├── LICENSE
├── README.md
├── cog.yaml
├── configs
│   ├── __init__.py
│   ├── data_configs.py
│   ├── paths_config.py
│   └── transforms_config.py
├── criteria
│   ├── __init__.py
│   ├── id_loss.py
│   ├── lpips
│   │   ├── __init__.py
│   │   ├── lpips.py
│   │   ├── networks.py
│   │   └── utils.py
│   ├── moco_loss.py
│   └── w_norm.py
├── datasets
│   ├── __init__.py
│   ├── augmentations.py
│   ├── gt_res_dataset.py
│   ├── images_dataset.py
│   └── inference_dataset.py
├── docs
│   ├── encoding_inputs.jpg
│   ├── encoding_outputs.jpg
│   ├── frontalization_inputs.jpg
│   ├── frontalization_outputs.jpg
│   ├── seg2image.png
│   ├── sketch2image.png
│   ├── super_res_32.jpg
│   ├── super_res_style_mixing.jpg
│   ├── teaser.png
│   ├── toonify_input.jpg
│   └── toonify_output.jpg
├── download-weights.sh
├── environment
│   └── psp_env.yaml
├── licenses
│   ├── LICENSE_HuangYG123
│   ├── LICENSE_S-aiueo32
│   ├── LICENSE_TreB1eN
│   ├── LICENSE_lessw2020
│   └── LICENSE_rosinality
├── models
│   ├── __init__.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── helpers.py
│   │   ├── model_irse.py
│   │   └── psp_encoders.py
│   ├── mtcnn
│   │   ├── __init__.py
│   │   ├── mtcnn.py
│   │   └── mtcnn_pytorch
│   │       ├── __init__.py
│   │       └── src
│   │           ├── __init__.py
│   │           ├── align_trans.py
│   │           ├── box_utils.py
│   │           ├── detector.py
│   │           ├── first_stage.py
│   │           ├── get_nets.py
│   │           ├── matlab_cp2tform.py
│   │           ├── visualization_utils.py
│   │           └── weights
│   │               ├── onet.npy
│   │               ├── pnet.npy
│   │               └── rnet.npy
│   ├── psp.py
│   └── stylegan2
│       ├── __init__.py
│       ├── model.py
│       └── op
│           ├── __init__.py
│           ├── fused_act.py
│           ├── fused_bias_act.cpp
│           ├── fused_bias_act_kernel.cu
│           ├── upfirdn2d.cpp
│           ├── upfirdn2d.py
│           └── upfirdn2d_kernel.cu
├── notebooks
│   ├── images
│   │   ├── input_img.jpg
│   │   ├── input_mask.png
│   │   └── input_sketch.jpg
│   └── inference_playground.ipynb
├── options
│   ├── __init__.py
│   ├── test_options.py
│   └── train_options.py
├── predict.py
├── scripts
│   ├── align_all_parallel.py
│   ├── calc_id_loss_parallel.py
│   ├── calc_losses_on_images.py
│   ├── generate_sketch_data.py
│   ├── inference.py
│   ├── style_mixing.py
│   └── train.py
├── training
│   ├── __init__.py
│   ├── coach.py
│   └── ranger.py
└── utils
    ├── __init__.py
    ├── common.py
    ├── data_utils.py
    ├── train_utils.py
    └── wandb_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | python_version: "3.8" 4 | system_packages: 5 | - "libgl1-mesa-glx" 6 | - "libglib2.0-0" 7 | - "ninja-build" 8 | python_packages: 9 | - "cmake==3.21.2" 10 | - "torch==1.8.0" 11 | - "torchvision==0.9.0" 12 | - "numpy==1.21.1" 13 | - "ipython==7.21.0" 14 | - "tensorboard==2.6.0" 15 | - "tqdm==4.43.0" 16 | - "torch-optimizer==0.1.0" 17 | - "opencv-python==4.5.3.56" 18 | - "Pillow==8.3.2" 19 | - "matplotlib==3.2.1" 20 | - "scipy==1.7.1" 21 | run: 22 | - pip install dlib 23 | 24 | predict: "predict.py:Predictor" 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/configs/__init__.py -------------------------------------------------------------------------------- /configs/data_configs.py: -------------------------------------------------------------------------------- 1 | from configs import transforms_config 2 | from configs.paths_config import dataset_paths 3 | 4 | 5 | DATASETS = { 6 | 'ffhq_encode': { 7 | 'transforms': transforms_config.EncodeTransforms, 8 | 'train_source_root': dataset_paths['ffhq'], 9 | 'train_target_root': dataset_paths['ffhq'], 10 | 'test_source_root': dataset_paths['celeba_test'], 11 | 'test_target_root': dataset_paths['celeba_test'], 12 | }, 13 | 'ffhq_frontalize': { 14 | 'transforms': transforms_config.FrontalizationTransforms, 15 | 'train_source_root': dataset_paths['ffhq'], 16 | 'train_target_root': dataset_paths['ffhq'], 17 | 'test_source_root': dataset_paths['celeba_test'], 18 | 'test_target_root': dataset_paths['celeba_test'], 19 | }, 20 | 'celebs_sketch_to_face': { 21 | 'transforms': transforms_config.SketchToImageTransforms, 22 | 'train_source_root': dataset_paths['celeba_train_sketch'], 23 | 'train_target_root': dataset_paths['celeba_train'], 24 | 'test_source_root': dataset_paths['celeba_test_sketch'], 25 | 'test_target_root': dataset_paths['celeba_test'], 26 | }, 27 | 'celebs_seg_to_face': { 28 | 'transforms': transforms_config.SegToImageTransforms, 29 | 'train_source_root': dataset_paths['celeba_train_segmentation'], 30 | 'train_target_root': dataset_paths['celeba_train'], 31 | 'test_source_root': dataset_paths['celeba_test_segmentation'], 32 | 'test_target_root': dataset_paths['celeba_test'], 33 | }, 34 | 'celebs_super_resolution': { 35 | 'transforms': transforms_config.SuperResTransforms, 36 | 'train_source_root': dataset_paths['celeba_train'], 37 | 'train_target_root': dataset_paths['celeba_train'], 38 | 'test_source_root': dataset_paths['celeba_test'], 39 | 'test_target_root': dataset_paths['celeba_test'], 40 | }, 41 | } 42 | -------------------------------------------------------------------------------- /configs/paths_config.py: -------------------------------------------------------------------------------- 1 | dataset_paths = { 2 | 'celeba_train': '', 3 | 'celeba_test': '', 4 | 'celeba_train_sketch': '', 5 | 'celeba_test_sketch': '', 6 | 'celeba_train_segmentation': '', 7 | 'celeba_test_segmentation': '', 8 | 'ffhq': '', 9 | } 10 | 11 | model_paths = { 12 | 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt', 13 | 'ir_se50': 'pretrained_models/model_ir_se50.pth', 
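    # annotation (comment added in this review, not in the original file): 'ir_se50'
    # above is the ArcFace / IR-SE50 backbone that criteria/id_loss.py loads for the
    # identity loss; like the other entries it is expected to sit under pretrained_models/.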
14 | 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth', 15 | 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy', 16 | 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy', 17 | 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy', 18 | 'shape_predictor': 'shape_predictor_68_face_landmarks.dat', 19 | 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar' 20 | } 21 | -------------------------------------------------------------------------------- /configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | from datasets import augmentations 4 | 5 | 6 | class TransformsConfig(object): 7 | 8 | def __init__(self, opts): 9 | self.opts = opts 10 | 11 | @abstractmethod 12 | def get_transforms(self): 13 | pass 14 | 15 | 16 | class EncodeTransforms(TransformsConfig): 17 | 18 | def __init__(self, opts): 19 | super(EncodeTransforms, self).__init__(opts) 20 | 21 | def get_transforms(self): 22 | transforms_dict = { 23 | 'transform_gt_train': transforms.Compose([ 24 | transforms.Resize((256, 256)), 25 | transforms.RandomHorizontalFlip(0.5), 26 | transforms.ToTensor(), 27 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 28 | 'transform_source': None, 29 | 'transform_test': transforms.Compose([ 30 | transforms.Resize((256, 256)), 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 33 | 'transform_inference': transforms.Compose([ 34 | transforms.Resize((256, 256)), 35 | transforms.ToTensor(), 36 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 37 | } 38 | return transforms_dict 39 | 40 | 41 | class FrontalizationTransforms(TransformsConfig): 42 | 43 | def __init__(self, opts): 44 | super(FrontalizationTransforms, self).__init__(opts) 45 | 46 | def get_transforms(self): 47 | transforms_dict = { 48 | 'transform_gt_train': transforms.Compose([ 49 | transforms.Resize((256, 256)), 50 | transforms.RandomHorizontalFlip(0.5), 51 | transforms.ToTensor(), 52 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 53 | 'transform_source': transforms.Compose([ 54 | transforms.Resize((256, 256)), 55 | transforms.RandomHorizontalFlip(0.5), 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 58 | 'transform_test': transforms.Compose([ 59 | transforms.Resize((256, 256)), 60 | transforms.ToTensor(), 61 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 62 | 'transform_inference': transforms.Compose([ 63 | transforms.Resize((256, 256)), 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 66 | } 67 | return transforms_dict 68 | 69 | 70 | class SketchToImageTransforms(TransformsConfig): 71 | 72 | def __init__(self, opts): 73 | super(SketchToImageTransforms, self).__init__(opts) 74 | 75 | def get_transforms(self): 76 | transforms_dict = { 77 | 'transform_gt_train': transforms.Compose([ 78 | transforms.Resize((256, 256)), 79 | transforms.ToTensor(), 80 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 81 | 'transform_source': transforms.Compose([ 82 | transforms.Resize((256, 256)), 83 | transforms.ToTensor()]), 84 | 'transform_test': transforms.Compose([ 85 | transforms.Resize((256, 256)), 86 | transforms.ToTensor(), 87 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 88 | 'transform_inference': transforms.Compose([ 89 | transforms.Resize((256, 256)), 90 | transforms.ToTensor()]), 91 | } 92 | return transforms_dict 
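# --- annotation (not part of the original file) -----------------------------------
# Sketch of how these transform configs are consumed; the canonical wiring lives in
# configs/data_configs.py and training/coach.py, so the snippet below is illustrative
# only and `opts` is assumed to come from options/train_options.py:
#
#     from configs.data_configs import DATASETS
#     dataset_args = DATASETS['celebs_sketch_to_face']
#     transforms_dict = dataset_args['transforms'](opts).get_transforms()
#
# 'transform_source' is applied to the conditioning input (sketch / segmentation map /
# down-sampled image), while 'transform_gt_train' / 'transform_test' resize and
# normalize the RGB target; note the sketch and segmentation sources are left
# unnormalized here.
# -----------------------------------------------------------------------------------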
93 | 94 | 95 | class SegToImageTransforms(TransformsConfig): 96 | 97 | def __init__(self, opts): 98 | super(SegToImageTransforms, self).__init__(opts) 99 | 100 | def get_transforms(self): 101 | transforms_dict = { 102 | 'transform_gt_train': transforms.Compose([ 103 | transforms.Resize((256, 256)), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 106 | 'transform_source': transforms.Compose([ 107 | transforms.Resize((256, 256)), 108 | augmentations.ToOneHot(self.opts.label_nc), 109 | transforms.ToTensor()]), 110 | 'transform_test': transforms.Compose([ 111 | transforms.Resize((256, 256)), 112 | transforms.ToTensor(), 113 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 114 | 'transform_inference': transforms.Compose([ 115 | transforms.Resize((256, 256)), 116 | augmentations.ToOneHot(self.opts.label_nc), 117 | transforms.ToTensor()]) 118 | } 119 | return transforms_dict 120 | 121 | 122 | class SuperResTransforms(TransformsConfig): 123 | 124 | def __init__(self, opts): 125 | super(SuperResTransforms, self).__init__(opts) 126 | 127 | def get_transforms(self): 128 | if self.opts.resize_factors is None: 129 | self.opts.resize_factors = '1,2,4,8,16,32' 130 | factors = [int(f) for f in self.opts.resize_factors.split(",")] 131 | print("Performing down-sampling with factors: {}".format(factors)) 132 | transforms_dict = { 133 | 'transform_gt_train': transforms.Compose([ 134 | transforms.Resize((256, 256)), 135 | transforms.ToTensor(), 136 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 137 | 'transform_source': transforms.Compose([ 138 | transforms.Resize((256, 256)), 139 | augmentations.BilinearResize(factors=factors), 140 | transforms.Resize((256, 256)), 141 | transforms.ToTensor(), 142 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 143 | 'transform_test': transforms.Compose([ 144 | transforms.Resize((256, 256)), 145 | transforms.ToTensor(), 146 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 147 | 'transform_inference': transforms.Compose([ 148 | transforms.Resize((256, 256)), 149 | augmentations.BilinearResize(factors=factors), 150 | transforms.Resize((256, 256)), 151 | transforms.ToTensor(), 152 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 153 | } 154 | return transforms_dict 155 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths_config import model_paths 4 | from models.encoders.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def extract_feats(self, x): 17 | x = x[:, :, 35:223, 32:220] # Crop interesting region 18 | x = self.face_pool(x) 19 | x_feats = self.facenet(x) 20 | return x_feats 21 | 22 | def forward(self, y_hat, y, x): 
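        # annotation (comment added, not in the original source): y_hat is the generated
        # image batch, y the target and x the source input. The ArcFace backbone ends with
        # an L2-normalization (model_irse.l2_norm), so each dot product below is a cosine
        # similarity, and the 1 - diff_target term pulls the reconstruction towards the
        # target identity.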
23 | n_samples = x.shape[0] 24 | x_feats = self.extract_feats(x) 25 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 26 | y_hat_feats = self.extract_feats(y_hat) 27 | y_feats = y_feats.detach() 28 | loss = 0 29 | sim_improvement = 0 30 | id_logs = [] 31 | count = 0 32 | for i in range(n_samples): 33 | diff_target = y_hat_feats[i].dot(y_feats[i]) 34 | diff_input = y_hat_feats[i].dot(x_feats[i]) 35 | diff_views = y_feats[i].dot(x_feats[i]) 36 | id_logs.append({'diff_target': float(diff_target), 37 | 'diff_input': float(diff_input), 38 | 'diff_views': float(diff_views)}) 39 | loss += 1 - diff_target 40 | id_diff = float(diff_target) - float(diff_views) 41 | sim_improvement += id_diff 42 | count += 1 43 | 44 | return loss / count, sim_improvement / count, id_logs 45 | -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 
15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | 
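        # annotation (comment added): as with AlexNet and SqueezeNet above, the VGG16
        # features are frozen below; together with the frozen LinLayers this keeps LPIPS
        # a fixed perceptual metric rather than a trainable module.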
self.set_requires_grad(False) -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | 6 | 7 | class MocoLoss(nn.Module): 8 | 9 | def __init__(self): 10 | super(MocoLoss, self).__init__() 11 | print("Loading MOCO model from path: {}".format(model_paths["moco"])) 12 | self.model = self.__load_model() 13 | self.model.cuda() 14 | self.model.eval() 15 | 16 | @staticmethod 17 | def __load_model(): 18 | import torchvision.models as models 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ['fc.weight', 'fc.bias']: 23 | param.requires_grad = False 24 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 25 | state_dict = checkpoint['state_dict'] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 30 | # remove prefix 31 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 32 | # delete renamed or unused k 33 | del state_dict[k] 34 | msg = model.load_state_dict(state_dict, strict=False) 35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 36 | # remove output layer 37 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 38 | return model 39 | 40 | def extract_feats(self, x): 41 | x = F.interpolate(x, size=224) 42 | x_feats = self.model(x) 43 | x_feats = nn.functional.normalize(x_feats, dim=1) 44 | x_feats = x_feats.squeeze() 45 | return x_feats 46 | 47 | def forward(self, y_hat, y, x): 48 | n_samples = x.shape[0] 49 | x_feats = self.extract_feats(x) 50 | y_feats = self.extract_feats(y) 51 | y_hat_feats = self.extract_feats(y_hat) 52 | y_feats = y_feats.detach() 53 | loss = 0 54 | sim_improvement = 0 55 | sim_logs = [] 56 | count = 0 57 | for i in range(n_samples): 58 | diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | diff_input = y_hat_feats[i].dot(x_feats[i]) 60 | diff_views = y_feats[i].dot(x_feats[i]) 61 | sim_logs.append({'diff_target': float(diff_target), 62 | 'diff_input': float(diff_input), 63 | 'diff_views': float(diff_views)}) 64 | loss += 
1 - diff_target 65 | sim_diff = float(diff_target) - float(diff_views) 66 | sim_improvement += sim_diff 67 | count += 1 68 | 69 | return loss / count, sim_improvement / count, sim_logs 70 | -------------------------------------------------------------------------------- /criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | 7 | 8 | class ToOneHot(object): 9 | """ Convert the input PIL image to a one-hot torch tensor """ 10 | def __init__(self, n_classes=None): 11 | self.n_classes = n_classes 12 | 13 | def onehot_initialization(self, a): 14 | if self.n_classes is None: 15 | self.n_classes = len(np.unique(a)) 16 | out = np.zeros(a.shape + (self.n_classes, ), dtype=int) 17 | out[self.__all_idx(a, axis=2)] = 1 18 | return out 19 | 20 | def __all_idx(self, idx, axis): 21 | grid = np.ogrid[tuple(map(slice, idx.shape))] 22 | grid.insert(axis, idx) 23 | return tuple(grid) 24 | 25 | def __call__(self, img): 26 | img = np.array(img) 27 | one_hot = self.onehot_initialization(img) 28 | return one_hot 29 | 30 | 31 | class BilinearResize(object): 32 | def __init__(self, factors=[1, 2, 4, 8, 16, 32]): 33 | self.factors = factors 34 | 35 | def __call__(self, image): 36 | factor = np.random.choice(self.factors, size=1)[0] 37 | D = BicubicDownSample(factor=factor, cuda=False) 38 | img_tensor = transforms.ToTensor()(image).unsqueeze(0) 39 | img_tensor_lr = D(img_tensor)[0].clamp(0, 1) 40 | img_low_res = transforms.ToPILImage()(img_tensor_lr) 41 | return img_low_res 42 | 43 | 44 | class BicubicDownSample(nn.Module): 45 | def bicubic_kernel(self, x, a=-0.50): 46 | """ 47 | This equation is exactly copied from the website below: 48 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic 49 | """ 50 | abs_x = torch.abs(x) 51 | if abs_x <= 1.: 52 | return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1 53 | elif 1. < abs_x < 2.: 54 | return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. 
* a 55 | else: 56 | return 0.0 57 | 58 | def __init__(self, factor=4, cuda=True, padding='reflect'): 59 | super().__init__() 60 | self.factor = factor 61 | size = factor * 4 62 | k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) 63 | for i in range(size)], dtype=torch.float32) 64 | k = k / torch.sum(k) 65 | k1 = torch.reshape(k, shape=(1, 1, size, 1)) 66 | self.k1 = torch.cat([k1, k1, k1], dim=0) 67 | k2 = torch.reshape(k, shape=(1, 1, 1, size)) 68 | self.k2 = torch.cat([k2, k2, k2], dim=0) 69 | self.cuda = '.cuda' if cuda else '' 70 | self.padding = padding 71 | for param in self.parameters(): 72 | param.requires_grad = False 73 | 74 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False): 75 | filter_height = self.factor * 4 76 | filter_width = self.factor * 4 77 | stride = self.factor 78 | 79 | pad_along_height = max(filter_height - stride, 0) 80 | pad_along_width = max(filter_width - stride, 0) 81 | filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda)) 82 | filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda)) 83 | 84 | # compute actual padding values for each side 85 | pad_top = pad_along_height // 2 86 | pad_bottom = pad_along_height - pad_top 87 | pad_left = pad_along_width // 2 88 | pad_right = pad_along_width - pad_left 89 | 90 | # apply mirror padding 91 | if nhwc: 92 | x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW 93 | 94 | # downscaling performed by 1-d convolution 95 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) 96 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) 97 | if clip_round: 98 | x = torch.clamp(torch.round(x), 0.0, 255.) 99 | 100 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) 101 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) 102 | if clip_round: 103 | x = torch.clamp(torch.round(x), 0.0, 255.) 
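        # annotation (comment added): the two F.conv2d calls above apply the bicubic
        # kernel separably, first along the height and then along the width, which is
        # equivalent to a full 2-D bicubic filter at lower cost; the block below only
        # converts the result back to NHWC / uint8 when requested.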
104 | 105 | if nhwc: 106 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) 107 | if byte_output: 108 | return x.type('torch.ByteTensor'.format(self.cuda)) 109 | else: 110 | return x 111 | -------------------------------------------------------------------------------- /datasets/gt_res_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | import os 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | 7 | 8 | class GTResDataset(Dataset): 9 | 10 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): 11 | self.pairs = [] 12 | for f in os.listdir(root_path): 13 | image_path = os.path.join(root_path, f) 14 | gt_path = os.path.join(gt_dir, f) 15 | if f.endswith(".jpg") or f.endswith(".png"): 16 | self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) 17 | self.transform = transform 18 | self.transform_train = transform_train 19 | 20 | def __len__(self): 21 | return len(self.pairs) 22 | 23 | def __getitem__(self, index): 24 | from_path, to_path, _ = self.pairs[index] 25 | from_im = Image.open(from_path).convert('RGB') 26 | to_im = Image.open(to_path).convert('RGB') 27 | 28 | if self.transform: 29 | to_im = self.transform(to_im) 30 | from_im = self.transform(from_im) 31 | 32 | return from_im, to_im 33 | -------------------------------------------------------------------------------- /datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class ImagesDataset(Dataset): 7 | 8 | def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): 9 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 10 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 11 | self.source_transform = source_transform 12 | self.target_transform = target_transform 13 | self.opts = opts 14 | 15 | def __len__(self): 16 | return len(self.source_paths) 17 | 18 | def __getitem__(self, index): 19 | from_path = self.source_paths[index] 20 | from_im = Image.open(from_path) 21 | from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') 22 | 23 | to_path = self.target_paths[index] 24 | to_im = Image.open(to_path).convert('RGB') 25 | if self.target_transform: 26 | to_im = self.target_transform(to_im) 27 | 28 | if self.source_transform: 29 | from_im = self.source_transform(from_im) 30 | else: 31 | from_im = to_im 32 | 33 | return from_im, to_im 34 | -------------------------------------------------------------------------------- /datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class InferenceDataset(Dataset): 7 | 8 | def __init__(self, root, opts, transform=None): 9 | self.paths = sorted(data_utils.make_dataset(root)) 10 | self.transform = transform 11 | self.opts = opts 12 | 13 | def __len__(self): 14 | return len(self.paths) 15 | 16 | def __getitem__(self, index): 17 | from_path = self.paths[index] 18 | from_im = Image.open(from_path) 19 | from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') 20 | if self.transform: 21 | from_im = self.transform(from_im) 22 | return from_im 23 | 
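Usage sketch (annotation, not a file from the repository): the dataset classes above are normally built from the transform dicts in configs/transforms_config.py and wrapped in a standard DataLoader, roughly as done in scripts/inference.py; `opts`, `transforms_dict` and `net` are assumed to come from the options, configs and models modules respectively:

    import torch
    from torch.utils.data import DataLoader
    from datasets.inference_dataset import InferenceDataset

    dataset = InferenceDataset(root=opts.data_path,
                               transform=transforms_dict['transform_inference'],
                               opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)
    for input_batch in dataloader:
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            # result_batch = net(input_cuda, randomize_noise=False)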
-------------------------------------------------------------------------------- /docs/encoding_inputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/encoding_inputs.jpg -------------------------------------------------------------------------------- /docs/encoding_outputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/encoding_outputs.jpg -------------------------------------------------------------------------------- /docs/frontalization_inputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/frontalization_inputs.jpg -------------------------------------------------------------------------------- /docs/frontalization_outputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/frontalization_outputs.jpg -------------------------------------------------------------------------------- /docs/seg2image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/seg2image.png -------------------------------------------------------------------------------- /docs/sketch2image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/sketch2image.png -------------------------------------------------------------------------------- /docs/super_res_32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/super_res_32.jpg -------------------------------------------------------------------------------- /docs/super_res_style_mixing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/super_res_style_mixing.jpg -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/teaser.png -------------------------------------------------------------------------------- /docs/toonify_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/toonify_input.jpg -------------------------------------------------------------------------------- /docs/toonify_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/docs/toonify_output.jpg -------------------------------------------------------------------------------- 
/download-weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | mkdir pretrained_models 3 | cd pretrained_models 4 | 5 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1lB7wk7MwtdxL-LL4Z_T76DuCfk00aSXA' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1lB7wk7MwtdxL-LL4Z_T76DuCfk00aSXA" -O psp_celebs_sketch_to_face.pt && rm -rf /tmp/cookies.txt 6 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1_S4THAzXb-97DbpXmanjHtXRyKxqjARv' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1_S4THAzXb-97DbpXmanjHtXRyKxqjARv" -O psp_ffhq_frontalization.pt && rm -rf /tmp/cookies.txt 7 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ZpmSXBpJ9pFEov6-jjQstAlfYbkebECu' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1ZpmSXBpJ9pFEov6-jjQstAlfYbkebECu" -O psp_celebs_super_resolution.pt && rm -rf /tmp/cookies.txt 8 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1YKoiVuFaqdvzDP5CZaqa3k5phL-VDmyz' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1YKoiVuFaqdvzDP5CZaqa3k5phL-VDmyz" -O psp_ffhq_toonify.pt && rm -rf /tmp/cookies.txt 9 | 10 | cd .. 
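# annotation (comment added, not in the original script): the archive fetched below is
# dlib's 68-point facial landmark model, referenced as 'shape_predictor' in
# configs/paths_config.py and used to crop and align faces before they are fed to the
# encoder (see scripts/align_all_parallel.py).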
11 | wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 12 | bunzip2 shape_predictor_68_face_landmarks.dat.bz2 13 | -------------------------------------------------------------------------------- /environment/psp_env.yaml: -------------------------------------------------------------------------------- 1 | name: psp_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - ca-certificates=2020.4.5.1=hecc5488_0 8 | - certifi=2020.4.5.1=py36h9f0ad1d_0 9 | - libedit=3.1.20181209=hc058e9b_0 10 | - libffi=3.2.1=hd88cf55_4 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - ninja=1.10.0=hc9558a2_0 15 | - openssl=1.1.1g=h516909a_0 16 | - pip=20.0.2=py36_3 17 | - python=3.6.7=h0371630_0 18 | - python_abi=3.6=1_cp36m 19 | - readline=7.0=h7b6447c_5 20 | - setuptools=46.4.0=py36_0 21 | - sqlite=3.31.1=h62c20be_1 22 | - tk=8.6.8=hbc83047_0 23 | - wheel=0.34.2=py36_0 24 | - xz=5.2.5=h7b6447c_0 25 | - zlib=1.2.11=h7b6447c_3 26 | - pip: 27 | - scipy==1.4.1 28 | - matplotlib==3.2.1 29 | - tqdm==4.46.0 30 | - numpy==1.18.4 31 | - opencv-python==4.2.0.34 32 | - pillow==7.1.2 33 | - tensorboard==2.2.1 34 | - torch==1.6.0 35 | - torchvision==0.4.2 36 | prefix: ~/anaconda3/envs/psp_env 37 | 38 | -------------------------------------------------------------------------------- /licenses/LICENSE_HuangYG123: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 HuangYG123 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_S-aiueo32: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /licenses/LICENSE_TreB1eN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TreB1eN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_lessw2020: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /licenses/LICENSE_rosinality: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/__init__.py -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/encoders/__init__.py -------------------------------------------------------------------------------- /models/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. 
""" 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, 
Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /models/encoders/psp_encoders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module 6 | 7 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE 8 | from models.stylegan2.model import EqualLinear 9 | 10 | 11 | class GradualStyleBlock(Module): 12 | def __init__(self, in_c, out_c, spatial): 13 | super(GradualStyleBlock, self).__init__() 14 | self.out_c = out_c 15 | self.spatial = spatial 16 | num_pools = int(np.log2(spatial)) 17 | modules = [] 18 | 
modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 19 | nn.LeakyReLU()] 20 | for i in range(num_pools - 1): 21 | modules += [ 22 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 23 | nn.LeakyReLU() 24 | ] 25 | self.convs = nn.Sequential(*modules) 26 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 27 | 28 | def forward(self, x): 29 | x = self.convs(x) 30 | x = x.view(-1, self.out_c) 31 | x = self.linear(x) 32 | return x 33 | 34 | 35 | class GradualStyleEncoder(Module): 36 | def __init__(self, num_layers, mode='ir', opts=None): 37 | super(GradualStyleEncoder, self).__init__() 38 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 39 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 40 | blocks = get_blocks(num_layers) 41 | if mode == 'ir': 42 | unit_module = bottleneck_IR 43 | elif mode == 'ir_se': 44 | unit_module = bottleneck_IR_SE 45 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 46 | BatchNorm2d(64), 47 | PReLU(64)) 48 | modules = [] 49 | for block in blocks: 50 | for bottleneck in block: 51 | modules.append(unit_module(bottleneck.in_channel, 52 | bottleneck.depth, 53 | bottleneck.stride)) 54 | self.body = Sequential(*modules) 55 | 56 | self.styles = nn.ModuleList() 57 | self.style_count = opts.n_styles 58 | self.coarse_ind = 3 59 | self.middle_ind = 7 60 | for i in range(self.style_count): 61 | if i < self.coarse_ind: 62 | style = GradualStyleBlock(512, 512, 16) 63 | elif i < self.middle_ind: 64 | style = GradualStyleBlock(512, 512, 32) 65 | else: 66 | style = GradualStyleBlock(512, 512, 64) 67 | self.styles.append(style) 68 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 69 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 70 | 71 | def _upsample_add(self, x, y): 72 | '''Upsample and add two feature maps. 73 | Args: 74 | x: (Variable) top feature map to be upsampled. 75 | y: (Variable) lateral feature map. 76 | Returns: 77 | (Variable) added feature map. 78 | Note in PyTorch, when input size is odd, the upsampled feature map 79 | with `F.upsample(..., scale_factor=2, mode='nearest')` 80 | maybe not equal to the lateral feature map size. 81 | e.g. 82 | original input size: [N,_,15,15] -> 83 | conv2d feature map size: [N,_,8,8] -> 84 | upsampled feature map size: [N,_,16,16] 85 | So we choose bilinear upsample which supports arbitrary output sizes. 
86 | ''' 87 | _, _, H, W = y.size() 88 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 89 | 90 | def forward(self, x): 91 | x = self.input_layer(x) 92 | 93 | latents = [] 94 | modulelist = list(self.body._modules.values()) 95 | for i, l in enumerate(modulelist): 96 | x = l(x) 97 | if i == 6: 98 | c1 = x 99 | elif i == 20: 100 | c2 = x 101 | elif i == 23: 102 | c3 = x 103 | 104 | for j in range(self.coarse_ind): 105 | latents.append(self.styles[j](c3)) 106 | 107 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 108 | for j in range(self.coarse_ind, self.middle_ind): 109 | latents.append(self.styles[j](p2)) 110 | 111 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 112 | for j in range(self.middle_ind, self.style_count): 113 | latents.append(self.styles[j](p1)) 114 | 115 | out = torch.stack(latents, dim=1) 116 | return out 117 | 118 | 119 | class BackboneEncoderUsingLastLayerIntoW(Module): 120 | def __init__(self, num_layers, mode='ir', opts=None): 121 | super(BackboneEncoderUsingLastLayerIntoW, self).__init__() 122 | print('Using BackboneEncoderUsingLastLayerIntoW') 123 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 124 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 125 | blocks = get_blocks(num_layers) 126 | if mode == 'ir': 127 | unit_module = bottleneck_IR 128 | elif mode == 'ir_se': 129 | unit_module = bottleneck_IR_SE 130 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 131 | BatchNorm2d(64), 132 | PReLU(64)) 133 | self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) 134 | self.linear = EqualLinear(512, 512, lr_mul=1) 135 | modules = [] 136 | for block in blocks: 137 | for bottleneck in block: 138 | modules.append(unit_module(bottleneck.in_channel, 139 | bottleneck.depth, 140 | bottleneck.stride)) 141 | self.body = Sequential(*modules) 142 | 143 | def forward(self, x): 144 | x = self.input_layer(x) 145 | x = self.body(x) 146 | x = self.output_pool(x) 147 | x = x.view(-1, 512) 148 | x = self.linear(x) 149 | return x 150 | 151 | 152 | class BackboneEncoderUsingLastLayerIntoWPlus(Module): 153 | def __init__(self, num_layers, mode='ir', opts=None): 154 | super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__() 155 | print('Using BackboneEncoderUsingLastLayerIntoWPlus') 156 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 157 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 158 | blocks = get_blocks(num_layers) 159 | if mode == 'ir': 160 | unit_module = bottleneck_IR 161 | elif mode == 'ir_se': 162 | unit_module = bottleneck_IR_SE 163 | self.n_styles = opts.n_styles 164 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 165 | BatchNorm2d(64), 166 | PReLU(64)) 167 | self.output_layer_2 = Sequential(BatchNorm2d(512), 168 | torch.nn.AdaptiveAvgPool2d((7, 7)), 169 | Flatten(), 170 | Linear(512 * 7 * 7, 512)) 171 | self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1) 172 | modules = [] 173 | for block in blocks: 174 | for bottleneck in block: 175 | modules.append(unit_module(bottleneck.in_channel, 176 | bottleneck.depth, 177 | bottleneck.stride)) 178 | self.body = Sequential(*modules) 179 | 180 | def forward(self, x): 181 | x = self.input_layer(x) 182 | x = self.body(x) 183 | x = self.output_layer_2(x) 184 | x = self.linear(x) 185 | x = x.view(-1, self.n_styles, 512) 186 | return x 187 | -------------------------------------------------------------------------------- 
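The three encoders defined above differ only in how they turn IR-SE-50 backbone features into StyleGAN style vectors: GradualStyleEncoder predicts one 512-dimensional style per generator layer from a three-level feature pyramid, while BackboneEncoderUsingLastLayerIntoW and BackboneEncoderUsingLastLayerIntoWPlus pool the final feature map into a single W code or into n_styles vectors derived from that pooled code. The sketch below is an illustrative usage example rather than repository code: the SimpleNamespace options object and the values input_nc=3 and n_styles=18 (the count pSp derives for a 1024x1024 generator) are assumptions made for the example, and importing models.encoders.psp_encoders pulls in models.stylegan2, whose fused CUDA ops are compiled on import, so a working CUDA toolchain is assumed.

# Illustrative sketch (not part of the repository): output shapes of the encoders above.
from types import SimpleNamespace

import torch

from models.encoders import psp_encoders

# Minimal stand-in for the repo's argparse options; 18 = int(log2(1024)) * 2 - 2.
opts = SimpleNamespace(input_nc=3, n_styles=18)

# A batch of two aligned 256x256 face crops, as used throughout the repo.
x = torch.randn(2, 3, 256, 256)

encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', opts)
w_plus = encoder(x)  # shape [2, 18, 512]: one style vector per generator layer

encoder_w = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', opts)
w = encoder_w(x)  # shape [2, 512]: a single latent shared by all layers

encoder_wplus = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', opts)
w_plus_pooled = encoder_wplus(x)  # shape [2, 18, 512], predicted from pooled features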
/models/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/mtcnn/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 5 | from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 7 | from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face 8 | 9 | device = 'cuda:0' 10 | 11 | 12 | class MTCNN(): 13 | def __init__(self): 14 | print(device) 15 | self.pnet = PNet().to(device) 16 | self.rnet = RNet().to(device) 17 | self.onet = ONet().to(device) 18 | self.pnet.eval() 19 | self.rnet.eval() 20 | self.onet.eval() 21 | self.refrence = get_reference_facial_points(default_square=True) 22 | 23 | def align(self, img): 24 | _, landmarks = self.detect_faces(img) 25 | if len(landmarks) == 0: 26 | return None, None 27 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 28 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 29 | return Image.fromarray(warped_face), tfm 30 | 31 | def align_multi(self, img, limit=None, min_face_size=30.0): 32 | boxes, landmarks = self.detect_faces(img, min_face_size) 33 | if limit: 34 | boxes = boxes[:limit] 35 | landmarks = landmarks[:limit] 36 | faces = [] 37 | tfms = [] 38 | for landmark in landmarks: 39 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 40 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 41 | faces.append(Image.fromarray(warped_face)) 42 | tfms.append(tfm) 43 | return boxes, faces, tfms 44 | 45 | def detect_faces(self, image, min_face_size=20.0, 46 | thresholds=[0.15, 0.25, 0.35], 47 | nms_thresholds=[0.7, 0.7, 0.7]): 48 | """ 49 | Arguments: 50 | image: an instance of PIL.Image. 51 | min_face_size: a float number. 52 | thresholds: a list of length 3. 53 | nms_thresholds: a list of length 3. 54 | 55 | Returns: 56 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 57 | bounding boxes and facial landmarks. 
58 | """ 59 | 60 | # BUILD AN IMAGE PYRAMID 61 | width, height = image.size 62 | min_length = min(height, width) 63 | 64 | min_detection_size = 12 65 | factor = 0.707 # sqrt(0.5) 66 | 67 | # scales for scaling the image 68 | scales = [] 69 | 70 | # scales the image so that 71 | # minimum size that we can detect equals to 72 | # minimum face size that we want to detect 73 | m = min_detection_size / min_face_size 74 | min_length *= m 75 | 76 | factor_count = 0 77 | while min_length > min_detection_size: 78 | scales.append(m * factor ** factor_count) 79 | min_length *= factor 80 | factor_count += 1 81 | 82 | # STAGE 1 83 | 84 | # it will be returned 85 | bounding_boxes = [] 86 | 87 | with torch.no_grad(): 88 | # run P-Net on different scales 89 | for s in scales: 90 | boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) 91 | bounding_boxes.append(boxes) 92 | 93 | # collect boxes (and offsets, and scores) from different scales 94 | bounding_boxes = [i for i in bounding_boxes if i is not None] 95 | bounding_boxes = np.vstack(bounding_boxes) 96 | 97 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 98 | bounding_boxes = bounding_boxes[keep] 99 | 100 | # use offsets predicted by pnet to transform bounding boxes 101 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 102 | # shape [n_boxes, 5] 103 | 104 | bounding_boxes = convert_to_square(bounding_boxes) 105 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 106 | 107 | # STAGE 2 108 | 109 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 110 | img_boxes = torch.FloatTensor(img_boxes).to(device) 111 | 112 | output = self.rnet(img_boxes) 113 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 114 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 115 | 116 | keep = np.where(probs[:, 1] > thresholds[1])[0] 117 | bounding_boxes = bounding_boxes[keep] 118 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 119 | offsets = offsets[keep] 120 | 121 | keep = nms(bounding_boxes, nms_thresholds[1]) 122 | bounding_boxes = bounding_boxes[keep] 123 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 124 | bounding_boxes = convert_to_square(bounding_boxes) 125 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 126 | 127 | # STAGE 3 128 | 129 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 130 | if len(img_boxes) == 0: 131 | return [], [] 132 | img_boxes = torch.FloatTensor(img_boxes).to(device) 133 | output = self.onet(img_boxes) 134 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 135 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 136 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 137 | 138 | keep = np.where(probs[:, 1] > thresholds[2])[0] 139 | bounding_boxes = bounding_boxes[keep] 140 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 141 | offsets = offsets[keep] 142 | landmarks = landmarks[keep] 143 | 144 | # compute landmark points 145 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 146 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 147 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 148 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 149 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 150 | 151 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 152 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 153 | bounding_boxes = 
bounding_boxes[keep] 154 | landmarks = landmarks[keep] 155 | 156 | return bounding_boxes, landmarks 157 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/align_trans.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 24 15:43:29 2017 4 | @author: zhaoy 5 | """ 6 | import numpy as np 7 | import cv2 8 | 9 | # from scipy.linalg import lstsq 10 | # from scipy.ndimage import geometric_transform # , map_coordinates 11 | 12 | from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2 13 | 14 | # reference facial points, a list of coordinates (x,y) 15 | REFERENCE_FACIAL_POINTS = [ 16 | [30.29459953, 51.69630051], 17 | [65.53179932, 51.50139999], 18 | [48.02519989, 71.73660278], 19 | [33.54930115, 92.3655014], 20 | [62.72990036, 92.20410156] 21 | ] 22 | 23 | DEFAULT_CROP_SIZE = (96, 112) 24 | 25 | 26 | class FaceWarpException(Exception): 27 | def __str__(self): 28 | return 'In File {}:{}'.format( 29 | __file__, super.__str__(self)) 30 | 31 | 32 | def get_reference_facial_points(output_size=None, 33 | inner_padding_factor=0.0, 34 | outer_padding=(0, 0), 35 | default_square=False): 36 | """ 37 | Function: 38 | ---------- 39 | get reference 5 key points according to crop settings: 40 | 0. Set default crop_size: 41 | if default_square: 42 | crop_size = (112, 112) 43 | else: 44 | crop_size = (96, 112) 45 | 1. Pad the crop_size by inner_padding_factor in each side; 46 | 2. Resize crop_size into (output_size - outer_padding*2), 47 | pad into output_size with outer_padding; 48 | 3. Output reference_5point; 49 | Parameters: 50 | ---------- 51 | @output_size: (w, h) or None 52 | size of aligned face image 53 | @inner_padding_factor: (w_factor, h_factor) 54 | padding factor for inner (w, h) 55 | @outer_padding: (w_pad, h_pad) 56 | each row is a pair of coordinates (x, y) 57 | @default_square: True or False 58 | if True: 59 | default crop_size = (112, 112) 60 | else: 61 | default crop_size = (96, 112); 62 | !!! 
make sure, if output_size is not None: 63 | (output_size - outer_padding) 64 | = some_scale * (default crop_size * (1.0 + inner_padding_factor)) 65 | Returns: 66 | ---------- 67 | @reference_5point: 5x2 np.array 68 | each row is a pair of transformed coordinates (x, y) 69 | """ 70 | # print('\n===> get_reference_facial_points():') 71 | 72 | # print('---> Params:') 73 | # print(' output_size: ', output_size) 74 | # print(' inner_padding_factor: ', inner_padding_factor) 75 | # print(' outer_padding:', outer_padding) 76 | # print(' default_square: ', default_square) 77 | 78 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 79 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 80 | 81 | # 0) make the inner region a square 82 | if default_square: 83 | size_diff = max(tmp_crop_size) - tmp_crop_size 84 | tmp_5pts += size_diff / 2 85 | tmp_crop_size += size_diff 86 | 87 | # print('---> default:') 88 | # print(' crop_size = ', tmp_crop_size) 89 | # print(' reference_5pts = ', tmp_5pts) 90 | 91 | if (output_size and 92 | output_size[0] == tmp_crop_size[0] and 93 | output_size[1] == tmp_crop_size[1]): 94 | # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) 95 | return tmp_5pts 96 | 97 | if (inner_padding_factor == 0 and 98 | outer_padding == (0, 0)): 99 | if output_size is None: 100 | # print('No paddings to do: return default reference points') 101 | return tmp_5pts 102 | else: 103 | raise FaceWarpException( 104 | 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) 105 | 106 | # check output size 107 | if not (0 <= inner_padding_factor <= 1.0): 108 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 109 | 110 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) 111 | and output_size is None): 112 | output_size = tmp_crop_size * \ 113 | (1 + inner_padding_factor * 2).astype(np.int32) 114 | output_size += np.array(outer_padding) 115 | # print(' deduced from paddings, output_size = ', output_size) 116 | 117 | if not (outer_padding[0] < output_size[0] 118 | and outer_padding[1] < output_size[1]): 119 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]' 120 | 'and outer_padding[1] < output_size[1])') 121 | 122 | # 1) pad the inner region according inner_padding_factor 123 | # print('---> STEP1: pad the inner region according inner_padding_factor') 124 | if inner_padding_factor > 0: 125 | size_diff = tmp_crop_size * inner_padding_factor * 2 126 | tmp_5pts += size_diff / 2 127 | tmp_crop_size += np.round(size_diff).astype(np.int32) 128 | 129 | # print(' crop_size = ', tmp_crop_size) 130 | # print(' reference_5pts = ', tmp_5pts) 131 | 132 | # 2) resize the padded inner region 133 | # print('---> STEP2: resize the padded inner region') 134 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 135 | # print(' crop_size = ', tmp_crop_size) 136 | # print(' size_bf_outer_pad = ', size_bf_outer_pad) 137 | 138 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 139 | raise FaceWarpException('Must have (output_size - outer_padding)' 140 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 141 | 142 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 143 | # print(' resize scale_factor = ', scale_factor) 144 | tmp_5pts = tmp_5pts * scale_factor 145 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 146 | # tmp_5pts = tmp_5pts + size_diff / 2 147 | tmp_crop_size = size_bf_outer_pad 148 | # 
print(' crop_size = ', tmp_crop_size) 149 | # print(' reference_5pts = ', tmp_5pts) 150 | 151 | # 3) add outer_padding to make output_size 152 | reference_5point = tmp_5pts + np.array(outer_padding) 153 | tmp_crop_size = output_size 154 | # print('---> STEP3: add outer_padding to make output_size') 155 | # print(' crop_size = ', tmp_crop_size) 156 | # print(' reference_5pts = ', tmp_5pts) 157 | 158 | # print('===> end get_reference_facial_points\n') 159 | 160 | return reference_5point 161 | 162 | 163 | def get_affine_transform_matrix(src_pts, dst_pts): 164 | """ 165 | Function: 166 | ---------- 167 | get affine transform matrix 'tfm' from src_pts to dst_pts 168 | Parameters: 169 | ---------- 170 | @src_pts: Kx2 np.array 171 | source points matrix, each row is a pair of coordinates (x, y) 172 | @dst_pts: Kx2 np.array 173 | destination points matrix, each row is a pair of coordinates (x, y) 174 | Returns: 175 | ---------- 176 | @tfm: 2x3 np.array 177 | transform matrix from src_pts to dst_pts 178 | """ 179 | 180 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 181 | n_pts = src_pts.shape[0] 182 | ones = np.ones((n_pts, 1), src_pts.dtype) 183 | src_pts_ = np.hstack([src_pts, ones]) 184 | dst_pts_ = np.hstack([dst_pts, ones]) 185 | 186 | # #print(('src_pts_:\n' + str(src_pts_)) 187 | # #print(('dst_pts_:\n' + str(dst_pts_)) 188 | 189 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 190 | 191 | # #print(('np.linalg.lstsq return A: \n' + str(A)) 192 | # #print(('np.linalg.lstsq return res: \n' + str(res)) 193 | # #print(('np.linalg.lstsq return rank: \n' + str(rank)) 194 | # #print(('np.linalg.lstsq return s: \n' + str(s)) 195 | 196 | if rank == 3: 197 | tfm = np.float32([ 198 | [A[0, 0], A[1, 0], A[2, 0]], 199 | [A[0, 1], A[1, 1], A[2, 1]] 200 | ]) 201 | elif rank == 2: 202 | tfm = np.float32([ 203 | [A[0, 0], A[1, 0], 0], 204 | [A[0, 1], A[1, 1], 0] 205 | ]) 206 | 207 | return tfm 208 | 209 | 210 | def warp_and_crop_face(src_img, 211 | facial_pts, 212 | reference_pts=None, 213 | crop_size=(96, 112), 214 | align_type='smilarity'): 215 | """ 216 | Function: 217 | ---------- 218 | apply affine transform 'trans' to uv 219 | Parameters: 220 | ---------- 221 | @src_img: 3x3 np.array 222 | input image 223 | @facial_pts: could be 224 | 1)a list of K coordinates (x,y) 225 | or 226 | 2) Kx2 or 2xK np.array 227 | each row or col is a pair of coordinates (x, y) 228 | @reference_pts: could be 229 | 1) a list of K coordinates (x,y) 230 | or 231 | 2) Kx2 or 2xK np.array 232 | each row or col is a pair of coordinates (x, y) 233 | or 234 | 3) None 235 | if None, use default reference facial points 236 | @crop_size: (w, h) 237 | output face image size 238 | @align_type: transform type, could be one of 239 | 1) 'similarity': use similarity transform 240 | 2) 'cv2_affine': use the first 3 points to do affine transform, 241 | by calling cv2.getAffineTransform() 242 | 3) 'affine': use all points to do affine transform 243 | Returns: 244 | ---------- 245 | @face_img: output face image with size (w, h) = @crop_size 246 | """ 247 | 248 | if reference_pts is None: 249 | if crop_size[0] == 96 and crop_size[1] == 112: 250 | reference_pts = REFERENCE_FACIAL_POINTS 251 | else: 252 | default_square = False 253 | inner_padding_factor = 0 254 | outer_padding = (0, 0) 255 | output_size = crop_size 256 | 257 | reference_pts = get_reference_facial_points(output_size, 258 | inner_padding_factor, 259 | outer_padding, 260 | default_square) 261 | 262 | ref_pts = np.float32(reference_pts) 263 | ref_pts_shp = ref_pts.shape 264 | if 
max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 265 | raise FaceWarpException( 266 | 'reference_pts.shape must be (K,2) or (2,K) and K>2') 267 | 268 | if ref_pts_shp[0] == 2: 269 | ref_pts = ref_pts.T 270 | 271 | src_pts = np.float32(facial_pts) 272 | src_pts_shp = src_pts.shape 273 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 274 | raise FaceWarpException( 275 | 'facial_pts.shape must be (K,2) or (2,K) and K>2') 276 | 277 | if src_pts_shp[0] == 2: 278 | src_pts = src_pts.T 279 | 280 | # #print('--->src_pts:\n', src_pts 281 | # #print('--->ref_pts\n', ref_pts 282 | 283 | if src_pts.shape != ref_pts.shape: 284 | raise FaceWarpException( 285 | 'facial_pts and reference_pts must have the same shape') 286 | 287 | if align_type == 'cv2_affine': 288 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 289 | # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm)) 290 | elif align_type == 'affine': 291 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 292 | # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm)) 293 | else: 294 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) 295 | # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm)) 296 | 297 | # #print('--->Transform matrix: ' 298 | # #print(('type(tfm):' + str(type(tfm))) 299 | # #print(('tfm.dtype:' + str(tfm.dtype)) 300 | # #print( tfm 301 | 302 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) 303 | 304 | return face_img, tfm 305 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/box_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def nms(boxes, overlap_threshold=0.5, mode='union'): 6 | """Non-maximum suppression. 7 | 8 | Arguments: 9 | boxes: a float numpy array of shape [n, 5], 10 | where each row is (xmin, ymin, xmax, ymax, score). 11 | overlap_threshold: a float number. 12 | mode: 'union' or 'min'.
13 | 14 | Returns: 15 | list with indices of the selected boxes 16 | """ 17 | 18 | # if there are no boxes, return the empty list 19 | if len(boxes) == 0: 20 | return [] 21 | 22 | # list of picked indices 23 | pick = [] 24 | 25 | # grab the coordinates of the bounding boxes 26 | x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] 27 | 28 | area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) 29 | ids = np.argsort(score) # in increasing order 30 | 31 | while len(ids) > 0: 32 | 33 | # grab index of the largest value 34 | last = len(ids) - 1 35 | i = ids[last] 36 | pick.append(i) 37 | 38 | # compute intersections 39 | # of the box with the largest score 40 | # with the rest of boxes 41 | 42 | # left top corner of intersection boxes 43 | ix1 = np.maximum(x1[i], x1[ids[:last]]) 44 | iy1 = np.maximum(y1[i], y1[ids[:last]]) 45 | 46 | # right bottom corner of intersection boxes 47 | ix2 = np.minimum(x2[i], x2[ids[:last]]) 48 | iy2 = np.minimum(y2[i], y2[ids[:last]]) 49 | 50 | # width and height of intersection boxes 51 | w = np.maximum(0.0, ix2 - ix1 + 1.0) 52 | h = np.maximum(0.0, iy2 - iy1 + 1.0) 53 | 54 | # intersections' areas 55 | inter = w * h 56 | if mode == 'min': 57 | overlap = inter / np.minimum(area[i], area[ids[:last]]) 58 | elif mode == 'union': 59 | # intersection over union (IoU) 60 | overlap = inter / (area[i] + area[ids[:last]] - inter) 61 | 62 | # delete all boxes where overlap is too big 63 | ids = np.delete( 64 | ids, 65 | np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) 66 | ) 67 | 68 | return pick 69 | 70 | 71 | def convert_to_square(bboxes): 72 | """Convert bounding boxes to a square form. 73 | 74 | Arguments: 75 | bboxes: a float numpy array of shape [n, 5]. 76 | 77 | Returns: 78 | a float numpy array of shape [n, 5], 79 | squared bounding boxes. 80 | """ 81 | 82 | square_bboxes = np.zeros_like(bboxes) 83 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 84 | h = y2 - y1 + 1.0 85 | w = x2 - x1 + 1.0 86 | max_side = np.maximum(h, w) 87 | square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 88 | square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 89 | square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 90 | square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 91 | return square_bboxes 92 | 93 | 94 | def calibrate_box(bboxes, offsets): 95 | """Transform bounding boxes to be more like true bounding boxes. 96 | 'offsets' is one of the outputs of the nets. 97 | 98 | Arguments: 99 | bboxes: a float numpy array of shape [n, 5]. 100 | offsets: a float numpy array of shape [n, 4]. 101 | 102 | Returns: 103 | a float numpy array of shape [n, 5]. 104 | """ 105 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 106 | w = x2 - x1 + 1.0 107 | h = y2 - y1 + 1.0 108 | w = np.expand_dims(w, 1) 109 | h = np.expand_dims(h, 1) 110 | 111 | # this is what happening here: 112 | # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] 113 | # x1_true = x1 + tx1*w 114 | # y1_true = y1 + ty1*h 115 | # x2_true = x2 + tx2*w 116 | # y2_true = y2 + ty2*h 117 | # below is just more compact form of this 118 | 119 | # are offsets always such that 120 | # x1 < x2 and y1 < y2 ? 121 | 122 | translation = np.hstack([w, h, w, h]) * offsets 123 | bboxes[:, 0:4] = bboxes[:, 0:4] + translation 124 | return bboxes 125 | 126 | 127 | def get_image_boxes(bounding_boxes, img, size=24): 128 | """Cut out boxes from the image. 129 | 130 | Arguments: 131 | bounding_boxes: a float numpy array of shape [n, 5]. 132 | img: an instance of PIL.Image. 133 | size: an integer, size of cutouts. 
134 | 135 | Returns: 136 | a float numpy array of shape [n, 3, size, size]. 137 | """ 138 | 139 | num_boxes = len(bounding_boxes) 140 | width, height = img.size 141 | 142 | [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height) 143 | img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') 144 | 145 | for i in range(num_boxes): 146 | img_box = np.zeros((h[i], w[i], 3), 'uint8') 147 | 148 | img_array = np.asarray(img, 'uint8') 149 | img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \ 150 | img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] 151 | 152 | # resize 153 | img_box = Image.fromarray(img_box) 154 | img_box = img_box.resize((size, size), Image.BILINEAR) 155 | img_box = np.asarray(img_box, 'float32') 156 | 157 | img_boxes[i, :, :, :] = _preprocess(img_box) 158 | 159 | return img_boxes 160 | 161 | 162 | def correct_bboxes(bboxes, width, height): 163 | """Crop boxes that are too big and get coordinates 164 | with respect to cutouts. 165 | 166 | Arguments: 167 | bboxes: a float numpy array of shape [n, 5], 168 | where each row is (xmin, ymin, xmax, ymax, score). 169 | width: a float number. 170 | height: a float number. 171 | 172 | Returns: 173 | dy, dx, edy, edx: a int numpy arrays of shape [n], 174 | coordinates of the boxes with respect to the cutouts. 175 | y, x, ey, ex: a int numpy arrays of shape [n], 176 | corrected ymin, xmin, ymax, xmax. 177 | h, w: a int numpy arrays of shape [n], 178 | just heights and widths of boxes. 179 | 180 | in the following order: 181 | [dy, edy, dx, edx, y, ey, x, ex, w, h]. 182 | """ 183 | 184 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 185 | w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 186 | num_boxes = bboxes.shape[0] 187 | 188 | # 'e' stands for end 189 | # (x, y) -> (ex, ey) 190 | x, y, ex, ey = x1, y1, x2, y2 191 | 192 | # we need to cut out a box from the image. 193 | # (x, y, ex, ey) are corrected coordinates of the box 194 | # in the image. 195 | # (dx, dy, edx, edy) are coordinates of the box in the cutout 196 | # from the image. 197 | dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) 198 | edx, edy = w.copy() - 1.0, h.copy() - 1.0 199 | 200 | # if box's bottom right corner is too far right 201 | ind = np.where(ex > width - 1.0)[0] 202 | edx[ind] = w[ind] + width - 2.0 - ex[ind] 203 | ex[ind] = width - 1.0 204 | 205 | # if box's bottom right corner is too low 206 | ind = np.where(ey > height - 1.0)[0] 207 | edy[ind] = h[ind] + height - 2.0 - ey[ind] 208 | ey[ind] = height - 1.0 209 | 210 | # if box's top left corner is too far left 211 | ind = np.where(x < 0.0)[0] 212 | dx[ind] = 0.0 - x[ind] 213 | x[ind] = 0.0 214 | 215 | # if box's top left corner is too high 216 | ind = np.where(y < 0.0)[0] 217 | dy[ind] = 0.0 - y[ind] 218 | y[ind] = 0.0 219 | 220 | return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] 221 | return_list = [i.astype('int32') for i in return_list] 222 | 223 | return return_list 224 | 225 | 226 | def _preprocess(img): 227 | """Preprocessing step before feeding the network. 228 | 229 | Arguments: 230 | img: a float numpy array of shape [h, w, c]. 231 | 232 | Returns: 233 | a float numpy array of shape [1, c, h, w]. 
234 | """ 235 | img = img.transpose((2, 0, 1)) 236 | img = np.expand_dims(img, 0) 237 | img = (img - 127.5) * 0.0078125 238 | return img 239 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from .get_nets import PNet, RNet, ONet 5 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from .first_stage import run_first_stage 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = 
torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from .box_utils import nms, _preprocess 7 | 8 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | device = 'cuda:0' 10 | 11 | 12 | def run_first_stage(image, net, scale, threshold): 13 | """Run P-Net, generate bounding boxes, and do NMS. 14 | 15 | Arguments: 16 | image: an instance of PIL.Image. 17 | net: an instance of pytorch's nn.Module, P-Net. 18 | scale: a float number, 19 | scale width and height of the image by this number. 20 | threshold: a float number, 21 | threshold on the probability of a face when generating 22 | bounding boxes from predictions of the net. 23 | 24 | Returns: 25 | a float numpy array of shape [n_boxes, 9], 26 | bounding boxes with scores and offsets (4 + 1 + 4). 27 | """ 28 | 29 | # scale the image and convert it to a float array 30 | width, height = image.size 31 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 32 | img = image.resize((sw, sh), Image.BILINEAR) 33 | img = np.asarray(img, 'float32') 34 | 35 | img = torch.FloatTensor(_preprocess(img)).to(device) 36 | with torch.no_grad(): 37 | output = net(img) 38 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 39 | offsets = output[0].cpu().data.numpy() 40 | # probs: probability of a face at each sliding window 41 | # offsets: transformations to true bounding boxes 42 | 43 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 44 | if len(boxes) == 0: 45 | return None 46 | 47 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 48 | return boxes[keep] 49 | 50 | 51 | def _generate_bboxes(probs, offsets, scale, threshold): 52 | """Generate bounding boxes at places 53 | where there is probably a face. 54 | 55 | Arguments: 56 | probs: a float numpy array of shape [n, m]. 57 | offsets: a float numpy array of shape [1, 4, n, m]. 58 | scale: a float number, 59 | width and height of the image were scaled by this number. 60 | threshold: a float number. 
61 | 62 | Returns: 63 | a float numpy array of shape [n_boxes, 9] 64 | """ 65 | 66 | # applying P-Net is equivalent, in some sense, to 67 | # moving 12x12 window with stride 2 68 | stride = 2 69 | cell_size = 12 70 | 71 | # indices of boxes where there is probably a face 72 | inds = np.where(probs > threshold) 73 | 74 | if inds[0].size == 0: 75 | return np.array([]) 76 | 77 | # transformations of bounding boxes 78 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 79 | # they are defined as: 80 | # w = x2 - x1 + 1 81 | # h = y2 - y1 + 1 82 | # x1_true = x1 + tx1*w 83 | # x2_true = x2 + tx2*w 84 | # y1_true = y1 + ty1*h 85 | # y2_true = y2 + ty2*h 86 | 87 | offsets = np.array([tx1, ty1, tx2, ty2]) 88 | score = probs[inds[0], inds[1]] 89 | 90 | # P-Net is applied to scaled images 91 | # so we need to rescale bounding boxes back 92 | bounding_boxes = np.vstack([ 93 | np.round((stride * inds[1] + 1.0) / scale), 94 | np.round((stride * inds[0] + 1.0) / scale), 95 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 96 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 97 | score, offsets 98 | ]) 99 | # why one is added? 100 | 101 | return bounding_boxes.T 102 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | from configs.paths_config import model_paths 8 | PNET_PATH = model_paths["mtcnn_pnet"] 9 | ONET_PATH = model_paths["mtcnn_onet"] 10 | RNET_PATH = model_paths["mtcnn_rnet"] 11 | 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w]. 24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | 37 | # suppose we have input with size HxW, then 38 | # after first layer: H - 2, 39 | # after pool: ceil((H - 2)/2), 40 | # after second conv: ceil((H - 2)/2) - 2, 41 | # after last conv: ceil((H - 2)/2) - 4, 42 | # and the same for W 43 | 44 | self.features = nn.Sequential(OrderedDict([ 45 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 46 | ('prelu1', nn.PReLU(10)), 47 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 48 | 49 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 50 | ('prelu2', nn.PReLU(16)), 51 | 52 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 53 | ('prelu3', nn.PReLU(32)) 54 | ])) 55 | 56 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 57 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 58 | 59 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 60 | for n, p in self.named_parameters(): 61 | p.data = torch.FloatTensor(weights[n]) 62 | 63 | def forward(self, x): 64 | """ 65 | Arguments: 66 | x: a float tensor with shape [batch_size, 3, h, w]. 67 | Returns: 68 | b: a float tensor with shape [batch_size, 4, h', w']. 69 | a: a float tensor with shape [batch_size, 2, h', w']. 
70 | """ 71 | x = self.features(x) 72 | a = self.conv4_1(x) 73 | b = self.conv4_2(x) 74 | a = F.softmax(a, dim=-1) 75 | return b, a 76 | 77 | 78 | class RNet(nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential(OrderedDict([ 84 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 85 | ('prelu1', nn.PReLU(28)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 87 | 88 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 89 | ('prelu2', nn.PReLU(48)), 90 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 91 | 92 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 93 | ('prelu3', nn.PReLU(64)), 94 | 95 | ('flatten', Flatten()), 96 | ('conv4', nn.Linear(576, 128)), 97 | ('prelu4', nn.PReLU(128)) 98 | ])) 99 | 100 | self.conv5_1 = nn.Linear(128, 2) 101 | self.conv5_2 = nn.Linear(128, 4) 102 | 103 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 104 | for n, p in self.named_parameters(): 105 | p.data = torch.FloatTensor(weights[n]) 106 | 107 | def forward(self, x): 108 | """ 109 | Arguments: 110 | x: a float tensor with shape [batch_size, 3, h, w]. 111 | Returns: 112 | b: a float tensor with shape [batch_size, 4]. 113 | a: a float tensor with shape [batch_size, 2]. 114 | """ 115 | x = self.features(x) 116 | a = self.conv5_1(x) 117 | b = self.conv5_2(x) 118 | a = F.softmax(a, dim=-1) 119 | return b, a 120 | 121 | 122 | class ONet(nn.Module): 123 | 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 129 | ('prelu1', nn.PReLU(32)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 131 | 132 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 133 | ('prelu2', nn.PReLU(64)), 134 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 135 | 136 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 137 | ('prelu3', nn.PReLU(64)), 138 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | 140 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 141 | ('prelu4', nn.PReLU(128)), 142 | 143 | ('flatten', Flatten()), 144 | ('conv5', nn.Linear(1152, 256)), 145 | ('drop5', nn.Dropout(0.25)), 146 | ('prelu5', nn.PReLU(256)), 147 | ])) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 
165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jul 11 06:54:28 2017 4 | 5 | @author: zhaoyafei 6 | """ 7 | 8 | import numpy as np 9 | from numpy.linalg import inv, norm, lstsq 10 | from numpy.linalg import matrix_rank as rank 11 | 12 | 13 | class MatlabCp2tormException(Exception): 14 | def __str__(self): 15 | return 'In File {}:{}'.format( 16 | __file__, super.__str__(self)) 17 | 18 | 19 | def tformfwd(trans, uv): 20 | """ 21 | Function: 22 | ---------- 23 | apply affine transform 'trans' to uv 24 | 25 | Parameters: 26 | ---------- 27 | @trans: 3x3 np.array 28 | transform matrix 29 | @uv: Kx2 np.array 30 | each row is a pair of coordinates (x, y) 31 | 32 | Returns: 33 | ---------- 34 | @xy: Kx2 np.array 35 | each row is a pair of transformed coordinates (x, y) 36 | """ 37 | uv = np.hstack(( 38 | uv, np.ones((uv.shape[0], 1)) 39 | )) 40 | xy = np.dot(uv, trans) 41 | xy = xy[:, 0:-1] 42 | return xy 43 | 44 | 45 | def tforminv(trans, uv): 46 | """ 47 | Function: 48 | ---------- 49 | apply the inverse of affine transform 'trans' to uv 50 | 51 | Parameters: 52 | ---------- 53 | @trans: 3x3 np.array 54 | transform matrix 55 | @uv: Kx2 np.array 56 | each row is a pair of coordinates (x, y) 57 | 58 | Returns: 59 | ---------- 60 | @xy: Kx2 np.array 61 | each row is a pair of inverse-transformed coordinates (x, y) 62 | """ 63 | Tinv = inv(trans) 64 | xy = tformfwd(Tinv, uv) 65 | return xy 66 | 67 | 68 | def findNonreflectiveSimilarity(uv, xy, options=None): 69 | options = {'K': 2} 70 | 71 | K = options['K'] 72 | M = xy.shape[0] 73 | x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector 74 | y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector 75 | # print('--->x, y:\n', x, y 76 | 77 | tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) 78 | tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) 79 | X = np.vstack((tmp1, tmp2)) 80 | # print('--->X.shape: ', X.shape 81 | # print('X:\n', X 82 | 83 | u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector 84 | v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector 85 | U = np.vstack((u, v)) 86 | # print('--->U.shape: ', U.shape 87 | # print('U:\n', U 88 | 89 | # We know that X * r = U 90 | if rank(X) >= 2 * K: 91 | r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want 92 | r = np.squeeze(r) 93 | else: 94 | raise Exception('cp2tform:twoUniquePointsReq') 95 | 96 | # print('--->r:\n', r 97 | 98 | sc = r[0] 99 | ss = r[1] 100 | tx = r[2] 101 | ty = r[3] 102 | 103 | Tinv = np.array([ 104 | [sc, -ss, 0], 105 | [ss, sc, 0], 106 | [tx, ty, 1] 107 | ]) 108 | 109 | # print('--->Tinv:\n', Tinv 110 | 111 | T = inv(Tinv) 112 | # print('--->T:\n', T 113 | 114 | T[:, 2] = np.array([0, 0, 1]) 115 | 116 | return T, Tinv 117 | 118 | 119 | def findSimilarity(uv, xy, options=None): 120 | options = {'K': 2} 121 | 122 | # uv = np.array(uv) 123 | # xy = np.array(xy) 124 | 125 | # Solve for trans1 126 | trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) 127 | 128 | # Solve for trans2 129 | 130 | # manually reflect the xy data across the Y-axis 131 | xyR = xy 132 | xyR[:, 0] = -1 * xyR[:, 0] 
133 | 134 | trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) 135 | 136 | # manually reflect the tform to undo the reflection done on xyR 137 | TreflectY = np.array([ 138 | [-1, 0, 0], 139 | [0, 1, 0], 140 | [0, 0, 1] 141 | ]) 142 | 143 | trans2 = np.dot(trans2r, TreflectY) 144 | 145 | # Figure out if trans1 or trans2 is better 146 | xy1 = tformfwd(trans1, uv) 147 | norm1 = norm(xy1 - xy) 148 | 149 | xy2 = tformfwd(trans2, uv) 150 | norm2 = norm(xy2 - xy) 151 | 152 | if norm1 <= norm2: 153 | return trans1, trans1_inv 154 | else: 155 | trans2_inv = inv(trans2) 156 | return trans2, trans2_inv 157 | 158 | 159 | def get_similarity_transform(src_pts, dst_pts, reflective=True): 160 | """ 161 | Function: 162 | ---------- 163 | Find Similarity Transform Matrix 'trans': 164 | u = src_pts[:, 0] 165 | v = src_pts[:, 1] 166 | x = dst_pts[:, 0] 167 | y = dst_pts[:, 1] 168 | [x, y, 1] = [u, v, 1] * trans 169 | 170 | Parameters: 171 | ---------- 172 | @src_pts: Kx2 np.array 173 | source points, each row is a pair of coordinates (x, y) 174 | @dst_pts: Kx2 np.array 175 | destination points, each row is a pair of transformed 176 | coordinates (x, y) 177 | @reflective: True or False 178 | if True: 179 | use reflective similarity transform 180 | else: 181 | use non-reflective similarity transform 182 | 183 | Returns: 184 | ---------- 185 | @trans: 3x3 np.array 186 | transform matrix from uv to xy 187 | trans_inv: 3x3 np.array 188 | inverse of trans, transform matrix from xy to uv 189 | """ 190 | 191 | if reflective: 192 | trans, trans_inv = findSimilarity(src_pts, dst_pts) 193 | else: 194 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) 195 | 196 | return trans, trans_inv 197 | 198 | 199 | def cvt_tform_mat_for_cv2(trans): 200 | """ 201 | Function: 202 | ---------- 203 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be 204 | directly used by cv2.warpAffine(): 205 | u = src_pts[:, 0] 206 | v = src_pts[:, 1] 207 | x = dst_pts[:, 0] 208 | y = dst_pts[:, 1] 209 | [x, y].T = cv_trans * [u, v, 1].T 210 | 211 | Parameters: 212 | ---------- 213 | @trans: 3x3 np.array 214 | transform matrix from uv to xy 215 | 216 | Returns: 217 | ---------- 218 | @cv2_trans: 2x3 np.array 219 | transform matrix from src_pts to dst_pts, could be directly used 220 | for cv2.warpAffine() 221 | """ 222 | cv2_trans = trans[:, 0:2].T 223 | 224 | return cv2_trans 225 | 226 | 227 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): 228 | """ 229 | Function: 230 | ---------- 231 | Find Similarity Transform Matrix 'cv2_trans' which could be 232 | directly used by cv2.warpAffine(): 233 | u = src_pts[:, 0] 234 | v = src_pts[:, 1] 235 | x = dst_pts[:, 0] 236 | y = dst_pts[:, 1] 237 | [x, y].T = cv_trans * [u, v, 1].T 238 | 239 | Parameters: 240 | ---------- 241 | @src_pts: Kx2 np.array 242 | source points, each row is a pair of coordinates (x, y) 243 | @dst_pts: Kx2 np.array 244 | destination points, each row is a pair of transformed 245 | coordinates (x, y) 246 | reflective: True or False 247 | if True: 248 | use reflective similarity transform 249 | else: 250 | use non-reflective similarity transform 251 | 252 | Returns: 253 | ---------- 254 | @cv2_trans: 2x3 np.array 255 | transform matrix from src_pts to dst_pts, could be directly used 256 | for cv2.warpAffine() 257 | """ 258 | trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) 259 | cv2_trans = cvt_tform_mat_for_cv2(trans) 260 | 261 | return cv2_trans 262 | 263 | 264 | if __name__ == '__main__': 
265 | """ 266 | u = [0, 6, -2] 267 | v = [0, 3, 5] 268 | x = [-1, 0, 4] 269 | y = [-1, -10, 4] 270 | 271 | # In Matlab, run: 272 | # 273 | # uv = [u'; v']; 274 | # xy = [x'; y']; 275 | # tform_sim=cp2tform(uv,xy,'similarity'); 276 | # 277 | # trans = tform_sim.tdata.T 278 | # ans = 279 | # -0.0764 -1.6190 0 280 | # 1.6190 -0.0764 0 281 | # -3.2156 0.0290 1.0000 282 | # trans_inv = tform_sim.tdata.Tinv 283 | # ans = 284 | # 285 | # -0.0291 0.6163 0 286 | # -0.6163 -0.0291 0 287 | # -0.0756 1.9826 1.0000 288 | # xy_m=tformfwd(tform_sim, u,v) 289 | # 290 | # xy_m = 291 | # 292 | # -3.2156 0.0290 293 | # 1.1833 -9.9143 294 | # 5.0323 2.8853 295 | # uv_m=tforminv(tform_sim, x,y) 296 | # 297 | # uv_m = 298 | # 299 | # 0.5698 1.3953 300 | # 6.0872 2.2733 301 | # -2.6570 4.3314 302 | """ 303 | u = [0, 6, -2] 304 | v = [0, 3, 5] 305 | x = [-1, 0, 4] 306 | y = [-1, -10, 4] 307 | 308 | uv = np.array((u, v)).T 309 | xy = np.array((x, y)).T 310 | 311 | print('\n--->uv:') 312 | print(uv) 313 | print('\n--->xy:') 314 | print(xy) 315 | 316 | trans, trans_inv = get_similarity_transform(uv, xy) 317 | 318 | print('\n--->trans matrix:') 319 | print(trans) 320 | 321 | print('\n--->trans_inv matrix:') 322 | print(trans_inv) 323 | 324 | print('\n---> apply transform to uv') 325 | print('\nxy_m = uv_augmented * trans') 326 | uv_aug = np.hstack(( 327 | uv, np.ones((uv.shape[0], 1)) 328 | )) 329 | xy_m = np.dot(uv_aug, trans) 330 | print(xy_m) 331 | 332 | print('\nxy_m = tformfwd(trans, uv)') 333 | xy_m = tformfwd(trans, uv) 334 | print(xy_m) 335 | 336 | print('\n---> apply inverse transform to xy') 337 | print('\nuv_m = xy_augmented * trans_inv') 338 | xy_aug = np.hstack(( 339 | xy, np.ones((xy.shape[0], 1)) 340 | )) 341 | uv_m = np.dot(xy_aug, trans_inv) 342 | print(uv_m) 343 | 344 | print('\nuv_m = tformfwd(trans_inv, xy)') 345 | uv_m = tformfwd(trans_inv, xy) 346 | print(uv_m) 347 | 348 | uv_m = tforminv(trans, xy) 349 | print('\nuv_m = tforminv(trans, xy)') 350 | print(uv_m) 351 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /models/psp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | from models.encoders import psp_encoders 11 | from models.stylegan2.model import Generator 12 | from configs.paths_config import model_paths 13 | 14 | 15 | def get_keys(d, name): 16 | if 'state_dict' in d: 17 | d = d['state_dict'] 18 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 19 | return d_filt 20 | 21 | 22 | class pSp(nn.Module): 23 | 24 | def __init__(self, opts): 25 | super(pSp, self).__init__() 26 | self.set_opts(opts) 27 | # compute number of style inputs based on the output resolution 28 | self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 29 | # Define architecture 30 | self.encoder = self.set_encoder() 31 | self.decoder = Generator(self.opts.output_size, 512, 8) 32 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 33 | # Load weights if needed 34 | self.load_weights() 35 | 36 | def set_encoder(self): 37 | if self.opts.encoder_type == 'GradualStyleEncoder': 38 | encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) 39 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': 40 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) 41 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': 42 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts) 43 | else: 44 | raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) 45 | return encoder 46 | 47 | def load_weights(self): 48 | if self.opts.checkpoint_path is not None: 49 | print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path)) 50 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 51 | 
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) 52 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) 53 | self.__load_latent_avg(ckpt) 54 | else: 55 | print('Loading encoders weights from irse50!') 56 | encoder_ckpt = torch.load(model_paths['ir_se50']) 57 | # if input to encoder is not an RGB image, do not load the input layer weights 58 | if self.opts.label_nc != 0: 59 | encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k} 60 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 61 | print('Loading decoder weights from pretrained!') 62 | ckpt = torch.load(self.opts.stylegan_weights) 63 | self.decoder.load_state_dict(ckpt['g_ema'], strict=False) 64 | if self.opts.learn_in_w: 65 | self.__load_latent_avg(ckpt, repeat=1) 66 | else: 67 | self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) 68 | 69 | def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 70 | inject_latent=None, return_latents=False, alpha=None): 71 | if input_code: 72 | codes = x 73 | else: 74 | codes = self.encoder(x) 75 | # normalize with respect to the center of an average face 76 | if self.opts.start_from_latent_avg: 77 | if self.opts.learn_in_w: 78 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1) 79 | else: 80 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 81 | 82 | 83 | if latent_mask is not None: 84 | for i in latent_mask: 85 | if inject_latent is not None: 86 | if alpha is not None: 87 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 88 | else: 89 | codes[:, i] = inject_latent[:, i] 90 | else: 91 | codes[:, i] = 0 92 | 93 | input_is_latent = not input_code 94 | images, result_latent = self.decoder([codes], 95 | input_is_latent=input_is_latent, 96 | randomize_noise=randomize_noise, 97 | return_latents=return_latents) 98 | 99 | if resize: 100 | images = self.face_pool(images) 101 | 102 | if return_latents: 103 | return images, result_latent 104 | else: 105 | return images 106 | 107 | def set_opts(self, opts): 108 | self.opts = opts 109 | 110 | def __load_latent_avg(self, ckpt, repeat=None): 111 | if 'latent_avg' in ckpt: 112 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 113 | if repeat is not None: 114 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 115 | else: 116 | self.latent_avg = None 117 | -------------------------------------------------------------------------------- /models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 
'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
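The C++ binding above only checks that its tensors live on the GPU and forwards them to the CUDA kernel; the Python wrapper defined in upfirdn2d.py below is what the StyleGAN2 layers actually call. The following is a minimal usage sketch, assuming a working CUDA build environment for the JIT extension; the 16x16 feature map and the [1, 3, 3, 1] blur kernel are illustrative values, not taken from this repository's configs.

import torch
from models.stylegan2.op import upfirdn2d

# NCHW feature map and a normalized, separable blur kernel of the kind used by
# StyleGAN2-style resampling layers (values here are illustrative).
x = torch.randn(1, 512, 16, 16, device='cuda')
k = torch.tensor([1., 3., 3., 1.], device='cuda')
k = k[None, :] * k[:, None]
k = k / k.sum()

# Upsample by 2 with low-pass filtering; pad=(2, 1) keeps the output aligned,
# giving (16*2 + 2 + 1 - 4) // 1 + 1 = 32 pixels per side.
y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(y.shape)  # torch.Size([1, 512, 32, 32])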
-------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | 
@staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int 
rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | 
upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /notebooks/images/input_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/notebooks/images/input_img.jpg -------------------------------------------------------------------------------- /notebooks/images/input_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/notebooks/images/input_mask.png -------------------------------------------------------------------------------- /notebooks/images/input_sketch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/notebooks/images/input_sketch.jpg -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/options/__init__.py -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TestOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | # arguments for inference script 12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 13 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 14 | self.parser.add_argument('--data_path', type=str, default='gt_images', help='Path to directory of images to evaluate') 15 | self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side') 16 | self.parser.add_argument('--resize_outputs', action='store_true', help='Whether to resize outputs to 256x256 or keep at 1024x1024') 17 | 18 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 19 | self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') 20 | 21 | # arguments for style-mixing script 22 | self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. 
If None, run on all data') 23 | self.parser.add_argument('--n_outputs_to_generate', type=int, default=5, help='Number of outputs to generate per input image.') 24 | self.parser.add_argument('--mix_alpha', type=float, default=None, help='Alpha value for style-mixing') 25 | self.parser.add_argument('--latent_mask', type=str, default=None, help='Comma-separated list of latents to perform style-mixing with') 26 | 27 | # arguments for super-resolution 28 | self.parser.add_argument('--resize_factors', type=str, default=None, 29 | help='Downsampling factor for super-res (should be a single value for inference).') 30 | 31 | def parse(self): 32 | opts = self.parser.parse_args() 33 | return opts -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from configs.paths_config import model_paths 3 | 4 | 5 | class TrainOptions: 6 | 7 | def __init__(self): 8 | self.parser = ArgumentParser() 9 | self.initialize() 10 | 11 | def initialize(self): 12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 13 | self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, help='Type of dataset/experiment to run') 14 | self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use') 15 | self.parser.add_argument('--input_nc', default=3, type=int, help='Number of input image channels to the psp encoder') 16 | self.parser.add_argument('--label_nc', default=0, type=int, help='Number of input label channels to the psp encoder') 17 | self.parser.add_argument('--output_size', default=1024, type=int, help='Output size of generator') 18 | 19 | self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training') 20 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 21 | self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') 22 | self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') 23 | 24 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate') 25 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 26 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') 27 | self.parser.add_argument('--start_from_latent_avg', action='store_true', help='Whether to add average latent vector to generate codes from encoder.') 28 | self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+') 29 | 30 | self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor') 31 | self.parser.add_argument('--id_lambda', default=0, type=float, help='ID loss multiplier factor') 32 | self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor') 33 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor') 34 | self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, help='LPIPS loss multiplier factor for inner image region') 35 | self.parser.add_argument('--l2_lambda_crop', default=0, type=float, help='L2 loss multiplier factor for inner 
image region') 36 | self.parser.add_argument('--moco_lambda', default=0, type=float, help='Moco-based feature similarity loss multiplier factor') 37 | 38 | self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, help='Path to StyleGAN model weights') 39 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 40 | 41 | self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') 42 | self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') 43 | self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard') 44 | self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') 45 | self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') 46 | 47 | # arguments for weights & biases support 48 | self.parser.add_argument('--use_wandb', action="store_true", help='Whether to use Weights & Biases to track experiment.') 49 | 50 | # arguments for super-resolution 51 | self.parser.add_argument('--resize_factors', type=str, default=None, help='For super-res, comma-separated resize factors to use for inference.') 52 | 53 | def parse(self): 54 | opts = self.parser.parse_args() 55 | return opts 56 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from cog import BasePredictor, Input, Path 4 | import shutil 5 | from argparse import Namespace 6 | import time 7 | import sys 8 | import pprint 9 | import numpy as np 10 | from PIL import Image 11 | import torch 12 | import torchvision.transforms as transforms 13 | import dlib 14 | 15 | sys.path.append(".") 16 | sys.path.append("..") 17 | 18 | from datasets import augmentations 19 | from utils.common import tensor2im, log_input_image 20 | from models.psp import pSp 21 | from scripts.align_all_parallel import align_face 22 | 23 | 24 | class Predictor(BasePredictor): 25 | def setup(self): 26 | self.predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 27 | model_paths = { 28 | "ffhq_frontalize": "pretrained_models/psp_ffhq_frontalization.pt", 29 | "celebs_sketch_to_face": "pretrained_models/psp_celebs_sketch_to_face.pt", 30 | "celebs_super_resolution": "pretrained_models/psp_celebs_super_resolution.pt", 31 | "toonify": "pretrained_models/psp_ffhq_toonify.pt", 32 | } 33 | 34 | loaded_models = {} 35 | for key, value in model_paths.items(): 36 | loaded_models[key] = torch.load(value, map_location="cpu") 37 | 38 | self.opts = {} 39 | for key, value in loaded_models.items(): 40 | self.opts[key] = value["opts"] 41 | 42 | for key in self.opts.keys(): 43 | self.opts[key]["checkpoint_path"] = model_paths[key] 44 | if "learn_in_w" not in self.opts[key]: 45 | self.opts[key]["learn_in_w"] = False 46 | if "output_size" not in self.opts[key]: 47 | self.opts[key]["output_size"] = 1024 48 | 49 | self.transforms = {} 50 | for key in model_paths.keys(): 51 | if key in ["ffhq_frontalize", "toonify"]: 52 | self.transforms[key] = transforms.Compose( 53 | [ 54 | transforms.Resize((256, 256)), 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 57 | ] 58 | ) 59 | elif key == "celebs_sketch_to_face": 60 | self.transforms[key] = 
transforms.Compose( 61 | [transforms.Resize((256, 256)), transforms.ToTensor()] 62 | ) 63 | elif key == "celebs_super_resolution": 64 | self.transforms[key] = transforms.Compose( 65 | [ 66 | transforms.Resize((256, 256)), 67 | augmentations.BilinearResize(factors=[16]), 68 | transforms.Resize((256, 256)), 69 | transforms.ToTensor(), 70 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 71 | ] 72 | ) 73 | 74 | def predict( 75 | self, 76 | image: Path = Input(description="input image"), 77 | model: str = Input( 78 | choices=[ 79 | "celebs_sketch_to_face", 80 | "ffhq_frontalize", 81 | "celebs_super_resolution", 82 | "toonify", 83 | ], 84 | description="choose model type", 85 | ), 86 | ) -> Path: 87 | opts = self.opts[model] 88 | opts = Namespace(**opts) 89 | pprint.pprint(opts) 90 | 91 | net = pSp(opts) 92 | net.eval() 93 | net.cuda() 94 | print("Model successfully loaded!") 95 | 96 | original_image = Image.open(str(image)) 97 | if opts.label_nc == 0: 98 | original_image = original_image.convert("RGB") 99 | else: 100 | original_image = original_image.convert("L") 101 | original_image.resize( 102 | (self.opts[model]["output_size"], self.opts[model]["output_size"]) 103 | ) 104 | 105 | # Align Image 106 | if model not in ["celebs_sketch_to_face", "celebs_seg_to_face"]: 107 | input_image = self.run_alignment(str(image)) 108 | else: 109 | input_image = original_image 110 | 111 | img_transforms = self.transforms[model] 112 | transformed_image = img_transforms(input_image) 113 | 114 | if model in ["celebs_sketch_to_face", "celebs_seg_to_face"]: 115 | latent_mask = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 116 | else: 117 | latent_mask = None 118 | 119 | with torch.no_grad(): 120 | result_image = run_on_batch( 121 | transformed_image.unsqueeze(0), net, latent_mask 122 | )[0] 123 | input_vis_image = log_input_image(transformed_image, opts) 124 | output_image = tensor2im(result_image) 125 | 126 | if model == "celebs_super_resolution": 127 | res = np.concatenate( 128 | [ 129 | np.array( 130 | input_vis_image.resize( 131 | ( 132 | self.opts[model]["output_size"], 133 | self.opts[model]["output_size"], 134 | ) 135 | ) 136 | ), 137 | np.array( 138 | output_image.resize( 139 | ( 140 | self.opts[model]["output_size"], 141 | self.opts[model]["output_size"], 142 | ) 143 | ) 144 | ), 145 | ], 146 | axis=1, 147 | ) 148 | else: 149 | res = np.array( 150 | output_image.resize( 151 | (self.opts[model]["output_size"], self.opts[model]["output_size"]) 152 | ) 153 | ) 154 | 155 | out_path = Path(tempfile.mkdtemp()) / "out.png" 156 | Image.fromarray(np.array(res)).save(str(out_path)) 157 | return out_path 158 | 159 | def run_alignment(self, image_path): 160 | aligned_image = align_face(filepath=image_path, predictor=self.predictor) 161 | print("Aligned image has shape: {}".format(aligned_image.size)) 162 | return aligned_image 163 | 164 | 165 | def run_on_batch(inputs, net, latent_mask=None): 166 | if latent_mask is None: 167 | result_batch = net(inputs.to("cuda").float(), randomize_noise=False) 168 | else: 169 | result_batch = [] 170 | for image_idx, input_image in enumerate(inputs): 171 | # get latent vector to inject into our input image 172 | vec_to_inject = np.random.randn(1, 512).astype("float32") 173 | _, latent_to_inject = net( 174 | torch.from_numpy(vec_to_inject).to("cuda"), 175 | input_code=True, 176 | return_latents=True, 177 | ) 178 | # get output image with injected style vector 179 | res = net( 180 | input_image.unsqueeze(0).to("cuda").float(), 181 | latent_mask=latent_mask, 182 | 
inject_latent=latent_to_inject, 183 | resize=False, 184 | ) 185 | result_batch.append(res) 186 | result_batch = torch.cat(result_batch, dim=0) 187 | return result_batch 188 | -------------------------------------------------------------------------------- /scripts/align_all_parallel.py: -------------------------------------------------------------------------------- 1 | """ 2 | brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) 3 | author: lzhbrian (https://lzhbrian.me) 4 | date: 2020.1.5 5 | note: code is heavily borrowed from 6 | https://github.com/NVlabs/ffhq-dataset 7 | http://dlib.net/face_landmark_detection.py.html 8 | 9 | requirements: 10 | apt install cmake 11 | conda install Pillow numpy scipy 12 | pip install dlib 13 | # download face landmark model from: 14 | # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 15 | """ 16 | from argparse import ArgumentParser 17 | import time 18 | import numpy as np 19 | import PIL 20 | import PIL.Image 21 | import os 22 | import scipy 23 | import scipy.ndimage 24 | import dlib 25 | import multiprocessing as mp 26 | import math 27 | 28 | from configs.paths_config import model_paths 29 | SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"] 30 | 31 | 32 | def get_landmark(filepath, predictor): 33 | """get landmark with dlib 34 | :return: np.array shape=(68, 2) 35 | """ 36 | detector = dlib.get_frontal_face_detector() 37 | 38 | img = dlib.load_rgb_image(filepath) 39 | dets = detector(img, 1) 40 | 41 | for k, d in enumerate(dets): 42 | shape = predictor(img, d) 43 | 44 | t = list(shape.parts()) 45 | a = [] 46 | for tt in t: 47 | a.append([tt.x, tt.y]) 48 | lm = np.array(a) 49 | return lm 50 | 51 | 52 | def align_face(filepath, predictor): 53 | """ 54 | :param filepath: str 55 | :return: PIL Image 56 | """ 57 | 58 | lm = get_landmark(filepath, predictor) 59 | 60 | lm_chin = lm[0: 17] # left-right 61 | lm_eyebrow_left = lm[17: 22] # left-right 62 | lm_eyebrow_right = lm[22: 27] # left-right 63 | lm_nose = lm[27: 31] # top-down 64 | lm_nostrils = lm[31: 36] # top-down 65 | lm_eye_left = lm[36: 42] # left-clockwise 66 | lm_eye_right = lm[42: 48] # left-clockwise 67 | lm_mouth_outer = lm[48: 60] # left-clockwise 68 | lm_mouth_inner = lm[60: 68] # left-clockwise 69 | 70 | # Calculate auxiliary vectors. 71 | eye_left = np.mean(lm_eye_left, axis=0) 72 | eye_right = np.mean(lm_eye_right, axis=0) 73 | eye_avg = (eye_left + eye_right) * 0.5 74 | eye_to_eye = eye_right - eye_left 75 | mouth_left = lm_mouth_outer[0] 76 | mouth_right = lm_mouth_outer[6] 77 | mouth_avg = (mouth_left + mouth_right) * 0.5 78 | eye_to_mouth = mouth_avg - eye_avg 79 | 80 | # Choose oriented crop rectangle. 81 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 82 | x /= np.hypot(*x) 83 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 84 | y = np.flipud(x) * [-1, 1] 85 | c = eye_avg + eye_to_mouth * 0.1 86 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 87 | qsize = np.hypot(*x) * 2 88 | 89 | # read image 90 | img = PIL.Image.open(filepath) 91 | 92 | output_size = 256 93 | transform_size = 256 94 | enable_padding = True 95 | 96 | # Shrink. 97 | shrink = int(np.floor(qsize / output_size * 0.5)) 98 | if shrink > 1: 99 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 100 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 101 | quad /= shrink 102 | qsize /= shrink 103 | 104 | # Crop. 
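    # (The crop below is the axis-aligned bounding box of the oriented quad,
    # expanded by a small border of at least 3 px (about 10% of the quad size)
    # and clamped to the image bounds; the quad is then shifted into the
    # cropped coordinate frame.)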
105 | border = max(int(np.rint(qsize * 0.1)), 3) 106 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 107 | int(np.ceil(max(quad[:, 1])))) 108 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 109 | min(crop[3] + border, img.size[1])) 110 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 111 | img = img.crop(crop) 112 | quad -= crop[0:2] 113 | 114 | # Pad. 115 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 116 | int(np.ceil(max(quad[:, 1])))) 117 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 118 | max(pad[3] - img.size[1] + border, 0)) 119 | if enable_padding and max(pad) > border - 4: 120 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 121 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 122 | h, w, _ = img.shape 123 | y, x, _ = np.ogrid[:h, :w, :1] 124 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 125 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 126 | blur = qsize * 0.02 127 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 128 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 129 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 130 | quad += pad[:2] 131 | 132 | # Transform. 133 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 134 | if output_size < transform_size: 135 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 136 | 137 | # Save aligned image. 
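    # (Note: nothing is written to disk here; the aligned 256x256 PIL image is
    # returned to the caller, and extract_on_paths below converts it to RGB and
    # saves it under the '<root_path>_crops' output directory.)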
138 | return img 139 | 140 | 141 | def chunks(lst, n): 142 | """Yield successive n-sized chunks from lst.""" 143 | for i in range(0, len(lst), n): 144 | yield lst[i:i + n] 145 | 146 | 147 | def extract_on_paths(file_paths): 148 | predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH) 149 | pid = mp.current_process().name 150 | print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths))) 151 | tot_count = len(file_paths) 152 | count = 0 153 | for file_path, res_path in file_paths: 154 | count += 1 155 | if count % 100 == 0: 156 | print('{} done with {}/{}'.format(pid, count, tot_count)) 157 | try: 158 | res = align_face(file_path, predictor) 159 | res = res.convert('RGB') 160 | os.makedirs(os.path.dirname(res_path), exist_ok=True) 161 | res.save(res_path) 162 | except Exception: 163 | continue 164 | print('\tDone!') 165 | 166 | 167 | def parse_args(): 168 | parser = ArgumentParser(add_help=False) 169 | parser.add_argument('--num_threads', type=int, default=1) 170 | parser.add_argument('--root_path', type=str, default='') 171 | args = parser.parse_args() 172 | return args 173 | 174 | 175 | def run(args): 176 | root_path = args.root_path 177 | out_crops_path = root_path + '_crops' 178 | if not os.path.exists(out_crops_path): 179 | os.makedirs(out_crops_path, exist_ok=True) 180 | 181 | file_paths = [] 182 | for root, dirs, files in os.walk(root_path): 183 | for file in files: 184 | file_path = os.path.join(root, file) 185 | fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path)) 186 | res_path = '{}.jpg'.format(os.path.splitext(fname)[0]) 187 | if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path): 188 | continue 189 | file_paths.append((file_path, res_path)) 190 | 191 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 192 | print(len(file_chunks)) 193 | pool = mp.Pool(args.num_threads) 194 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 195 | tic = time.time() 196 | pool.map(extract_on_paths, file_chunks) 197 | toc = time.time() 198 | print('Mischief managed in {}s'.format(toc - tic)) 199 | 200 | 201 | if __name__ == '__main__': 202 | args = parse_args() 203 | run(args) 204 | -------------------------------------------------------------------------------- /scripts/calc_id_loss_parallel.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import time 3 | import numpy as np 4 | import os 5 | import json 6 | import sys 7 | from PIL import Image 8 | import multiprocessing as mp 9 | import math 10 | import torch 11 | import torchvision.transforms as trans 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from models.mtcnn.mtcnn import MTCNN 17 | from models.encoders.model_irse import IR_101 18 | from configs.paths_config import model_paths 19 | CIRCULAR_FACE_PATH = model_paths['circular_face'] 20 | 21 | 22 | def chunks(lst, n): 23 | """Yield successive n-sized chunks from lst.""" 24 | for i in range(0, len(lst), n): 25 | yield lst[i:i + n] 26 | 27 | 28 | def extract_on_paths(file_paths): 29 | facenet = IR_101(input_size=112) 30 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 31 | facenet.cuda() 32 | facenet.eval() 33 | mtcnn = MTCNN() 34 | id_transform = trans.Compose([ 35 | trans.ToTensor(), 36 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 37 | ]) 38 | 39 | pid = mp.current_process().name 40 | print('\t{} is starting to extract on {} images'.format(pid, 
len(file_paths))) 41 | tot_count = len(file_paths) 42 | count = 0 43 | 44 | scores_dict = {} 45 | for res_path, gt_path in file_paths: 46 | count += 1 47 | if count % 100 == 0: 48 | print('{} done with {}/{}'.format(pid, count, tot_count)) 49 | if True: 50 | input_im = Image.open(res_path) 51 | input_im, _ = mtcnn.align(input_im) 52 | if input_im is None: 53 | print('{} skipping {}'.format(pid, res_path)) 54 | continue 55 | 56 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 57 | 58 | result_im = Image.open(gt_path) 59 | result_im, _ = mtcnn.align(result_im) 60 | if result_im is None: 61 | print('{} skipping {}'.format(pid, gt_path)) 62 | continue 63 | 64 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 65 | score = float(input_id.dot(result_id)) 66 | scores_dict[os.path.basename(gt_path)] = score 67 | 68 | return scores_dict 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser(add_help=False) 73 | parser.add_argument('--num_threads', type=int, default=4) 74 | parser.add_argument('--data_path', type=str, default='results') 75 | parser.add_argument('--gt_path', type=str, default='gt_images') 76 | args = parser.parse_args() 77 | return args 78 | 79 | 80 | def run(args): 81 | file_paths = [] 82 | for f in os.listdir(args.data_path): 83 | image_path = os.path.join(args.data_path, f) 84 | gt_path = os.path.join(args.gt_path, f) 85 | if f.endswith(".jpg") or f.endswith('.png'): 86 | file_paths.append([image_path, gt_path.replace('.png','.jpg')]) 87 | 88 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 89 | pool = mp.Pool(args.num_threads) 90 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 91 | 92 | tic = time.time() 93 | results = pool.map(extract_on_paths, file_chunks) 94 | scores_dict = {} 95 | for d in results: 96 | scores_dict.update(d) 97 | 98 | all_scores = list(scores_dict.values()) 99 | mean = np.mean(all_scores) 100 | std = np.std(all_scores) 101 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 102 | print(result_str) 103 | 104 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 105 | if not os.path.exists(out_path): 106 | os.makedirs(out_path) 107 | 108 | with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f: 109 | f.write(result_str) 110 | with open(os.path.join(out_path, 'scores_id.json'), 'w') as f: 111 | json.dump(scores_dict, f) 112 | 113 | toc = time.time() 114 | print('Mischief managed in {}s'.format(toc - tic)) 115 | 116 | 117 | if __name__ == '__main__': 118 | args = parse_args() 119 | run(args) 120 | -------------------------------------------------------------------------------- /scripts/calc_losses_on_images.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import json 4 | import sys 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torchvision.transforms as transforms 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from criteria.lpips.lpips import LPIPS 15 | from datasets.gt_res_dataset import GTResDataset 16 | 17 | 18 | def parse_args(): 19 | parser = ArgumentParser(add_help=False) 20 | parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) 21 | parser.add_argument('--data_path', type=str, default='results') 22 | parser.add_argument('--gt_path', type=str, default='gt_images') 23 | 
parser.add_argument('--workers', type=int, default=4) 24 | parser.add_argument('--batch_size', type=int, default=4) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def run(args): 30 | 31 | transform = transforms.Compose([transforms.Resize((256, 256)), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 34 | 35 | print('Loading dataset') 36 | dataset = GTResDataset(root_path=args.data_path, 37 | gt_dir=args.gt_path, 38 | transform=transform) 39 | 40 | dataloader = DataLoader(dataset, 41 | batch_size=args.batch_size, 42 | shuffle=False, 43 | num_workers=int(args.workers), 44 | drop_last=True) 45 | 46 | if args.mode == 'lpips': 47 | loss_func = LPIPS(net_type='alex') 48 | elif args.mode == 'l2': 49 | loss_func = torch.nn.MSELoss() 50 | else: 51 | raise Exception('Not a valid mode!') 52 | loss_func.cuda() 53 | 54 | global_i = 0 55 | scores_dict = {} 56 | all_scores = [] 57 | for result_batch, gt_batch in tqdm(dataloader): 58 | for i in range(args.batch_size): 59 | loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda())) 60 | all_scores.append(loss) 61 | im_path = dataset.pairs[global_i][0] 62 | scores_dict[os.path.basename(im_path)] = loss 63 | global_i += 1 64 | 65 | all_scores = list(scores_dict.values()) 66 | mean = np.mean(all_scores) 67 | std = np.std(all_scores) 68 | result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std) 69 | print('Finished with ', args.data_path) 70 | print(result_str) 71 | 72 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 73 | if not os.path.exists(out_path): 74 | os.makedirs(out_path) 75 | 76 | with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: 77 | f.write(result_str) 78 | with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: 79 | json.dump(scores_dict, f) 80 | 81 | 82 | if __name__ == '__main__': 83 | args = parse_args() 84 | run(args) 85 | -------------------------------------------------------------------------------- /scripts/generate_sketch_data.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.utils import save_image 3 | from torch.utils.serialization import load_lua 4 | import os 5 | import cv2 6 | import numpy as np 7 | 8 | """ 9 | NOTE!: Must have torch==0.4.1 and torchvision==0.2.1 10 | The sketch simplification model (sketch_gan.t7) from Simo Serra et al. 
can be downloaded from their official implementation: 11 | https://github.com/bobbens/sketch_simplification 12 | """ 13 | 14 | 15 | def sobel(img): 16 | opImgx = cv2.Sobel(img, cv2.CV_8U, 0, 1, ksize=3) 17 | opImgy = cv2.Sobel(img, cv2.CV_8U, 1, 0, ksize=3) 18 | return cv2.bitwise_or(opImgx, opImgy) 19 | 20 | 21 | def sketch(frame): 22 | frame = cv2.GaussianBlur(frame, (3, 3), 0) 23 | invImg = 255 - frame 24 | edgImg0 = sobel(frame) 25 | edgImg1 = sobel(invImg) 26 | edgImg = cv2.addWeighted(edgImg0, 0.75, edgImg1, 0.75, 0) 27 | opImg = 255 - edgImg 28 | return opImg 29 | 30 | 31 | def get_sketch_image(image_path): 32 | original = cv2.imread(image_path) 33 | original = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY) 34 | sketch_image = sketch(original) 35 | return sketch_image[:, :, np.newaxis] 36 | 37 | 38 | use_cuda = True 39 | 40 | cache = load_lua("/path/to/sketch_gan.t7") 41 | model = cache.model 42 | immean = cache.mean 43 | imstd = cache.std 44 | model.evaluate() 45 | 46 | data_path = "/path/to/data/imgs" 47 | images = [os.path.join(data_path, f) for f in os.listdir(data_path)] 48 | 49 | output_dir = "/path/to/data/edges" 50 | if not os.path.exists(output_dir): 51 | os.makedirs(output_dir) 52 | 53 | for idx, image_path in enumerate(images): 54 | if idx % 50 == 0: 55 | print("{} out of {}".format(idx, len(images))) 56 | data = get_sketch_image(image_path) 57 | data = ((transforms.ToTensor()(data) - immean) / imstd).unsqueeze(0) 58 | if use_cuda: 59 | pred = model.cuda().forward(data.cuda()).float() 60 | else: 61 | pred = model.forward(data) 62 | save_image(pred[0], os.path.join(output_dir, "{}_edges.jpg".format(image_path.split("/")[-1].split('.')[0]))) 63 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | 4 | from tqdm import tqdm 5 | import time 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | import sys 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | from configs import data_configs 16 | from datasets.inference_dataset import InferenceDataset 17 | from utils.common import tensor2im, log_input_image 18 | from options.test_options import TestOptions 19 | from models.psp import pSp 20 | 21 | 22 | def run(): 23 | test_opts = TestOptions().parse() 24 | 25 | if test_opts.resize_factors is not None: 26 | assert len( 27 | test_opts.resize_factors.split(',')) == 1, "When running inference, provide a single downsampling factor!" 
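        # Results are grouped by the requested factor in that case, e.g. passing
        # --resize_factors=8 writes to <exp_dir>/inference_results/downsampling_8
        # and <exp_dir>/inference_coupled/downsampling_8.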
28 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results', 29 | 'downsampling_{}'.format(test_opts.resize_factors)) 30 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled', 31 | 'downsampling_{}'.format(test_opts.resize_factors)) 32 | else: 33 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 34 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled') 35 | 36 | os.makedirs(out_path_results, exist_ok=True) 37 | os.makedirs(out_path_coupled, exist_ok=True) 38 | 39 | # update test options with options used during training 40 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 41 | opts = ckpt['opts'] 42 | opts.update(vars(test_opts)) 43 | if 'learn_in_w' not in opts: 44 | opts['learn_in_w'] = False 45 | if 'output_size' not in opts: 46 | opts['output_size'] = 1024 47 | opts = Namespace(**opts) 48 | 49 | net = pSp(opts) 50 | net.eval() 51 | net.cuda() 52 | 53 | print('Loading dataset for {}'.format(opts.dataset_type)) 54 | dataset_args = data_configs.DATASETS[opts.dataset_type] 55 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 56 | dataset = InferenceDataset(root=opts.data_path, 57 | transform=transforms_dict['transform_inference'], 58 | opts=opts) 59 | dataloader = DataLoader(dataset, 60 | batch_size=opts.test_batch_size, 61 | shuffle=False, 62 | num_workers=int(opts.test_workers), 63 | drop_last=True) 64 | 65 | if opts.n_images is None: 66 | opts.n_images = len(dataset) 67 | 68 | global_i = 0 69 | global_time = [] 70 | for input_batch in tqdm(dataloader): 71 | if global_i >= opts.n_images: 72 | break 73 | with torch.no_grad(): 74 | input_cuda = input_batch.cuda().float() 75 | tic = time.time() 76 | result_batch = run_on_batch(input_cuda, net, opts) 77 | toc = time.time() 78 | global_time.append(toc - tic) 79 | 80 | for i in range(opts.test_batch_size): 81 | result = tensor2im(result_batch[i]) 82 | im_path = dataset.paths[global_i] 83 | 84 | if opts.couple_outputs or global_i % 100 == 0: 85 | input_im = log_input_image(input_batch[i], opts) 86 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 87 | if opts.resize_factors is not None: 88 | # for super resolution, save the original, down-sampled, and output 89 | source = Image.open(im_path) 90 | res = np.concatenate([np.array(source.resize(resize_amount)), 91 | np.array(input_im.resize(resize_amount, resample=Image.NEAREST)), 92 | np.array(result.resize(resize_amount))], axis=1) 93 | else: 94 | # otherwise, save the original and output 95 | res = np.concatenate([np.array(input_im.resize(resize_amount)), 96 | np.array(result.resize(resize_amount))], axis=1) 97 | Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path))) 98 | 99 | im_save_path = os.path.join(out_path_results, os.path.basename(im_path)) 100 | Image.fromarray(np.array(result)).save(im_save_path) 101 | 102 | global_i += 1 103 | 104 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 105 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 106 | print(result_str) 107 | 108 | with open(stats_path, 'w') as f: 109 | f.write(result_str) 110 | 111 | 112 | def run_on_batch(inputs, net, opts): 113 | if opts.latent_mask is None: 114 | result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs) 115 | else: 116 | latent_mask = [int(l) for l in opts.latent_mask.split(",")] 117 | result_batch = [] 118 | for image_idx, input_image in enumerate(inputs): 119 
| # get latent vector to inject into our input image 120 | vec_to_inject = np.random.randn(1, 512).astype('float32') 121 | _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to("cuda"), 122 | input_code=True, 123 | return_latents=True) 124 | # get output image with injected style vector 125 | res = net(input_image.unsqueeze(0).to("cuda").float(), 126 | latent_mask=latent_mask, 127 | inject_latent=latent_to_inject, 128 | alpha=opts.mix_alpha, 129 | resize=opts.resize_outputs) 130 | result_batch.append(res) 131 | result_batch = torch.cat(result_batch, dim=0) 132 | return result_batch 133 | 134 | 135 | if __name__ == '__main__': 136 | run() 137 | -------------------------------------------------------------------------------- /scripts/style_mixing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | from PIL import Image 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import sys 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from configs import data_configs 15 | from datasets.inference_dataset import InferenceDataset 16 | from utils.common import tensor2im, log_input_image 17 | from options.test_options import TestOptions 18 | from models.psp import pSp 19 | 20 | 21 | def run(): 22 | test_opts = TestOptions().parse() 23 | 24 | if test_opts.resize_factors is not None: 25 | factors = test_opts.resize_factors.split(',') 26 | assert len(factors) == 1, "When running inference, please provide a single downsampling factor!" 27 | mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing', 28 | 'downsampling_{}'.format(test_opts.resize_factors)) 29 | else: 30 | mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing') 31 | os.makedirs(mixed_path_results, exist_ok=True) 32 | 33 | # update test options with options used during training 34 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 35 | opts = ckpt['opts'] 36 | opts.update(vars(test_opts)) 37 | if 'learn_in_w' not in opts: 38 | opts['learn_in_w'] = False 39 | if 'output_size' not in opts: 40 | opts['output_size'] = 1024 41 | opts = Namespace(**opts) 42 | 43 | net = pSp(opts) 44 | net.eval() 45 | net.cuda() 46 | 47 | print('Loading dataset for {}'.format(opts.dataset_type)) 48 | dataset_args = data_configs.DATASETS[opts.dataset_type] 49 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 50 | dataset = InferenceDataset(root=opts.data_path, 51 | transform=transforms_dict['transform_inference'], 52 | opts=opts) 53 | dataloader = DataLoader(dataset, 54 | batch_size=opts.test_batch_size, 55 | shuffle=False, 56 | num_workers=int(opts.test_workers), 57 | drop_last=True) 58 | 59 | latent_mask = [int(l) for l in opts.latent_mask.split(",")] 60 | if opts.n_images is None: 61 | opts.n_images = len(dataset) 62 | 63 | global_i = 0 64 | for input_batch in tqdm(dataloader): 65 | if global_i >= opts.n_images: 66 | break 67 | with torch.no_grad(): 68 | input_batch = input_batch.cuda() 69 | for image_idx, input_image in enumerate(input_batch): 70 | # generate random vectors to inject into input image 71 | vecs_to_inject = np.random.randn(opts.n_outputs_to_generate, 512).astype('float32') 72 | multi_modal_outputs = [] 73 | for vec_to_inject in vecs_to_inject: 74 | cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda") 75 | # get latent vector to inject into our input image 76 | _, latent_to_inject = net(cur_vec, 77 | 
input_code=True, 78 | return_latents=True) 79 | # get output image with injected style vector 80 | res = net(input_image.unsqueeze(0).to("cuda").float(), 81 | latent_mask=latent_mask, 82 | inject_latent=latent_to_inject, 83 | alpha=opts.mix_alpha, 84 | resize=opts.resize_outputs) 85 | multi_modal_outputs.append(res[0]) 86 | 87 | # visualize multi modal outputs 88 | input_im_path = dataset.paths[global_i] 89 | image = input_batch[image_idx] 90 | input_image = log_input_image(image, opts) 91 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 92 | res = np.array(input_image.resize(resize_amount)) 93 | for output in multi_modal_outputs: 94 | output = tensor2im(output) 95 | res = np.concatenate([res, np.array(output.resize(resize_amount))], axis=1) 96 | Image.fromarray(res).save(os.path.join(mixed_path_results, os.path.basename(input_im_path))) 97 | global_i += 1 98 | 99 | 100 | if __name__ == '__main__': 101 | run() 102 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import os 5 | import json 6 | import sys 7 | import pprint 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from options.train_options import TrainOptions 13 | from training.coach import Coach 14 | 15 | 16 | def main(): 17 | opts = TrainOptions().parse() 18 | if os.path.exists(opts.exp_dir): 19 | raise Exception('Oops... {} already exists'.format(opts.exp_dir)) 20 | os.makedirs(opts.exp_dir) 21 | 22 | opts_dict = vars(opts) 23 | pprint.pprint(opts_dict) 24 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 25 | json.dump(opts_dict, f, indent=4, sort_keys=True) 26 | 27 | coach = Coach(opts) 28 | coach.train() 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/training/__init__.py -------------------------------------------------------------------------------- /training/coach.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | matplotlib.use('Agg') 6 | 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torch.nn.functional as F 12 | 13 | from utils import common, train_utils 14 | from criteria import id_loss, w_norm, moco_loss 15 | from configs import data_configs 16 | from datasets.images_dataset import ImagesDataset 17 | from criteria.lpips.lpips import LPIPS 18 | from models.psp import pSp 19 | from training.ranger import Ranger 20 | 21 | 22 | class Coach: 23 | def __init__(self, opts): 24 | self.opts = opts 25 | 26 | self.global_step = 0 27 | 28 | self.device = 'cuda:0' # TODO: Allow multiple GPU? 
currently using CUDA_VISIBLE_DEVICES 29 | self.opts.device = self.device 30 | 31 | if self.opts.use_wandb: 32 | from utils.wandb_utils import WBLogger 33 | self.wb_logger = WBLogger(self.opts) 34 | 35 | # Initialize network 36 | self.net = pSp(self.opts).to(self.device) 37 | 38 | # Estimate latent_avg via dense sampling if latent_avg is not available 39 | if self.net.latent_avg is None: 40 | self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0].detach() 41 | 42 | # Initialize loss 43 | if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0: 44 | raise ValueError('Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!') 45 | 46 | self.mse_loss = nn.MSELoss().to(self.device).eval() 47 | if self.opts.lpips_lambda > 0: 48 | self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval() 49 | if self.opts.id_lambda > 0: 50 | self.id_loss = id_loss.IDLoss().to(self.device).eval() 51 | if self.opts.w_norm_lambda > 0: 52 | self.w_norm_loss = w_norm.WNormLoss(start_from_latent_avg=self.opts.start_from_latent_avg) 53 | if self.opts.moco_lambda > 0: 54 | self.moco_loss = moco_loss.MocoLoss().to(self.device).eval() 55 | 56 | # Initialize optimizer 57 | self.optimizer = self.configure_optimizers() 58 | 59 | # Initialize dataset 60 | self.train_dataset, self.test_dataset = self.configure_datasets() 61 | self.train_dataloader = DataLoader(self.train_dataset, 62 | batch_size=self.opts.batch_size, 63 | shuffle=True, 64 | num_workers=int(self.opts.workers), 65 | drop_last=True) 66 | self.test_dataloader = DataLoader(self.test_dataset, 67 | batch_size=self.opts.test_batch_size, 68 | shuffle=False, 69 | num_workers=int(self.opts.test_workers), 70 | drop_last=True) 71 | 72 | # Initialize logger 73 | log_dir = os.path.join(opts.exp_dir, 'logs') 74 | os.makedirs(log_dir, exist_ok=True) 75 | self.logger = SummaryWriter(log_dir=log_dir) 76 | 77 | # Initialize checkpoint dir 78 | self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') 79 | os.makedirs(self.checkpoint_dir, exist_ok=True) 80 | self.best_val_loss = None 81 | if self.opts.save_interval is None: 82 | self.opts.save_interval = self.opts.max_steps 83 | 84 | def train(self): 85 | self.net.train() 86 | while self.global_step < self.opts.max_steps: 87 | for batch_idx, batch in enumerate(self.train_dataloader): 88 | self.optimizer.zero_grad() 89 | x, y = batch 90 | x, y = x.to(self.device).float(), y.to(self.device).float() 91 | y_hat, latent = self.net.forward(x, return_latents=True) 92 | loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) 93 | loss.backward() 94 | self.optimizer.step() 95 | 96 | # Logging related 97 | if self.global_step % self.opts.image_interval == 0 or (self.global_step < 1000 and self.global_step % 25 == 0): 98 | self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces') 99 | if self.global_step % self.opts.board_interval == 0: 100 | self.print_metrics(loss_dict, prefix='train') 101 | self.log_metrics(loss_dict, prefix='train') 102 | 103 | # Log images of first batch to wandb 104 | if self.opts.use_wandb and batch_idx == 0: 105 | self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="train", step=self.global_step, opts=self.opts) 106 | 107 | # Validation related 108 | val_loss_dict = None 109 | if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps: 110 | val_loss_dict = self.validate() 111 | if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss): 112 | 
self.best_val_loss = val_loss_dict['loss'] 113 | self.checkpoint_me(val_loss_dict, is_best=True) 114 | 115 | if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: 116 | if val_loss_dict is not None: 117 | self.checkpoint_me(val_loss_dict, is_best=False) 118 | else: 119 | self.checkpoint_me(loss_dict, is_best=False) 120 | 121 | if self.global_step == self.opts.max_steps: 122 | print('OMG, finished training!') 123 | break 124 | 125 | self.global_step += 1 126 | 127 | def validate(self): 128 | self.net.eval() 129 | agg_loss_dict = [] 130 | for batch_idx, batch in enumerate(self.test_dataloader): 131 | x, y = batch 132 | 133 | with torch.no_grad(): 134 | x, y = x.to(self.device).float(), y.to(self.device).float() 135 | y_hat, latent = self.net.forward(x, return_latents=True) 136 | loss, cur_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) 137 | agg_loss_dict.append(cur_loss_dict) 138 | 139 | # Logging related 140 | self.parse_and_log_images(id_logs, x, y, y_hat, 141 | title='images/test/faces', 142 | subscript='{:04d}'.format(batch_idx)) 143 | 144 | # Log images of first batch to wandb 145 | if self.opts.use_wandb and batch_idx == 0: 146 | self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="test", step=self.global_step, opts=self.opts) 147 | 148 | # For first step just do sanity test on small amount of data 149 | if self.global_step == 0 and batch_idx >= 4: 150 | self.net.train() 151 | return None # Do not log, inaccurate in first batch 152 | 153 | loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict) 154 | self.log_metrics(loss_dict, prefix='test') 155 | self.print_metrics(loss_dict, prefix='test') 156 | 157 | self.net.train() 158 | return loss_dict 159 | 160 | def checkpoint_me(self, loss_dict, is_best): 161 | save_name = 'best_model.pt' if is_best else f'iteration_{self.global_step}.pt' 162 | save_dict = self.__get_save_dict() 163 | checkpoint_path = os.path.join(self.checkpoint_dir, save_name) 164 | torch.save(save_dict, checkpoint_path) 165 | with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: 166 | if is_best: 167 | f.write(f'**Best**: Step - {self.global_step}, Loss - {self.best_val_loss} \n{loss_dict}\n') 168 | if self.opts.use_wandb: 169 | self.wb_logger.log_best_model() 170 | else: 171 | f.write(f'Step - {self.global_step}, \n{loss_dict}\n') 172 | 173 | def configure_optimizers(self): 174 | params = list(self.net.encoder.parameters()) 175 | if self.opts.train_decoder: 176 | params += list(self.net.decoder.parameters()) 177 | if self.opts.optim_name == 'adam': 178 | optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) 179 | else: 180 | optimizer = Ranger(params, lr=self.opts.learning_rate) 181 | return optimizer 182 | 183 | def configure_datasets(self): 184 | if self.opts.dataset_type not in data_configs.DATASETS.keys(): 185 | raise Exception(f'{self.opts.dataset_type} is not a valid dataset_type') 186 | print(f'Loading dataset for {self.opts.dataset_type}') 187 | dataset_args = data_configs.DATASETS[self.opts.dataset_type] 188 | transforms_dict = dataset_args['transforms'](self.opts).get_transforms() 189 | train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'], 190 | target_root=dataset_args['train_target_root'], 191 | source_transform=transforms_dict['transform_source'], 192 | target_transform=transforms_dict['transform_gt_train'], 193 | opts=self.opts) 194 | test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'], 195 |
target_root=dataset_args['test_target_root'], 196 | source_transform=transforms_dict['transform_source'], 197 | target_transform=transforms_dict['transform_test'], 198 | opts=self.opts) 199 | if self.opts.use_wandb: 200 | self.wb_logger.log_dataset_wandb(train_dataset, dataset_name="Train") 201 | self.wb_logger.log_dataset_wandb(test_dataset, dataset_name="Test") 202 | print(f"Number of training samples: {len(train_dataset)}") 203 | print(f"Number of test samples: {len(test_dataset)}") 204 | return train_dataset, test_dataset 205 | 206 | def calc_loss(self, x, y, y_hat, latent): 207 | loss_dict = {} 208 | loss = 0.0 209 | id_logs = None 210 | if self.opts.id_lambda > 0: 211 | loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x) 212 | loss_dict['loss_id'] = float(loss_id) 213 | loss_dict['id_improve'] = float(sim_improvement) 214 | loss = loss_id * self.opts.id_lambda 215 | if self.opts.l2_lambda > 0: 216 | loss_l2 = F.mse_loss(y_hat, y) 217 | loss_dict['loss_l2'] = float(loss_l2) 218 | loss += loss_l2 * self.opts.l2_lambda 219 | if self.opts.lpips_lambda > 0: 220 | loss_lpips = self.lpips_loss(y_hat, y) 221 | loss_dict['loss_lpips'] = float(loss_lpips) 222 | loss += loss_lpips * self.opts.lpips_lambda 223 | if self.opts.lpips_lambda_crop > 0: 224 | loss_lpips_crop = self.lpips_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220]) 225 | loss_dict['loss_lpips_crop'] = float(loss_lpips_crop) 226 | loss += loss_lpips_crop * self.opts.lpips_lambda_crop 227 | if self.opts.l2_lambda_crop > 0: 228 | loss_l2_crop = F.mse_loss(y_hat[:, :, 35:223, 32:220], y[:, :, 35:223, 32:220]) 229 | loss_dict['loss_l2_crop'] = float(loss_l2_crop) 230 | loss += loss_l2_crop * self.opts.l2_lambda_crop 231 | if self.opts.w_norm_lambda > 0: 232 | loss_w_norm = self.w_norm_loss(latent, self.net.latent_avg) 233 | loss_dict['loss_w_norm'] = float(loss_w_norm) 234 | loss += loss_w_norm * self.opts.w_norm_lambda 235 | if self.opts.moco_lambda > 0: 236 | loss_moco, sim_improvement, id_logs = self.moco_loss(y_hat, y, x) 237 | loss_dict['loss_moco'] = float(loss_moco) 238 | loss_dict['id_improve'] = float(sim_improvement) 239 | loss += loss_moco * self.opts.moco_lambda 240 | 241 | loss_dict['loss'] = float(loss) 242 | return loss, loss_dict, id_logs 243 | 244 | def log_metrics(self, metrics_dict, prefix): 245 | for key, value in metrics_dict.items(): 246 | self.logger.add_scalar(f'{prefix}/{key}', value, self.global_step) 247 | if self.opts.use_wandb: 248 | self.wb_logger.log(prefix, metrics_dict, self.global_step) 249 | 250 | def print_metrics(self, metrics_dict, prefix): 251 | print(f'Metrics for {prefix}, step {self.global_step}') 252 | for key, value in metrics_dict.items(): 253 | print(f'\t{key} = ', value) 254 | 255 | def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2): 256 | im_data = [] 257 | for i in range(display_count): 258 | cur_im_data = { 259 | 'input_face': common.log_input_image(x[i], self.opts), 260 | 'target_face': common.tensor2im(y[i]), 261 | 'output_face': common.tensor2im(y_hat[i]), 262 | } 263 | if id_logs is not None: 264 | for key in id_logs[i]: 265 | cur_im_data[key] = id_logs[i][key] 266 | im_data.append(cur_im_data) 267 | self.log_images(title, im_data=im_data, subscript=subscript) 268 | 269 | def log_images(self, name, im_data, subscript=None, log_latest=False): 270 | fig = common.vis_faces(im_data) 271 | step = self.global_step 272 | if log_latest: 273 | step = 0 274 | if subscript: 275 | path = os.path.join(self.logger.log_dir, name, 
f'{subscript}_{step:04d}.jpg') 276 | else: 277 | path = os.path.join(self.logger.log_dir, name, f'{step:04d}.jpg') 278 | os.makedirs(os.path.dirname(path), exist_ok=True) 279 | fig.savefig(path) 280 | plt.close(fig) 281 | 282 | def __get_save_dict(self): 283 | save_dict = { 284 | 'state_dict': self.net.state_dict(), 285 | 'opts': vars(self.opts) 286 | } 287 | # save the latent avg in state_dict for inference if truncation of w was used during training 288 | if self.opts.start_from_latent_avg: 289 | save_dict['latent_avg'] = self.net.latent_avg 290 | return save_dict 291 | -------------------------------------------------------------------------------- /training/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
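# Illustrative construction (comment only -- not part of the original file): training/coach.py
# builds this optimizer as `Ranger(params, lr=self.opts.learning_rate)` when opts.optim_name
# is not 'adam', leaving alpha, k, betas, eps, weight_decay and use_gc at the defaults below.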
52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 124 | # compute mean moving avg 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | 127 | buffered = self.radam_buffer[int(state['step'] % 10)] 128 | 129 | if state['step'] == buffered[0]: 130 | N_sma, step_size = buffered[1], buffered[2] 131 | else: 132 | buffered[0] = state['step'] 133 | beta2_t = beta2 ** state['step'] 134 | N_sma_max = 2 / (1 - beta2) - 1 135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 136 | buffered[1] = N_sma 137 | if N_sma > self.N_sma_threshhold: 138 | step_size = math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 140 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 141 | else: 142 | step_size = 1.0 / (1 - beta1 ** state['step']) 143 | buffered[2] = step_size 144 | 145 | if group['weight_decay'] != 0: 146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 147 | 148 | # apply lr 149 | if N_sma > self.N_sma_threshhold: 150 | denom = exp_avg_sq.sqrt().add_(group['eps']) 151 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 152 | else: 153 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 154 | 155 | p.data.copy_(p_data_fp32) 156 | 157 | # integrated look 
ahead... 158 | # we do it at the param level instead of group level 159 | if state['step'] % group['k'] == 0: 160 | slow_p = state['slow_buffer'] # get access to slow param tensor 161 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 163 | 164 | return loss -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladrich/pixel2style2pixel/5cfff385beb7b95fbce775662b48fcc80081928d/utils/__init__.py -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | # Log images 8 | def log_input_image(x, opts): 9 | if opts.label_nc == 0: 10 | return tensor2im(x) 11 | elif opts.label_nc == 1: 12 | return tensor2sketch(x) 13 | else: 14 | return tensor2map(x) 15 | 16 | 17 | def tensor2im(var): 18 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 19 | var = ((var + 1) / 2) 20 | var[var < 0] = 0 21 | var[var > 1] = 1 22 | var = var * 255 23 | return Image.fromarray(var.astype('uint8')) 24 | 25 | 26 | def tensor2map(var): 27 | mask = np.argmax(var.data.cpu().numpy(), axis=0) 28 | colors = get_colors() 29 | mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) 30 | for class_idx in np.unique(mask): 31 | mask_image[mask == class_idx] = colors[class_idx] 32 | mask_image = mask_image.astype('uint8') 33 | return Image.fromarray(mask_image) 34 | 35 | 36 | def tensor2sketch(var): 37 | im = var[0].cpu().detach().numpy() 38 | im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) 39 | im = (im * 255).astype(np.uint8) 40 | return Image.fromarray(im) 41 | 42 | 43 | # Visualization utils 44 | def get_colors(): 45 | # currently support up to 19 classes (for the celebs-hq-mask dataset) 46 | colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], 47 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], 48 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] 49 | return colors 50 | 51 | 52 | def vis_faces(log_hooks): 53 | display_count = len(log_hooks) 54 | fig = plt.figure(figsize=(8, 4 * display_count)) 55 | gs = fig.add_gridspec(display_count, 3) 56 | for i in range(display_count): 57 | hooks_dict = log_hooks[i] 58 | fig.add_subplot(gs[i, 0]) 59 | if 'diff_input' in hooks_dict: 60 | vis_faces_with_id(hooks_dict, fig, gs, i) 61 | else: 62 | vis_faces_no_id(hooks_dict, fig, gs, i) 63 | plt.tight_layout() 64 | return fig 65 | 66 | 67 | def vis_faces_with_id(hooks_dict, fig, gs, i): 68 | plt.imshow(hooks_dict['input_face']) 69 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 70 | fig.add_subplot(gs[i, 1]) 71 | plt.imshow(hooks_dict['target_face']) 72 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), 73 | float(hooks_dict['diff_target']))) 74 | fig.add_subplot(gs[i, 2]) 75 | plt.imshow(hooks_dict['output_face']) 76 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) 77 | 78 | 79 | def vis_faces_no_id(hooks_dict, fig, gs, i): 80 | plt.imshow(hooks_dict['input_face'], cmap="gray") 81 | plt.title('Input') 82 | 
fig.add_subplot(gs[i, 1]) 83 | plt.imshow(hooks_dict['target_face']) 84 | plt.title('Target') 85 | fig.add_subplot(gs[i, 2]) 86 | plt.imshow(hooks_dict['output_face']) 87 | plt.title('Output') 88 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 20 | for root, _, fnames in sorted(os.walk(dir)): 21 | for fname in fnames: 22 | if is_image_file(fname): 23 | path = os.path.join(root, fname) 24 | images.append(path) 25 | return images 26 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def aggregate_loss_dict(agg_loss_dict): 3 | mean_vals = {} 4 | for output in agg_loss_dict: 5 | for key in output: 6 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 7 | for key in mean_vals: 8 | if len(mean_vals[key]) > 0: 9 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 10 | else: 11 | print('{} has no value'.format(key)) 12 | mean_vals[key] = 0 13 | return mean_vals 14 | -------------------------------------------------------------------------------- /utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import numpy as np 4 | import wandb 5 | 6 | from utils import common 7 | 8 | 9 | class WBLogger: 10 | 11 | def __init__(self, opts): 12 | wandb_run_name = os.path.basename(opts.exp_dir) 13 | wandb.init(project="pixel2style2pixel", config=vars(opts), name=wandb_run_name) 14 | 15 | @staticmethod 16 | def log_best_model(): 17 | wandb.run.summary["best-model-save-time"] = datetime.datetime.now() 18 | 19 | @staticmethod 20 | def log(prefix, metrics_dict, global_step): 21 | log_dict = {f'{prefix}_{key}': value for key, value in metrics_dict.items()} 22 | log_dict["global_step"] = global_step 23 | wandb.log(log_dict) 24 | 25 | @staticmethod 26 | def log_dataset_wandb(dataset, dataset_name, n_images=16): 27 | idxs = np.random.choice(a=range(len(dataset)), size=n_images, replace=False) 28 | data = [wandb.Image(dataset.source_paths[idx]) for idx in idxs] 29 | wandb.log({f"{dataset_name} Data Samples": data}) 30 | 31 | @staticmethod 32 | def log_images_to_wandb(x, y, y_hat, id_logs, prefix, step, opts): 33 | im_data = [] 34 | column_names = ["Source", "Target", "Output"] 35 | if id_logs is not None: 36 | column_names.append("ID Diff Output to Target") 37 | for i in range(len(x)): 38 | cur_im_data = [ 39 | wandb.Image(common.log_input_image(x[i], opts)), 40 | wandb.Image(common.tensor2im(y[i])), 41 | wandb.Image(common.tensor2im(y_hat[i])), 42 | ] 43 | if id_logs is not None: 44 | cur_im_data.append(id_logs[i]["diff_target"]) 45 | im_data.append(cur_im_data) 46 | outputs_table = wandb.Table(data=im_data, columns=column_names) 47 | wandb.log({f"{prefix.title()} Step {step} Output Samples": outputs_table}) 48 | 
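# Illustrative call sequence (comment only -- not part of the original file), mirroring how
# training/coach.py drives this logger when opts.use_wandb is set:
#   wb_logger = WBLogger(opts)
#   wb_logger.log_dataset_wandb(train_dataset, dataset_name="Train")
#   wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="train", step=global_step, opts=opts)
#   wb_logger.log("train", loss_dict, global_step)
#   wb_logger.log_best_model()  # called when a new best validation loss is reached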
--------------------------------------------------------------------------------
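# ==== Illustrative appendix (not a file from this repository) ====
# A minimal sketch showing how the pieces above fit together for programmatic inference,
# following the same steps as scripts/inference.py. The checkpoint path, input directory,
# and output filename below are placeholders/assumptions -- adjust them to your own setup.
# Assumes the code is run from the repository root so the package imports resolve.
from argparse import Namespace

import torch
from torch.utils.data import DataLoader

from configs import data_configs
from datasets.inference_dataset import InferenceDataset
from models.psp import pSp
from utils.common import tensor2im

ckpt_path = 'pretrained_models/psp_ffhq_encode.pt'  # placeholder checkpoint path
ckpt = torch.load(ckpt_path, map_location='cpu')

# reuse the options stored at training time, overriding only what inference needs
opts = ckpt['opts']
opts.update({'checkpoint_path': ckpt_path,
             'data_path': 'inputs/',        # placeholder: directory of (aligned) input images
             'test_batch_size': 1,
             'test_workers': 1,
             'resize_outputs': False})
opts.setdefault('learn_in_w', False)        # same backwards-compatibility defaults as the scripts
opts.setdefault('output_size', 1024)
opts = Namespace(**opts)

net = pSp(opts).cuda().eval()

dataset_args = data_configs.DATASETS[opts.dataset_type]
transforms_dict = dataset_args['transforms'](opts).get_transforms()
dataset = InferenceDataset(root=opts.data_path,
                           transform=transforms_dict['transform_inference'],
                           opts=opts)
dataloader = DataLoader(dataset, batch_size=opts.test_batch_size, shuffle=False)

with torch.no_grad():
    for input_batch in dataloader:
        # same call as the no-latent-mask branch of run_on_batch in scripts/inference.py
        result_batch = net(input_batch.cuda().float(), randomize_noise=False, resize=opts.resize_outputs)
        tensor2im(result_batch[0]).save('output.jpg')  # placeholder output filename
        break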