├── LICENSE ├── README.md ├── configs ├── __init__.py ├── data_configs.py ├── paths_config.py └── transforms_config.py ├── criteria ├── __init__.py ├── id_loss.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── moco_loss.py └── ms_ssim.py ├── datasets ├── __init__.py ├── dataset_fetcher.py ├── gt_res_dataset.py ├── images_dataset.py ├── inference_dataset.py └── latents_images_dataset.py ├── docs ├── blunt_interfacegan.jpg ├── cars_ganspace.jpg ├── dicaprio_styleclip.jpg ├── domain_adaptation.jpg └── teaser.jpg ├── editing ├── __init__.py ├── cars_editor.py ├── face_editor.py ├── ganspace_directions │ └── cars_pca.pt ├── inference_cars_editing.py ├── inference_face_editing.py ├── interfacegan_directions │ ├── age.pt │ ├── pose.pt │ └── smile.pt └── styleclip │ ├── __init__.py │ ├── edit.py │ ├── global_direction.py │ ├── global_directions │ ├── ffhq │ │ ├── S_mean_std │ │ └── fs3.npy │ └── templates.txt │ ├── model.py │ └── stylespace_utils.py ├── environment └── hyperstyle_env.yaml ├── licenses ├── LICENSE_encoder4editing ├── LICENSE_insightface ├── LICENSE_lpips ├── LICENSE_pixel2style2pixel ├── LICENSE_ranger ├── LICENSE_restyle └── LICENSE_stylegan2 ├── models ├── __init__.py ├── encoders │ ├── __init__.py │ ├── e4e.py │ ├── helpers.py │ ├── model_irse.py │ ├── psp.py │ ├── restyle_e4e_encoders.py │ └── w_encoder.py ├── hypernetworks │ ├── __init__.py │ ├── hypernetwork.py │ ├── refinement_blocks.py │ └── shared_weights_hypernet.py ├── hyperstyle.py ├── mtcnn │ ├── __init__.py │ ├── mtcnn.py │ └── mtcnn_pytorch │ │ ├── __init__.py │ │ └── src │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── box_utils.py │ │ ├── detector.py │ │ ├── first_stage.py │ │ ├── get_nets.py │ │ ├── matlab_cp2tform.py │ │ ├── visualization_utils.py │ │ └── weights │ │ ├── onet.npy │ │ ├── pnet.npy │ │ └── rnet.npy └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── notebooks ├── __init__.py ├── animations_playground.ipynb ├── domain_adaptation_playground.ipynb ├── images │ ├── afhq_wild_image.jpg │ ├── animations │ │ ├── affleck.jpg │ │ ├── bezos.jpg │ │ ├── blunt.jpg │ │ ├── damon.jpg │ │ ├── dicaprio.jpg │ │ ├── downey.jpg │ │ ├── driver.jpg │ │ ├── jackson.jpg │ │ ├── johansson.jpg │ │ ├── kunis.jpg │ │ ├── pitt.jpg │ │ ├── robbie.jpg │ │ ├── stone.jpg │ │ ├── watson.jpg │ │ └── zuckerberg.jpg │ ├── car_image.jpg │ ├── domain_adaptation.jpg │ └── face_image.jpg ├── inference_playground.ipynb └── notebook_utils.py ├── options ├── __init__.py ├── test_options.py └── train_options.py ├── scripts ├── align_faces_parallel.py ├── calc_id_loss_parallel.py ├── calc_losses_on_images.py ├── inference.py ├── run_domain_adaptation.py └── train.py ├── training ├── __init__.py ├── coach_hyperstyle.py └── ranger.py └── utils ├── __init__.py ├── common.py ├── data_utils.py ├── domain_adaptation_utils.py ├── inference_utils.py ├── model_utils.py ├── resnet_mapping.py ├── restyle_inference_utils.py └── train_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yuval Alaluf, Omer Tov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/configs/__init__.py -------------------------------------------------------------------------------- /configs/data_configs.py: -------------------------------------------------------------------------------- 1 | from configs import transforms_config 2 | from configs.paths_config import dataset_paths 3 | 4 | 5 | DATASETS = { 6 | 'ffhq_hypernet': { 7 | 'transforms': transforms_config.EncodeTransforms, 8 | 'train_source_root': dataset_paths['ffhq'], 9 | 'train_target_root': dataset_paths['ffhq'], 10 | 'test_source_root': dataset_paths['celeba_test'], 11 | 'test_target_root': dataset_paths['celeba_test'], 12 | }, 13 | 'ffhq_hypernet_pre_extract': { 14 | 'transforms': transforms_config.NoFlipTransforms, 15 | 'train_source_root': dataset_paths['ffhq_w_inv'], 16 | 'train_target_root': dataset_paths['ffhq'], 17 | 'train_latents_path': dataset_paths['ffhq_w_latents'], 18 | 'test_source_root': dataset_paths['celeba_test_w_inv'], 19 | 'test_target_root': dataset_paths['celeba_test'], 20 | 'test_latents_path': dataset_paths['celeba_test_w_latents'] 21 | }, 22 | "cars_hypernet": { 23 | 'transforms': transforms_config.CarsEncodeTransforms, 24 | 'train_source_root': dataset_paths['cars_train'], 25 | 'train_target_root': dataset_paths['cars_train'], 26 | 'test_source_root': dataset_paths['cars_test'], 27 | 'test_target_root': dataset_paths['cars_test'] 28 | }, 29 | "afhq_wild_hypernet": { 30 | 'transforms': transforms_config.EncodeTransforms, 31 | 'train_source_root': dataset_paths['afhq_wild_train'], 32 | 'train_target_root': dataset_paths['afhq_wild_train'], 33 | 'test_source_root': dataset_paths['afhq_wild_test'], 34 | 'test_target_root': dataset_paths['afhq_wild_test'] 35 | } 36 | } -------------------------------------------------------------------------------- /configs/paths_config.py: -------------------------------------------------------------------------------- 1 | dataset_paths = { 2 | 'cars_train': '', 3 | 'cars_test': '', 4 | 5 | 'celeba_train': '', 6 | 'celeba_test': '', 7 | 'celeba_test_w_inv': '', 8 | 'celeba_test_w_latents': '', 9 | 10 | 'ffhq': '', 11 | 'ffhq_w_inv': '', 12 | 'ffhq_w_latents': '', 13 | 14 | 'afhq_wild_train': '', 15 | 'afhq_wild_test': '', 16 | 17 | } 18 | 19 | model_paths = { 20 | # models for backbones and losses 21 | 'ir_se50': 'pretrained_models/model_ir_se50.pth', 22 | 'resnet34': 'pretrained_models/resnet34-333f7ec4.pth', 23 | 'moco': 
'pretrained_models/moco_v2_800ep_pretrain.pt', 24 | # stylegan2 generators 25 | 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt', 26 | 'stylegan_cars': 'pretrained_models/stylegan2-car-config-f.pt', 27 | 'stylegan_ada_wild': 'pretrained_models/afhqwild.pt', 28 | # model for face alignment 29 | 'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat', 30 | # models for ID similarity computation 31 | 'curricular_face': 'pretrained_models/CurricularFace_Backbone.pth', 32 | 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy', 33 | 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy', 34 | 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy', 35 | # WEncoders for training on various domains 36 | 'faces_w_encoder': 'pretrained_models/faces_w_encoder.pt', 37 | 'cars_w_encoder': 'pretrained_models/cars_w_encoder.pt', 38 | 'afhq_wild_w_encoder': 'pretrained_models/afhq_wild_w_encoder.pt', 39 | # models for domain adaptation 40 | 'restyle_e4e_ffhq': 'pretrained_models/restyle_e4e_ffhq_encode.pt', 41 | 'stylegan_pixar': 'pretrained_models/pixar.pt', 42 | 'stylegan_toonify': 'pretrained_models/ffhq_cartoon_blended.pt', 43 | 'stylegan_sketch': 'pretrained_models/sketch.pt', 44 | 'stylegan_disney': 'pretrained_models/disney_princess.pt' 45 | } 46 | 47 | edit_paths = { 48 | 'age': 'editing/interfacegan_directions/age.pt', 49 | 'smile': 'editing/interfacegan_directions/smile.pt', 50 | 'pose': 'editing/interfacegan_directions/pose.pt', 51 | 'cars': 'editing/ganspace_directions/cars_pca.pt', 52 | 'styleclip': { 53 | 'delta_i_c': 'editing/styleclip/global_directions/ffhq/fs3.npy', 54 | 's_statistics': 'editing/styleclip/global_directions/ffhq/S_mean_std', 55 | 'templates': 'editing/styleclip/global_directions/templates.txt' 56 | } 57 | } -------------------------------------------------------------------------------- /configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | 4 | 5 | class TransformsConfig(object): 6 | 7 | def __init__(self, opts): 8 | self.opts = opts 9 | 10 | @abstractmethod 11 | def get_transforms(self): 12 | pass 13 | 14 | 15 | class EncodeTransforms(TransformsConfig): 16 | 17 | def __init__(self, opts): 18 | super(EncodeTransforms, self).__init__(opts) 19 | 20 | def get_transforms(self): 21 | transforms_dict = { 22 | 'transform_gt_train': transforms.Compose([ 23 | transforms.Resize((256, 256)), 24 | transforms.RandomHorizontalFlip(0.5), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 27 | 'transform_source': None, 28 | 'transform_test': transforms.Compose([ 29 | transforms.Resize((256, 256)), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 32 | 'transform_inference': transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 36 | } 37 | return transforms_dict 38 | 39 | 40 | class NoFlipTransforms(TransformsConfig): 41 | 42 | def __init__(self, opts): 43 | super(NoFlipTransforms, self).__init__(opts) 44 | 45 | def get_transforms(self): 46 | transforms_dict = { 47 | 'transform_gt_train': transforms.Compose([ 48 | transforms.Resize((256, 256)), 49 | transforms.ToTensor(), 50 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 51 | 'transform_source': transforms.Compose([ 52 | transforms.Resize((256, 256)), 53 | transforms.ToTensor(), 54 | 
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 55 | 'transform_test': transforms.Compose([ 56 | transforms.Resize((256, 256)), 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 59 | 'transform_inference': transforms.Compose([ 60 | transforms.Resize((256, 256)), 61 | transforms.ToTensor(), 62 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 63 | } 64 | return transforms_dict 65 | 66 | 67 | class CarsEncodeTransforms(TransformsConfig): 68 | 69 | def __init__(self, opts): 70 | super(CarsEncodeTransforms, self).__init__(opts) 71 | 72 | def get_transforms(self): 73 | transforms_dict = { 74 | 'transform_gt_train': transforms.Compose([ 75 | transforms.Resize((192, 256)), 76 | transforms.RandomHorizontalFlip(0.5), 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 79 | 'transform_source': None, 80 | 'transform_test': transforms.Compose([ 81 | transforms.Resize((192, 256)), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 84 | 'transform_inference': transforms.Compose([ 85 | transforms.Resize((192, 256)), 86 | transforms.ToTensor(), 87 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 88 | } 89 | return transforms_dict 90 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths_config import model_paths 4 | from models.encoders.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self, opts): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | self.opts = opts 16 | 17 | def extract_feats(self, x): 18 | x = x[:, :, 35:223, 32:220] # Crop interesting region 19 | x = self.face_pool(x) 20 | x_feats = self.facenet(x) 21 | return x_feats 22 | 23 | def forward(self, y_hat, y, x): 24 | n_samples = x.shape[0] 25 | x_feats = self.extract_feats(x) 26 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 27 | y_hat_feats = self.extract_feats(y_hat) 28 | y_feats = y_feats.detach() 29 | loss = 0 30 | sim_improvement = 0 31 | id_logs = [] 32 | count = 0 33 | for i in range(n_samples): 34 | diff_target = y_hat_feats[i].dot(y_feats[i]) 35 | diff_input = y_hat_feats[i].dot(x_feats[i]) 36 | diff_views = y_feats[i].dot(x_feats[i]) 37 | id_logs.append({'diff_target': float(diff_target), 38 | 'diff_input': float(diff_input), 39 | 'diff_views': float(diff_views)}) 40 | loss += 1 - diff_target 41 | id_diff = float(diff_target) - float(diff_views) 42 | sim_improvement += id_diff 43 | count += 1 44 | 45 | return loss / count, sim_improvement / count, id_logs 46 | -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, 
self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | 6 | 7 | class MocoLoss(nn.Module): 8 | 9 | def __init__(self): 10 | super(MocoLoss, self).__init__() 11 | print("Loading MOCO model from path: {}".format(model_paths["moco"])) 12 | self.model = self.__load_model() 13 | self.model.cuda() 14 | self.model.eval() 15 | 16 | @staticmethod 17 | def __load_model(): 18 | import torchvision.models as models 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ['fc.weight', 'fc.bias']: 23 | param.requires_grad = False 24 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 25 | state_dict = checkpoint['state_dict'] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 30 | # remove prefix 31 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 32 | # delete renamed or unused k 33 | del state_dict[k] 34 | msg = model.load_state_dict(state_dict, strict=False) 35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 36 | # remove output layer 37 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 38 | return model 39 | 40 | def extract_feats(self, x): 41 | x = 
F.interpolate(x, size=224) 42 | x_feats = self.model(x) 43 | x_feats = nn.functional.normalize(x_feats, dim=1) 44 | x_feats = x_feats.squeeze() 45 | return x_feats 46 | 47 | def forward(self, y_hat, y, x): 48 | n_samples = x.shape[0] 49 | x_feats = self.extract_feats(x) 50 | y_feats = self.extract_feats(y) 51 | y_hat_feats = self.extract_feats(y_hat) 52 | y_feats = y_feats.detach() 53 | loss = 0 54 | sim_improvement = 0 55 | sim_logs = [] 56 | count = 0 57 | for i in range(n_samples): 58 | diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | diff_input = y_hat_feats[i].dot(x_feats[i]) 60 | diff_views = y_feats[i].dot(x_feats[i]) 61 | sim_logs.append({'diff_target': float(diff_target), 62 | 'diff_input': float(diff_input), 63 | 'diff_views': float(diff_views)}) 64 | loss += 1 - diff_target 65 | sim_diff = float(diff_target) - float(diff_views) 66 | sim_improvement += sim_diff 67 | count += 1 68 | 69 | return loss / count, sim_improvement / count, sim_logs 70 | -------------------------------------------------------------------------------- /criteria/ms_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | 6 | """ 7 | Taken from https://github.com/jorge-pessoa/pytorch-msssim 8 | """ 9 | 10 | 11 | def gaussian(window_size, sigma): 12 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 13 | return gauss/gauss.sum() 14 | 15 | 16 | def create_window(window_size, channel=1): 17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 18 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 19 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 20 | return window 21 | 22 | 23 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 24 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
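# When no explicit val_range is given, the block below estimates the dynamic range L from the
# input statistics: max_val is taken as 255 for uint8-style inputs (values above 128) and 1 otherwise,
# and min_val is taken as -1 for tanh-normalized inputs (values below -0.5) and 0 otherwise.
# L = max_val - min_val then feeds the SSIM stabilizing constants C1 = (0.01*L)**2 and C2 = (0.03*L)**2 used further down.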
25 | if val_range is None: 26 | if torch.max(img1) > 128: 27 | max_val = 255 28 | else: 29 | max_val = 1 30 | 31 | if torch.min(img1) < -0.5: 32 | min_val = -1 33 | else: 34 | min_val = 0 35 | L = max_val - min_val 36 | else: 37 | L = val_range 38 | 39 | padd = 0 40 | (_, channel, height, width) = img1.size() 41 | if window is None: 42 | real_size = min(window_size, height, width) 43 | window = create_window(real_size, channel=channel).to(img1.device) 44 | 45 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 46 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 47 | 48 | mu1_sq = mu1.pow(2) 49 | mu2_sq = mu2.pow(2) 50 | mu1_mu2 = mu1 * mu2 51 | 52 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 53 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 54 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 55 | 56 | C1 = (0.01 * L) ** 2 57 | C2 = (0.03 * L) ** 2 58 | 59 | v1 = 2.0 * sigma12 + C2 60 | v2 = sigma1_sq + sigma2_sq + C2 61 | cs = v1 / v2 # contrast sensitivity 62 | 63 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 64 | 65 | if size_average: 66 | cs = cs.mean() 67 | ret = ssim_map.mean() 68 | else: 69 | cs = cs.mean(1).mean(1).mean(1) 70 | ret = ssim_map.mean(1).mean(1).mean(1) 71 | 72 | if full: 73 | return ret, cs 74 | return ret 75 | 76 | 77 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None): 78 | device = img1.device 79 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 80 | levels = weights.size()[0] 81 | ssims = [] 82 | mcs = [] 83 | for _ in range(levels): 84 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 85 | 86 | # Relu normalize (not compliant with original definition) 87 | if normalize == "relu": 88 | ssims.append(torch.relu(sim)) 89 | mcs.append(torch.relu(cs)) 90 | else: 91 | ssims.append(sim) 92 | mcs.append(cs) 93 | 94 | img1 = F.avg_pool2d(img1, (2, 2)) 95 | img2 = F.avg_pool2d(img2, (2, 2)) 96 | 97 | ssims = torch.stack(ssims) 98 | mcs = torch.stack(mcs) 99 | 100 | # Simple normalize (not compliant with original definition) 101 | if normalize == "simple" or normalize == True: 102 | ssims = (ssims + 1) / 2 103 | mcs = (mcs + 1) / 2 104 | 105 | pow1 = mcs ** weights 106 | pow2 = ssims ** weights 107 | 108 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 109 | output = torch.prod(pow1[:-1]) * pow2[-1] 110 | return output 111 | 112 | 113 | # Classes to re-use window 114 | class SSIM(torch.nn.Module): 115 | def __init__(self, window_size=11, size_average=True, val_range=None): 116 | super(SSIM, self).__init__() 117 | self.window_size = window_size 118 | self.size_average = size_average 119 | self.val_range = val_range 120 | 121 | # Assume 1 channel for SSIM 122 | self.channel = 1 123 | self.window = create_window(window_size) 124 | 125 | def forward(self, img1, img2): 126 | (_, channel, _, _) = img1.size() 127 | 128 | if channel == self.channel and self.window.dtype == img1.dtype: 129 | window = self.window 130 | else: 131 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 132 | self.window = window 133 | self.channel = channel 134 | 135 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 136 | 137 | class MSSSIM(torch.nn.Module): 138 | def __init__(self, window_size=11, 
size_average=True, channel=3): 139 | super(MSSSIM, self).__init__() 140 | self.window_size = window_size 141 | self.size_average = size_average 142 | self.channel = channel 143 | 144 | def forward(self, img1, img2): 145 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 146 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/dataset_fetcher.py: -------------------------------------------------------------------------------- 1 | from datasets.images_dataset import ImagesDataset 2 | from datasets.latents_images_dataset import LatentsImagesDataset 3 | 4 | 5 | class DatasetFetcher: 6 | 7 | def get_dataset(self, opts, dataset_args, transforms_dict): 8 | if opts.dataset_type in ['ffhq_hypernet_pre_extract']: 9 | return self.__get_latents_dataset(opts, dataset_args, transforms_dict) 10 | else: 11 | return self.__get_images_dataset(opts, dataset_args, transforms_dict) 12 | 13 | @staticmethod 14 | def __get_latents_dataset(opts, dataset_args, transforms_dict): 15 | train_dataset = LatentsImagesDataset(source_root=dataset_args['train_source_root'], 16 | target_root=dataset_args['train_target_root'], 17 | latents_path=dataset_args['train_latents_path'], 18 | source_transform=transforms_dict['transform_source'], 19 | target_transform=transforms_dict['transform_gt_train'], 20 | opts=opts) 21 | test_dataset = LatentsImagesDataset(source_root=dataset_args['test_source_root'], 22 | target_root=dataset_args['test_target_root'], 23 | latents_path=dataset_args['test_latents_path'], 24 | source_transform=transforms_dict['transform_source'], 25 | target_transform=transforms_dict['transform_test'], 26 | opts=opts) 27 | return train_dataset, test_dataset 28 | 29 | @staticmethod 30 | def __get_images_dataset(opts, dataset_args, transforms_dict): 31 | train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'], 32 | target_root=dataset_args['train_target_root'], 33 | source_transform=transforms_dict['transform_source'], 34 | target_transform=transforms_dict['transform_gt_train'], 35 | opts=opts) 36 | test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'], 37 | target_root=dataset_args['test_target_root'], 38 | source_transform=transforms_dict['transform_source'], 39 | target_transform=transforms_dict['transform_test'], 40 | opts=opts) 41 | return train_dataset, test_dataset 42 | -------------------------------------------------------------------------------- /datasets/gt_res_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | 5 | 6 | class GTResDataset(Dataset): 7 | 8 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): 9 | self.pairs = [] 10 | for f in os.listdir(root_path): 11 | image_path = os.path.join(root_path, f) 12 | gt_path = os.path.join(gt_dir, f) 13 | if f.endswith(".jpg") or f.endswith(".png") or f.endswith(".jpeg"): 14 | self.pairs.append([image_path, gt_path, None]) 15 | self.transform = transform 16 | self.transform_train = transform_train 17 | 18 | def __len__(self): 19 | return len(self.pairs) 20 | 21 | def __getitem__(self, 
index): 22 | from_path, to_path, _ = self.pairs[index] 23 | from_im = Image.open(from_path).convert('RGB') 24 | to_im = Image.open(to_path).convert('RGB') 25 | if self.transform: 26 | to_im = self.transform(to_im) 27 | from_im = self.transform(from_im) 28 | return from_im, to_im 29 | -------------------------------------------------------------------------------- /datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class ImagesDataset(Dataset): 7 | 8 | def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): 9 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 10 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 11 | self.source_transform = source_transform 12 | self.target_transform = target_transform 13 | self.opts = opts 14 | 15 | def __len__(self): 16 | return len(self.source_paths) 17 | 18 | def __getitem__(self, index): 19 | from_path = self.source_paths[index] 20 | to_path = self.target_paths[index] 21 | 22 | from_im = Image.open(from_path).convert('RGB') 23 | to_im = Image.open(to_path).convert('RGB') 24 | 25 | if self.target_transform: 26 | to_im = self.target_transform(to_im) 27 | 28 | if self.source_transform: 29 | from_im = self.source_transform(from_im) 30 | else: 31 | from_im = to_im 32 | 33 | return from_im, to_im 34 | -------------------------------------------------------------------------------- /datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class InferenceDataset(Dataset): 7 | 8 | def __init__(self, root, opts, transform=None): 9 | self.paths = sorted(data_utils.make_dataset(root)) 10 | self.transform = transform 11 | self.opts = opts 12 | 13 | def __len__(self): 14 | return len(self.paths) 15 | 16 | def __getitem__(self, index): 17 | from_path = self.paths[index] 18 | from_im = Image.open(from_path).convert('RGB') 19 | if self.transform: 20 | from_im = self.transform(from_im) 21 | return from_im 22 | -------------------------------------------------------------------------------- /datasets/latents_images_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from utils import data_utils 6 | 7 | 8 | class LatentsImagesDataset(Dataset): 9 | 10 | def __init__(self, source_root, target_root, latents_path, opts, target_transform=None, source_transform=None): 11 | # path to inversions directory 12 | self.source_root = source_root 13 | # path to original dataset 14 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 15 | # path to latents corresponding to inversions 16 | # this should be a dictionary mapping image name to the image's latent code 17 | self.latents = torch.load(latents_path, map_location='cpu') 18 | self.latents.requires_grad = False 19 | self.source_transform = source_transform 20 | self.target_transform = target_transform 21 | self.opts = opts 22 | 23 | def __len__(self): 24 | return len(self.target_paths) 25 | 26 | def __getitem__(self, index): 27 | from_path = os.path.join(self.source_root, f'{index+1:05d}.png') 28 | to_path = self.target_paths[index] 29 | 30 | from_im = 
Image.open(from_path).convert('RGB') 31 | to_im = Image.open(to_path).convert('RGB') 32 | 33 | if self.target_transform: 34 | to_im = self.target_transform(to_im) 35 | 36 | if self.source_transform: 37 | from_im = self.source_transform(from_im) 38 | else: 39 | from_im = to_im 40 | 41 | latent = self.latents[os.path.basename(from_path)] 42 | if latent.ndim == 1: 43 | latent = latent.repeat(18, 1) 44 | 45 | return from_im, to_im, latent 46 | -------------------------------------------------------------------------------- /docs/blunt_interfacegan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/docs/blunt_interfacegan.jpg -------------------------------------------------------------------------------- /docs/cars_ganspace.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/docs/cars_ganspace.jpg -------------------------------------------------------------------------------- /docs/dicaprio_styleclip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/docs/dicaprio_styleclip.jpg -------------------------------------------------------------------------------- /docs/domain_adaptation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/docs/domain_adaptation.jpg -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/docs/teaser.jpg -------------------------------------------------------------------------------- /editing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/__init__.py -------------------------------------------------------------------------------- /editing/cars_editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import os 5 | 6 | from configs.paths_config import edit_paths 7 | from utils.common import tensor2im 8 | 9 | 10 | class CarsEditor: 11 | 12 | def __init__(self, stylegan_generator): 13 | self.generator = stylegan_generator 14 | self.gs = torch.load(edit_paths['cars']) 15 | # Directions: Pose I, Pose II, Cube, Color, Grass 16 | self.directions = [(0, 0, 5, 2), (0, 0, 5, -2), (16, 3, 6, 25), (22, 9, 11, -8), (41, 9, 11, -18)] 17 | 18 | def apply_ganspace(self, latents, weights_deltas, input_im, save_dir, noise=None): 19 | for image_id in latents.keys(): 20 | latent = latents[image_id].to('cuda') 21 | inputs = [latent] 22 | for i, (pca_idx, start, end, strength) in enumerate(self.directions): 23 | delta = self._get_delta(self.gs, latent, pca_idx, strength) 24 | delta_padded = torch.zeros(latent.shape).to('cuda') 25 | delta_padded[start:end] += delta.repeat(end - start, 1) 26 | inputs.append(latent + delta_padded) 27 | inputs = torch.stack(inputs) 28 | edited_images = 
self._latents_to_image(inputs, weights_deltas, noise=noise) 29 | self._save_coupled_image(input_im, image_id, edited_images, save_dir) 30 | 31 | def _latents_to_image(self, inputs, weights_deltas, noise=None): 32 | with torch.no_grad(): 33 | images, _ = self.generator([inputs], input_is_latent=True, noise=noise, randomize_noise=False, 34 | weights_deltas=weights_deltas, return_latents=True) 35 | images = images[:, :, 64:448, :] 36 | return images 37 | 38 | @staticmethod 39 | def _get_delta(gs, latent, idx=16, strength=25): 40 | # gs: ganspace checkpoint loaded, latent: (16, 512) w+ 41 | w_centered = latent - gs['mean'].to('cuda') 42 | lat_comp = gs['comp'].to('cuda') 43 | lat_std = gs['std'].to('cuda') 44 | w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx] 45 | delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx] 46 | return delta 47 | 48 | @staticmethod 49 | def _save_coupled_image(input_im, image_id, edited_images, save_dir): 50 | res = np.array(input_im) 51 | for img in edited_images: 52 | res = np.concatenate([res, tensor2im(img)], axis=1) 53 | res_im = Image.fromarray(res) 54 | im_save_path = os.path.join(save_dir, f"{image_id.split('.')[0]}.jpg") 55 | res_im.save(im_save_path) 56 | -------------------------------------------------------------------------------- /editing/face_editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from configs.paths_config import edit_paths 4 | from utils.common import tensor2im 5 | 6 | 7 | class FaceEditor: 8 | 9 | def __init__(self, stylegan_generator): 10 | self.generator = stylegan_generator 11 | self.interfacegan_directions = { 12 | 'age': torch.load(edit_paths['age']).cuda(), 13 | 'smile': torch.load(edit_paths['smile']).cuda(), 14 | 'pose': torch.load(edit_paths['pose']).cuda() 15 | } 16 | 17 | def apply_interfacegan(self, latents, weights_deltas, direction, factor=1, factor_range=None): 18 | edit_latents = [] 19 | direction = self.interfacegan_directions[direction] 20 | if factor_range is not None: # Apply a range of editing factors. 
for example, (-5, 5) 21 | for f in range(*factor_range): 22 | edit_latent = latents + f * direction 23 | edit_latents.append(edit_latent) 24 | edit_latents = torch.stack(edit_latents).transpose(0,1) 25 | else: 26 | edit_latents = latents + factor * direction 27 | return self._latents_to_image(edit_latents, weights_deltas) 28 | 29 | def _latents_to_image(self, all_latents, weights_deltas): 30 | sample_results = {} 31 | with torch.no_grad(): 32 | for idx, sample_latents in enumerate(all_latents): 33 | sample_deltas = [d[idx] if d is not None else None for d in weights_deltas] 34 | images, _ = self.generator([sample_latents], 35 | weights_deltas=sample_deltas, 36 | randomize_noise=False, 37 | input_is_latent=True) 38 | sample_results[idx] = [tensor2im(image) for image in images] 39 | return sample_results 40 | -------------------------------------------------------------------------------- /editing/ganspace_directions/cars_pca.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/ganspace_directions/cars_pca.pt -------------------------------------------------------------------------------- /editing/inference_cars_editing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import time 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | import sys 8 | sys.path.append(".") 9 | sys.path.append("..") 10 | 11 | from configs import data_configs 12 | from datasets.inference_dataset import InferenceDataset 13 | from editing.cars_editor import CarsEditor 14 | from options.test_options import TestOptions 15 | from utils.inference_utils import run_inversion 16 | from utils.model_utils import load_model 17 | from utils.common import tensor2im 18 | 19 | 20 | def run(): 21 | test_opts = TestOptions().parse() 22 | 23 | out_path_coupled = os.path.join(test_opts.exp_dir, 'editing_coupled') 24 | os.makedirs(out_path_coupled, exist_ok=True) 25 | 26 | # update test options with options used during training 27 | net, opts = load_model(test_opts.checkpoint_path, update_opts=test_opts) 28 | 29 | print('Loading dataset for {}'.format(opts.dataset_type)) 30 | dataset_args = data_configs.DATASETS[opts.dataset_type] 31 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 32 | dataset = InferenceDataset(root=opts.data_path, 33 | transform=transforms_dict['transform_inference'], 34 | opts=opts) 35 | dataloader = DataLoader(dataset, 36 | batch_size=opts.test_batch_size, 37 | shuffle=False, 38 | num_workers=int(opts.test_workers), 39 | drop_last=False) 40 | 41 | if opts.n_images is None: 42 | opts.n_images = len(dataset) 43 | 44 | latent_editor = CarsEditor(net.decoder) 45 | 46 | global_i = 0 47 | global_time = [] 48 | for input_batch in tqdm(dataloader): 49 | if global_i >= opts.n_images: 50 | break 51 | with torch.no_grad(): 52 | input_cuda = input_batch.cuda().float() 53 | tic = time.time() 54 | y_hat, batch_latents, weights_deltas, codes = run_inversion(input_cuda, net, opts) 55 | toc = time.time() 56 | global_time.append(toc - tic) 57 | 58 | for i in range(input_batch.shape[0]): 59 | 60 | im_path = dataset.paths[global_i] 61 | input_im = tensor2im(input_batch[i]).resize((512, 384)) 62 | 63 | sample_deltas = [d[i] if d is not None else None for d in weights_deltas] 64 | latents = {os.path.basename(im_path): batch_latents[i]} 65 | latent_editor.apply_ganspace(latents=latents, 
66 | weights_deltas=sample_deltas, 67 | input_im=input_im, 68 | save_dir=out_path_coupled) 69 | global_i += 1 70 | 71 | 72 | if __name__ == '__main__': 73 | run() 74 | -------------------------------------------------------------------------------- /editing/inference_face_editing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import DataLoader 7 | 8 | import sys 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from configs import data_configs 13 | from datasets.inference_dataset import InferenceDataset 14 | from editing.face_editor import FaceEditor 15 | from options.test_options import TestOptions 16 | from utils.common import tensor2im 17 | from utils.inference_utils import run_inversion 18 | from utils.model_utils import load_model 19 | 20 | 21 | def run(): 22 | test_opts = TestOptions().parse() 23 | 24 | out_path_results = os.path.join(test_opts.exp_dir, 'editing_results') 25 | out_path_coupled = os.path.join(test_opts.exp_dir, 'editing_coupled') 26 | 27 | os.makedirs(out_path_results, exist_ok=True) 28 | os.makedirs(out_path_coupled, exist_ok=True) 29 | 30 | # update test options with options used during training 31 | net, opts = load_model(test_opts.checkpoint_path, update_opts=test_opts) 32 | 33 | print('Loading dataset for {}'.format(opts.dataset_type)) 34 | dataset_args = data_configs.DATASETS[opts.dataset_type] 35 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 36 | dataset = InferenceDataset(root=opts.data_path, 37 | transform=transforms_dict['transform_inference'], 38 | opts=opts) 39 | dataloader = DataLoader(dataset, 40 | batch_size=opts.test_batch_size, 41 | shuffle=False, 42 | num_workers=int(opts.test_workers), 43 | drop_last=False) 44 | 45 | if opts.n_images is None: 46 | opts.n_images = len(dataset) 47 | 48 | latent_editor = FaceEditor(net.decoder) 49 | 50 | global_i = 0 51 | for input_batch in tqdm(dataloader): 52 | 53 | if global_i >= opts.n_images: 54 | break 55 | 56 | with torch.no_grad(): 57 | input_cuda = input_batch.cuda().float() 58 | result_batch = run_on_batch(input_cuda, net, latent_editor, opts) 59 | 60 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 61 | for i in range(input_batch.shape[0]): 62 | 63 | im_path = dataset.paths[global_i] 64 | results = result_batch[i] 65 | 66 | inversion = results.pop('inversion') 67 | input_im = tensor2im(input_batch[i]) 68 | 69 | all_edit_results = [] 70 | for edit_name, edit_res in results.items(): 71 | # set the input image 72 | res = np.array(input_im.resize(resize_amount)) 73 | # set the inversion 74 | res = np.concatenate([res, np.array(inversion.resize(resize_amount))], axis=1) 75 | # add editing results side-by-side 76 | for result in edit_res: 77 | res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1) 78 | res_im = Image.fromarray(res) 79 | all_edit_results.append(res_im) 80 | 81 | edit_save_dir = os.path.join(out_path_results, edit_name) 82 | os.makedirs(edit_save_dir, exist_ok=True) 83 | res_im.save(os.path.join(edit_save_dir, os.path.basename(im_path))) 84 | 85 | # save final concatenated result of all edits 86 | coupled_res = np.concatenate(all_edit_results, axis=0) 87 | im_save_path = os.path.join(out_path_coupled, os.path.basename(im_path)) 88 | Image.fromarray(coupled_res).save(im_save_path) 89 | global_i += 1 90 | 91 | 92 | def 
run_on_batch(inputs, net, latent_editor, opts): 93 | y_hat, _, weights_deltas, codes = run_inversion(inputs, net, opts) 94 | edit_directions = opts.edit_directions.split(',') 95 | # store all results for each sample, split by the edit direction 96 | results = {idx: {'inversion': tensor2im(y_hat[idx])} for idx in range(len(inputs))} 97 | for edit_direction in edit_directions: 98 | edit_res = latent_editor.apply_interfacegan(latents=codes, 99 | weights_deltas=weights_deltas, 100 | direction=edit_direction, 101 | factor_range=(-1 * opts.factor_range, opts.factor_range)) 102 | # store the results for each sample 103 | for idx, sample_res in edit_res.items(): 104 | results[idx][edit_direction] = sample_res 105 | return results 106 | 107 | 108 | if __name__ == '__main__': 109 | run() 110 | -------------------------------------------------------------------------------- /editing/interfacegan_directions/age.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/interfacegan_directions/age.pt -------------------------------------------------------------------------------- /editing/interfacegan_directions/pose.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/interfacegan_directions/pose.pt -------------------------------------------------------------------------------- /editing/interfacegan_directions/smile.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/interfacegan_directions/smile.pt -------------------------------------------------------------------------------- /editing/styleclip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/styleclip/__init__.py -------------------------------------------------------------------------------- /editing/styleclip/edit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import torch 5 | import numpy as np 6 | import torchvision 7 | 8 | import sys 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from configs.paths_config import edit_paths, model_paths 13 | from editing.styleclip.global_direction import StyleCLIPGlobalDirection 14 | from editing.styleclip.model import Generator 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--exp_dir", type=str, default="./experiment", 20 | help="Path to inference results with `latents.npy` saved here (obtained with inference.py).") 21 | parser.add_argument("--weight_deltas_path", type=str, default="./weight_deltas", 22 | help="Root path holding all weight deltas (obtained by running inference.py).") 23 | parser.add_argument('--n_images', type=int, default=None, 24 | help="Maximum number of images to edit. 
If None, edit all images.") 25 | parser.add_argument("--neutral_text", type=str, default="face with hair") 26 | parser.add_argument("--target_text", type=str, default="face with long hair") 27 | parser.add_argument("--stylegan_weights", type=str, default=model_paths["stylegan_ffhq"]) 28 | parser.add_argument("--stylegan_size", type=int, default=1024) 29 | parser.add_argument("--stylegan_truncation", type=int, default=1.) 30 | parser.add_argument("--stylegan_truncation_mean", type=int, default=4096) 31 | parser.add_argument("--beta_min", type=float, default=0.11) 32 | parser.add_argument("--beta_max", type=float, default=0.16) 33 | parser.add_argument("--num_betas", type=int, default=5) 34 | parser.add_argument("--alpha_min", type=float, default=-5) 35 | parser.add_argument("--alpha_max", type=float, default=5) 36 | parser.add_argument("--num_alphas", type=int, default=11) 37 | parser.add_argument("--delta_i_c", type=str, default=edit_paths["styleclip"]["delta_i_c"], 38 | help="path to file containing delta_i_c") 39 | parser.add_argument("--s_statistics", type=str, default=edit_paths["styleclip"]["s_statistics"], 40 | help="path to file containing s statistics") 41 | parser.add_argument("--text_prompt_templates", default=edit_paths["styleclip"]["templates"]) 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def load_direction_calculator(args): 47 | delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().cuda() 48 | with open(args.s_statistics, "rb") as channels_statistics: 49 | _, s_std = pickle.load(channels_statistics) 50 | s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std] 51 | with open(args.text_prompt_templates, "r") as templates: 52 | text_prompt_templates = templates.readlines() 53 | global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates) 54 | return global_direction_calculator 55 | 56 | 57 | def load_stylegan_generator(args): 58 | stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).cuda() 59 | checkpoint = torch.load(args.stylegan_weights) 60 | stylegan_model.load_state_dict(checkpoint['g_ema']) 61 | return stylegan_model 62 | 63 | 64 | def run(): 65 | args = parse_args() 66 | stylegan_model = load_stylegan_generator(args) 67 | global_direction_calculator = load_direction_calculator(args) 68 | # load latents obtained via inference 69 | latents = np.load(os.path.join(args.exp_dir, 'latents.npy'), allow_pickle=True).item() 70 | # prepare output directory 71 | args.output_path = os.path.join(args.exp_dir, "styleclip_edits", f"{args.neutral_text}_to_{args.target_text}") 72 | os.makedirs(args.output_path, exist_ok=True) 73 | # edit all images 74 | for idx, (image_name, latent) in enumerate(latents.items()): 75 | if args.n_images is not None and idx >= args.n_images: 76 | break 77 | edit_image(image_name, latent, stylegan_model, global_direction_calculator, args) 78 | 79 | 80 | def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args): 81 | print(f'Editing {image_name}') 82 | 83 | latent_code = torch.from_numpy(latent).cuda() 84 | truncation = 1 85 | mean_latent = None 86 | input_is_latent = True 87 | latent_code_i = latent_code.unsqueeze(0) 88 | 89 | weight_deltas = np.load(os.path.join(args.weight_deltas_path, image_name.split(".")[0] + ".npy"), allow_pickle=True) 90 | weight_deltas = [torch.from_numpy(w).cuda() if w is not None else None for w in weight_deltas] 91 | 92 | with torch.no_grad(): 93 | 94 | source_im, _, latent_code_s = 
stylegan_model([latent_code_i], 95 | input_is_latent=input_is_latent, 96 | randomize_noise=False, 97 | return_latents=True, 98 | truncation=truncation, 99 | truncation_latent=mean_latent, 100 | weights_deltas=weight_deltas) 101 | 102 | alphas = np.linspace(args.alpha_min, args.alpha_max, args.num_alphas) 103 | betas = np.linspace(args.beta_min, args.beta_max, args.num_betas) 104 | results = [] 105 | for beta in betas: 106 | direction = global_direction_calculator.get_delta_s(args.neutral_text, args.target_text, beta) 107 | edited_latent_code_s = [[s_i + alpha * b_i for s_i, b_i in zip(latent_code_s, direction)] for alpha in alphas] 108 | edited_latent_code_s = [torch.cat([edited_latent_code_s[i][j] for i in range(args.num_alphas)]) 109 | for j in range(len(edited_latent_code_s[0]))] 110 | for b in range(0, edited_latent_code_s[0].shape[0]): 111 | edited_latent_code_s_batch = [s_i[b:b + 1] for s_i in edited_latent_code_s] 112 | with torch.no_grad(): 113 | edited_image, _, _ = stylegan_model([edited_latent_code_s_batch], 114 | input_is_stylespace=True, 115 | randomize_noise=False, 116 | return_latents=True, 117 | weights_deltas=weight_deltas) 118 | results.append(edited_image) 119 | 120 | results = torch.cat(results) 121 | torchvision.utils.save_image(results, f"{args.output_path}/{image_name.split('.')[0]}.jpg", 122 | normalize=True, range=(-1, 1), padding=0, nrow=args.num_alphas) 123 | 124 | 125 | if __name__ == "__main__": 126 | run() 127 | -------------------------------------------------------------------------------- /editing/styleclip/global_direction.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import clip 3 | import torch 4 | 5 | from editing.styleclip.stylespace_utils import features_channels_to_s 6 | 7 | 8 | class StyleCLIPGlobalDirection: 9 | 10 | def __init__(self, delta_i_c, s_std, text_prompts_templates): 11 | super(StyleCLIPGlobalDirection, self).__init__() 12 | self.delta_i_c = delta_i_c 13 | self.s_std = s_std 14 | self.text_prompts_templates = text_prompts_templates 15 | self.clip_model, _ = clip.load("ViT-B/32", device="cuda") 16 | 17 | def get_delta_s(self, neutral_text, target_text, beta): 18 | delta_i = self.get_delta_i([target_text, neutral_text]).float() 19 | r_c = torch.matmul(self.delta_i_c, delta_i) 20 | delta_s = copy.copy(r_c) 21 | channels_to_zero = torch.abs(r_c) < beta 22 | delta_s[channels_to_zero] = 0 23 | max_channel_value = torch.abs(delta_s).max() 24 | if max_channel_value > 0: 25 | delta_s /= max_channel_value 26 | direction = features_channels_to_s(delta_s, self.s_std) 27 | return direction 28 | 29 | def get_delta_i(self, text_prompts): 30 | text_features = self._get_averaged_text_features(text_prompts) 31 | delta_t = text_features[0] - text_features[1] 32 | delta_i = delta_t / torch.norm(delta_t) 33 | return delta_i 34 | 35 | def _get_averaged_text_features(self, text_prompts): 36 | with torch.no_grad(): 37 | text_features_list = [] 38 | for text_prompt in text_prompts: 39 | formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates] # format with class 40 | formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda() # tokenize 41 | text_embeddings = self.clip_model.encode_text(formatted_text_prompts) # embed with text encoder 42 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 43 | text_embedding = text_embeddings.mean(dim=0) 44 | text_embedding /= text_embedding.norm() 45 | text_features_list.append(text_embedding) 46 | 
text_features = torch.stack(text_features_list, dim=1).cuda() 47 | return text_features.t() 48 | -------------------------------------------------------------------------------- /editing/styleclip/global_directions/ffhq/S_mean_std: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/styleclip/global_directions/ffhq/S_mean_std -------------------------------------------------------------------------------- /editing/styleclip/global_directions/ffhq/fs3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/editing/styleclip/global_directions/ffhq/fs3.npy -------------------------------------------------------------------------------- /editing/styleclip/global_directions/templates.txt: -------------------------------------------------------------------------------- 1 | a bad photo of a {}. 2 | a sculpture of a {}. 3 | a photo of the hard to see {}. 4 | a low resolution photo of the {}. 5 | a rendering of a {}. 6 | graffiti of a {}. 7 | a bad photo of the {}. 8 | a cropped photo of the {}. 9 | a tattoo of a {}. 10 | the embroidered {}. 11 | a photo of a hard to see {}. 12 | a bright photo of a {}. 13 | a photo of a clean {}. 14 | a photo of a dirty {}. 15 | a dark photo of the {}. 16 | a drawing of a {}. 17 | a photo of my {}. 18 | the plastic {}. 19 | a photo of the cool {}. 20 | a close-up photo of a {}. 21 | a black and white photo of the {}. 22 | a painting of the {}. 23 | a painting of a {}. 24 | a pixelated photo of the {}. 25 | a sculpture of the {}. 26 | a bright photo of the {}. 27 | a cropped photo of a {}. 28 | a plastic {}. 29 | a photo of the dirty {}. 30 | a jpeg corrupted photo of a {}. 31 | a blurry photo of the {}. 32 | a photo of the {}. 33 | a good photo of the {}. 34 | a rendering of the {}. 35 | a {} in a video game. 36 | a photo of one {}. 37 | a doodle of a {}. 38 | a close-up photo of the {}. 39 | a photo of a {}. 40 | the origami {}. 41 | the {} in a video game. 42 | a sketch of a {}. 43 | a doodle of the {}. 44 | a origami {}. 45 | a low resolution photo of a {}. 46 | the toy {}. 47 | a rendition of the {}. 48 | a photo of the clean {}. 49 | a photo of a large {}. 50 | a rendition of a {}. 51 | a photo of a nice {}. 52 | a photo of a weird {}. 53 | a blurry photo of a {}. 54 | a cartoon {}. 55 | art of a {}. 56 | a sketch of the {}. 57 | a embroidered {}. 58 | a pixelated photo of a {}. 59 | itap of the {}. 60 | a jpeg corrupted photo of the {}. 61 | a good photo of a {}. 62 | a plushie {}. 63 | a photo of the nice {}. 64 | a photo of the small {}. 65 | a photo of the weird {}. 66 | the cartoon {}. 67 | art of the {}. 68 | a drawing of the {}. 69 | a photo of the large {}. 70 | a black and white photo of a {}. 71 | the plushie {}. 72 | a dark photo of a {}. 73 | itap of a {}. 74 | graffiti of the {}. 75 | a toy {}. 76 | itap of my {}. 77 | a photo of a cool {}. 78 | a photo of a small {}. 79 | a tattoo of the {}. 
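A minimal sketch of how the pieces above compose (assuming edit_paths points at the ffhq files shipped under global_directions/ above; CUDA is required, as in the surrounding code):

import pickle
import numpy as np
import torch

from editing.styleclip.global_direction import StyleCLIPGlobalDirection

delta_i_c = torch.from_numpy(np.load("editing/styleclip/global_directions/ffhq/fs3.npy")).float().cuda()
with open("editing/styleclip/global_directions/ffhq/S_mean_std", "rb") as f:
    _, s_std = pickle.load(f)
s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
with open("editing/styleclip/global_directions/templates.txt", "r") as f:
    templates = f.readlines()

calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, templates)

# beta thresholds the per-channel relevance scores: a larger beta keeps fewer, more
# disentangled StyleSpace channels (edit.py sweeps beta_min..beta_max)
direction = calculator.get_delta_s("face with hair", "face with long hair", beta=0.13)

# alpha scales the edit strength; latent_code_s is the per-layer StyleSpace code returned by
# the generator in edit.py, so an edited code is simply
#   edited_s = [s_i + alpha * d_i for s_i, d_i in zip(latent_code_s, direction)]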
-------------------------------------------------------------------------------- /editing/styleclip/stylespace_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32] 4 | 5 | TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3)) 6 | STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11] 7 | 8 | def features_channels_to_s(s_without_torgb, s_std): 9 | s = [] 10 | start_index_features = 0 11 | for c in range(len(STYLESPACE_DIMENSIONS)): 12 | if c in STYLESPACE_INDICES_WITHOUT_TORGB: 13 | end_index_features = start_index_features + STYLESPACE_DIMENSIONS[c] 14 | s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c] 15 | start_index_features = end_index_features 16 | else: 17 | s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).cuda() 18 | s_i = s_i.view(1, 1, -1, 1, 1) 19 | s.append(s_i) 20 | return s -------------------------------------------------------------------------------- /environment/hyperstyle_env.yaml: -------------------------------------------------------------------------------- 1 | name: hyperstyle_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - ca-certificates=2020.4.5.1=hecc5488_0 8 | - certifi=2020.4.5.1=py36h9f0ad1d_0 9 | - libedit=3.1.20181209=hc058e9b_0 10 | - libffi=3.2.1=hd88cf55_4 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - ninja=1.10.0=hc9558a2_0 15 | - openssl=1.1.1g=h516909a_0 16 | - pip=20.0.2=py36_3 17 | - python=3.6.7=h0371630_0 18 | - python_abi=3.6=1_cp36m 19 | - readline=7.0=h7b6447c_5 20 | - setuptools=46.4.0=py36_0 21 | - sqlite=3.31.1=h62c20be_1 22 | - tk=8.6.8=hbc83047_0 23 | - wheel=0.34.2=py36_0 24 | - xz=5.2.5=h7b6447c_0 25 | - zlib=1.2.11=h7b6447c_3 26 | - pip: 27 | - scipy==1.4.1 28 | - matplotlib==3.2.1 29 | - tqdm==4.46.0 30 | - numpy==1.18.4 31 | - opencv-python==4.2.0.34 32 | - pillow==7.1.2 33 | - tensorboard==2.2.1 34 | - torch==1.10.0 35 | - torchvision==0.11.1 36 | prefix: ~/anaconda3/envs/hyperstyle_env -------------------------------------------------------------------------------- /licenses/LICENSE_encoder4editing: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 omertov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_insightface: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TreB1eN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_lpips: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /licenses/LICENSE_pixel2style2pixel: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_restyle: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /licenses/LICENSE_stylegan2: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/__init__.py -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/encoders/__init__.py -------------------------------------------------------------------------------- /models/encoders/e4e.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import math 5 | import torch 6 | from torch import nn 7 | 8 | from models.stylegan2.model import Generator 9 | from configs.paths_config import model_paths 10 | from models.encoders import restyle_e4e_encoders 11 | from utils.resnet_mapping import RESNET_MAPPING 12 | 13 | 14 | class e4e(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(e4e, self).__init__() 18 | self.set_opts(opts) 19 | self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 20 | # Define architecture 21 | self.encoder = self.set_encoder() 22 | self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2) 23 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 24 | # Load weights if needed 25 | self.load_weights() 26 | 27 | def set_encoder(self): 28 | if self.opts.encoder_type == 'ProgressiveBackboneEncoder': 29 | encoder = restyle_e4e_encoders.ProgressiveBackboneEncoder(50, 'ir_se', self.n_styles, self.opts) 30 | elif self.opts.encoder_type == 'ResNetProgressiveBackboneEncoder': 31 | encoder = restyle_e4e_encoders.ResNetProgressiveBackboneEncoder(self.n_styles, self.opts) 32 | else: 33 | raise Exception(f'{self.opts.encoder_type} is not a valid encoders') 34 | return encoder 35 | 36 | def load_weights(self): 37 | if self.opts.checkpoint_path is not None: 38 | print(f'Loading ReStyle e4e from 
checkpoint: {self.opts.checkpoint_path}') 39 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 40 | self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=True) 41 | self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True) 42 | self.__load_latent_avg(ckpt) 43 | else: 44 | encoder_ckpt = self.__get_encoder_checkpoint() 45 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 46 | print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}') 47 | ckpt = torch.load(self.opts.stylegan_weights) 48 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True) 49 | self.__load_latent_avg(ckpt, repeat=self.n_styles) 50 | 51 | def forward(self, x, latent=None, resize=True, input_code=False, randomize_noise=True, 52 | return_latents=False, average_code=False, input_is_full=False): 53 | 54 | if input_code: 55 | codes = x 56 | else: 57 | codes = self.encoder(x) 58 | # residual step 59 | if x.shape[1] == 6 and latent is not None: 60 | # learn error with respect to previous iteration 61 | codes = codes + latent 62 | else: 63 | # first iteration is with respect to the avg latent code 64 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 65 | 66 | if average_code: 67 | input_is_latent = True 68 | else: 69 | input_is_latent = (not input_code) or (input_is_full) 70 | 71 | images, result_latent = self.decoder([codes], 72 | input_is_latent=input_is_latent, 73 | randomize_noise=randomize_noise, 74 | return_latents=return_latents) 75 | 76 | if resize: 77 | images = self.face_pool(images) 78 | 79 | if return_latents: 80 | return images, result_latent 81 | else: 82 | return images 83 | 84 | def set_opts(self, opts): 85 | self.opts = opts 86 | 87 | def __load_latent_avg(self, ckpt, repeat=None): 88 | if 'latent_avg' in ckpt: 89 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 90 | if repeat is not None: 91 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 92 | else: 93 | self.latent_avg = None 94 | 95 | def __get_encoder_checkpoint(self): 96 | if "ffhq" in self.opts.dataset_type: 97 | print('Loading encoders weights from irse50!') 98 | encoder_ckpt = torch.load(model_paths['ir_se50']) 99 | # Transfer the RGB input of the irse50 network to the first 3 input channels of pSp's encoder 100 | if self.opts.input_nc != 3: 101 | shape = encoder_ckpt['input_layer.0.weight'].shape 102 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 103 | altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight'] 104 | encoder_ckpt['input_layer.0.weight'] = altered_input_layer 105 | return encoder_ckpt 106 | else: 107 | print('Loading encoders weights from resnet34!') 108 | encoder_ckpt = torch.load(model_paths['resnet34']) 109 | # Transfer the RGB input of the resnet34 network to the first 3 input channels of pSp's encoder 110 | if self.opts.input_nc != 3: 111 | shape = encoder_ckpt['conv1.weight'].shape 112 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 113 | altered_input_layer[:, :3, :, :] = encoder_ckpt['conv1.weight'] 114 | encoder_ckpt['conv1.weight'] = altered_input_layer 115 | mapped_encoder_ckpt = dict(encoder_ckpt) 116 | for p, v in encoder_ckpt.items(): 117 | for original_name, psp_name in RESNET_MAPPING.items(): 118 | if original_name in p: 119 | mapped_encoder_ckpt[p.replace(original_name, psp_name)] = v 120 | mapped_encoder_ckpt.pop(p) 121 | return encoder_ckpt 122 | 123 | @staticmethod 124 | def 
__get_keys(d, name): 125 | if 'state_dict' in d: 126 | d = d['state_dict'] 127 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 128 | return d_filt 129 | -------------------------------------------------------------------------------- /models/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Linear 4 | import torch.nn.functional as F 5 | 6 | """ 7 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | return output 20 | 21 | 22 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 23 | """ A named tuple describing a ResNet block. """ 24 | 25 | 26 | def get_block(in_channel, depth, num_units, stride=2): 27 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 28 | 29 | 30 | def get_blocks(num_layers): 31 | if num_layers == 50: 32 | blocks = [ 33 | get_block(in_channel=64, depth=64, num_units=3), 34 | get_block(in_channel=64, depth=128, num_units=4), 35 | get_block(in_channel=128, depth=256, num_units=14), 36 | get_block(in_channel=256, depth=512, num_units=3) 37 | ] 38 | elif num_layers == 100: 39 | blocks = [ 40 | get_block(in_channel=64, depth=64, num_units=3), 41 | get_block(in_channel=64, depth=128, num_units=13), 42 | get_block(in_channel=128, depth=256, num_units=30), 43 | get_block(in_channel=256, depth=512, num_units=3) 44 | ] 45 | elif num_layers == 152: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=8), 49 | get_block(in_channel=128, depth=256, num_units=36), 50 | get_block(in_channel=256, depth=512, num_units=3) 51 | ] 52 | else: 53 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 54 | return blocks 55 | 56 | 57 | class SEModule(Module): 58 | def __init__(self, channels, reduction): 59 | super(SEModule, self).__init__() 60 | self.avg_pool = AdaptiveAvgPool2d(1) 61 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 62 | self.relu = ReLU(inplace=True) 63 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 64 | self.sigmoid = Sigmoid() 65 | 66 | def forward(self, x): 67 | module_input = x 68 | x = self.avg_pool(x) 69 | x = self.fc1(x) 70 | x = self.relu(x) 71 | x = self.fc2(x) 72 | x = self.sigmoid(x) 73 | return module_input * x 74 | 75 | 76 | class bottleneck_IR(Module): 77 | def __init__(self, in_channel, depth, stride): 78 | super(bottleneck_IR, self).__init__() 79 | if in_channel == depth: 80 | self.shortcut_layer = MaxPool2d(1, stride) 81 | else: 82 | self.shortcut_layer = Sequential( 83 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 84 | BatchNorm2d(depth) 85 | ) 86 | self.res_layer = Sequential( 87 | BatchNorm2d(in_channel), 88 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 89 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 90 | ) 91 | 92 | def forward(self, x): 93 | shortcut = self.shortcut_layer(x) 94 | res = self.res_layer(x) 95 | return res + shortcut 96 | 97 | 98 | class bottleneck_IR_SE(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR_SE, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth) 107 | ) 108 | self.res_layer = Sequential( 109 | BatchNorm2d(in_channel), 110 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 111 | PReLU(depth), 112 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 113 | BatchNorm2d(depth), 114 | SEModule(depth, 16) 115 | ) 116 | 117 | def forward(self, x): 118 | shortcut = self.shortcut_layer(x) 119 | res = self.res_layer(x) 120 | return res + shortcut 121 | 122 | 123 | class SeparableConv2d(torch.nn.Module): 124 | 125 | def __init__(self, in_channels, out_channels, kernel_size, bias=False): 126 | super(SeparableConv2d, self).__init__() 127 | self.depthwise = Conv2d(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, bias=bias, padding=1) 128 | self.pointwise = Conv2d(in_channels, out_channels, kernel_size=1, bias=bias) 129 | 130 | def forward(self, x): 131 | out = self.depthwise(x) 132 | out = self.pointwise(out) 133 | return out 134 | 135 | 136 | def _upsample_add(x, y): 137 | """Upsample and add two feature maps. 138 | Args: 139 | x: (Variable) top feature map to be upsampled. 140 | y: (Variable) lateral feature map. 141 | Returns: 142 | (Variable) added feature map. 143 | Note in PyTorch, when input size is odd, the upsampled feature map 144 | with `F.upsample(..., scale_factor=2, mode='nearest')` 145 | maybe not equal to the lateral feature map size. 146 | e.g. 147 | original input size: [N,_,15,15] -> 148 | conv2d feature map size: [N,_,8,8] -> 149 | upsampled feature map size: [N,_,16,16] 150 | So we choose bilinear upsample which supports arbitrary output sizes. 
151 | """ 152 | _, _, H, W = y.size() 153 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 154 | 155 | 156 | class SeparableBlock(Module): 157 | 158 | def __init__(self, input_size, kernel_channels_in, kernel_channels_out, kernel_size): 159 | super(SeparableBlock, self).__init__() 160 | 161 | self.input_size = input_size 162 | self.kernel_size = kernel_size 163 | self.kernel_channels_in = kernel_channels_in 164 | self.kernel_channels_out = kernel_channels_out 165 | 166 | self.make_kernel_in = Linear(input_size, kernel_size * kernel_size * kernel_channels_in) 167 | self.make_kernel_out = Linear(input_size, kernel_size * kernel_size * kernel_channels_out) 168 | 169 | self.kernel_linear_in = Linear(kernel_channels_in, kernel_channels_in) 170 | self.kernel_linear_out = Linear(kernel_channels_out, kernel_channels_out) 171 | 172 | def forward(self, features): 173 | 174 | features = features.view(-1, self.input_size) 175 | 176 | kernel_in = self.make_kernel_in(features).view(-1, self.kernel_size, self.kernel_size, 1, self.kernel_channels_in) 177 | kernel_out = self.make_kernel_out(features).view(-1, self.kernel_size, self.kernel_size, self.kernel_channels_out, 1) 178 | 179 | kernel = torch.matmul(kernel_out, kernel_in) 180 | 181 | kernel = self.kernel_linear_in(kernel).permute(0, 1, 2, 4, 3) 182 | kernel = self.kernel_linear_out(kernel) 183 | kernel = kernel.permute(0, 4, 3, 1, 2) 184 | 185 | return kernel 186 | -------------------------------------------------------------------------------- /models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | 
def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /models/encoders/psp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import math 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import torch 8 | from torch import nn 9 | from models.encoders import w_encoder 10 | from models.stylegan2.model import Generator 11 | from configs.paths_config import model_paths 12 | 13 | 14 | def get_keys(d, name): 15 | if 'state_dict' in d: 16 | d = d['state_dict'] 17 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 18 | return d_filt 19 | 20 | 21 | class pSp(nn.Module): 22 | 23 | def __init__(self, opts): 24 | super(pSp, self).__init__() 25 | self.set_opts(opts) 26 | # compute number of style inputs based on the output resolution 27 | self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 28 | # Define architecture 29 | self.encoder = self.set_encoder() 30 | self.decoder = Generator(self.opts.output_size, 512, 8) 31 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 32 | # Load weights if needed 33 | self.load_weights() 34 | 35 | def set_encoder(self): 36 | if self.opts.encoder_type == 'WEncoder': 37 | encoder = w_encoder.WEncoder(50, 'ir_se', self.opts) 38 | else: 39 | raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) 40 | return encoder 41 | 42 | def load_weights(self): 43 | if self.opts.checkpoint_path is not None: 44 | print('Loading WEncoder from checkpoint: {}'.format(self.opts.checkpoint_path)) 45 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 46 | self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) 47 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) 48 | self.__load_latent_avg(ckpt) 49 | else: 50 | print('Loading encoders weights from irse50!') 51 | encoder_ckpt = torch.load(model_paths['ir_se50']) 52 | # if input to encoder is not an RGB image, do not load the input layer weights 53 | if self.opts.label_nc != 0: 54 | encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k} 55 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 56 | print('Loading decoder weights from pretrained!') 57 | ckpt = torch.load(self.opts.stylegan_weights) 58 | self.decoder.load_state_dict(ckpt['g_ema'], strict=False) 59 | if self.opts.learn_in_w: 60 | self.__load_latent_avg(ckpt, repeat=1) 61 | else: 62 | self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) 63 | 64 | def forward(self, x, 
resize=True, latent_mask=None, input_code=False, randomize_noise=True, inject_latent=None, 65 | return_latents=False, alpha=None): 66 | 67 | if input_code: 68 | codes = x 69 | else: 70 | codes = self.encoder(x) 71 | if codes.ndim == 2: 72 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] 73 | else: 74 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 75 | 76 | if latent_mask is not None: 77 | for i in latent_mask: 78 | if inject_latent is not None: 79 | if alpha is not None: 80 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 81 | else: 82 | codes[:, i] = inject_latent[:, i] 83 | else: 84 | codes[:, i] = 0 85 | 86 | input_is_latent = not input_code 87 | images, result_latent = self.decoder([codes], 88 | input_is_latent=input_is_latent, 89 | randomize_noise=randomize_noise, 90 | return_latents=return_latents) 91 | 92 | if resize: 93 | images = self.face_pool(images) 94 | 95 | if return_latents: 96 | return images, result_latent 97 | else: 98 | return images 99 | 100 | def set_opts(self, opts): 101 | self.opts = opts 102 | 103 | def __load_latent_avg(self, ckpt, repeat=None): 104 | if 'latent_avg' in ckpt: 105 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 106 | if repeat is not None: 107 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 108 | else: 109 | self.latent_avg = None -------------------------------------------------------------------------------- /models/encoders/restyle_e4e_encoders.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import numpy as np 3 | from torch import nn 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 5 | from torchvision.models import resnet34 6 | 7 | from models.stylegan2.model import EqualLinear 8 | from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 9 | 10 | 11 | class ProgressiveStage(Enum): 12 | WTraining = 0 13 | Delta1Training = 1 14 | Delta2Training = 2 15 | Delta3Training = 3 16 | Delta4Training = 4 17 | Delta5Training = 5 18 | Delta6Training = 6 19 | Delta7Training = 7 20 | Delta8Training = 8 21 | Delta9Training = 9 22 | Delta10Training = 10 23 | Delta11Training = 11 24 | Delta12Training = 12 25 | Delta13Training = 13 26 | Delta14Training = 14 27 | Delta15Training = 15 28 | Delta16Training = 16 29 | Delta17Training = 17 30 | Inference = 18 31 | 32 | 33 | class GradualStyleBlock(Module): 34 | def __init__(self, in_c, out_c, spatial): 35 | super(GradualStyleBlock, self).__init__() 36 | self.out_c = out_c 37 | self.spatial = spatial 38 | num_pools = int(np.log2(spatial)) 39 | modules = [] 40 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 41 | nn.LeakyReLU()] 42 | for i in range(num_pools - 1): 43 | modules += [ 44 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 45 | nn.LeakyReLU() 46 | ] 47 | self.convs = nn.Sequential(*modules) 48 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 49 | 50 | def forward(self, x): 51 | x = self.convs(x) 52 | x = x.view(-1, self.out_c) 53 | x = self.linear(x) 54 | return x 55 | 56 | 57 | class ProgressiveBackboneEncoder(Module): 58 | """ 59 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 60 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE50 backbone with the 61 | progressive training scheme from e4e_modules. 62 | Note this class is designed to be used for the human facial domain. 
63 | """ 64 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 65 | super(ProgressiveBackboneEncoder, self).__init__() 66 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 67 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 68 | blocks = get_blocks(num_layers) 69 | if mode == 'ir': 70 | unit_module = bottleneck_IR 71 | elif mode == 'ir_se': 72 | unit_module = bottleneck_IR_SE 73 | 74 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 75 | BatchNorm2d(64), 76 | PReLU(64)) 77 | modules = [] 78 | for block in blocks: 79 | for bottleneck in block: 80 | modules.append(unit_module(bottleneck.in_channel, 81 | bottleneck.depth, 82 | bottleneck.stride)) 83 | self.body = Sequential(*modules) 84 | 85 | self.styles = nn.ModuleList() 86 | self.style_count = n_styles 87 | for i in range(self.style_count): 88 | style = GradualStyleBlock(512, 512, 16) 89 | self.styles.append(style) 90 | self.progressive_stage = ProgressiveStage.Inference 91 | 92 | def get_deltas_starting_dimensions(self): 93 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 94 | return list(range(self.style_count)) # Each dimension has a delta applied to 95 | 96 | def set_progressive_stage(self, new_stage: ProgressiveStage): 97 | # In this encoder we train all the pyramid (At least as a first stage experiment 98 | self.progressive_stage = new_stage 99 | print('Changed progressive stage to: ', new_stage) 100 | 101 | def forward(self, x): 102 | x = self.input_layer(x) 103 | x = self.body(x) 104 | 105 | # get initial w0 from first map2style layer 106 | w0 = self.styles[0](x) 107 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 108 | 109 | # learn the deltas up to the current stage 110 | stage = self.progressive_stage.value 111 | for i in range(1, min(stage + 1, self.style_count)): 112 | delta_i = self.styles[i](x) 113 | w[:, i] += delta_i 114 | return w 115 | 116 | 117 | class ResNetProgressiveBackboneEncoder(Module): 118 | """ 119 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 120 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone with the 121 | progressive training scheme from e4e_modules. 
122 | """ 123 | def __init__(self, n_styles=18, opts=None): 124 | super(ResNetProgressiveBackboneEncoder, self).__init__() 125 | 126 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 127 | self.bn1 = BatchNorm2d(64) 128 | self.relu = PReLU(64) 129 | 130 | resnet_basenet = resnet34(pretrained=True) 131 | blocks = [ 132 | resnet_basenet.layer1, 133 | resnet_basenet.layer2, 134 | resnet_basenet.layer3, 135 | resnet_basenet.layer4 136 | ] 137 | modules = [] 138 | for block in blocks: 139 | for bottleneck in block: 140 | modules.append(bottleneck) 141 | self.body = Sequential(*modules) 142 | 143 | self.styles = nn.ModuleList() 144 | self.style_count = n_styles 145 | for i in range(self.style_count): 146 | style = GradualStyleBlock(512, 512, 16) 147 | self.styles.append(style) 148 | self.progressive_stage = ProgressiveStage.Inference 149 | 150 | def get_deltas_starting_dimensions(self): 151 | ''' Get a list of the initial dimension of every delta from which it is applied ''' 152 | return list(range(self.style_count)) # Each dimension has a delta applied to 153 | 154 | def set_progressive_stage(self, new_stage: ProgressiveStage): 155 | # In this encoder we train all the pyramid (At least as a first stage experiment 156 | self.progressive_stage = new_stage 157 | print('Changed progressive stage to: ', new_stage) 158 | 159 | def forward(self, x): 160 | x = self.conv1(x) 161 | x = self.bn1(x) 162 | x = self.relu(x) 163 | x = self.body(x) 164 | 165 | # get initial w0 from first map2style layer 166 | w0 = self.styles[0](x) 167 | w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) 168 | 169 | # learn the deltas up to the current stage 170 | stage = self.progressive_stage.value 171 | for i in range(1, min(stage + 1, self.style_count)): 172 | delta_i = self.styles[i](x) 173 | w[:, i] += delta_i 174 | return w 175 | -------------------------------------------------------------------------------- /models/encoders/w_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | 5 | from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 6 | from models.stylegan2.model import EqualLinear 7 | 8 | 9 | class WEncoder(Module): 10 | def __init__(self, num_layers, mode='ir', opts=None): 11 | super(WEncoder, self).__init__() 12 | print('Using WEncoder') 13 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 14 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) 24 | self.linear = EqualLinear(512, 512, lr_mul=1) 25 | modules = [] 26 | for block in blocks: 27 | for bottleneck in block: 28 | modules.append(unit_module(bottleneck.in_channel, 29 | bottleneck.depth, 30 | bottleneck.stride)) 31 | self.body = Sequential(*modules) 32 | log_size = int(math.log(opts.output_size, 2)) 33 | self.style_count = 2 * log_size - 2 34 | 35 | def forward(self, x): 36 | x = self.input_layer(x) 37 | x = self.body(x) 38 | x = self.output_pool(x) 39 | x = x.view(-1, 512) 40 | x = self.linear(x) 41 | return x.repeat(self.style_count, 1, 1).permute(1, 0, 2) 42 | 43 | 
-------------------------------------------------------------------------------- /models/hypernetworks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/hypernetworks/__init__.py -------------------------------------------------------------------------------- /models/hypernetworks/hypernetwork.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import BatchNorm2d, PReLU, Sequential, Module 3 | from torchvision.models import resnet34 4 | 5 | from models.hypernetworks.refinement_blocks import HyperRefinementBlock, RefinementBlock, RefinementBlockSeparable 6 | from models.hypernetworks.shared_weights_hypernet import SharedWeightsHypernet 7 | 8 | 9 | class SharedWeightsHyperNetResNet(Module): 10 | 11 | def __init__(self, opts): 12 | super(SharedWeightsHyperNetResNet, self).__init__() 13 | 14 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 15 | self.bn1 = BatchNorm2d(64) 16 | self.relu = PReLU(64) 17 | 18 | resnet_basenet = resnet34(pretrained=True) 19 | blocks = [ 20 | resnet_basenet.layer1, 21 | resnet_basenet.layer2, 22 | resnet_basenet.layer3, 23 | resnet_basenet.layer4 24 | ] 25 | modules = [] 26 | for block in blocks: 27 | for bottleneck in block: 28 | modules.append(bottleneck) 29 | self.body = Sequential(*modules) 30 | 31 | if len(opts.layers_to_tune) == 0: 32 | self.layers_to_tune = list(range(opts.n_hypernet_outputs)) 33 | else: 34 | self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')] 35 | 36 | self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12] 37 | self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None) 38 | 39 | self.refinement_blocks = nn.ModuleList() 40 | self.n_outputs = opts.n_hypernet_outputs 41 | for layer_idx in range(self.n_outputs): 42 | if layer_idx in self.layers_to_tune: 43 | if layer_idx in self.shared_layers: 44 | refinement_block = HyperRefinementBlock(self.shared_weight_hypernet, n_channels=512, inner_c=128) 45 | else: 46 | refinement_block = RefinementBlock(layer_idx, opts, n_channels=512, inner_c=256) 47 | else: 48 | refinement_block = None 49 | self.refinement_blocks.append(refinement_block) 50 | 51 | def forward(self, x): 52 | x = self.conv1(x) 53 | x = self.bn1(x) 54 | x = self.relu(x) 55 | x = self.body(x) 56 | weight_deltas = [] 57 | for j in range(self.n_outputs): 58 | if self.refinement_blocks[j] is not None: 59 | delta = self.refinement_blocks[j](x) 60 | else: 61 | delta = None 62 | weight_deltas.append(delta) 63 | return weight_deltas 64 | 65 | 66 | class SharedWeightsHyperNetResNetSeparable(Module): 67 | 68 | def __init__(self, opts): 69 | super(SharedWeightsHyperNetResNetSeparable, self).__init__() 70 | 71 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 72 | self.bn1 = BatchNorm2d(64) 73 | self.relu = PReLU(64) 74 | 75 | resnet_basenet = resnet34(pretrained=True) 76 | blocks = [ 77 | resnet_basenet.layer1, 78 | resnet_basenet.layer2, 79 | resnet_basenet.layer3, 80 | resnet_basenet.layer4 81 | ] 82 | modules = [] 83 | for block in blocks: 84 | for bottleneck in block: 85 | modules.append(bottleneck) 86 | self.body = Sequential(*modules) 87 | 88 | if len(opts.layers_to_tune) == 0: 89 | self.layers_to_tune = list(range(opts.n_hypernet_outputs)) 90 | else: 91 | self.layers_to_tune = [int(l) for l 
in opts.layers_to_tune.split(',')] 92 | 93 | self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12] 94 | self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None) 95 | 96 | self.refinement_blocks = nn.ModuleList() 97 | self.n_outputs = opts.n_hypernet_outputs 98 | for layer_idx in range(self.n_outputs): 99 | if layer_idx in self.layers_to_tune: 100 | if layer_idx in self.shared_layers: 101 | refinement_block = HyperRefinementBlock(self.shared_weight_hypernet, n_channels=512, inner_c=128) 102 | else: 103 | refinement_block = RefinementBlockSeparable(layer_idx, opts, n_channels=512, inner_c=256) 104 | else: 105 | refinement_block = None 106 | self.refinement_blocks.append(refinement_block) 107 | 108 | def forward(self, x): 109 | x = self.conv1(x) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.body(x) 113 | weight_deltas = [] 114 | for j in range(self.n_outputs): 115 | if self.refinement_blocks[j] is not None: 116 | delta = self.refinement_blocks[j](x) 117 | else: 118 | delta = None 119 | weight_deltas.append(delta) 120 | return weight_deltas 121 | -------------------------------------------------------------------------------- /models/hypernetworks/refinement_blocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn import Conv2d, Sequential, Module 4 | 5 | from models.encoders.helpers import SeparableBlock 6 | from models.stylegan2.model import EqualLinear 7 | 8 | 9 | # layer_idx: [kernel_size, in_channels, out_channels] 10 | PARAMETERS = { 11 | 0: [3, 512, 512], 12 | 1: [1, 512, 3], 13 | 2: [3, 512, 512], 14 | 3: [3, 512, 512], 15 | 4: [1, 512, 3], 16 | 5: [3, 512, 512], 17 | 6: [3, 512, 512], 18 | 7: [1, 512, 3], 19 | 8: [3, 512, 512], 20 | 9: [3, 512, 512], 21 | 10: [1, 512, 3], 22 | 11: [3, 512, 512], 23 | 12: [3, 512, 512], 24 | 13: [1, 512, 3], 25 | 14: [3, 512, 256], 26 | 15: [3, 256, 256], 27 | 16: [1, 256, 3], 28 | 17: [3, 256, 128], 29 | 18: [3, 128, 128], 30 | 19: [1, 128, 3], 31 | 20: [3, 128, 64], 32 | 21: [3, 64, 64], 33 | 22: [1, 64, 3], 34 | 23: [3, 64, 32], 35 | 24: [3, 32, 32], 36 | 25: [1, 32, 3] 37 | } 38 | TO_RGB_LAYERS = [1, 4, 7, 10, 13, 16, 19, 22, 25] 39 | 40 | 41 | class RefinementBlock(Module): 42 | 43 | def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16): 44 | super(RefinementBlock, self).__init__() 45 | self.layer_idx = layer_idx 46 | self.opts = opts 47 | self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx] 48 | self.spatial = spatial 49 | self.n_channels = n_channels 50 | self.inner_c = inner_c 51 | self.out_c = 512 52 | num_pools = int(np.log2(self.spatial)) - 1 53 | if self.kernel_size == 3: 54 | num_pools = num_pools - 1 55 | self.modules = [] 56 | self.modules += [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 57 | for i in range(num_pools - 1): 58 | self.modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 59 | self.modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 60 | self.convs = nn.Sequential(*self.modules) 61 | 62 | if layer_idx in TO_RGB_LAYERS: 63 | self.output = Sequential( 64 | Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, stride=1, padding=0)) 65 | else: 66 | self.output = Sequential(nn.AdaptiveAvgPool2d((1, 1)), 67 | Conv2d(self.out_c, self.in_channels * self.out_channels, kernel_size=1, 
stride=1, 68 | padding=0)) 69 | 70 | def forward(self, x): 71 | x = self.convs(x) 72 | x = self.output(x) 73 | if self.layer_idx in TO_RGB_LAYERS: 74 | x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size) 75 | else: 76 | x = x.view(-1, self.out_channels, self.in_channels) 77 | x = x.unsqueeze(3).repeat(1, 1, 1, self.kernel_size).unsqueeze(4).repeat(1, 1, 1, 1, self.kernel_size) 78 | return x 79 | 80 | 81 | class HyperRefinementBlock(Module): 82 | def __init__(self, hypernet, n_channels=512, inner_c=128, spatial=16): 83 | super(HyperRefinementBlock, self).__init__() 84 | self.n_channels = n_channels 85 | self.inner_c = inner_c 86 | self.out_c = 512 87 | num_pools = int(np.log2(spatial)) 88 | modules = [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=1, padding=1), nn.LeakyReLU()] 89 | for i in range(num_pools - 1): 90 | modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 91 | modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 92 | self.convs = nn.Sequential(*modules) 93 | self.linear = EqualLinear(self.out_c, self.out_c, lr_mul=1) 94 | self.hypernet = hypernet 95 | 96 | def forward(self, features): 97 | code = self.convs(features) 98 | code = code.view(-1, self.out_c) 99 | code = self.linear(code) 100 | weight_delta = self.hypernet(code) 101 | return weight_delta 102 | 103 | 104 | class RefinementBlockSeparable(Module): 105 | 106 | def __init__(self, layer_idx, opts, n_channels=512, inner_c=256, spatial=16): 107 | super(RefinementBlockSeparable, self).__init__() 108 | self.layer_idx = layer_idx 109 | self.kernel_size, self.in_channels, self.out_channels = PARAMETERS[self.layer_idx] 110 | self.spatial = spatial 111 | self.n_channels = n_channels 112 | self.inner_c = inner_c 113 | self.out_c = 512 114 | num_pools = int(np.log2(self.spatial)) - 1 115 | self.modules = [] 116 | self.modules += [Conv2d(self.n_channels, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 117 | for i in range(num_pools - 1): 118 | self.modules += [Conv2d(self.inner_c, self.inner_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 119 | self.modules += [Conv2d(self.inner_c, self.out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 120 | self.convs = nn.Sequential(*self.modules) 121 | 122 | self.opts = opts 123 | if self.layer_idx in TO_RGB_LAYERS: 124 | self.output = Sequential(Conv2d(self.out_c, self.in_channels * self.out_channels, 125 | kernel_size=1, stride=1, padding=0)) 126 | else: 127 | self.output = Sequential(SeparableBlock(input_size=self.out_c, 128 | kernel_channels_in=self.in_channels, 129 | kernel_channels_out=self.out_channels, 130 | kernel_size=self.kernel_size)) 131 | 132 | def forward(self, x): 133 | x = self.convs(x) 134 | x = self.output(x) 135 | if self.layer_idx in TO_RGB_LAYERS: 136 | x = x.view(-1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size) 137 | return x -------------------------------------------------------------------------------- /models/hypernetworks/shared_weights_hypernet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class SharedWeightsHypernet(nn.Module): 7 | 8 | def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None): 9 | super(SharedWeightsHypernet, self).__init__() 10 | self.mode = mode 11 | self.z_dim = z_dim 12 | self.f_size = 
f_size 13 | if self.mode == 'delta_per_channel': 14 | self.f_size = 1 15 | self.out_size = out_size 16 | self.in_size = in_size 17 | 18 | self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).cuda() / 40, 2)) 19 | self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).cuda() / 40, 2)) 20 | 21 | self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).cuda() / 40, 2)) 22 | self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).cuda() / 40, 2)) 23 | 24 | def forward(self, z): 25 | batch_size = z.shape[0] 26 | h_in = torch.matmul(z, self.w2) + self.b2 27 | h_in = h_in.view(batch_size, self.in_size, self.z_dim) 28 | 29 | h_final = torch.matmul(h_in, self.w1) + self.b1 30 | kernel = h_final.view(batch_size, self.out_size, self.in_size, self.f_size, self.f_size) 31 | if self.mode == 'delta_per_channel': # repeat per channel values to the 3x3 conv kernels 32 | kernel = kernel.repeat(1, 1, 1, 3, 3) 33 | return kernel 34 | -------------------------------------------------------------------------------- /models/hyperstyle.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import copy 5 | from argparse import Namespace 6 | 7 | from models.encoders.psp import pSp 8 | from models.stylegan2.model import Generator 9 | from configs.paths_config import model_paths 10 | from models.hypernetworks.hypernetwork import SharedWeightsHyperNetResNet, SharedWeightsHyperNetResNetSeparable 11 | from utils.resnet_mapping import RESNET_MAPPING 12 | 13 | 14 | class HyperStyle(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(HyperStyle, self).__init__() 18 | self.set_opts(opts) 19 | self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 20 | # Define architecture 21 | self.hypernet = self.set_hypernet() 22 | self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2) 23 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 24 | # Load weights if needed 25 | self.load_weights() 26 | if self.opts.load_w_encoder: 27 | self.w_encoder.eval() 28 | 29 | def set_hypernet(self): 30 | if self.opts.output_size == 1024: 31 | self.opts.n_hypernet_outputs = 26 32 | elif self.opts.output_size == 512: 33 | self.opts.n_hypernet_outputs = 23 34 | elif self.opts.output_size == 256: 35 | self.opts.n_hypernet_outputs = 20 36 | else: 37 | raise ValueError(f"Invalid Output Size! 
Support sizes: [1024, 512, 256]!") 38 | networks = { 39 | "SharedWeightsHyperNetResNet": SharedWeightsHyperNetResNet(opts=self.opts), 40 | "SharedWeightsHyperNetResNetSeparable": SharedWeightsHyperNetResNetSeparable(opts=self.opts), 41 | } 42 | return networks[self.opts.encoder_type] 43 | 44 | def load_weights(self): 45 | if self.opts.checkpoint_path is not None: 46 | print(f'Loading HyperStyle from checkpoint: {self.opts.checkpoint_path}') 47 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 48 | self.hypernet.load_state_dict(self.__get_keys(ckpt, 'hypernet'), strict=True) 49 | self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True) 50 | self.__load_latent_avg(ckpt) 51 | if self.opts.load_w_encoder: 52 | self.w_encoder = self.__get_pretrained_w_encoder() 53 | else: 54 | hypernet_ckpt = self.__get_hypernet_checkpoint() 55 | self.hypernet.load_state_dict(hypernet_ckpt, strict=False) 56 | print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}') 57 | ckpt = torch.load(self.opts.stylegan_weights) 58 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True) 59 | self.__load_latent_avg(ckpt, repeat=self.n_styles) 60 | if self.opts.load_w_encoder: 61 | self.w_encoder = self.__get_pretrained_w_encoder() 62 | 63 | def forward(self, x, resize=True, input_code=False, randomize_noise=True, return_latents=False, 64 | return_weight_deltas_and_codes=False, weights_deltas=None, y_hat=None, codes=None): 65 | 66 | if input_code: 67 | codes = x 68 | else: 69 | if y_hat is None: 70 | assert self.opts.load_w_encoder, "Cannot infer latent code when e4e isn't loaded." 71 | y_hat, codes = self.__get_initial_inversion(x, resize=True) 72 | 73 | # concatenate original input with w-reconstruction or current reconstruction 74 | x_input = torch.cat([x, y_hat], dim=1) 75 | 76 | # pass through hypernet to get per-layer deltas 77 | hypernet_outputs = self.hypernet(x_input) 78 | if weights_deltas is None: 79 | weights_deltas = hypernet_outputs 80 | else: 81 | weights_deltas = [weights_deltas[i] + hypernet_outputs[i] if weights_deltas[i] is not None else None 82 | for i in range(len(hypernet_outputs))] 83 | 84 | input_is_latent = (not input_code) 85 | images, result_latent = self.decoder([codes], 86 | weights_deltas=weights_deltas, 87 | input_is_latent=input_is_latent, 88 | randomize_noise=randomize_noise, 89 | return_latents=return_latents) 90 | 91 | if resize: 92 | images = self.face_pool(images) 93 | 94 | if return_latents and return_weight_deltas_and_codes: 95 | return images, result_latent, weights_deltas, codes, y_hat 96 | elif return_latents: 97 | return images, result_latent 98 | elif return_weight_deltas_and_codes: 99 | return images, weights_deltas, codes 100 | else: 101 | return images 102 | 103 | def set_opts(self, opts): 104 | self.opts = opts 105 | 106 | def __load_latent_avg(self, ckpt, repeat=None): 107 | if 'latent_avg' in ckpt: 108 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 109 | if repeat is not None: 110 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 111 | else: 112 | self.latent_avg = None 113 | 114 | def __get_hypernet_checkpoint(self): 115 | print('Loading hypernet weights from resnet34!') 116 | hypernet_ckpt = torch.load(model_paths['resnet34']) 117 | # Transfer the RGB input of the resnet34 network to the first 3 input channels of hypernet 118 | if self.opts.input_nc != 3: 119 | shape = hypernet_ckpt['conv1.weight'].shape 120 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], 
dtype=torch.float32) 121 | altered_input_layer[:, :3, :, :] = hypernet_ckpt['conv1.weight'] 122 | hypernet_ckpt['conv1.weight'] = altered_input_layer 123 | mapped_hypernet_ckpt = dict(hypernet_ckpt) 124 | for p, v in hypernet_ckpt.items(): 125 | for original_name, net_name in RESNET_MAPPING.items(): 126 | if original_name in p: 127 | mapped_hypernet_ckpt[p.replace(original_name, net_name)] = v 128 | mapped_hypernet_ckpt.pop(p) 129 | return hypernet_ckpt 130 | 131 | @staticmethod 132 | def __get_keys(d, name): 133 | if 'state_dict' in d: 134 | d = d['state_dict'] 135 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 136 | return d_filt 137 | 138 | def __get_pretrained_w_encoder(self): 139 | print("Loading pretrained W encoder...") 140 | opts_w_encoder = vars(copy.deepcopy(self.opts)) 141 | opts_w_encoder['checkpoint_path'] = self.opts.w_encoder_checkpoint_path 142 | opts_w_encoder['encoder_type'] = self.opts.w_encoder_type 143 | opts_w_encoder['input_nc'] = 3 144 | opts_w_encoder = Namespace(**opts_w_encoder) 145 | w_net = pSp(opts_w_encoder) 146 | w_net = w_net.encoder 147 | w_net.eval() 148 | w_net.cuda() 149 | return w_net 150 | 151 | def __get_initial_inversion(self, x, resize=True): 152 | # get initial inversion and reconstruction of batch 153 | with torch.no_grad(): 154 | return self.__get_w_inversion(x, resize) 155 | 156 | def __get_w_inversion(self, x, resize=True): 157 | if self.w_encoder.training: 158 | self.w_encoder.eval() 159 | codes = self.w_encoder.forward(x) 160 | if codes.ndim == 2: 161 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] 162 | else: 163 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 164 | y_hat, _ = self.decoder([codes], 165 | weights_deltas=None, 166 | input_is_latent=True, 167 | randomize_noise=False, 168 | return_latents=False) 169 | if resize: 170 | y_hat = self.face_pool(y_hat) 171 | if "cars" in self.opts.dataset_type: 172 | y_hat = y_hat[:, :, 32:224, :] 173 | return y_hat, codes 174 | -------------------------------------------------------------------------------- /models/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/mtcnn/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 5 | from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 7 | from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face 8 | 9 | device = 'cuda:0' 10 | 11 | 12 | class MTCNN(): 13 | def __init__(self): 14 | print(device) 15 | self.pnet = PNet().to(device) 16 | self.rnet = RNet().to(device) 17 | self.onet = ONet().to(device) 18 | self.pnet.eval() 19 | self.rnet.eval() 20 | self.onet.eval() 21 | self.refrence = get_reference_facial_points(default_square=True) 22 | 23 | def align(self, img): 24 | _, landmarks = self.detect_faces(img) 25 | if len(landmarks) == 0: 26 | return None, None 27 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 28 | warped_face, tfm = 
warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 29 | return Image.fromarray(warped_face), tfm 30 | 31 | def align_multi(self, img, limit=None, min_face_size=30.0): 32 | boxes, landmarks = self.detect_faces(img, min_face_size) 33 | if limit: 34 | boxes = boxes[:limit] 35 | landmarks = landmarks[:limit] 36 | faces = [] 37 | tfms = [] 38 | for landmark in landmarks: 39 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 40 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 41 | faces.append(Image.fromarray(warped_face)) 42 | tfms.append(tfm) 43 | return boxes, faces, tfms 44 | 45 | def detect_faces(self, image, min_face_size=20.0, 46 | thresholds=[0.15, 0.25, 0.35], 47 | nms_thresholds=[0.7, 0.7, 0.7]): 48 | """ 49 | Arguments: 50 | image: an instance of PIL.Image. 51 | min_face_size: a float number. 52 | thresholds: a list of length 3. 53 | nms_thresholds: a list of length 3. 54 | 55 | Returns: 56 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 57 | bounding boxes and facial landmarks. 58 | """ 59 | 60 | # BUILD AN IMAGE PYRAMID 61 | width, height = image.size 62 | min_length = min(height, width) 63 | 64 | min_detection_size = 12 65 | factor = 0.707 # sqrt(0.5) 66 | 67 | # scales for scaling the image 68 | scales = [] 69 | 70 | # scales the image so that 71 | # minimum size that we can detect equals to 72 | # minimum face size that we want to detect 73 | m = min_detection_size / min_face_size 74 | min_length *= m 75 | 76 | factor_count = 0 77 | while min_length > min_detection_size: 78 | scales.append(m * factor ** factor_count) 79 | min_length *= factor 80 | factor_count += 1 81 | 82 | # STAGE 1 83 | 84 | # it will be returned 85 | bounding_boxes = [] 86 | 87 | with torch.no_grad(): 88 | # run P-Net on different scales 89 | for s in scales: 90 | boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) 91 | bounding_boxes.append(boxes) 92 | 93 | # collect boxes (and offsets, and scores) from different scales 94 | bounding_boxes = [i for i in bounding_boxes if i is not None] 95 | bounding_boxes = np.vstack(bounding_boxes) 96 | 97 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 98 | bounding_boxes = bounding_boxes[keep] 99 | 100 | # use offsets predicted by pnet to transform bounding boxes 101 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 102 | # shape [n_boxes, 5] 103 | 104 | bounding_boxes = convert_to_square(bounding_boxes) 105 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 106 | 107 | # STAGE 2 108 | 109 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 110 | img_boxes = torch.FloatTensor(img_boxes).to(device) 111 | 112 | output = self.rnet(img_boxes) 113 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 114 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 115 | 116 | keep = np.where(probs[:, 1] > thresholds[1])[0] 117 | bounding_boxes = bounding_boxes[keep] 118 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 119 | offsets = offsets[keep] 120 | 121 | keep = nms(bounding_boxes, nms_thresholds[1]) 122 | bounding_boxes = bounding_boxes[keep] 123 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 124 | bounding_boxes = convert_to_square(bounding_boxes) 125 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 126 | 127 | # STAGE 3 128 | 129 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 130 | if len(img_boxes) == 0: 
131 | return [], [] 132 | img_boxes = torch.FloatTensor(img_boxes).to(device) 133 | output = self.onet(img_boxes) 134 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 135 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 136 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 137 | 138 | keep = np.where(probs[:, 1] > thresholds[2])[0] 139 | bounding_boxes = bounding_boxes[keep] 140 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 141 | offsets = offsets[keep] 142 | landmarks = landmarks[keep] 143 | 144 | # compute landmark points 145 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 146 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 147 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 148 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 149 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 150 | 151 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 152 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 153 | bounding_boxes = bounding_boxes[keep] 154 | landmarks = landmarks[keep] 155 | 156 | return bounding_boxes, landmarks 157 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/box_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def nms(boxes, overlap_threshold=0.5, mode='union'): 6 | """Non-maximum suppression. 7 | 8 | Arguments: 9 | boxes: a float numpy array of shape [n, 5], 10 | where each row is (xmin, ymin, xmax, ymax, score). 11 | overlap_threshold: a float number. 12 | mode: 'union' or 'min'. 
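            ('min' divides the intersection by the smaller of the two box
            areas, so nested boxes are suppressed more aggressively than
            with the standard IoU used in 'union' mode.)

        Illustrative example (hypothetical values, not from the original
        source):
            >>> boxes = np.array([[0., 0., 10., 10., 0.9],
            ...                   [1., 1., 11., 11., 0.8]])
            >>> keep = nms(boxes, overlap_threshold=0.5)
            # the two boxes overlap with IoU ~0.70 > 0.5, so only index 0
            # (the higher-scoring box) is expected to survive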
13 | 14 | Returns: 15 | list with indices of the selected boxes 16 | """ 17 | 18 | # if there are no boxes, return the empty list 19 | if len(boxes) == 0: 20 | return [] 21 | 22 | # list of picked indices 23 | pick = [] 24 | 25 | # grab the coordinates of the bounding boxes 26 | x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] 27 | 28 | area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) 29 | ids = np.argsort(score) # in increasing order 30 | 31 | while len(ids) > 0: 32 | 33 | # grab index of the largest value 34 | last = len(ids) - 1 35 | i = ids[last] 36 | pick.append(i) 37 | 38 | # compute intersections 39 | # of the box with the largest score 40 | # with the rest of boxes 41 | 42 | # left top corner of intersection boxes 43 | ix1 = np.maximum(x1[i], x1[ids[:last]]) 44 | iy1 = np.maximum(y1[i], y1[ids[:last]]) 45 | 46 | # right bottom corner of intersection boxes 47 | ix2 = np.minimum(x2[i], x2[ids[:last]]) 48 | iy2 = np.minimum(y2[i], y2[ids[:last]]) 49 | 50 | # width and height of intersection boxes 51 | w = np.maximum(0.0, ix2 - ix1 + 1.0) 52 | h = np.maximum(0.0, iy2 - iy1 + 1.0) 53 | 54 | # intersections' areas 55 | inter = w * h 56 | if mode == 'min': 57 | overlap = inter / np.minimum(area[i], area[ids[:last]]) 58 | elif mode == 'union': 59 | # intersection over union (IoU) 60 | overlap = inter / (area[i] + area[ids[:last]] - inter) 61 | 62 | # delete all boxes where overlap is too big 63 | ids = np.delete( 64 | ids, 65 | np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) 66 | ) 67 | 68 | return pick 69 | 70 | 71 | def convert_to_square(bboxes): 72 | """Convert bounding boxes to a square form. 73 | 74 | Arguments: 75 | bboxes: a float numpy array of shape [n, 5]. 76 | 77 | Returns: 78 | a float numpy array of shape [n, 5], 79 | squared bounding boxes. 80 | """ 81 | 82 | square_bboxes = np.zeros_like(bboxes) 83 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 84 | h = y2 - y1 + 1.0 85 | w = x2 - x1 + 1.0 86 | max_side = np.maximum(h, w) 87 | square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 88 | square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 89 | square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 90 | square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 91 | return square_bboxes 92 | 93 | 94 | def calibrate_box(bboxes, offsets): 95 | """Transform bounding boxes to be more like true bounding boxes. 96 | 'offsets' is one of the outputs of the nets. 97 | 98 | Arguments: 99 | bboxes: a float numpy array of shape [n, 5]. 100 | offsets: a float numpy array of shape [n, 4]. 101 | 102 | Returns: 103 | a float numpy array of shape [n, 5]. 104 | """ 105 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 106 | w = x2 - x1 + 1.0 107 | h = y2 - y1 + 1.0 108 | w = np.expand_dims(w, 1) 109 | h = np.expand_dims(h, 1) 110 | 111 | # this is what happening here: 112 | # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] 113 | # x1_true = x1 + tx1*w 114 | # y1_true = y1 + ty1*h 115 | # x2_true = x2 + tx2*w 116 | # y2_true = y2 + ty2*h 117 | # below is just more compact form of this 118 | 119 | # are offsets always such that 120 | # x1 < x2 and y1 < y2 ? 121 | 122 | translation = np.hstack([w, h, w, h]) * offsets 123 | bboxes[:, 0:4] = bboxes[:, 0:4] + translation 124 | return bboxes 125 | 126 | 127 | def get_image_boxes(bounding_boxes, img, size=24): 128 | """Cut out boxes from the image. 129 | 130 | Arguments: 131 | bounding_boxes: a float numpy array of shape [n, 5]. 132 | img: an instance of PIL.Image. 133 | size: an integer, size of cutouts. 
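            Each box is cropped out of `img`, zero-padded wherever it
            extends past the image border (see correct_bboxes), resized to
            `size` x `size` with bilinear interpolation, and normalized by
            `_preprocess` before being stacked into the output array.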
134 | 135 | Returns: 136 | a float numpy array of shape [n, 3, size, size]. 137 | """ 138 | 139 | num_boxes = len(bounding_boxes) 140 | width, height = img.size 141 | 142 | [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height) 143 | img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') 144 | 145 | for i in range(num_boxes): 146 | img_box = np.zeros((h[i], w[i], 3), 'uint8') 147 | 148 | img_array = np.asarray(img, 'uint8') 149 | img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \ 150 | img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] 151 | 152 | # resize 153 | img_box = Image.fromarray(img_box) 154 | img_box = img_box.resize((size, size), Image.BILINEAR) 155 | img_box = np.asarray(img_box, 'float32') 156 | 157 | img_boxes[i, :, :, :] = _preprocess(img_box) 158 | 159 | return img_boxes 160 | 161 | 162 | def correct_bboxes(bboxes, width, height): 163 | """Crop boxes that are too big and get coordinates 164 | with respect to cutouts. 165 | 166 | Arguments: 167 | bboxes: a float numpy array of shape [n, 5], 168 | where each row is (xmin, ymin, xmax, ymax, score). 169 | width: a float number. 170 | height: a float number. 171 | 172 | Returns: 173 | dy, dx, edy, edx: a int numpy arrays of shape [n], 174 | coordinates of the boxes with respect to the cutouts. 175 | y, x, ey, ex: a int numpy arrays of shape [n], 176 | corrected ymin, xmin, ymax, xmax. 177 | h, w: a int numpy arrays of shape [n], 178 | just heights and widths of boxes. 179 | 180 | in the following order: 181 | [dy, edy, dx, edx, y, ey, x, ex, w, h]. 182 | """ 183 | 184 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 185 | w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 186 | num_boxes = bboxes.shape[0] 187 | 188 | # 'e' stands for end 189 | # (x, y) -> (ex, ey) 190 | x, y, ex, ey = x1, y1, x2, y2 191 | 192 | # we need to cut out a box from the image. 193 | # (x, y, ex, ey) are corrected coordinates of the box 194 | # in the image. 195 | # (dx, dy, edx, edy) are coordinates of the box in the cutout 196 | # from the image. 197 | dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) 198 | edx, edy = w.copy() - 1.0, h.copy() - 1.0 199 | 200 | # if box's bottom right corner is too far right 201 | ind = np.where(ex > width - 1.0)[0] 202 | edx[ind] = w[ind] + width - 2.0 - ex[ind] 203 | ex[ind] = width - 1.0 204 | 205 | # if box's bottom right corner is too low 206 | ind = np.where(ey > height - 1.0)[0] 207 | edy[ind] = h[ind] + height - 2.0 - ey[ind] 208 | ey[ind] = height - 1.0 209 | 210 | # if box's top left corner is too far left 211 | ind = np.where(x < 0.0)[0] 212 | dx[ind] = 0.0 - x[ind] 213 | x[ind] = 0.0 214 | 215 | # if box's top left corner is too high 216 | ind = np.where(y < 0.0)[0] 217 | dy[ind] = 0.0 - y[ind] 218 | y[ind] = 0.0 219 | 220 | return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] 221 | return_list = [i.astype('int32') for i in return_list] 222 | 223 | return return_list 224 | 225 | 226 | def _preprocess(img): 227 | """Preprocessing step before feeding the network. 228 | 229 | Arguments: 230 | img: a float numpy array of shape [h, w, c]. 231 | 232 | Returns: 233 | a float numpy array of shape [1, c, h, w]. 
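        Note: the scaling (img - 127.5) * 0.0078125 (i.e. dividing by 128)
        maps uint8 pixel values from [0, 255] to roughly [-1, 1], which is
        the input range assumed by the bundled MTCNN weights.

        Shape example (illustrative):
            >>> x = np.zeros((12, 12, 3), 'float32')
            >>> _preprocess(x).shape
            (1, 3, 12, 12)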
234 | """ 235 | img = img.transpose((2, 0, 1)) 236 | img = np.expand_dims(img, 0) 237 | img = (img - 127.5) * 0.0078125 238 | return img 239 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from .get_nets import PNet, RNet, ONet 5 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from .first_stage import run_first_stage 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = 
torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from .box_utils import nms, _preprocess 7 | 8 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | device = 'cuda:0' 10 | 11 | 12 | def run_first_stage(image, net, scale, threshold): 13 | """Run P-Net, generate bounding boxes, and do NMS. 14 | 15 | Arguments: 16 | image: an instance of PIL.Image. 17 | net: an instance of pytorch's nn.Module, P-Net. 18 | scale: a float number, 19 | scale width and height of the image by this number. 20 | threshold: a float number, 21 | threshold on the probability of a face when generating 22 | bounding boxes from predictions of the net. 23 | 24 | Returns: 25 | a float numpy array of shape [n_boxes, 9], 26 | bounding boxes with scores and offsets (4 + 1 + 4). 27 | """ 28 | 29 | # scale the image and convert it to a float array 30 | width, height = image.size 31 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 32 | img = image.resize((sw, sh), Image.BILINEAR) 33 | img = np.asarray(img, 'float32') 34 | 35 | img = torch.FloatTensor(_preprocess(img)).to(device) 36 | with torch.no_grad(): 37 | output = net(img) 38 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 39 | offsets = output[0].cpu().data.numpy() 40 | # probs: probability of a face at each sliding window 41 | # offsets: transformations to true bounding boxes 42 | 43 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 44 | if len(boxes) == 0: 45 | return None 46 | 47 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 48 | return boxes[keep] 49 | 50 | 51 | def _generate_bboxes(probs, offsets, scale, threshold): 52 | """Generate bounding boxes at places 53 | where there is probably a face. 54 | 55 | Arguments: 56 | probs: a float numpy array of shape [n, m]. 57 | offsets: a float numpy array of shape [1, 4, n, m]. 58 | scale: a float number, 59 | width and height of the image were scaled by this number. 60 | threshold: a float number. 
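            (each row of the returned array is
            [xmin, ymin, xmax, ymax, score, tx1, ty1, tx2, ty2]: box
            corners mapped back to the original image scale, the face
            probability, and the four raw offsets that calibrate_box
            applies later in the pipeline)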
61 | 62 | Returns: 63 | a float numpy array of shape [n_boxes, 9] 64 | """ 65 | 66 | # applying P-Net is equivalent, in some sense, to 67 | # moving 12x12 window with stride 2 68 | stride = 2 69 | cell_size = 12 70 | 71 | # indices of boxes where there is probably a face 72 | inds = np.where(probs > threshold) 73 | 74 | if inds[0].size == 0: 75 | return np.array([]) 76 | 77 | # transformations of bounding boxes 78 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 79 | # they are defined as: 80 | # w = x2 - x1 + 1 81 | # h = y2 - y1 + 1 82 | # x1_true = x1 + tx1*w 83 | # x2_true = x2 + tx2*w 84 | # y1_true = y1 + ty1*h 85 | # y2_true = y2 + ty2*h 86 | 87 | offsets = np.array([tx1, ty1, tx2, ty2]) 88 | score = probs[inds[0], inds[1]] 89 | 90 | # P-Net is applied to scaled images 91 | # so we need to rescale bounding boxes back 92 | bounding_boxes = np.vstack([ 93 | np.round((stride * inds[1] + 1.0) / scale), 94 | np.round((stride * inds[0] + 1.0) / scale), 95 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 96 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 97 | score, offsets 98 | ]) 99 | # why one is added? 100 | 101 | return bounding_boxes.T 102 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | from configs.paths_config import model_paths 8 | PNET_PATH = model_paths["mtcnn_pnet"] 9 | ONET_PATH = model_paths["mtcnn_onet"] 10 | RNET_PATH = model_paths["mtcnn_rnet"] 11 | 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w]. 24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | 37 | # suppose we have input with size HxW, then 38 | # after first layer: H - 2, 39 | # after pool: ceil((H - 2)/2), 40 | # after second conv: ceil((H - 2)/2) - 2, 41 | # after last conv: ceil((H - 2)/2) - 4, 42 | # and the same for W 43 | 44 | self.features = nn.Sequential(OrderedDict([ 45 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 46 | ('prelu1', nn.PReLU(10)), 47 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 48 | 49 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 50 | ('prelu2', nn.PReLU(16)), 51 | 52 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 53 | ('prelu3', nn.PReLU(32)) 54 | ])) 55 | 56 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 57 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 58 | 59 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 60 | for n, p in self.named_parameters(): 61 | p.data = torch.FloatTensor(weights[n]) 62 | 63 | def forward(self, x): 64 | """ 65 | Arguments: 66 | x: a float tensor with shape [batch_size, 3, h, w]. 67 | Returns: 68 | b: a float tensor with shape [batch_size, 4, h', w']. 69 | a: a float tensor with shape [batch_size, 2, h', w']. 
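            ('b' carries the bounding-box regression offsets and 'a' the
            face / non-face scores for each 12x12 sliding window.)

        Shape example (illustrative): for a 24x24 input,
        h' = w' = ceil((24 - 2) / 2) - 4 = 7, so 'a' has shape
        [batch_size, 2, 7, 7], matching the size arithmetic in the
        comment inside __init__.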
70 | """ 71 | x = self.features(x) 72 | a = self.conv4_1(x) 73 | b = self.conv4_2(x) 74 | a = F.softmax(a, dim=-1) 75 | return b, a 76 | 77 | 78 | class RNet(nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential(OrderedDict([ 84 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 85 | ('prelu1', nn.PReLU(28)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 87 | 88 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 89 | ('prelu2', nn.PReLU(48)), 90 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 91 | 92 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 93 | ('prelu3', nn.PReLU(64)), 94 | 95 | ('flatten', Flatten()), 96 | ('conv4', nn.Linear(576, 128)), 97 | ('prelu4', nn.PReLU(128)) 98 | ])) 99 | 100 | self.conv5_1 = nn.Linear(128, 2) 101 | self.conv5_2 = nn.Linear(128, 4) 102 | 103 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 104 | for n, p in self.named_parameters(): 105 | p.data = torch.FloatTensor(weights[n]) 106 | 107 | def forward(self, x): 108 | """ 109 | Arguments: 110 | x: a float tensor with shape [batch_size, 3, h, w]. 111 | Returns: 112 | b: a float tensor with shape [batch_size, 4]. 113 | a: a float tensor with shape [batch_size, 2]. 114 | """ 115 | x = self.features(x) 116 | a = self.conv5_1(x) 117 | b = self.conv5_2(x) 118 | a = F.softmax(a, dim=-1) 119 | return b, a 120 | 121 | 122 | class ONet(nn.Module): 123 | 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 129 | ('prelu1', nn.PReLU(32)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 131 | 132 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 133 | ('prelu2', nn.PReLU(64)), 134 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 135 | 136 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 137 | ('prelu3', nn.PReLU(64)), 138 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | 140 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 141 | ('prelu4', nn.PReLU(128)), 142 | 143 | ('flatten', Flatten()), 144 | ('conv5', nn.Linear(1152, 256)), 145 | ('drop5', nn.Dropout(0.25)), 146 | ('prelu5', nn.PReLU(256)), 147 | ])) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
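        Illustrative usage (a minimal sketch; `img` is assumed to be an
        RGB PIL.Image already loaded by the caller):
            >>> from models.mtcnn.mtcnn_pytorch.src import detect_faces, show_bboxes
            >>> bboxes, landmarks = detect_faces(img)
            >>> show_bboxes(img, bboxes, landmarks).save('annotated.jpg')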
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class 
FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? 
p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = 
grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = 
F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/__init__.py -------------------------------------------------------------------------------- /notebooks/images/afhq_wild_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/afhq_wild_image.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/affleck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/affleck.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/bezos.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/bezos.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/blunt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/blunt.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/damon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/damon.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/dicaprio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/dicaprio.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/downey.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/downey.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/driver.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/driver.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/jackson.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/jackson.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/johansson.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/johansson.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/kunis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/kunis.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/pitt.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/robbie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/robbie.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/stone.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/stone.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/watson.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/watson.jpg -------------------------------------------------------------------------------- /notebooks/images/animations/zuckerberg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/animations/zuckerberg.jpg -------------------------------------------------------------------------------- /notebooks/images/car_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/car_image.jpg 
-------------------------------------------------------------------------------- /notebooks/images/domain_adaptation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/domain_adaptation.jpg -------------------------------------------------------------------------------- /notebooks/images/face_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/notebooks/images/face_image.jpg -------------------------------------------------------------------------------- /notebooks/notebook_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pydrive.auth import GoogleAuth 3 | from pydrive.drive import GoogleDrive 4 | from google.colab import auth 5 | from oauth2client.client import GoogleCredentials 6 | 7 | 8 | HYPERSTYLE_PATHS = { 9 | "faces": {"id": "1C3dEIIH1y8w1-zQMCyx7rDF0ndswSXh4", "name": "hyperstyle_ffhq.pt"}, 10 | "cars": {"id": "1WZ7iNv5ENmxXFn6dzPeue1jQGNp6Nr9d", "name": "hyperstyle_cars.pt"}, 11 | "afhq_wild": {"id": "1OMAKYRp3T6wzGr0s3887rQK-5XHlJ2gp", "name": "hyperstyle_afhq_wild.pt"} 12 | } 13 | W_ENCODERS_PATHS = { 14 | "faces": {"id": "1M-hsL3W_cJKs77xM1mwq2e9-J0_m7rHP", "name": "faces_w_encoder.pt"}, 15 | "cars": {"id": "1GZke8pfXMSZM9mfT-AbP1Csyddf5fas7", "name": "cars_w_encoder.pt"}, 16 | "afhq_wild": {"id": "1MhEHGgkTpnTanIwuHYv46i6MJeet2Nlr", "name": "afhq_wild_w_encoder.pt"} 17 | } 18 | FINETUNED_MODELS = { 19 | "toonify": {'id': '1r3XVCt_WYUKFZFxhNH-xO2dTtF6B5szu', 'name': 'toonify.pt'}, 20 | "pixar": {'id': '1trPW-To9L63x5gaXrbAIPkOU0q9f_h05', 'name': 'pixar.pt'}, 21 | "sketch": {'id': '1aHhzmxT7eD90txAN93zCl8o9CUVbMFnD', 'name': 'sketch.pt'}, 22 | "disney_princess": {'id': '1rXHZu4Vd0l_KCiCxGbwL9Xtka7n3S2NB', 'name': 'disney_princess.pt'} 23 | } 24 | RESTYLE_E4E_MODELS = {'id': '1e2oXVeBPXMQoUoC_4TNwAWpOPpSEhE_e', 'name': 'restlye_e4e.pt'} 25 | 26 | 27 | class Downloader: 28 | 29 | def __init__(self, code_dir, use_pydrive): 30 | self.use_pydrive = use_pydrive 31 | current_directory = os.getcwd() 32 | self.save_dir = os.path.join(os.path.dirname(current_directory), code_dir, "pretrained_models") 33 | os.makedirs(self.save_dir, exist_ok=True) 34 | if self.use_pydrive: 35 | self.authenticate() 36 | 37 | def authenticate(self): 38 | auth.authenticate_user() 39 | gauth = GoogleAuth() 40 | gauth.credentials = GoogleCredentials.get_application_default() 41 | self.drive = GoogleDrive(gauth) 42 | 43 | def download_file(self, file_id, file_name): 44 | file_dst = f'{self.save_dir}/{file_name}' 45 | if os.path.exists(file_dst): 46 | print(f'{file_name} already exists!') 47 | return 48 | if self.use_pydrive: 49 | downloaded = self.drive.CreateFile({'id':file_id}) 50 | downloaded.FetchMetadata(fetch_all=True) 51 | downloaded.GetContentFile(file_dst) 52 | else: 53 | os.system(f"gdown --id {file_id} -O {file_dst}") 54 | 55 | 56 | def run_alignment(image_path): 57 | import dlib 58 | from scripts.align_faces_parallel import align_face 59 | if not os.path.exists("shape_predictor_68_face_landmarks.dat"): 60 | print('Downloading files for aligning face image...') 61 | os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2') 62 | os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2') 63 | print('Done.') 64 | predictor = 
dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 65 | aligned_image = align_face(filepath=image_path, predictor=predictor) 66 | print(f"Finished running alignment on image: {image_path}") 67 | return aligned_image 68 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/options/__init__.py -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from configs.paths_config import model_paths 4 | 5 | 6 | class TestOptions: 7 | 8 | def __init__(self): 9 | self.parser = ArgumentParser() 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # arguments for inference script 14 | self.parser.add_argument('--exp_dir', type=str, 15 | help='Path to experiment output directory') 16 | self.parser.add_argument('--checkpoint_path', default=None, type=str, 17 | help='Path to HyperStyle model checkpoint') 18 | self.parser.add_argument('--data_path', type=str, default='gt_images', 19 | help='Path to directory of images to evaluate') 20 | self.parser.add_argument('--resize_outputs', action='store_true', 21 | help='Whether to resize outputs to 256x256 or keep at original output resolution') 22 | self.parser.add_argument('--test_batch_size', default=2, type=int, 23 | help='Batch size for testing and inference') 24 | self.parser.add_argument('--test_workers', default=2, type=int, 25 | help='Number of test/inference dataloader workers') 26 | self.parser.add_argument('--n_images', type=int, default=None, 27 | help='Number of images to output. If None, run on all data') 28 | self.parser.add_argument('--save_weight_deltas', action='store_true', 29 | help='Whether to save the weight deltas of each image. 
Note: file weighs about 200MB.') 30 | 31 | # arguments for iterative inference 32 | self.parser.add_argument('--n_iters_per_batch', default=5, type=int, 33 | help='Number of forward passes per batch during training.') 34 | 35 | # arguments for loading pre-trained encoder 36 | self.parser.add_argument('--load_w_encoder', action='store_true', help='Whether to load the w e4e encoder.') 37 | self.parser.add_argument('--w_encoder_checkpoint_path', default=model_paths["faces_w_encoder"], type=str, 38 | help='Path to pre-trained W-encoder.') 39 | self.parser.add_argument('--w_encoder_type', default='WEncoder', 40 | help='Encoder type for the encoder used to get the initial inversion') 41 | 42 | # arguments for editing scripts 43 | self.parser.add_argument('--edit_directions', default='age,smile,pose', help='which edit directions top perform') 44 | self.parser.add_argument('--factor_range', type=int, default=5, help='max range for interfacegan edits.') 45 | 46 | # arguments for domain adaptation 47 | self.parser.add_argument('--restyle_checkpoint_path', default=model_paths["restyle_e4e_ffhq"], type=str, 48 | help='ReStyle e4e checkpoint path used for domain adaptation') 49 | self.parser.add_argument('--restyle_n_iterations', default=2, type=int, 50 | help='Number of forward passes per batch for ReStyle-e4e inference.') 51 | self.parser.add_argument('--finetuned_generator_checkpoint_path', type=str, default=model_paths["stylegan_pixar"], 52 | help='Path to fine-tuned generator checkpoint used for domain adaptation.') 53 | 54 | def parse(self): 55 | opts = self.parser.parse_args() 56 | return opts 57 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from configs.paths_config import model_paths 4 | 5 | 6 | class TrainOptions: 7 | 8 | def __init__(self): 9 | self.parser = ArgumentParser() 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # general setup 14 | self.parser.add_argument('--exp_dir', type=str, 15 | help='Path to experiment output directory') 16 | self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, 17 | help='Type of dataset/experiment to run') 18 | self.parser.add_argument('--encoder_type', default='HyperNet', type=str, 19 | help='Which encoder to use') 20 | self.parser.add_argument('--input_nc', default=6, type=int, 21 | help='Number of input image channels to the HyperStyle network. 
Should be set to 6.') 22 | self.parser.add_argument('--output_size', default=1024, type=int, 23 | help='Output size of generator') 24 | 25 | # batch size and dataloader workers 26 | self.parser.add_argument('--batch_size', default=4, type=int, 27 | help='Batch size for training') 28 | self.parser.add_argument('--test_batch_size', default=2, type=int, 29 | help='Batch size for testing and inference') 30 | self.parser.add_argument('--workers', default=4, type=int, 31 | help='Number of train dataloader workers') 32 | self.parser.add_argument('--test_workers', default=2, type=int, 33 | help='Number of test/inference dataloader workers') 34 | 35 | # optimizers 36 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, 37 | help='Optimizer learning rate') 38 | self.parser.add_argument('--optim_name', default='ranger', type=str, 39 | help='Which optimizer to use') 40 | self.parser.add_argument('--train_decoder', default=False, type=bool, 41 | help='Whether to train the decoder model') 42 | 43 | # loss lambdas 44 | self.parser.add_argument('--lpips_lambda', default=0, type=float, 45 | help='LPIPS loss multiplier factor') 46 | self.parser.add_argument('--id_lambda', default=0, type=float, 47 | help='ID loss multiplier factor') 48 | self.parser.add_argument('--l2_lambda', default=0, type=float, 49 | help='L2 loss multiplier factor') 50 | self.parser.add_argument('--moco_lambda', default=0, type=float, 51 | help='MoCo feature loss multiplier factor') 52 | 53 | # weights and checkpoint paths 54 | self.parser.add_argument('--stylegan_weights', default=model_paths["stylegan_ffhq"], type=str, 55 | help='Path to StyleGAN model weights') 56 | self.parser.add_argument('--checkpoint_path', default=None, type=str, 57 | help='Path to HyperStyle model checkpoint') 58 | 59 | # intervals for logging, validation, and saving 60 | self.parser.add_argument('--max_steps', default=500000, type=int, 61 | help='Maximum number of training steps') 62 | self.parser.add_argument('--max_val_batches', type=int, default=None, 63 | help='Number of batches to run validation on. 
If None, run on all batches.') 64 | self.parser.add_argument('--image_interval', default=100, type=int, 65 | help='Interval for logging train images during training') 66 | self.parser.add_argument('--board_interval', default=50, type=int, 67 | help='Interval for logging metrics to tensorboard') 68 | self.parser.add_argument('--val_interval', default=1000, type=int, 69 | help='Validation interval') 70 | self.parser.add_argument('--save_interval', default=None, type=int, 71 | help='Model checkpoint interval') 72 | 73 | # arguments for iterative encoding 74 | self.parser.add_argument('--n_iters_per_batch', default=1, type=int, 75 | help='Number of forward passes per batch during training') 76 | 77 | # hypernet parameters 78 | self.parser.add_argument('--load_w_encoder', action='store_true', help='Whether to load the w e4e encoder.') 79 | self.parser.add_argument('--w_encoder_checkpoint_path', default=model_paths["e4e_w_encoder"], type=str, 80 | help='Path to pre-trained W-encoder.') 81 | self.parser.add_argument('--w_encoder_type', default='WEncoder', 82 | help='Encoder type for the encoder used to get the initial inversion') 83 | self.parser.add_argument('--layers_to_tune', default='0,2,3,5,6,8,9,11,12,14,15,17,18,20,21,23,24', type=str, 84 | help='comma-separated list of which layers of the StyleGAN generator to tune') 85 | 86 | def parse(self): 87 | opts = self.parser.parse_args() 88 | return opts 89 | -------------------------------------------------------------------------------- /scripts/align_faces_parallel.py: -------------------------------------------------------------------------------- 1 | """ 2 | brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) 3 | author: lzhbrian (https://lzhbrian.me) 4 | date: 2020.1.5 5 | note: code is heavily borrowed from 6 | https://github.com/NVlabs/ffhq-dataset 7 | http://dlib.net/face_landmark_detection.py.html 8 | 9 | requirements: 10 | apt install cmake 11 | conda install Pillow numpy scipy 12 | pip install dlib 13 | # download face landmark model from: 14 | # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 15 | """ 16 | from argparse import ArgumentParser 17 | import time 18 | import numpy as np 19 | import PIL 20 | import PIL.Image 21 | import os 22 | import scipy 23 | import scipy.ndimage 24 | import dlib 25 | import multiprocessing as mp 26 | import math 27 | import sys 28 | 29 | sys.path.append(".") 30 | sys.path.append("..") 31 | 32 | from configs.paths_config import model_paths 33 | SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"] 34 | 35 | 36 | def get_landmark(filepath, predictor): 37 | """get landmark with dlib 38 | :return: np.array shape=(68, 2) 39 | """ 40 | detector = dlib.get_frontal_face_detector() 41 | 42 | img = dlib.load_rgb_image(filepath) 43 | dets = detector(img, 1) 44 | 45 | for k, d in enumerate(dets): 46 | shape = predictor(img, d) 47 | 48 | t = list(shape.parts()) 49 | a = [] 50 | for tt in t: 51 | a.append([tt.x, tt.y]) 52 | lm = np.array(a) 53 | return lm 54 | 55 | 56 | def align_face(filepath, predictor): 57 | """ 58 | :param filepath: str 59 | :return: PIL Image 60 | """ 61 | 62 | lm = get_landmark(filepath, predictor) 63 | 64 | lm_chin = lm[0: 17] # left-right 65 | lm_eyebrow_left = lm[17: 22] # left-right 66 | lm_eyebrow_right = lm[22: 27] # left-right 67 | lm_nose = lm[27: 31] # top-down 68 | lm_nostrils = lm[31: 36] # top-down 69 | lm_eye_left = lm[36: 42] # left-clockwise 70 | lm_eye_right = lm[42: 48] # left-clockwise 71 | lm_mouth_outer = lm[48: 60] # 
left-clockwise 72 | lm_mouth_inner = lm[60: 68] # left-clockwise 73 | 74 | # Calculate auxiliary vectors. 75 | eye_left = np.mean(lm_eye_left, axis=0) 76 | eye_right = np.mean(lm_eye_right, axis=0) 77 | eye_avg = (eye_left + eye_right) * 0.5 78 | eye_to_eye = eye_right - eye_left 79 | mouth_left = lm_mouth_outer[0] 80 | mouth_right = lm_mouth_outer[6] 81 | mouth_avg = (mouth_left + mouth_right) * 0.5 82 | eye_to_mouth = mouth_avg - eye_avg 83 | 84 | # Choose oriented crop rectangle. 85 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 86 | x /= np.hypot(*x) 87 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 88 | y = np.flipud(x) * [-1, 1] 89 | c = eye_avg + eye_to_mouth * 0.1 90 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 91 | qsize = np.hypot(*x) * 2 92 | 93 | # read image 94 | img = PIL.Image.open(filepath) 95 | 96 | output_size = 1024 97 | transform_size = 1024 98 | enable_padding = True 99 | 100 | # Shrink. 101 | shrink = int(np.floor(qsize / output_size * 0.5)) 102 | if shrink > 1: 103 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 104 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 105 | quad /= shrink 106 | qsize /= shrink 107 | 108 | # Crop. 109 | border = max(int(np.rint(qsize * 0.1)), 3) 110 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 111 | int(np.ceil(max(quad[:, 1])))) 112 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 113 | min(crop[3] + border, img.size[1])) 114 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 115 | img = img.crop(crop) 116 | quad -= crop[0:2] 117 | 118 | # Pad. 119 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 120 | int(np.ceil(max(quad[:, 1])))) 121 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 122 | max(pad[3] - img.size[1] + border, 0)) 123 | if enable_padding and max(pad) > border - 4: 124 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 125 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 126 | h, w, _ = img.shape 127 | y, x, _ = np.ogrid[:h, :w, :1] 128 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 129 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 130 | blur = qsize * 0.02 131 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 132 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 133 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 134 | quad += pad[:2] 135 | 136 | # Transform. 137 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 138 | if output_size < transform_size: 139 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 140 | 141 | # Save aligned image. 
142 | return img 143 | 144 | 145 | def chunks(lst, n): 146 | """Yield successive n-sized chunks from lst.""" 147 | for i in range(0, len(lst), n): 148 | yield lst[i:i + n] 149 | 150 | 151 | def extract_on_paths(file_paths): 152 | predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH) 153 | pid = mp.current_process().name 154 | print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths))) 155 | tot_count = len(file_paths) 156 | count = 0 157 | for file_path, res_path in file_paths: 158 | count += 1 159 | if count % 100 == 0: 160 | print('{} done with {}/{}'.format(pid, count, tot_count)) 161 | try: 162 | res = align_face(file_path, predictor) 163 | res = res.convert('RGB') 164 | os.makedirs(os.path.dirname(res_path), exist_ok=True) 165 | res.save(res_path) 166 | except Exception as e: 167 | print(f"Failed on image: {file_path} - {e}") 168 | continue 169 | print('\tDone!') 170 | 171 | 172 | def parse_args(): 173 | parser = ArgumentParser(add_help=False) 174 | parser.add_argument('--num_threads', type=int, default=1) 175 | parser.add_argument('--root_path', type=str, default='') 176 | args = parser.parse_args() 177 | return args 178 | 179 | 180 | def run(args): 181 | root_path = args.root_path 182 | out_crops_path = os.path.join(root_path, 'aligned') 183 | if not os.path.exists(out_crops_path): 184 | os.makedirs(out_crops_path, exist_ok=True) 185 | 186 | file_paths = [] 187 | for file in os.listdir(root_path): 188 | file_path = os.path.join(root_path, file) 189 | fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path)) 190 | res_path = '{}.jpg'.format(os.path.splitext(fname)[0]) 191 | if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path): 192 | print(res_path) 193 | continue 194 | file_paths.append((file_path, res_path)) 195 | 196 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 197 | print(len(file_chunks)) 198 | pool = mp.Pool(args.num_threads) 199 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 200 | tic = time.time() 201 | pool.map(extract_on_paths, file_chunks) 202 | toc = time.time() 203 | print('Mischief managed in {}s'.format(toc - tic)) 204 | 205 | 206 | if __name__ == '__main__': 207 | args = parse_args() 208 | run(args) 209 | -------------------------------------------------------------------------------- /scripts/calc_id_loss_parallel.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import time 3 | import numpy as np 4 | import os 5 | import json 6 | import sys 7 | from PIL import Image 8 | import multiprocessing as mp 9 | import math 10 | import torch 11 | import torchvision.transforms as trans 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from models.mtcnn.mtcnn import MTCNN 17 | from models.encoders.model_irse import IR_101 18 | from configs.paths_config import model_paths 19 | CIRCULAR_FACE_PATH = model_paths['curricular_face'] 20 | 21 | 22 | def chunks(lst, n): 23 | """Yield successive n-sized chunks from lst.""" 24 | for i in range(0, len(lst), n): 25 | yield lst[i:i + n] 26 | 27 | 28 | def extract_on_paths(file_paths): 29 | facenet = IR_101(input_size=112) 30 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 31 | facenet.cuda() 32 | facenet.eval() 33 | mtcnn = MTCNN() 34 | id_transform = trans.Compose([ 35 | trans.ToTensor(), 36 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 37 | ]) 38 | 39 | pid = mp.current_process().name 40 | print('\t{} 
is starting to extract on {} images'.format(pid, len(file_paths))) 41 | tot_count = len(file_paths) 42 | count = 0 43 | 44 | scores_dict = {} 45 | for res_path, gt_path in file_paths: 46 | count += 1 47 | if count % 100 == 0: 48 | print('{} done with {}/{}'.format(pid, count, tot_count)) 49 | if True: 50 | input_im = Image.open(res_path) 51 | input_im, _ = mtcnn.align(input_im) 52 | if input_im is None: 53 | print('{} skipping {}'.format(pid, res_path)) 54 | continue 55 | 56 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 57 | 58 | result_im = Image.open(gt_path) 59 | result_im, _ = mtcnn.align(result_im) 60 | if result_im is None: 61 | print('{} skipping {}'.format(pid, gt_path)) 62 | continue 63 | 64 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 65 | score = float(input_id.dot(result_id)) 66 | scores_dict[os.path.basename(gt_path)] = score 67 | 68 | return scores_dict 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser(add_help=False) 73 | parser.add_argument('--num_threads', type=int, default=4) 74 | parser.add_argument('--output_path', type=str, default='inference_results', help='path to inference outputs') 75 | parser.add_argument('--gt_path', type=str, default='gt_images', help='path to gt images') 76 | args = parser.parse_args() 77 | return args 78 | 79 | 80 | def run(args): 81 | for step in sorted(os.listdir(args.output_path)): 82 | if not step.isdigit(): 83 | continue 84 | step_outputs_path = os.path.join(args.output_path, step) 85 | if os.path.isdir(step_outputs_path): 86 | print('#' * 80) 87 | print(f'Running on step: {step}') 88 | print('#' * 80) 89 | run_on_step_output(step=step, args=args) 90 | 91 | 92 | def run_on_step_output(step, args): 93 | file_paths = [] 94 | step_outputs_path = os.path.join(args.output_path, step) 95 | for f in os.listdir(step_outputs_path): 96 | image_path = os.path.join(step_outputs_path, f) 97 | gt_path = os.path.join(args.gt_path, f) 98 | if not os.path.exists(gt_path): 99 | continue 100 | if f.endswith(".jpg") or f.endswith('.png') or f.endswith('.jpeg'): 101 | file_paths.append([image_path, gt_path]) 102 | 103 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 104 | pool = mp.Pool(args.num_threads) 105 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 106 | 107 | tic = time.time() 108 | results = pool.map(extract_on_paths, file_chunks) 109 | scores_dict = {} 110 | for d in results: 111 | scores_dict.update(d) 112 | 113 | all_scores = list(scores_dict.values()) 114 | mean = np.mean(all_scores) 115 | std = np.std(all_scores) 116 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 117 | print(result_str) 118 | 119 | out_path = os.path.join(os.path.dirname(args.output_path), 'inference_metrics') 120 | if not os.path.exists(out_path): 121 | os.makedirs(out_path) 122 | 123 | with open(os.path.join(out_path, f'stat_id.txt'), 'w') as f: 124 | f.write(result_str) 125 | with open(os.path.join(out_path, f'scores_id.json'), 'w') as f: 126 | json.dump(scores_dict, f) 127 | 128 | toc = time.time() 129 | print('Mischief managed in {}s'.format(toc - tic)) 130 | 131 | 132 | if __name__ == '__main__': 133 | args = parse_args() 134 | run(args) 135 | -------------------------------------------------------------------------------- /scripts/calc_losses_on_images.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import json 4 | import sys 5 | 
from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torchvision.transforms as transforms 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from criteria.lpips.lpips import LPIPS 15 | from criteria.ms_ssim import MSSSIM 16 | from datasets.gt_res_dataset import GTResDataset 17 | 18 | 19 | def parse_args(): 20 | parser = ArgumentParser(add_help=False) 21 | parser.add_argument('--metrics', type=str, default='lpips,l2,msssim') 22 | parser.add_argument('--output_path', type=str, default='results') 23 | parser.add_argument('--gt_path', type=str, default='gt_images') 24 | parser.add_argument('--workers', type=int, default=4) 25 | parser.add_argument('--batch_size', type=int, default=4) 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def run(args): 31 | for metric in args.metrics.split(','): 32 | args.metric = metric 33 | for step in sorted(os.listdir(args.output_path)): 34 | if not step.isdigit(): 35 | continue 36 | step_outputs_path = os.path.join(args.output_path, step) 37 | if os.path.isdir(step_outputs_path): 38 | print('#' * 80) 39 | print(f'Computing {args.metric} on step: {step}') 40 | print('#' * 80) 41 | run_on_step_output(step=step, args=args) 42 | 43 | 44 | def run_on_step_output(step, args): 45 | 46 | transform = transforms.Compose([transforms.Resize((256, 256)), 47 | transforms.ToTensor(), 48 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 49 | 50 | step_outputs_path = os.path.join(args.output_path, step) 51 | 52 | print('Loading dataset') 53 | dataset = GTResDataset(root_path=step_outputs_path, 54 | gt_dir=args.gt_path, 55 | transform=transform) 56 | 57 | dataloader = DataLoader(dataset, 58 | batch_size=args.batch_size, 59 | shuffle=False, 60 | num_workers=int(args.workers), 61 | drop_last=True) 62 | 63 | if args.metric == 'lpips': 64 | loss_func = LPIPS(net_type='alex') 65 | elif args.metric == 'l2': 66 | loss_func = torch.nn.MSELoss() 67 | elif args.metric == 'msssim': 68 | loss_func = MSSSIM() 69 | else: 70 | raise Exception(f'Not a valid metric: {args.metric}!') 71 | 72 | loss_func.cuda() 73 | 74 | global_i = 0 75 | scores_dict = {} 76 | all_scores = [] 77 | for result_batch, gt_batch in tqdm(dataloader): 78 | for i in range(args.batch_size): 79 | loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda())) 80 | all_scores.append(loss) 81 | im_path = dataset.pairs[global_i][0] 82 | scores_dict[os.path.basename(im_path)] = loss 83 | global_i += 1 84 | 85 | all_scores = list(scores_dict.values()) 86 | all_scores = [s for s in all_scores if not np.isnan(s)] 87 | mean = np.mean(all_scores) 88 | std = np.std(all_scores) 89 | result_str = 'Average loss is {:.5f}+-{:.5f}'.format(mean, std) 90 | print('Finished with ', step_outputs_path) 91 | print(result_str) 92 | 93 | out_path = os.path.join(os.path.dirname(args.output_path), 'inference_metrics') 94 | if not os.path.exists(out_path): 95 | os.makedirs(out_path) 96 | 97 | with open(os.path.join(out_path, f'stat_{args.metric}.txt'), 'w') as f: 98 | f.write(result_str) 99 | with open(os.path.join(out_path, f'scores_{args.metric}.json'), 'w') as f: 100 | json.dump(scores_dict, f) 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parse_args() 105 | run(args) 106 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import time 5 | 
import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | 10 | import sys 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from configs import data_configs 15 | from datasets.inference_dataset import InferenceDataset 16 | from utils.common import tensor2im 17 | from utils.inference_utils import run_inversion 18 | from utils.model_utils import load_model 19 | from options.test_options import TestOptions 20 | 21 | 22 | def run(): 23 | test_opts = TestOptions().parse() 24 | 25 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 26 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled') 27 | 28 | os.makedirs(out_path_results, exist_ok=True) 29 | os.makedirs(out_path_coupled, exist_ok=True) 30 | 31 | # update test options with options used during training 32 | net, opts = load_model(test_opts.checkpoint_path, update_opts=test_opts) 33 | 34 | print('Loading dataset for {}'.format(opts.dataset_type)) 35 | dataset_args = data_configs.DATASETS[opts.dataset_type] 36 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 37 | dataset = InferenceDataset(root=opts.data_path, 38 | transform=transforms_dict['transform_inference'], 39 | opts=opts) 40 | dataloader = DataLoader(dataset, 41 | batch_size=opts.test_batch_size, 42 | shuffle=False, 43 | num_workers=int(opts.test_workers), 44 | drop_last=False) 45 | 46 | if opts.n_images is None: 47 | opts.n_images = len(dataset) 48 | 49 | if "cars" in opts.dataset_type: 50 | resize_amount = (256, 192) if opts.resize_outputs else (512, 384) 51 | else: 52 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 53 | 54 | global_i = 0 55 | global_time = [] 56 | all_latents = {} 57 | for input_batch in tqdm(dataloader): 58 | 59 | if global_i >= opts.n_images: 60 | break 61 | 62 | with torch.no_grad(): 63 | input_cuda = input_batch.cuda().float() 64 | tic = time.time() 65 | result_batch, result_latents, result_deltas = run_inversion(input_cuda, net, opts, 66 | return_intermediate_results=True) 67 | toc = time.time() 68 | global_time.append(toc - tic) 69 | 70 | for i in range(input_batch.shape[0]): 71 | results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch)] 72 | im_path = dataset.paths[global_i] 73 | 74 | input_im = tensor2im(input_batch[i]) 75 | res = np.array(input_im.resize(resize_amount)) 76 | for idx, result in enumerate(results): 77 | res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1) 78 | # save individual outputs 79 | save_dir = os.path.join(out_path_results, str(idx)) 80 | os.makedirs(save_dir, exist_ok=True) 81 | result.resize(resize_amount).save(os.path.join(save_dir, os.path.basename(im_path))) 82 | 83 | # save coupled image with side-by-side results 84 | Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path))) 85 | 86 | all_latents[os.path.basename(im_path)] = result_latents[i][0] 87 | 88 | if opts.save_weight_deltas: 89 | weight_deltas_dir = os.path.join(test_opts.exp_dir, "weight_deltas") 90 | os.makedirs(weight_deltas_dir, exist_ok=True) 91 | np.save(os.path.join(weight_deltas_dir, os.path.basename(im_path).split('.')[0] + ".npy"), 92 | result_deltas[i][-1]) 93 | 94 | global_i += 1 95 | 96 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 97 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 98 | print(result_str) 99 | 100 | with open(stats_path, 'w') as f: 101 | 
f.write(result_str) 102 | 103 | # save all latents as npy file 104 | np.save(os.path.join(test_opts.exp_dir, 'latents.npy'), all_latents) 105 | 106 | 107 | if __name__ == '__main__': 108 | run() 109 | -------------------------------------------------------------------------------- /scripts/run_domain_adaptation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import DataLoader 7 | 8 | import sys 9 | sys.path.extend([".", ".."]) 10 | 11 | from configs import data_configs 12 | from datasets.inference_dataset import InferenceDataset 13 | from options.test_options import TestOptions 14 | from utils.common import tensor2im 15 | from utils.domain_adaptation_utils import run_domain_adaptation 16 | from utils.model_utils import load_model, load_generator 17 | 18 | 19 | def run(): 20 | test_opts = TestOptions().parse() 21 | 22 | out_path_results = os.path.join(test_opts.exp_dir, 'domain_adaptation_results') 23 | out_path_coupled = os.path.join(test_opts.exp_dir, 'domain_adaptation_coupled') 24 | 25 | os.makedirs(out_path_results, exist_ok=True) 26 | os.makedirs(out_path_coupled, exist_ok=True) 27 | 28 | # update test options with options used during training 29 | net, opts = load_model(test_opts.checkpoint_path, update_opts=test_opts) 30 | 31 | restyle_e4e, restyle_opts = load_model(test_opts.restyle_checkpoint_path, 32 | update_opts={"resize_outputs": test_opts.resize_outputs, 33 | "n_iters_per_batch": test_opts.restyle_n_iterations}, 34 | is_restyle_encoder=True) 35 | finetuned_generator = load_generator(test_opts.finetuned_generator_checkpoint_path) 36 | 37 | print('Loading dataset for {}'.format(opts.dataset_type)) 38 | dataset_args = data_configs.DATASETS[opts.dataset_type] 39 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 40 | dataset = InferenceDataset(root=opts.data_path, 41 | transform=transforms_dict['transform_inference'], 42 | opts=opts) 43 | dataloader = DataLoader(dataset, 44 | batch_size=opts.test_batch_size, 45 | shuffle=False, 46 | num_workers=int(opts.test_workers), 47 | drop_last=False) 48 | 49 | if opts.n_images is None: 50 | opts.n_images = len(dataset) 51 | 52 | global_i = 0 53 | for input_batch in tqdm(dataloader): 54 | 55 | if global_i >= opts.n_images: 56 | break 57 | 58 | with torch.no_grad(): 59 | input_cuda = input_batch.cuda().float() 60 | result_batch, _ = run_domain_adaptation(input_cuda, net, opts, finetuned_generator, 61 | restyle_e4e, restyle_opts) 62 | 63 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 64 | for i in range(input_batch.shape[0]): 65 | 66 | im_path = dataset.paths[global_i] 67 | 68 | curr_result = tensor2im(result_batch[i]) 69 | input_im = tensor2im(input_batch[i]) 70 | 71 | res_save_path = os.path.join(out_path_results, os.path.basename(im_path)) 72 | curr_result.resize(resize_amount).save(res_save_path) 73 | 74 | coupled_save_path = os.path.join(out_path_coupled, os.path.basename(im_path)) 75 | res = np.concatenate([np.array(input_im.resize(resize_amount)), np.array(curr_result.resize(resize_amount))], 76 | axis=1) 77 | Image.fromarray(res).save(coupled_save_path) 78 | global_i += 1 79 | 80 | 81 | if __name__ == '__main__': 82 | run() 83 | -------------------------------------------------------------------------------- /scripts/train.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import json 5 | import os 6 | import sys 7 | import pprint 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from options.train_options import TrainOptions 13 | from training.coach_hyperstyle import Coach 14 | 15 | 16 | def main(): 17 | opts = TrainOptions().parse() 18 | create_initial_experiment_dir(opts) 19 | coach = Coach(opts) 20 | coach.train() 21 | 22 | 23 | def create_initial_experiment_dir(opts): 24 | os.makedirs(opts.exp_dir, exist_ok=True) 25 | opts_dict = vars(opts) 26 | pprint.pprint(opts_dict) 27 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 28 | json.dump(opts_dict, f, indent=4, sort_keys=True) 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/training/__init__.py -------------------------------------------------------------------------------- /training/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 
23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 124 | # compute mean moving avg 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | 127 | buffered = self.radam_buffer[int(state['step'] % 10)] 128 | 129 | if 
state['step'] == buffered[0]: 130 | N_sma, step_size = buffered[1], buffered[2] 131 | else: 132 | buffered[0] = state['step'] 133 | beta2_t = beta2 ** state['step'] 134 | N_sma_max = 2 / (1 - beta2) - 1 135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 136 | buffered[1] = N_sma 137 | if N_sma > self.N_sma_threshhold: 138 | step_size = math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 140 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 141 | else: 142 | step_size = 1.0 / (1 - beta1 ** state['step']) 143 | buffered[2] = step_size 144 | 145 | if group['weight_decay'] != 0: 146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 147 | 148 | # apply lr 149 | if N_sma > self.N_sma_threshhold: 150 | denom = exp_avg_sq.sqrt().add_(group['eps']) 151 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 152 | else: 153 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 154 | 155 | p.data.copy_(p_data_fp32) 156 | 157 | # integrated look ahead... 158 | # we do it at the param level instead of group level 159 | if state['step'] % group['k'] == 0: 160 | slow_p = state['slow_buffer'] # get access to slow param tensor 161 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 163 | 164 | return loss -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuval-alaluf/hyperstyle/a723c7310312ca5b7f859340027f78e18965acbc/utils/__init__.py -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def tensor2im(var): 6 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 7 | var = ((var + 1) / 2) 8 | var[var < 0] = 0 9 | var[var > 1] = 1 10 | var = var * 255 11 | return Image.fromarray(var.astype('uint8')) 12 | 13 | 14 | def vis_faces(log_hooks): 15 | display_count = len(log_hooks) 16 | n_outputs = len(log_hooks[0]['output_face']) if type(log_hooks[0]['output_face']) == list else 1 17 | fig = plt.figure(figsize=(8 + (n_outputs * 2), 3 * display_count)) 18 | gs = fig.add_gridspec(display_count, (3 + n_outputs)) 19 | for i in range(display_count): 20 | hooks_dict = log_hooks[i] 21 | fig.add_subplot(gs[i, 0]) 22 | vis_faces_iterative(hooks_dict, fig, gs, i) 23 | plt.tight_layout() 24 | return fig 25 | 26 | 27 | def vis_faces_with_id(hooks_dict, fig, gs, i): 28 | plt.imshow(hooks_dict['input_face']) 29 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 30 | fig.add_subplot(gs[i, 1]) 31 | plt.imshow(hooks_dict['target_face']) 32 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), 33 | float(hooks_dict['diff_target']))) 34 | fig.add_subplot(gs[i, 2]) 35 | plt.imshow(hooks_dict['output_face']) 36 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) 37 | 38 | 39 | def vis_faces_iterative(hooks_dict, fig, gs, i): 40 | plt.imshow(hooks_dict['input_face']) 41 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 42 | fig.add_subplot(gs[i, 1]) 43 | plt.imshow(hooks_dict['w_inversion']) 44 | plt.title('W-Inversion\n') 45 | fig.add_subplot(gs[i, 
2]) 46 | plt.imshow(hooks_dict['target_face']) 47 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), float(hooks_dict['diff_target']))) 48 | for idx, output_idx in enumerate(range(len(hooks_dict['output_face']) - 1, -1, -1)): 49 | output_image, similarity = hooks_dict['output_face'][output_idx] 50 | fig.add_subplot(gs[i, 3 + idx]) 51 | plt.imshow(output_image) 52 | plt.title('Output {}\n Target Sim={:.2f}'.format(output_idx, float(similarity))) 53 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 20 | for fname in os.listdir(dir): 21 | if is_image_file(fname): 22 | path = os.path.join(dir, fname) 23 | images.append(path) 24 | return images 25 | -------------------------------------------------------------------------------- /utils/domain_adaptation_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.extend(['.', '..']) 3 | 4 | from utils.inference_utils import run_inversion 5 | from utils import restyle_inference_utils 6 | 7 | 8 | def run_domain_adaptation(inputs, net, opts, fine_tuned_generator, restyle_e4e, restyle_opts): 9 | """ Combine restyle e4e's latent code with HyperStyle's predicted weight offsets. 
""" 10 | y_hat, latents = restyle_inference_utils.run_on_batch(inputs, restyle_e4e, restyle_opts) 11 | y_hat2, _, weights_deltas, _ = run_inversion(inputs, net, opts) 12 | weights_deltas = filter_non_ffhq_layers_in_toonify_model(weights_deltas) 13 | return fine_tuned_generator([latents], 14 | input_is_latent=True, 15 | randomize_noise=True, 16 | return_latents=True, 17 | weights_deltas=weights_deltas) 18 | 19 | 20 | def filter_non_ffhq_layers_in_toonify_model(weights_deltas): 21 | toonify_ffhq_layer_idx = [14, 15, 17, 18, 20, 21, 23, 24] # convs 8-15 according to model_utils.py 22 | for i in range(len(weights_deltas)): 23 | if weights_deltas[i] is not None and i not in toonify_ffhq_layer_idx: 24 | weights_deltas[i] = None 25 | return weights_deltas 26 | 27 | -------------------------------------------------------------------------------- /utils/inference_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def run_inversion(inputs, net, opts, return_intermediate_results=False): 5 | y_hat, latent, weights_deltas, codes = None, None, None, None 6 | 7 | if return_intermediate_results: 8 | results_batch = {idx: [] for idx in range(inputs.shape[0])} 9 | results_latent = {idx: [] for idx in range(inputs.shape[0])} 10 | results_deltas = {idx: [] for idx in range(inputs.shape[0])} 11 | else: 12 | results_batch, results_latent, results_deltas = None, None, None 13 | 14 | for iter in range(opts.n_iters_per_batch): 15 | y_hat, latent, weights_deltas, codes, _ = net.forward(inputs, 16 | y_hat=y_hat, 17 | codes=codes, 18 | weights_deltas=weights_deltas, 19 | return_latents=True, 20 | resize=opts.resize_outputs, 21 | randomize_noise=False, 22 | return_weight_deltas_and_codes=True) 23 | 24 | if "cars" in opts.dataset_type: 25 | if opts.resize_outputs: 26 | y_hat = y_hat[:, :, 32:224, :] 27 | else: 28 | y_hat = y_hat[:, :, 64:448, :] 29 | 30 | if return_intermediate_results: 31 | store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas) 32 | 33 | # resize input to 256 before feeding into next iteration 34 | if "cars" in opts.dataset_type: 35 | y_hat = torch.nn.AdaptiveAvgPool2d((192, 256))(y_hat) 36 | else: 37 | y_hat = net.face_pool(y_hat) 38 | 39 | if return_intermediate_results: 40 | return results_batch, results_latent, results_deltas 41 | return y_hat, latent, weights_deltas, codes 42 | 43 | 44 | def store_intermediate_results(results_batch, results_latent, results_deltas, y_hat, latent, weights_deltas): 45 | for idx in range(y_hat.shape[0]): 46 | results_batch[idx].append(y_hat[idx]) 47 | results_latent[idx].append(latent[idx].cpu().numpy()) 48 | results_deltas[idx].append([w[idx].cpu().numpy() if w is not None else None for w in weights_deltas]) 49 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import Namespace 3 | 4 | import sys 5 | sys.path.extend(['.', '..']) 6 | 7 | from models.stylegan2.model import Generator 8 | from models.hyperstyle import HyperStyle 9 | from models.encoders.e4e import e4e 10 | 11 | 12 | def load_model(checkpoint_path, device='cuda', update_opts=None, is_restyle_encoder=False): 13 | ckpt = torch.load(checkpoint_path, map_location='cpu') 14 | opts = ckpt['opts'] 15 | 16 | opts['checkpoint_path'] = checkpoint_path 17 | opts['load_w_encoder'] = True 18 | 19 | if update_opts is not None: 20 | 
if type(update_opts) == dict: 21 | opts.update(update_opts) 22 | else: 23 | opts.update(vars(update_opts)) 24 | 25 | opts = Namespace(**opts) 26 | 27 | if is_restyle_encoder: 28 | net = e4e(opts) 29 | else: 30 | net = HyperStyle(opts) 31 | 32 | net.eval() 33 | net.to(device) 34 | return net, opts 35 | 36 | 37 | def load_generator(checkpoint_path, device='cuda'): 38 | print(f"Loading generator from checkpoint: {checkpoint_path}") 39 | generator = Generator(1024, 512, 8, channel_multiplier=2) 40 | ckpt = torch.load(checkpoint_path, map_location='cpu') 41 | generator.load_state_dict(ckpt['g_ema']) 42 | generator.eval() 43 | generator.to(device) 44 | return generator 45 | -------------------------------------------------------------------------------- /utils/resnet_mapping.py: -------------------------------------------------------------------------------- 1 | RESNET_MAPPING = { 2 | 'layer1.0': 'body.0', 3 | 'layer1.1': 'body.1', 4 | 'layer1.2': 'body.2', 5 | 'layer2.0': 'body.3', 6 | 'layer2.1': 'body.4', 7 | 'layer2.2': 'body.5', 8 | 'layer2.3': 'body.6', 9 | 'layer3.0': 'body.7', 10 | 'layer3.1': 'body.8', 11 | 'layer3.2': 'body.9', 12 | 'layer3.3': 'body.10', 13 | 'layer3.4': 'body.11', 14 | 'layer3.5': 'body.12', 15 | 'layer4.0': 'body.13', 16 | 'layer4.1': 'body.14', 17 | 'layer4.2': 'body.15', 18 | } -------------------------------------------------------------------------------- /utils/restyle_inference_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_average_image(net, opts): 5 | avg_image = net(net.latent_avg.unsqueeze(0), 6 | input_code=True, 7 | randomize_noise=False, 8 | return_latents=False, 9 | average_code=True)[0] 10 | avg_image = avg_image.to('cuda').float().detach() 11 | if "cars" in opts.dataset_type: 12 | avg_image = avg_image[:, 32:224, :] 13 | return avg_image 14 | 15 | 16 | def run_on_batch(inputs, net, opts): 17 | avg_image = get_average_image(net, opts) 18 | y_hat, latent = None, None 19 | for iter in range(opts.n_iters_per_batch): 20 | if iter == 0: 21 | avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1) 22 | x_input = torch.cat([inputs, avg_image_for_batch], dim=1) 23 | else: 24 | x_input = torch.cat([inputs, y_hat], dim=1) 25 | 26 | y_hat, latent = net.forward(x_input, 27 | latent=latent, 28 | randomize_noise=False, 29 | return_latents=True, 30 | resize=opts.resize_outputs) 31 | 32 | if "cars" in opts.dataset_type: 33 | if opts.resize_outputs: 34 | y_hat = y_hat[:, :, 32:224, :] 35 | else: 36 | y_hat = y_hat[:, :, 64:448, :] 37 | 38 | # resize input to 256 before feeding into next iteration 39 | if "cars" in opts.dataset_type: 40 | y_hat = torch.nn.AdaptiveAvgPool2d((192, 256))(y_hat) 41 | else: 42 | y_hat = net.face_pool(y_hat) 43 | 44 | return y_hat, latent 45 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | def aggregate_loss_dict(agg_loss_dict): 2 | mean_vals = {} 3 | for output in agg_loss_dict: 4 | for key in output: 5 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 6 | for key in mean_vals: 7 | if len(mean_vals[key]) > 0: 8 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 9 | else: 10 | print(f'{key} has no value') 11 | mean_vals[key] = 0 12 | return mean_vals 13 | --------------------------------------------------------------------------------
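A minimal usage sketch tying together the utilities above (load_model, run_inversion, tensor2im). This is an illustrative example rather than a file from the repository; the checkpoint path is hypothetical and the random tensor merely stands in for an aligned 256x256 face crop, normalized to [-1, 1], as produced by InferenceDataset.

import torch

from utils.common import tensor2im
from utils.inference_utils import run_inversion
from utils.model_utils import load_model

# Load a trained HyperStyle checkpoint (hypothetical path); update_opts overrides inference-time options.
net, opts = load_model("pretrained_models/hyperstyle_ffhq.pt",
                       update_opts={"n_iters_per_batch": 5, "resize_outputs": False})

# Placeholder input; a real pipeline feeds aligned, normalized face crops from InferenceDataset.
x = torch.randn(1, 3, 256, 256).cuda().float()

with torch.no_grad():
    # Returns the final reconstruction, its latent code, and the predicted per-layer weight deltas.
    y_hat, latent, weights_deltas, codes = run_inversion(x, net, opts)

# Convert the output tensor to a PIL image and save it.
tensor2im(y_hat[0]).save("inversion_example.jpg")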