├── .gitignore ├── LICENSE ├── README.md ├── criteria ├── __init__.py ├── all_loss.py ├── id_loss.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── networks.py │ └── utils.py └── moco_loss.py ├── data_function.py ├── distributed.py ├── hparams.py ├── infer.py ├── metrics ├── cal_sliced_wasserstein.py ├── fideity_loss_calc.py ├── id_loss_calc.py ├── img_loss_calc.py └── sliced_wasserstein.py ├── models ├── __init__.py ├── fpn_encoders.py ├── helpers.py ├── map2style.py ├── model_irse.py └── mtcnn │ ├── __init__.py │ ├── mtcnn.py │ └── mtcnn_pytorch │ ├── __init__.py │ └── src │ ├── __init__.py │ ├── align_trans.py │ ├── box_utils.py │ ├── detector.py │ ├── first_stage.py │ ├── get_nets.py │ ├── matlab_cp2tform.py │ ├── visualization_utils.py │ └── weights │ ├── onet.npy │ ├── pnet.npy │ └── rnet.npy ├── optimizer ├── LookAhead.py ├── RAdam.py ├── Ranger.py └── __init__.py ├── pre_trained_model ├── __init__.py ├── helpers.py └── model_irse.py ├── stylegan2 ├── __init__.py ├── model.py ├── op │ ├── __init__.py │ ├── conv2d_gradfix.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu └── stylegan2_infer.py ├── train.py ├── train_ddp.py ├── transforms_config ├── car_transforms.py └── normal_transforms.py └── weights_init └── weight_init_normal.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kangneng Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Framework of GAN Inversion 2 | 3 | ## Introduction 4 | * You can implement your own inversion idea with this repo. We offer a full range of tuning settings (in hparams.py), several strong backbones and classic loss functions. The network architecture and the losses are easy to modify; a minimal sketch of the weighted loss is shown below. 5 | 6 | ## Recent Updates 7 | * 2021.9.1 The simplified framework of GAN Inversion is released. 
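As a quick illustration of the loss setup mentioned in the Introduction, the sketch below mirrors how `criteria/all_loss.py` combines the pixel, LPIPS and ID terms with the `loss_lambda_*` weights from `hparams.py`. It is only a minimal, self-contained sketch: `lpips_term` and `id_term` are placeholder scalars standing in for the real LPIPS and ArcFace/MoCo losses, which need pretrained weights.

```python
import torch
import torch.nn.functional as F

# Default weights from hparams.py
loss_lambda_mse, loss_lambda_lpips, loss_lambda_id = 1.0, 0.8, 0.1

def combined_loss(pred, target, lpips_term, id_term):
    """Weighted sum of pixel, perceptual and identity terms (as combined in Base_Loss)."""
    loss_mse = F.mse_loss(pred, target)
    return (loss_lambda_mse * loss_mse
            + loss_lambda_lpips * lpips_term
            + loss_lambda_id * id_term)

pred, target = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
# Placeholder perceptual/identity terms; the repo computes these with LPIPS and IDLoss/MocoLoss.
print(combined_loss(pred, target, torch.tensor(0.3), torch.tensor(0.2)))
```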
8 | 9 | ## Requirements 10 | * `pip install git+git://github.com/lehduong/torch-warmup-lr.git` 11 | * PyTorch 1.7 12 | 13 | ## Train 14 | ### Without DDP 15 | `python train.py` 16 | ### With DDP 17 | `python -m torch.distributed.launch --nproc_per_node=num_gpus train.py` 18 | ## Done 19 | ### Tuning Settings 20 | * Apply_init 21 | * Optimizer_mode 22 | * Scheduler_mode 23 | * Open_warn_up 24 | 25 | ### Backbone 26 | * GradualStyleEncoder from [restyle-encoder](https://github.com/yuval-alaluf/restyle-encoder) and [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) 27 | * ResNetGradualStyleEncoder from [restyle-encoder](https://github.com/yuval-alaluf/restyle-encoder) and [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) 28 | 29 | 30 | ### Loss 31 | * MSE 32 | * LPIPS 33 | * ID loss from [restyle-encoder](https://github.com/yuval-alaluf/restyle-encoder) and [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) 34 | * Moco loss from [restyle-encoder](https://github.com/yuval-alaluf/restyle-encoder) and [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) 35 | 36 | 37 | ## TODO 38 | - [x] DDP 39 | - [ ] More Backbones 40 | - [ ] Metrics 41 | 42 | ## Acknowledgements 43 | This repository is an unofficial PyTorch framework for GAN inversion, heavily based on [restyle-encoder](https://github.com/yuval-alaluf/restyle-encoder), [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), [AliProducts](https://github.com/misads/AliProducts) and [stylegan2](https://github.com/NVlabs/stylegan2). It was jointly developed with [Prof. Shuang Song](ssong@ustb.edu.cn). Thanks to the authors of the above repositories, and to [Daiheng Gao](https://github.com/tomguluson92) and [Jie Zhang](https://scholar.google.com.hk/citations?user=gBkYZeMAAAAJ) for all their help. 
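To make the training and inference commands above concrete, the following sketch shows the forward pass that `train.py` and `infer.py` wire together: an encoder predicts one 512-dimensional latent per StyleGAN2 style layer (W+), and a frozen generator decodes those codes back to an image. The `ToyEncoder` below is a hypothetical stand-in for the repo's `GradualStyleEncoder`/`ResNetGradualStyleEncoder` so the snippet runs without pretrained weights; only the `n_styles` formula and the latent shape are taken from `infer.py`.

```python
import math
import torch
from torch import nn

img_size = 1024
# Same formula as in infer.py: a 1024px StyleGAN2 generator takes 2*log2(1024) - 2 = 18 style vectors.
n_styles = 2 * int(math.log(img_size, 2)) - 2

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for GradualStyleEncoder: image -> (batch, n_styles, 512) W+ codes."""
    def __init__(self, n_styles):
        super().__init__()
        self.n_styles = n_styles
        self.pool = nn.AdaptiveAvgPool2d(8)
        self.fc = nn.Linear(3 * 8 * 8, n_styles * 512)

    def forward(self, x):
        h = self.pool(x).flatten(1)
        return self.fc(h).view(-1, self.n_styles, 512)

encoder = ToyEncoder(n_styles)
imgs = torch.rand(2, 3, 256, 256)   # transformed input images
w_plus = encoder(imgs)              # latent codes, shape (2, 18, 512)
print(w_plus.shape)
# In the repo, these codes are decoded by the frozen StyleGAN2 generator, e.g.
# predicts = class_generate.generate_from_synthesis(w_plus, None, randomize_noise=False, return_latents=True)
```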
44 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/all_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | import cv2 5 | import numpy as np 6 | from torchvision import transforms 7 | from torchvision import models 8 | from torch.autograd import Variable 9 | from PIL import Image 10 | from torchvision import utils 11 | from torch.autograd import Variable 12 | import torch.autograd as autograd 13 | from hparams import hparams as hp 14 | from criteria import id_loss, moco_loss 15 | from criteria.lpips.lpips import LPIPS 16 | 17 | class Base_Loss(nn.Module): 18 | def __init__(self): 19 | super(Base_Loss,self).__init__() 20 | import lpips 21 | # self.loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores 22 | self.loss_fn_vgg = LPIPS(net_type='alex').cuda().eval() 23 | 24 | if hp.dataset_type == 'car': 25 | self.moco_loss = moco_loss.MocoLoss() 26 | else: 27 | self.id_loss = id_loss.IDLoss().cuda().eval() 28 | 29 | 30 | 31 | self.criterion_mse = nn.MSELoss() 32 | 33 | 34 | def forward(self, gt,predict_images): 35 | 36 | loss_mse = self.criterion_mse(gt, predict_images) 37 | loss_lpips = self.loss_fn_vgg(gt,predict_images) 38 | 39 | if hp.dataset_type == 'car': 40 | loss_per = self.moco_loss(predict_images,gt,gt)[0] 41 | else: 42 | loss_per = self.id_loss(predict_images,gt,gt)[0] 43 | 44 | loss_all = hp.loss_lambda_mse*loss_mse + hp.loss_lambda_lpips*loss_lpips + hp.loss_lambda_id*loss_per 45 | 46 | return loss_all,loss_mse,loss_lpips,loss_per 47 | -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from hparams import hparams as hp 4 | from pre_trained_model.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self): 9 | arc_model_path = hp.arc_model_path 10 | super(IDLoss, self).__init__() 11 | print('Loading ResNet ArcFace') 12 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 13 | self.facenet.load_state_dict(torch.load(arc_model_path)) 14 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 15 | self.facenet.eval() 16 | 17 | def extract_feats(self, x): 18 | x = x[:, :, 35:223, 32:220] # Crop interesting region 19 | x = self.face_pool(x) 20 | x_feats = self.facenet(x) 21 | return x_feats 22 | 23 | def forward(self, y_hat, y, x): 24 | n_samples = x.shape[0] 25 | x_feats = self.extract_feats(x) 26 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 27 | y_hat_feats = self.extract_feats(y_hat) 28 | y_feats = y_feats.detach() 29 | loss = 0 30 | sim_improvement = 0 31 | id_logs = [] 32 | count = 0 33 | for i in range(n_samples): 34 | diff_target = y_hat_feats[i].dot(y_feats[i]) 35 | diff_input = y_hat_feats[i].dot(x_feats[i]) 36 | diff_views = y_feats[i].dot(x_feats[i]) 37 | id_logs.append({'diff_target': float(diff_target), 38 | 'diff_input': float(diff_input), 39 | 'diff_views': float(diff_views)}) 40 | loss += 1 - diff_target 41 | id_diff = 
float(diff_target) - float(diff_views) 42 | sim_improvement += id_diff 43 | count += 1 44 | 45 | return loss / count, sim_improvement / count, id_logs -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output 
= [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2+1e-8, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from hparams import hparams as hp 5 | 6 | class MocoLoss(nn.Module): 7 | 8 | def __init__(self): 9 | super(MocoLoss, self).__init__() 10 | moco_model_path = hp.moco_model_path 11 | print("Loading MOCO model from path: {}".format(moco_model_path)) 12 | self.model = self.__load_model(moco_model_path) 13 | self.model.cuda() 14 | self.model.eval() 15 | 16 | @staticmethod 17 | def __load_model(moco_model_path): 18 | import torchvision.models as models 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ['fc.weight', 'fc.bias']: 23 | param.requires_grad = False 24 | checkpoint = torch.load(moco_model_path, map_location="cpu") 25 | state_dict = checkpoint['state_dict'] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 30 | # remove prefix 31 | 
state_dict[k[len("module.encoder_q."):]] = state_dict[k] 32 | # delete renamed or unused k 33 | del state_dict[k] 34 | msg = model.load_state_dict(state_dict, strict=False) 35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 36 | # remove output layer 37 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 38 | return model 39 | 40 | def extract_feats(self, x): 41 | x = F.interpolate(x, size=224) 42 | x_feats = self.model(x) 43 | x_feats = nn.functional.normalize(x_feats, dim=1) 44 | x_feats = x_feats.squeeze() 45 | return x_feats 46 | 47 | def forward(self, y_hat, y, x): 48 | n_samples = x.shape[0] 49 | x_feats = self.extract_feats(x) 50 | y_feats = self.extract_feats(y) 51 | y_hat_feats = self.extract_feats(y_hat) 52 | y_feats = y_feats.detach() 53 | loss = 0 54 | sim_improvement = 0 55 | sim_logs = [] 56 | count = 0 57 | for i in range(n_samples): 58 | diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | diff_input = y_hat_feats[i].dot(x_feats[i]) 60 | diff_views = y_feats[i].dot(x_feats[i]) 61 | sim_logs.append({'diff_target': float(diff_target), 62 | 'diff_input': float(diff_input), 63 | 'diff_views': float(diff_views)}) 64 | loss += 1 - diff_target 65 | sim_diff = float(diff_target) - float(diff_views) 66 | sim_improvement += sim_diff 67 | count += 1 68 | 69 | return loss / count, sim_improvement / count, sim_logs -------------------------------------------------------------------------------- /data_function.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from os.path import dirname, join, basename, isfile 3 | import sys 4 | import csv 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import random 11 | import os 12 | from pathlib import Path 13 | import argparse 14 | import cv2 15 | import sys 16 | from sklearn.metrics import mean_squared_error, r2_score 17 | import json 18 | from torchvision import utils 19 | from hparams import hparams as hp 20 | from tqdm import tqdm 21 | 22 | 23 | 24 | class ImageData(torch.utils.data.Dataset): 25 | def __init__(self, root_dir, transfom): 26 | 27 | 28 | self.transformss = transfom 29 | self.root_dir = root_dir 30 | self.root_dir_list = os.listdir(self.root_dir) 31 | 32 | 33 | 34 | def __len__(self): 35 | 36 | return len(self.root_dir_list) 37 | 38 | def __getitem__(self, index): 39 | 40 | 41 | 42 | img_path = os.path.join(self.root_dir, self.root_dir_list[index]) 43 | 44 | img = Image.open(img_path).convert('RGB') 45 | 46 | img = self.transformss(img) 47 | return img 48 | 49 | 50 | class GTResDataset(torch.utils.data.Dataset): 51 | 52 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): 53 | self.pairs = [] 54 | for f in os.listdir(root_path): 55 | f_ = f.replace('predict','origin') 56 | image_path = os.path.join(root_path, f) 57 | gt_path = os.path.join(gt_dir, f_) 58 | # if f.endswith(".jpg") or f.endswith(".png"): 59 | # self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) 60 | self.pairs.append([image_path, gt_path, None]) 61 | self.transform = transform 62 | self.transform_train = transform_train 63 | 64 | def __len__(self): 65 | return len(self.pairs) 66 | 67 | def __getitem__(self, index): 68 | from_path, to_path, _ = self.pairs[index] 69 | from_im = Image.open(from_path).convert('RGB') 70 | to_im = Image.open(to_path).convert('RGB') 71 | 72 | if self.transform: 73 | to_im = self.transform(to_im) 74 | from_im = 
self.transform(from_im) 75 | 76 | return from_im, to_im 77 | 78 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from transforms_config import car_transforms, normal_transforms 2 | 3 | class hparams: 4 | description = 'PyTorch Training' 5 | output_dir = 'log/0831' 6 | 7 | img_size = 1024 8 | resize = True 9 | dataset_type = 'ffhq' # 
['ffhq', 'celebahq', 'car', 'bggn'] 10 | dataset_path = '' 11 | 12 | apply_init = True 13 | if apply_init: 14 | init_type = 'normal' # ['normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none] 15 | 16 | backbone = 'GradualStyleEncoder' # ['GradualStyleEncoder', 'ResNetGradualStyleEncoder'] 17 | epochs_per_checkpoint = 10 18 | latest_checkpoint_file = 'checkpoint_latest.pt' 19 | epochs = 50000 20 | batch = 5 21 | ckpt = None 22 | 23 | init_lr = 0.001 24 | 25 | # loss 26 | loss_lambda_mse = 1 27 | loss_lambda_lpips = 0.8 28 | loss_lambda_id = 0.1 29 | 30 | 31 | arc_model_path = '' 32 | moco_model_path = '' 33 | circular_face_model_paths = '' 34 | weight_path_pytorch = '' 35 | 36 | mtcnn_path_pnet = '' 37 | mtcnn_path_rnet = '' 38 | mtcnn_path_onet = '' 39 | 40 | if dataset_type == 'car': 41 | transform = car_transforms.get_transforms() 42 | else: 43 | transform = normal_transforms.get_transforms() 44 | 45 | optimizer_mode = 'adam' # ['adam', 'sgd', 'radam', 'lookahead', 'ranger'] 46 | scheduler_mode = 'StepLR' # ['StepLR', 'MultiStepLR', 'ReduceLROnPlateau'] 47 | 48 | open_warn_up = True 49 | if open_warn_up: 50 | warn_up_strategy = 'cos' # ['cos', 'linear', 'constant'] 51 | num_warmup = 3 52 | 53 | # for save 54 | norm = True 55 | row = 1 56 | rangee = (-1,1) 57 | 58 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from weights_init.weight_init_normal import weights_init_normal 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3' 4 | devicess = [0,1,2] 5 | import re 6 | import time 7 | import argparse 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch import nn 13 | import torch.distributed as dist 14 | import math 15 | import warnings 16 | from tqdm import tqdm 17 | from torch.optim.lr_scheduler import ReduceLROnPlateau,StepLR,MultiStepLR 18 | from torchvision import utils 19 | from hparams import hparams as hp 20 | from torch.autograd import Variable 21 | from torch_warmup_lr import WarmupLR 22 | from optimizer.LookAhead import Lookahead 23 | from optimizer.RAdam import RAdam 24 | from optimizer.Ranger import Ranger 25 | warnings.filterwarnings("ignore") 26 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 27 | 28 | 29 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 30 | 31 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 32 | 33 | def parse_testing_args(parser): 34 | """ 35 | Parse commandline arguments. 
36 | """ 37 | 38 | parser.add_argument('-o', '--output_dir', type=str, default=hp.output_dir, required=False, help='Directory to save results') 39 | parser.add_argument('--latest-checkpoint-file', type=str, default=hp.latest_checkpoint_file, help='Use the latest checkpoint in each epoch') 40 | 41 | testing = parser.add_argument_group('testing setup') 42 | testing.add_argument('--batch', type=int, default=1, help='batch-size') 43 | 44 | testing.add_argument('--cudnn-enabled', default=True, help='Enable cudnn') 45 | testing.add_argument('--cudnn-benchmark', default=True, help='Run cudnn benchmark') 46 | 47 | return parser 48 | 49 | 50 | 51 | def test(): 52 | 53 | parser = argparse.ArgumentParser(description=hp.description) 54 | parser = parse_testing_args(parser) 55 | args, _ = parser.parse_known_args() 56 | args = parser.parse_args() 57 | torch.backends.cudnn.deterministic = True 58 | torch.backends.cudnn.enabled = args.cudnn_enabled 59 | torch.backends.cudnn.benchmark = args.cudnn_benchmark 60 | 61 | 62 | os.makedirs(args.output_dir, exist_ok=True) 63 | 64 | 65 | from stylegan2.stylegan2_infer import infer_face 66 | class_generate = infer_face(hp.weight_path_pytorch) 67 | 68 | 69 | n_styles = 2*int(math.log(hp.img_size, 2))-2 70 | if hp.backbone == 'GradualStyleEncoder': 71 | from models.fpn_encoders import GradualStyleEncoder 72 | model = GradualStyleEncoder(num_layers=50,n_styles=n_styles) 73 | elif hp.backbone == 'ResNetGradualStyleEncoder': 74 | from models.fpn_encoders import ResNetGradualStyleEncoder 75 | model = ResNetGradualStyleEncoder(n_styles=n_styles) 76 | else: 77 | Exception('Backbone error!') 78 | 79 | 80 | 81 | model = torch.nn.DataParallel(model, device_ids=devicess) 82 | 83 | 84 | 85 | print(os.path.join(args.output_dir, args.latest_checkpoint_file)) 86 | ckpt = torch.load(os.path.join(args.output_dir, args.latest_checkpoint_file), map_location=lambda storage, loc: storage) 87 | 88 | model.load_state_dict(ckpt["model"]) 89 | 90 | 91 | # model cuda 92 | model.cuda() 93 | 94 | 95 | from data_function import ImageData 96 | 97 | test_dataset = ImageData(hp.dataset_path, hp.transform['transform_inference']) 98 | test_loader = DataLoader(test_dataset, 99 | batch_size=args.batch, 100 | shuffle=False, 101 | pin_memory=False, 102 | drop_last=True) 103 | 104 | 105 | 106 | 107 | 108 | model.eval() 109 | 110 | 111 | 112 | for i, batch in enumerate(test_loader): 113 | 114 | img = batch.cuda() 115 | 116 | outputs = model(img) 117 | 118 | predicts = class_generate.generate_from_synthesis(outputs,None,randomize_noise=False,return_latents=True) 119 | if hp.resize: 120 | predicts = face_pool(predicts) 121 | if hp.dataset_type == 'car': 122 | predicts = predicts[:, :, 32:224, :] 123 | 124 | 125 | 126 | with torch.no_grad(): 127 | utils.save_image( 128 | predicts, 129 | os.path.join(args.output_dir,("step-{}-predict.png").format(i)), 130 | nrow=hp.row, 131 | normalize=hp.norm, 132 | range=hp.rangee, 133 | ) 134 | 135 | 136 | with torch.no_grad(): 137 | utils.save_image( 138 | img, 139 | os.path.join(args.output_dir,("step-{}-origin.png").format(i)), 140 | nrow=hp.row, 141 | normalize=hp.norm, 142 | range=hp.rangee, 143 | ) 144 | 145 | 146 | 147 | if __name__ == '__main__': 148 | test() 149 | -------------------------------------------------------------------------------- /metrics/cal_sliced_wasserstein.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/microsoft/CoCosNet/issues/25 2 | import os 3 | import time 4 | import 
re 5 | import bisect 6 | from collections import OrderedDict 7 | import numpy as np 8 | import scipy.ndimage 9 | import scipy.misc 10 | import importlib 11 | from PIL import Image 12 | import sys 13 | 14 | 15 | #---------------------------------------------------------------------------- 16 | # Evaluate one or more metrics for a previous training run. 17 | # To run, uncomment one of the appropriate lines in config.py and launch train.py. 18 | def import_module(module_or_obj_name): 19 | parts = module_or_obj_name.split('.') 20 | parts[0] = {'np': 'numpy', 'tf': 'tensorflow'}.get(parts[0], parts[0]) 21 | for i in range(len(parts), 0, -1): 22 | try: 23 | module = importlib.import_module('.'.join(parts[:i])) 24 | relative_obj_name = '.'.join(parts[i:]) 25 | return module, relative_obj_name 26 | except: 27 | pass 28 | raise ImportError(module_or_obj_name) 29 | 30 | def find_obj_in_module(module, relative_obj_name): 31 | obj = module 32 | for part in relative_obj_name.split('.'): 33 | obj = getattr(obj, part) 34 | return obj 35 | 36 | def import_obj(obj_name): 37 | module, relative_obj_name = import_module(obj_name) 38 | return find_obj_in_module(module, relative_obj_name) 39 | 40 | def evaluate_metrics(fake_imgs, real_imgs, real_passes, minibatch_size=20, metrics=['swd']): 41 | metric_class_names = { 42 | 'swd': 'sliced_wasserstein.API', 43 | } 44 | 45 | assert fake_imgs.shape[0] == real_imgs.shape[0] 46 | num_images = fake_imgs.shape[0] 47 | # Initialize metrics. 48 | metric_objs = [] 49 | for name in metrics: 50 | class_name = metric_class_names.get(name, name) 51 | print('Initializing %s...' % class_name) 52 | class_def = import_obj(class_name) 53 | image_shape = [3] + [256, 256] 54 | obj = class_def(num_images=num_images, image_shape=image_shape, image_dtype=np.uint8, minibatch_size=minibatch_size) 55 | mode = 'warmup' 56 | obj.begin(mode) 57 | for idx in range(10): 58 | obj.feed(mode, np.random.randint(0, 256, size=[minibatch_size]+image_shape, dtype=np.uint8)) 59 | obj.end(mode) 60 | metric_objs.append(obj) 61 | 62 | # Print table header. 63 | print() 64 | print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='') 65 | for obj in metric_objs: 66 | for name, fmt in zip(obj.get_metric_names(), obj.get_metric_formatting()): 67 | print('%-*s' % (len(fmt % 0), name), end='') 68 | print() 69 | print('%-10s%-12s' % ('---', '---'), end='') 70 | for obj in metric_objs: 71 | for fmt in obj.get_metric_formatting(): 72 | print('%-*s' % (len(fmt % 0), '---'), end='') 73 | print() 74 | 75 | # Feed in reals. 
76 | for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]: 77 | print('%-10s' % title, end='') 78 | time_begin = time.time() 79 | [obj.begin(mode) for obj in metric_objs] 80 | for begin in range(0, num_images, minibatch_size): 81 | end = min(begin + minibatch_size, num_images) 82 | if mode == 'fakes': 83 | images = fake_imgs[begin:end] 84 | else: 85 | images = real_imgs[begin:end] 86 | if images.shape[1] == 1: 87 | images = np.tile(images, [1, 3, 1, 1]) # grayscale => RGB 88 | [obj.feed(mode, images) for obj in metric_objs] 89 | results = [obj.end(mode) for obj in metric_objs] 90 | print('------') 91 | for obj, vals in zip(metric_objs, results): 92 | for val, fmt in zip(vals, obj.get_metric_formatting()): 93 | print(fmt % val, end='') 94 | print() 95 | 96 | def get_image(folder): 97 | files = os.listdir(folder) 98 | imgs_path = [it for it in files if (it.endswith('.jpg') or it.endswith('.png'))] 99 | print('load {} imgs'.format(len(imgs_path))) 100 | img_list = [] 101 | for path in imgs_path: 102 | img = Image.open(os.path.join(folder, path)) 103 | img_list.append(np.array(img)[np.newaxis, :].transpose(0,3,1,2)) 104 | imgs = np.concatenate(img_list, axis=0) 105 | return imgs 106 | 107 | fake_imgs = get_image(sys.argv[1]) 108 | real_imgs = get_image(sys.argv[2]) 109 | #real_imgs = real_imgs[:len(fake_imgs)] 110 | evaluate_metrics(fake_imgs, real_imgs, 2, minibatch_size=20, metrics=['swd']) 111 | 112 | #python cal_sliced_wasserstein.py path1 path2 -------------------------------------------------------------------------------- /metrics/fideity_loss_calc.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/eladrich/pixel2style2pixel/blob/master/scripts/calc_losses_on_images.py 2 | from argparse import ArgumentParser 3 | import os 4 | import json 5 | import sys 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torchvision.transforms as transforms 11 | import torch_fidelity 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | from data_function import GTResDataset 16 | from cleanfid import fid 17 | 18 | def parse_args(): 19 | parser = ArgumentParser(add_help=False) 20 | parser.add_argument('--data_path', type=str, default='results') 21 | parser.add_argument('--gt_path', type=str, default='gt_images') 22 | parser.add_argument('--batch_size', type=int, default=4) 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def run(args): 28 | 29 | 30 | metrics_dict = torch_fidelity.calculate_metrics( 31 | input1=args.data_path, 32 | input2=args.gt_path, 33 | cuda=True, 34 | isc=True, 35 | fid=True, 36 | kid=True, 37 | ppl=True, 38 | verbose=False, 39 | ) 40 | 41 | 42 | print('Finished') 43 | print(metrics_dict) 44 | 45 | score_fid_clean = fid.compute_fid(args.gt_path, args.data_path,mode="clean") 46 | score_fid_tf = fid.compute_fid(args.gt_path, args.data_path,mode="legacy_tensorflow") 47 | score_fid_torch = fid.compute_fid(args.gt_path, args.data_path,mode="legacy_pytorch") 48 | score_kid_clean = fid.compute_kid(args.gt_path, args.data_path,mode="clean") 49 | score_kid_tf = fid.compute_kid(args.gt_path, args.data_path,mode="legacy_tensorflow") 50 | score_kid_torch = fid.compute_kid(args.gt_path, args.data_path,mode="legacy_pytorch") 51 | print('fid clean:',score_fid_clean) 52 | print('fid tf:',score_fid_tf) 53 | print('fid torch:',score_fid_torch) 54 | print('kid clean:',score_kid_clean) 55 | print('kid tf:',score_kid_tf) 
56 | print('kid torch:',score_kid_torch) 57 | 58 | # out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 59 | # if not os.path.exists(out_path): 60 | # os.makedirs(out_path) 61 | 62 | # with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: 63 | # f.write(result_str) 64 | # with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: 65 | # json.dump(scores_dict, f) 66 | 67 | 68 | if __name__ == '__main__': 69 | args = parse_args() 70 | run(args) -------------------------------------------------------------------------------- /metrics/id_loss_calc.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/eladrich/pixel2style2pixel/blob/master/scripts/calc_id_loss_parallel.py 2 | from argparse import ArgumentParser 3 | import time 4 | import numpy as np 5 | import os 6 | import json 7 | import sys 8 | from PIL import Image 9 | import multiprocessing as mp 10 | import math 11 | import torch 12 | import torchvision.transforms as trans 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 14 | sys.path.append(".") 15 | sys.path.append("..") 16 | 17 | from models.mtcnn.mtcnn import MTCNN 18 | from models.model_irse import IR_101 19 | from hparams import hparams as hp 20 | CIRCULAR_FACE_PATH = hp.circular_face_model_paths 21 | 22 | 23 | def chunks(lst, n): 24 | """Yield successive n-sized chunks from lst.""" 25 | for i in range(0, len(lst), n): 26 | yield lst[i:i + n] 27 | 28 | 29 | def extract_on_paths(file_paths): 30 | facenet = IR_101(input_size=112) 31 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 32 | facenet.cuda() 33 | facenet.eval() 34 | mtcnn = MTCNN() 35 | id_transform = trans.Compose([ 36 | trans.ToTensor(), 37 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 38 | ]) 39 | 40 | pid = mp.current_process().name 41 | print('\t{} is starting to extract on {} images'.format(pid, len(file_paths))) 42 | tot_count = len(file_paths) 43 | count = 0 44 | 45 | scores_dict = {} 46 | for res_path, gt_path in file_paths: 47 | count += 1 48 | if count % 100 == 0: 49 | print('{} done with {}/{}'.format(pid, count, tot_count)) 50 | if True: 51 | input_im = Image.open(res_path) 52 | input_im, _ = mtcnn.align(input_im) 53 | if input_im is None: 54 | print('{} skipping {}'.format(pid, res_path)) 55 | continue 56 | 57 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 58 | 59 | result_im = Image.open(gt_path) 60 | result_im, _ = mtcnn.align(result_im) 61 | if result_im is None: 62 | print('{} skipping {}'.format(pid, gt_path)) 63 | continue 64 | 65 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 66 | score = float(input_id.dot(result_id)) 67 | scores_dict[os.path.basename(gt_path)] = score 68 | 69 | return scores_dict 70 | 71 | 72 | def parse_args(): 73 | parser = ArgumentParser(add_help=False) 74 | parser.add_argument('--num_threads', type=int, default=4) 75 | parser.add_argument('--data_path', type=str, default='results') 76 | parser.add_argument('--gt_path', type=str, default='gt_images') 77 | args = parser.parse_args() 78 | return args 79 | 80 | 81 | def run(args): 82 | file_paths = [] 83 | for f in os.listdir(args.data_path): 84 | image_path = os.path.join(args.data_path, f) 85 | f_ = f.replace('predict','origin') 86 | gt_path = os.path.join(args.gt_path, f_) 87 | 88 | # if f.endswith(".jpg") or f.endswith('.png'): 89 | file_paths.append([image_path, gt_path]) 90 | 91 | file_chunks = list(chunks(file_paths, 
int(math.ceil(len(file_paths) / args.num_threads)))) 92 | pool = mp.Pool(args.num_threads) 93 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 94 | 95 | tic = time.time() 96 | results = pool.map(extract_on_paths, file_chunks) 97 | scores_dict = {} 98 | for d in results: 99 | scores_dict.update(d) 100 | 101 | all_scores = list(scores_dict.values()) 102 | mean = np.mean(all_scores) 103 | std = np.std(all_scores) 104 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 105 | print(result_str) 106 | 107 | # out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 108 | # if not os.path.exists(out_path): 109 | # os.makedirs(out_path) 110 | 111 | # with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f: 112 | # f.write(result_str) 113 | # with open(os.path.join(out_path, 'scores_id.json'), 'w') as f: 114 | # json.dump(scores_dict, f) 115 | 116 | # toc = time.time() 117 | # print('Mischief managed in {}s'.format(toc - tic)) 118 | 119 | 120 | if __name__ == '__main__': 121 | args = parse_args() 122 | run(args) -------------------------------------------------------------------------------- /metrics/img_loss_calc.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/eladrich/pixel2style2pixel/blob/master/scripts/calc_losses_on_images.py 2 | from argparse import ArgumentParser 3 | import os 4 | import json 5 | import sys 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torchvision.transforms as transforms 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | from criteria.lpips.lpips import LPIPS 16 | from data_function import GTResDataset 17 | 18 | from piqa import PSNR, TV, SSIM, MS_SSIM, GMSD, MS_GMSD, MDSI, HaarPSI, VSI, FSIM 19 | 20 | def parse_args(): 21 | parser = ArgumentParser(add_help=False) 22 | parser.add_argument('--mode', type=str, default='ms_ssim', choices=['lpips', 'l2', 'psnr', 'tv', 'ssim', 'ms_ssim', 'gmsd', 'ms_gmsd', 'mdsi', 'haarpsi', 'vsi', 'fsim']) 23 | parser.add_argument('--data_path', type=str, default='results') 24 | parser.add_argument('--gt_path', type=str, default='gt_images') 25 | parser.add_argument('--workers', type=int, default=4) 26 | parser.add_argument('--batch_size', type=int, default=4) 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def run(args): 32 | 33 | 34 | 35 | if args.mode == 'lpips' or args.mode == 'l2': 36 | transform = transforms.Compose([transforms.Resize((256, 256)), 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 39 | else: 40 | transform = transforms.Compose([transforms.Resize((256, 256)), 41 | transforms.ToTensor()]) 42 | 43 | print('Loading dataset') 44 | dataset = GTResDataset(root_path=args.data_path, 45 | gt_dir=args.gt_path, 46 | transform=transform) 47 | 48 | dataloader = DataLoader(dataset, 49 | batch_size=args.batch_size, 50 | shuffle=False, 51 | num_workers=int(args.workers), 52 | drop_last=True) 53 | 54 | if args.mode == 'lpips': 55 | loss_func = LPIPS(net_type='alex') 56 | elif args.mode == 'l2': 57 | loss_func = torch.nn.MSELoss() 58 | elif args.mode == 'psnr': 59 | loss_func = PSNR() 60 | elif args.mode == 'tv': 61 | loss_func = TV() 62 | elif args.mode == 'ssim': 63 | loss_func = SSIM() 64 | elif args.mode == 'ms_ssim': 65 | loss_func = MS_SSIM() 66 | elif args.mode == 'gmsd': 67 | loss_func = GMSD() 68 | elif args.mode == 'ms_gmsd': 69 | loss_func = 
MS_GMSD() 70 | elif args.mode == 'mdsi': 71 | loss_func = MDSI() 72 | elif args.mode == 'haarpsi': 73 | loss_func = HaarPSI() 74 | elif args.mode == 'vsi': 75 | loss_func = VSI() 76 | elif args.mode == 'fsim': 77 | loss_func = FSIM() 78 | else: 79 | raise Exception('Not a valid mode!') 80 | loss_func.cuda() 81 | 82 | global_i = 0 83 | scores_dict = {} 84 | all_scores = [] 85 | for result_batch, gt_batch in tqdm(dataloader): 86 | # print(result_batch) 87 | # print(gt_batch) 88 | for i in range(args.batch_size): 89 | if args.mode == 'lpips' or args.mode == 'l2': 90 | loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda())) 91 | else: 92 | loss = loss_func(result_batch.cuda(),gt_batch.cuda()) 93 | 94 | all_scores.append(loss) 95 | im_path = dataset.pairs[global_i][0] 96 | if args.mode == 'lpips' or args.mode == 'l2': 97 | scores_dict[os.path.basename(im_path)] = loss 98 | else: 99 | scores_dict[os.path.basename(im_path)] = loss.cpu() 100 | global_i += 1 101 | 102 | if args.mode == 'lpips' or args.mode == 'l2': 103 | continue 104 | else: 105 | break 106 | all_scores = list(scores_dict.values()) 107 | 108 | mean = np.mean(all_scores) 109 | std = np.std(all_scores) 110 | result_str = 'Average loss is {:.4f}+-{:.4f}'.format(mean, std) 111 | print('Finished with ', args.data_path) 112 | print(result_str) 113 | 114 | # out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 115 | # if not os.path.exists(out_path): 116 | # os.makedirs(out_path) 117 | 118 | # with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: 119 | # f.write(result_str) 120 | # with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: 121 | # json.dump(scores_dict, f) 122 | 123 | 124 | if __name__ == '__main__': 125 | args = parse_args() 126 | run(args) -------------------------------------------------------------------------------- /metrics/sliced_wasserstein.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | 4 | #---------------------------------------------------------------------------- 5 | 6 | def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image): 7 | S = minibatch.shape # (minibatch, channel, height, width) 8 | assert len(S) == 4 and S[1] == 3 9 | N = nhoods_per_image * S[0] 10 | H = nhood_size // 2 11 | nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H+1, -H:H+1] 12 | img = nhood // nhoods_per_image 13 | x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1)) 14 | y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1)) 15 | idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x 16 | return minibatch.flat[idx] 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def finalize_descriptors(desc): 21 | if isinstance(desc, list): 22 | desc = np.concatenate(desc, axis=0) 23 | assert desc.ndim == 4 # (neighborhood, channel, height, width) 24 | desc -= np.mean(desc, axis=(0, 2, 3), keepdims=True) 25 | desc /= np.std(desc, axis=(0, 2, 3), keepdims=True) 26 | desc = desc.reshape(desc.shape[0], -1) 27 | return desc 28 | 29 | #---------------------------------------------------------------------------- 30 | 31 | def sliced_wasserstein(A, B, dir_repeats, dirs_per_repeat): 32 | assert A.ndim == 2 and A.shape == B.shape # (neighborhood, descriptor_component) 33 | results = [] 34 | for repeat in range(dir_repeats): 35 | dirs = np.random.randn(A.shape[1], dirs_per_repeat) # (descriptor_component, direction) 36 | 
dirs /= np.sqrt(np.sum(np.square(dirs), axis=0, keepdims=True)) # normalize descriptor components for each direction 37 | dirs = dirs.astype(np.float32) 38 | projA = np.matmul(A, dirs) # (neighborhood, direction) 39 | projB = np.matmul(B, dirs) 40 | projA = np.sort(projA, axis=0) # sort neighborhood projections for each direction 41 | projB = np.sort(projB, axis=0) 42 | dists = np.abs(projA - projB) # pointwise wasserstein distances 43 | results.append(np.mean(dists)) # average over neighborhoods and directions 44 | return np.mean(results) # average over repeats 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | def downscale_minibatch(minibatch, lod): 49 | if lod == 0: 50 | return minibatch 51 | t = minibatch.astype(np.float32) 52 | for i in range(lod): 53 | t = (t[:, :, 0::2, 0::2] + t[:, :, 0::2, 1::2] + t[:, :, 1::2, 0::2] + t[:, :, 1::2, 1::2]) * 0.25 54 | return np.round(t).clip(0, 255).astype(np.uint8) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | gaussian_filter = np.float32([ 59 | [1, 4, 6, 4, 1], 60 | [4, 16, 24, 16, 4], 61 | [6, 24, 36, 24, 6], 62 | [4, 16, 24, 16, 4], 63 | [1, 4, 6, 4, 1]]) / 256.0 64 | 65 | def pyr_down(minibatch): # matches cv2.pyrDown() 66 | assert minibatch.ndim == 4 67 | return scipy.ndimage.convolve(minibatch, gaussian_filter[np.newaxis, np.newaxis, :, :], mode='mirror')[:, :, ::2, ::2] 68 | 69 | def pyr_up(minibatch): # matches cv2.pyrUp() 70 | assert minibatch.ndim == 4 71 | S = minibatch.shape 72 | res = np.zeros((S[0], S[1], S[2] * 2, S[3] * 2), minibatch.dtype) 73 | res[:, :, ::2, ::2] = minibatch 74 | return scipy.ndimage.convolve(res, gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, mode='mirror') 75 | 76 | def generate_laplacian_pyramid(minibatch, num_levels): 77 | pyramid = [np.float32(minibatch)] 78 | for i in range(1, num_levels): 79 | pyramid.append(pyr_down(pyramid[-1])) 80 | pyramid[-2] -= pyr_up(pyramid[-1]) 81 | return pyramid 82 | 83 | def reconstruct_laplacian_pyramid(pyramid): 84 | minibatch = pyramid[-1] 85 | for level in pyramid[-2::-1]: 86 | minibatch = pyr_up(minibatch) + level 87 | return minibatch 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | class API: 92 | def __init__(self, num_images, image_shape, image_dtype, minibatch_size): 93 | self.nhood_size = 7 94 | self.nhoods_per_image = 128 95 | self.dir_repeats = 4 96 | self.dirs_per_repeat = 128 97 | self.resolutions = [] 98 | res = image_shape[1] 99 | while res >= 16: 100 | self.resolutions.append(res) 101 | res //= 2 102 | 103 | def get_metric_names(self): 104 | return ['SWDx1e3_%d' % res for res in self.resolutions] + ['SWDx1e3_avg'] 105 | 106 | def get_metric_formatting(self): 107 | return ['%-13.4f'] * len(self.get_metric_names()) 108 | 109 | def begin(self, mode): 110 | assert mode in ['warmup', 'reals', 'fakes'] 111 | self.descriptors = [[] for res in self.resolutions] 112 | 113 | def feed(self, mode, minibatch): 114 | for lod, level in enumerate(generate_laplacian_pyramid(minibatch, len(self.resolutions))): 115 | desc = get_descriptors_for_minibatch(level, self.nhood_size, self.nhoods_per_image) 116 | self.descriptors[lod].append(desc) 117 | 118 | def end(self, mode): 119 | desc = [finalize_descriptors(d) for d in self.descriptors] 120 | del self.descriptors 121 | if mode in ['warmup', 'reals']: 122 | self.desc_real = desc 123 | if mode == 'fakes': 124 | dist = [sliced_wasserstein(dreal, dfake, self.dir_repeats, 
self.dirs_per_repeat) for dreal, dfake in zip(self.desc_real, desc)] 125 | else: 126 | dist = [0, 0, 0, 0, 0, 0] 127 | del desc 128 | dist = [d * 1e3 for d in dist] # multiply by 10^3 129 | return dist + [np.mean(dist)] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/models/__init__.py -------------------------------------------------------------------------------- /models/fpn_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 5 | from torchvision.models.resnet import resnet34 6 | 7 | from models.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 8 | from models.map2style import GradualStyleBlock 9 | 10 | 11 | class GradualStyleEncoder(Module): 12 | """ 13 | Original encoder architecture from pixel2style2pixel. This classes uses an FPN-based architecture applied over 14 | an ResNet IRSE-50 backbone. 15 | Note this class is designed to be used for the human facial domain. 16 | """ 17 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 18 | super(GradualStyleEncoder, self).__init__() 19 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 20 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 21 | blocks = get_blocks(num_layers) 22 | if mode == 'ir': 23 | unit_module = bottleneck_IR 24 | elif mode == 'ir_se': 25 | unit_module = bottleneck_IR_SE 26 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | self.coarse_ind = 3 40 | self.middle_ind = 7 41 | for i in range(self.style_count): 42 | if i < self.coarse_ind: 43 | style = GradualStyleBlock(512, 512, 16) 44 | elif i < self.middle_ind: 45 | style = GradualStyleBlock(512, 512, 32) 46 | else: 47 | style = GradualStyleBlock(512, 512, 64) 48 | self.styles.append(style) 49 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 50 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 51 | 52 | def _upsample_add(self, x, y): 53 | _, _, H, W = y.size() 54 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 55 | 56 | def forward(self, x): 57 | x = self.input_layer(x) 58 | 59 | latents = [] 60 | modulelist = list(self.body._modules.values()) 61 | for i, l in enumerate(modulelist): 62 | x = l(x) 63 | if i == 6: 64 | c1 = x 65 | elif i == 20: 66 | c2 = x 67 | elif i == 23: 68 | c3 = x 69 | 70 | for j in range(self.coarse_ind): 71 | latents.append(self.styles[j](c3)) 72 | 73 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 74 | for j in range(self.coarse_ind, self.middle_ind): 75 | latents.append(self.styles[j](p2)) 76 | 77 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 78 | for j in range(self.middle_ind, self.style_count): 79 | latents.append(self.styles[j](p1)) 80 | 81 | out = torch.stack(latents, dim=1) 82 | return out 83 | 84 | 85 | class 
ResNetGradualStyleEncoder(Module): 86 | """ 87 | Original encoder architecture from pixel2style2pixel. This classes uses an FPN-based architecture applied over 88 | an ResNet34 backbone. 89 | """ 90 | def __init__(self, n_styles=18, opts=None): 91 | super(ResNetGradualStyleEncoder, self).__init__() 92 | 93 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 94 | self.bn1 = BatchNorm2d(64) 95 | self.relu = PReLU(64) 96 | 97 | resnet_basenet = resnet34(pretrained=True) 98 | blocks = [ 99 | resnet_basenet.layer1, 100 | resnet_basenet.layer2, 101 | resnet_basenet.layer3, 102 | resnet_basenet.layer4 103 | ] 104 | 105 | modules = [] 106 | for block in blocks: 107 | for bottleneck in block: 108 | modules.append(bottleneck) 109 | 110 | self.body = Sequential(*modules) 111 | 112 | self.styles = nn.ModuleList() 113 | self.style_count = n_styles 114 | self.coarse_ind = 3 115 | self.middle_ind = 7 116 | for i in range(self.style_count): 117 | if i < self.coarse_ind: 118 | style = GradualStyleBlock(512, 512, 16) 119 | elif i < self.middle_ind: 120 | style = GradualStyleBlock(512, 512, 32) 121 | else: 122 | style = GradualStyleBlock(512, 512, 64) 123 | self.styles.append(style) 124 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 125 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 126 | 127 | def _upsample_add(self, x, y): 128 | _, _, H, W = y.size() 129 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | 136 | latents = [] 137 | modulelist = list(self.body._modules.values()) 138 | for i, l in enumerate(modulelist): 139 | x = l(x) 140 | if i == 6: 141 | c1 = x 142 | elif i == 12: 143 | c2 = x 144 | elif i == 15: 145 | c3 = x 146 | 147 | for j in range(self.coarse_ind): 148 | latents.append(self.styles[j](c3)) 149 | 150 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 151 | for j in range(self.coarse_ind, self.middle_ind): 152 | latents.append(self.styles[j](p2)) 153 | 154 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 155 | for j in range(self.middle_ind, self.style_count): 156 | latents.append(self.styles[j](p1)) 157 | 158 | out = torch.stack(latents, dim=1) 159 | return out -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. 
""" 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut -------------------------------------------------------------------------------- /models/map2style.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn import Conv2d, Module 4 | 
5 | from stylegan2.model import EqualLinear 6 | 7 | 8 | class GradualStyleBlock(Module): 9 | def __init__(self, in_c, out_c, spatial): 10 | super(GradualStyleBlock, self).__init__() 11 | self.out_c = out_c 12 | self.spatial = spatial 13 | num_pools = int(np.log2(spatial)) 14 | modules = [] 15 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 16 | nn.LeakyReLU()] 17 | for i in range(num_pools - 1): 18 | modules += [ 19 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 20 | nn.LeakyReLU() 21 | ] 22 | self.convs = nn.Sequential(*modules) 23 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 24 | 25 | def forward(self, x): 26 | x = self.convs(x) 27 | x = x.view(-1, self.out_c) 28 | x = self.linear(x) 29 | return x -------------------------------------------------------------------------------- /models/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 
| return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model -------------------------------------------------------------------------------- /models/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/models/mtcnn/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 5 | from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 7 | from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face 8 | 9 | device = 'cuda:0' 10 | 11 | 12 | class MTCNN(): 13 | def __init__(self): 14 | print(device) 15 | self.pnet = PNet().to(device) 16 | self.rnet = RNet().to(device) 17 | self.onet = ONet().to(device) 18 | self.pnet.eval() 19 | self.rnet.eval() 20 | self.onet.eval() 21 | self.refrence = get_reference_facial_points(default_square=True) 22 | 23 | def align(self, img): 24 | _, landmarks = self.detect_faces(img) 25 | if len(landmarks) == 0: 26 | return None, None 27 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 28 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 29 | return Image.fromarray(warped_face), tfm 30 | 31 | def align_multi(self, img, limit=None, min_face_size=30.0): 32 | boxes, landmarks = self.detect_faces(img, min_face_size) 33 | if limit: 34 | boxes = boxes[:limit] 35 | landmarks = landmarks[:limit] 36 | faces = [] 37 | tfms = [] 38 | for landmark in landmarks: 39 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 40 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 41 | faces.append(Image.fromarray(warped_face)) 42 | tfms.append(tfm) 43 | return boxes, faces, tfms 44 | 45 | def detect_faces(self, image, min_face_size=20.0, 46 | thresholds=[0.15, 0.25, 0.35], 47 | nms_thresholds=[0.7, 0.7, 0.7]): 48 | """ 49 | Arguments: 50 | image: an instance of PIL.Image. 51 | min_face_size: a float number. 52 | thresholds: a list of length 3. 53 | nms_thresholds: a list of length 3. 54 | 55 | Returns: 56 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 57 | bounding boxes and facial landmarks. 
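        Note: the default thresholds here (0.15/0.25/0.35) are much lower than the
        0.6/0.7/0.8 used in detector.detect_faces, so more candidate boxes survive
        each stage before alignment.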
58 | """ 59 | 60 | # BUILD AN IMAGE PYRAMID 61 | width, height = image.size 62 | min_length = min(height, width) 63 | 64 | min_detection_size = 12 65 | factor = 0.707 # sqrt(0.5) 66 | 67 | # scales for scaling the image 68 | scales = [] 69 | 70 | # scales the image so that 71 | # minimum size that we can detect equals to 72 | # minimum face size that we want to detect 73 | m = min_detection_size / min_face_size 74 | min_length *= m 75 | 76 | factor_count = 0 77 | while min_length > min_detection_size: 78 | scales.append(m * factor ** factor_count) 79 | min_length *= factor 80 | factor_count += 1 81 | 82 | # STAGE 1 83 | 84 | # it will be returned 85 | bounding_boxes = [] 86 | 87 | with torch.no_grad(): 88 | # run P-Net on different scales 89 | for s in scales: 90 | boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) 91 | bounding_boxes.append(boxes) 92 | 93 | # collect boxes (and offsets, and scores) from different scales 94 | bounding_boxes = [i for i in bounding_boxes if i is not None] 95 | bounding_boxes = np.vstack(bounding_boxes) 96 | 97 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 98 | bounding_boxes = bounding_boxes[keep] 99 | 100 | # use offsets predicted by pnet to transform bounding boxes 101 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 102 | # shape [n_boxes, 5] 103 | 104 | bounding_boxes = convert_to_square(bounding_boxes) 105 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 106 | 107 | # STAGE 2 108 | 109 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 110 | img_boxes = torch.FloatTensor(img_boxes).to(device) 111 | 112 | output = self.rnet(img_boxes) 113 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 114 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 115 | 116 | keep = np.where(probs[:, 1] > thresholds[1])[0] 117 | bounding_boxes = bounding_boxes[keep] 118 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 119 | offsets = offsets[keep] 120 | 121 | keep = nms(bounding_boxes, nms_thresholds[1]) 122 | bounding_boxes = bounding_boxes[keep] 123 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 124 | bounding_boxes = convert_to_square(bounding_boxes) 125 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 126 | 127 | # STAGE 3 128 | 129 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 130 | if len(img_boxes) == 0: 131 | return [], [] 132 | img_boxes = torch.FloatTensor(img_boxes).to(device) 133 | output = self.onet(img_boxes) 134 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 135 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 136 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 137 | 138 | keep = np.where(probs[:, 1] > thresholds[2])[0] 139 | bounding_boxes = bounding_boxes[keep] 140 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 141 | offsets = offsets[keep] 142 | landmarks = landmarks[keep] 143 | 144 | # compute landmark points 145 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 146 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 147 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 148 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 149 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 150 | 151 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 152 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 153 | bounding_boxes = 
bounding_boxes[keep] 154 | landmarks = landmarks[keep] 155 | 156 | return bounding_boxes, landmarks 157 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/models/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/align_trans.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 24 15:43:29 2017 4 | @author: zhaoy 5 | """ 6 | import numpy as np 7 | import cv2 8 | 9 | # from scipy.linalg import lstsq 10 | # from scipy.ndimage import geometric_transform # , map_coordinates 11 | 12 | from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2 13 | 14 | # reference facial points, a list of coordinates (x,y) 15 | REFERENCE_FACIAL_POINTS = [ 16 | [30.29459953, 51.69630051], 17 | [65.53179932, 51.50139999], 18 | [48.02519989, 71.73660278], 19 | [33.54930115, 92.3655014], 20 | [62.72990036, 92.20410156] 21 | ] 22 | 23 | DEFAULT_CROP_SIZE = (96, 112) 24 | 25 | 26 | class FaceWarpException(Exception): 27 | def __str__(self): 28 | return 'In File {}:{}'.format( 29 | __file__, super.__str__(self)) 30 | 31 | 32 | def get_reference_facial_points(output_size=None, 33 | inner_padding_factor=0.0, 34 | outer_padding=(0, 0), 35 | default_square=False): 36 | """ 37 | Function: 38 | ---------- 39 | get reference 5 key points according to crop settings: 40 | 0. Set default crop_size: 41 | if default_square: 42 | crop_size = (112, 112) 43 | else: 44 | crop_size = (96, 112) 45 | 1. Pad the crop_size by inner_padding_factor in each side; 46 | 2. Resize crop_size into (output_size - outer_padding*2), 47 | pad into output_size with outer_padding; 48 | 3. Output reference_5point; 49 | Parameters: 50 | ---------- 51 | @output_size: (w, h) or None 52 | size of aligned face image 53 | @inner_padding_factor: (w_factor, h_factor) 54 | padding factor for inner (w, h) 55 | @outer_padding: (w_pad, h_pad) 56 | each row is a pair of coordinates (x, y) 57 | @default_square: True or False 58 | if True: 59 | default crop_size = (112, 112) 60 | else: 61 | default crop_size = (96, 112); 62 | !!! 
make sure, if output_size is not None: 63 | (output_size - outer_padding) 64 | = some_scale * (default crop_size * (1.0 + inner_padding_factor)) 65 | Returns: 66 | ---------- 67 | @reference_5point: 5x2 np.array 68 | each row is a pair of transformed coordinates (x, y) 69 | """ 70 | # print('\n===> get_reference_facial_points():') 71 | 72 | # print('---> Params:') 73 | # print(' output_size: ', output_size) 74 | # print(' inner_padding_factor: ', inner_padding_factor) 75 | # print(' outer_padding:', outer_padding) 76 | # print(' default_square: ', default_square) 77 | 78 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 79 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 80 | 81 | # 0) make the inner region a square 82 | if default_square: 83 | size_diff = max(tmp_crop_size) - tmp_crop_size 84 | tmp_5pts += size_diff / 2 85 | tmp_crop_size += size_diff 86 | 87 | # print('---> default:') 88 | # print(' crop_size = ', tmp_crop_size) 89 | # print(' reference_5pts = ', tmp_5pts) 90 | 91 | if (output_size and 92 | output_size[0] == tmp_crop_size[0] and 93 | output_size[1] == tmp_crop_size[1]): 94 | # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) 95 | return tmp_5pts 96 | 97 | if (inner_padding_factor == 0 and 98 | outer_padding == (0, 0)): 99 | if output_size is None: 100 | # print('No paddings to do: return default reference points') 101 | return tmp_5pts 102 | else: 103 | raise FaceWarpException( 104 | 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) 105 | 106 | # check output size 107 | if not (0 <= inner_padding_factor <= 1.0): 108 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 109 | 110 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) 111 | and output_size is None): 112 | output_size = tmp_crop_size * \ 113 | (1 + inner_padding_factor * 2).astype(np.int32) 114 | output_size += np.array(outer_padding) 115 | # print(' deduced from paddings, output_size = ', output_size) 116 | 117 | if not (outer_padding[0] < output_size[0] 118 | and outer_padding[1] < output_size[1]): 119 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]' 120 | 'and outer_padding[1] < output_size[1])') 121 | 122 | # 1) pad the inner region according inner_padding_factor 123 | # print('---> STEP1: pad the inner region according inner_padding_factor') 124 | if inner_padding_factor > 0: 125 | size_diff = tmp_crop_size * inner_padding_factor * 2 126 | tmp_5pts += size_diff / 2 127 | tmp_crop_size += np.round(size_diff).astype(np.int32) 128 | 129 | # print(' crop_size = ', tmp_crop_size) 130 | # print(' reference_5pts = ', tmp_5pts) 131 | 132 | # 2) resize the padded inner region 133 | # print('---> STEP2: resize the padded inner region') 134 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 135 | # print(' crop_size = ', tmp_crop_size) 136 | # print(' size_bf_outer_pad = ', size_bf_outer_pad) 137 | 138 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 139 | raise FaceWarpException('Must have (output_size - outer_padding)' 140 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 141 | 142 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 143 | # print(' resize scale_factor = ', scale_factor) 144 | tmp_5pts = tmp_5pts * scale_factor 145 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 146 | # tmp_5pts = tmp_5pts + size_diff / 2 147 | tmp_crop_size = size_bf_outer_pad 148 | # 
print(' crop_size = ', tmp_crop_size) 149 | # print(' reference_5pts = ', tmp_5pts) 150 | 151 | # 3) add outer_padding to make output_size 152 | reference_5point = tmp_5pts + np.array(outer_padding) 153 | tmp_crop_size = output_size 154 | # print('---> STEP3: add outer_padding to make output_size') 155 | # print(' crop_size = ', tmp_crop_size) 156 | # print(' reference_5pts = ', tmp_5pts) 157 | 158 | # print('===> end get_reference_facial_points\n') 159 | 160 | return reference_5point 161 | 162 | 163 | def get_affine_transform_matrix(src_pts, dst_pts): 164 | """ 165 | Function: 166 | ---------- 167 | get affine transform matrix 'tfm' from src_pts to dst_pts 168 | Parameters: 169 | ---------- 170 | @src_pts: Kx2 np.array 171 | source points matrix, each row is a pair of coordinates (x, y) 172 | @dst_pts: Kx2 np.array 173 | destination points matrix, each row is a pair of coordinates (x, y) 174 | Returns: 175 | ---------- 176 | @tfm: 2x3 np.array 177 | transform matrix from src_pts to dst_pts 178 | """ 179 | 180 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 181 | n_pts = src_pts.shape[0] 182 | ones = np.ones((n_pts, 1), src_pts.dtype) 183 | src_pts_ = np.hstack([src_pts, ones]) 184 | dst_pts_ = np.hstack([dst_pts, ones]) 185 | 186 | # #print(('src_pts_:\n' + str(src_pts_)) 187 | # #print(('dst_pts_:\n' + str(dst_pts_)) 188 | 189 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 190 | 191 | # #print(('np.linalg.lstsq return A: \n' + str(A)) 192 | # #print(('np.linalg.lstsq return res: \n' + str(res)) 193 | # #print(('np.linalg.lstsq return rank: \n' + str(rank)) 194 | # #print(('np.linalg.lstsq return s: \n' + str(s)) 195 | 196 | if rank == 3: 197 | tfm = np.float32([ 198 | [A[0, 0], A[1, 0], A[2, 0]], 199 | [A[0, 1], A[1, 1], A[2, 1]] 200 | ]) 201 | elif rank == 2: 202 | tfm = np.float32([ 203 | [A[0, 0], A[1, 0], 0], 204 | [A[0, 1], A[1, 1], 0] 205 | ]) 206 | 207 | return tfm 208 | 209 | 210 | def warp_and_crop_face(src_img, 211 | facial_pts, 212 | reference_pts=None, 213 | crop_size=(96, 112), 214 | align_type='smilarity'): 215 | """ 216 | Function: 217 | ---------- 218 | apply affine transform 'trans' to uv 219 | Parameters: 220 | ---------- 221 | @src_img: 3x3 np.array 222 | input image 223 | @facial_pts: could be 224 | 1)a list of K coordinates (x,y) 225 | or 226 | 2) Kx2 or 2xK np.array 227 | each row or col is a pair of coordinates (x, y) 228 | @reference_pts: could be 229 | 1) a list of K coordinates (x,y) 230 | or 231 | 2) Kx2 or 2xK np.array 232 | each row or col is a pair of coordinates (x, y) 233 | or 234 | 3) None 235 | if None, use default reference facial points 236 | @crop_size: (w, h) 237 | output face image size 238 | @align_type: transform type, could be one of 239 | 1) 'similarity': use similarity transform 240 | 2) 'cv2_affine': use the first 3 points to do affine transform, 241 | by calling cv2.getAffineTransform() 242 | 3) 'affine': use all points to do affine transform 243 | Returns: 244 | ---------- 245 | @face_img: output face image with size (w, h) = @crop_size 246 | """ 247 | 248 | if reference_pts is None: 249 | if crop_size[0] == 96 and crop_size[1] == 112: 250 | reference_pts = REFERENCE_FACIAL_POINTS 251 | else: 252 | default_square = False 253 | inner_padding_factor = 0 254 | outer_padding = (0, 0) 255 | output_size = crop_size 256 | 257 | reference_pts = get_reference_facial_points(output_size, 258 | inner_padding_factor, 259 | outer_padding, 260 | default_square) 261 | 262 | ref_pts = np.float32(reference_pts) 263 | ref_pts_shp = ref_pts.shape 264 | if 
max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 265 | raise FaceWarpException( 266 | 'reference_pts.shape must be (K,2) or (2,K) and K>2') 267 | 268 | if ref_pts_shp[0] == 2: 269 | ref_pts = ref_pts.T 270 | 271 | src_pts = np.float32(facial_pts) 272 | src_pts_shp = src_pts.shape 273 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 274 | raise FaceWarpException( 275 | 'facial_pts.shape must be (K,2) or (2,K) and K>2') 276 | 277 | if src_pts_shp[0] == 2: 278 | src_pts = src_pts.T 279 | 280 | # #print('--->src_pts:\n', src_pts 281 | # #print('--->ref_pts\n', ref_pts 282 | 283 | if src_pts.shape != ref_pts.shape: 284 | raise FaceWarpException( 285 | 'facial_pts and reference_pts must have the same shape') 286 | 287 | if align_type is 'cv2_affine': 288 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 289 | # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm)) 290 | elif align_type is 'affine': 291 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 292 | # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm)) 293 | else: 294 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) 295 | # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm)) 296 | 297 | # #print('--->Transform matrix: ' 298 | # #print(('type(tfm):' + str(type(tfm))) 299 | # #print(('tfm.dtype:' + str(tfm.dtype)) 300 | # #print( tfm 301 | 302 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) 303 | 304 | return face_img, tfm 305 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/box_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def nms(boxes, overlap_threshold=0.5, mode='union'): 6 | """Non-maximum suppression. 7 | 8 | Arguments: 9 | boxes: a float numpy array of shape [n, 5], 10 | where each row is (xmin, ymin, xmax, ymax, score). 11 | overlap_threshold: a float number. 12 | mode: 'union' or 'min'. 
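            'union' measures overlap as intersection-over-union, while 'min'
            divides the intersection by the smaller box area, which suppresses
            nested boxes more aggressively.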
13 | 14 | Returns: 15 | list with indices of the selected boxes 16 | """ 17 | 18 | # if there are no boxes, return the empty list 19 | if len(boxes) == 0: 20 | return [] 21 | 22 | # list of picked indices 23 | pick = [] 24 | 25 | # grab the coordinates of the bounding boxes 26 | x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] 27 | 28 | area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) 29 | ids = np.argsort(score) # in increasing order 30 | 31 | while len(ids) > 0: 32 | 33 | # grab index of the largest value 34 | last = len(ids) - 1 35 | i = ids[last] 36 | pick.append(i) 37 | 38 | # compute intersections 39 | # of the box with the largest score 40 | # with the rest of boxes 41 | 42 | # left top corner of intersection boxes 43 | ix1 = np.maximum(x1[i], x1[ids[:last]]) 44 | iy1 = np.maximum(y1[i], y1[ids[:last]]) 45 | 46 | # right bottom corner of intersection boxes 47 | ix2 = np.minimum(x2[i], x2[ids[:last]]) 48 | iy2 = np.minimum(y2[i], y2[ids[:last]]) 49 | 50 | # width and height of intersection boxes 51 | w = np.maximum(0.0, ix2 - ix1 + 1.0) 52 | h = np.maximum(0.0, iy2 - iy1 + 1.0) 53 | 54 | # intersections' areas 55 | inter = w * h 56 | if mode == 'min': 57 | overlap = inter / np.minimum(area[i], area[ids[:last]]) 58 | elif mode == 'union': 59 | # intersection over union (IoU) 60 | overlap = inter / (area[i] + area[ids[:last]] - inter) 61 | 62 | # delete all boxes where overlap is too big 63 | ids = np.delete( 64 | ids, 65 | np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) 66 | ) 67 | 68 | return pick 69 | 70 | 71 | def convert_to_square(bboxes): 72 | """Convert bounding boxes to a square form. 73 | 74 | Arguments: 75 | bboxes: a float numpy array of shape [n, 5]. 76 | 77 | Returns: 78 | a float numpy array of shape [n, 5], 79 | squared bounding boxes. 80 | """ 81 | 82 | square_bboxes = np.zeros_like(bboxes) 83 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 84 | h = y2 - y1 + 1.0 85 | w = x2 - x1 + 1.0 86 | max_side = np.maximum(h, w) 87 | square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 88 | square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 89 | square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 90 | square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 91 | return square_bboxes 92 | 93 | 94 | def calibrate_box(bboxes, offsets): 95 | """Transform bounding boxes to be more like true bounding boxes. 96 | 'offsets' is one of the outputs of the nets. 97 | 98 | Arguments: 99 | bboxes: a float numpy array of shape [n, 5]. 100 | offsets: a float numpy array of shape [n, 4]. 101 | 102 | Returns: 103 | a float numpy array of shape [n, 5]. 104 | """ 105 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 106 | w = x2 - x1 + 1.0 107 | h = y2 - y1 + 1.0 108 | w = np.expand_dims(w, 1) 109 | h = np.expand_dims(h, 1) 110 | 111 | # this is what happening here: 112 | # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] 113 | # x1_true = x1 + tx1*w 114 | # y1_true = y1 + ty1*h 115 | # x2_true = x2 + tx2*w 116 | # y2_true = y2 + ty2*h 117 | # below is just more compact form of this 118 | 119 | # are offsets always such that 120 | # x1 < x2 and y1 < y2 ? 121 | 122 | translation = np.hstack([w, h, w, h]) * offsets 123 | bboxes[:, 0:4] = bboxes[:, 0:4] + translation 124 | return bboxes 125 | 126 | 127 | def get_image_boxes(bounding_boxes, img, size=24): 128 | """Cut out boxes from the image. 129 | 130 | Arguments: 131 | bounding_boxes: a float numpy array of shape [n, 5]. 132 | img: an instance of PIL.Image. 133 | size: an integer, size of cutouts. 
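            Each box is clipped to the image bounds via correct_bboxes, and each
            cutout is resized to (size, size) with bilinear interpolation and
            normalized by _preprocess.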
134 | 135 | Returns: 136 | a float numpy array of shape [n, 3, size, size]. 137 | """ 138 | 139 | num_boxes = len(bounding_boxes) 140 | width, height = img.size 141 | 142 | [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height) 143 | img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') 144 | 145 | for i in range(num_boxes): 146 | img_box = np.zeros((h[i], w[i], 3), 'uint8') 147 | 148 | img_array = np.asarray(img, 'uint8') 149 | img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \ 150 | img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] 151 | 152 | # resize 153 | img_box = Image.fromarray(img_box) 154 | img_box = img_box.resize((size, size), Image.BILINEAR) 155 | img_box = np.asarray(img_box, 'float32') 156 | 157 | img_boxes[i, :, :, :] = _preprocess(img_box) 158 | 159 | return img_boxes 160 | 161 | 162 | def correct_bboxes(bboxes, width, height): 163 | """Crop boxes that are too big and get coordinates 164 | with respect to cutouts. 165 | 166 | Arguments: 167 | bboxes: a float numpy array of shape [n, 5], 168 | where each row is (xmin, ymin, xmax, ymax, score). 169 | width: a float number. 170 | height: a float number. 171 | 172 | Returns: 173 | dy, dx, edy, edx: a int numpy arrays of shape [n], 174 | coordinates of the boxes with respect to the cutouts. 175 | y, x, ey, ex: a int numpy arrays of shape [n], 176 | corrected ymin, xmin, ymax, xmax. 177 | h, w: a int numpy arrays of shape [n], 178 | just heights and widths of boxes. 179 | 180 | in the following order: 181 | [dy, edy, dx, edx, y, ey, x, ex, w, h]. 182 | """ 183 | 184 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 185 | w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 186 | num_boxes = bboxes.shape[0] 187 | 188 | # 'e' stands for end 189 | # (x, y) -> (ex, ey) 190 | x, y, ex, ey = x1, y1, x2, y2 191 | 192 | # we need to cut out a box from the image. 193 | # (x, y, ex, ey) are corrected coordinates of the box 194 | # in the image. 195 | # (dx, dy, edx, edy) are coordinates of the box in the cutout 196 | # from the image. 197 | dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) 198 | edx, edy = w.copy() - 1.0, h.copy() - 1.0 199 | 200 | # if box's bottom right corner is too far right 201 | ind = np.where(ex > width - 1.0)[0] 202 | edx[ind] = w[ind] + width - 2.0 - ex[ind] 203 | ex[ind] = width - 1.0 204 | 205 | # if box's bottom right corner is too low 206 | ind = np.where(ey > height - 1.0)[0] 207 | edy[ind] = h[ind] + height - 2.0 - ey[ind] 208 | ey[ind] = height - 1.0 209 | 210 | # if box's top left corner is too far left 211 | ind = np.where(x < 0.0)[0] 212 | dx[ind] = 0.0 - x[ind] 213 | x[ind] = 0.0 214 | 215 | # if box's top left corner is too high 216 | ind = np.where(y < 0.0)[0] 217 | dy[ind] = 0.0 - y[ind] 218 | y[ind] = 0.0 219 | 220 | return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] 221 | return_list = [i.astype('int32') for i in return_list] 222 | 223 | return return_list 224 | 225 | 226 | def _preprocess(img): 227 | """Preprocessing step before feeding the network. 228 | 229 | Arguments: 230 | img: a float numpy array of shape [h, w, c]. 231 | 232 | Returns: 233 | a float numpy array of shape [1, c, h, w]. 
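        Channels are moved to the front, a batch dimension is added, and pixel
        values are rescaled from [0, 255] to roughly [-1, 1] via
        (img - 127.5) * 0.0078125.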
234 | """ 235 | img = img.transpose((2, 0, 1)) 236 | img = np.expand_dims(img, 0) 237 | img = (img - 127.5) * 0.0078125 238 | return img 239 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from .get_nets import PNet, RNet, ONet 5 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from .first_stage import run_first_stage 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = 
torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from .box_utils import nms, _preprocess 7 | 8 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | device = 'cuda:0' 10 | 11 | 12 | def run_first_stage(image, net, scale, threshold): 13 | """Run P-Net, generate bounding boxes, and do NMS. 14 | 15 | Arguments: 16 | image: an instance of PIL.Image. 17 | net: an instance of pytorch's nn.Module, P-Net. 18 | scale: a float number, 19 | scale width and height of the image by this number. 20 | threshold: a float number, 21 | threshold on the probability of a face when generating 22 | bounding boxes from predictions of the net. 23 | 24 | Returns: 25 | a float numpy array of shape [n_boxes, 9], 26 | bounding boxes with scores and offsets (4 + 1 + 4). 27 | """ 28 | 29 | # scale the image and convert it to a float array 30 | width, height = image.size 31 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 32 | img = image.resize((sw, sh), Image.BILINEAR) 33 | img = np.asarray(img, 'float32') 34 | 35 | img = torch.FloatTensor(_preprocess(img)).to(device) 36 | with torch.no_grad(): 37 | output = net(img) 38 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 39 | offsets = output[0].cpu().data.numpy() 40 | # probs: probability of a face at each sliding window 41 | # offsets: transformations to true bounding boxes 42 | 43 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 44 | if len(boxes) == 0: 45 | return None 46 | 47 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 48 | return boxes[keep] 49 | 50 | 51 | def _generate_bboxes(probs, offsets, scale, threshold): 52 | """Generate bounding boxes at places 53 | where there is probably a face. 54 | 55 | Arguments: 56 | probs: a float numpy array of shape [n, m]. 57 | offsets: a float numpy array of shape [1, 4, n, m]. 58 | scale: a float number, 59 | width and height of the image were scaled by this number. 60 | threshold: a float number. 
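            only sliding-window positions whose face probability exceeds
            this value are turned into candidate boxes.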
61 | 62 | Returns: 63 | a float numpy array of shape [n_boxes, 9] 64 | """ 65 | 66 | # applying P-Net is equivalent, in some sense, to 67 | # moving 12x12 window with stride 2 68 | stride = 2 69 | cell_size = 12 70 | 71 | # indices of boxes where there is probably a face 72 | inds = np.where(probs > threshold) 73 | 74 | if inds[0].size == 0: 75 | return np.array([]) 76 | 77 | # transformations of bounding boxes 78 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 79 | # they are defined as: 80 | # w = x2 - x1 + 1 81 | # h = y2 - y1 + 1 82 | # x1_true = x1 + tx1*w 83 | # x2_true = x2 + tx2*w 84 | # y1_true = y1 + ty1*h 85 | # y2_true = y2 + ty2*h 86 | 87 | offsets = np.array([tx1, ty1, tx2, ty2]) 88 | score = probs[inds[0], inds[1]] 89 | 90 | # P-Net is applied to scaled images 91 | # so we need to rescale bounding boxes back 92 | bounding_boxes = np.vstack([ 93 | np.round((stride * inds[1] + 1.0) / scale), 94 | np.round((stride * inds[0] + 1.0) / scale), 95 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 96 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 97 | score, offsets 98 | ]) 99 | # why one is added? 100 | 101 | return bounding_boxes.T 102 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | from hparams import hparams as hp 7 | PNET_PATH = hp.mtcnn_path_pnet 8 | ONET_PATH = hp.mtcnn_path_onet 9 | RNET_PATH = hp.mtcnn_path_rnet 10 | 11 | 12 | class Flatten(nn.Module): 13 | 14 | def __init__(self): 15 | super(Flatten, self).__init__() 16 | 17 | def forward(self, x): 18 | """ 19 | Arguments: 20 | x: a float tensor with shape [batch_size, c, h, w]. 21 | Returns: 22 | a float tensor with shape [batch_size, c*h*w]. 23 | """ 24 | 25 | # without this pretrained model isn't working 26 | x = x.transpose(3, 2).contiguous() 27 | 28 | return x.view(x.size(0), -1) 29 | 30 | 31 | class PNet(nn.Module): 32 | 33 | def __init__(self): 34 | super().__init__() 35 | 36 | # suppose we have input with size HxW, then 37 | # after first layer: H - 2, 38 | # after pool: ceil((H - 2)/2), 39 | # after second conv: ceil((H - 2)/2) - 2, 40 | # after last conv: ceil((H - 2)/2) - 4, 41 | # and the same for W 42 | 43 | self.features = nn.Sequential(OrderedDict([ 44 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 45 | ('prelu1', nn.PReLU(10)), 46 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 47 | 48 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 49 | ('prelu2', nn.PReLU(16)), 50 | 51 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 52 | ('prelu3', nn.PReLU(32)) 53 | ])) 54 | 55 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 56 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 57 | 58 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 59 | for n, p in self.named_parameters(): 60 | p.data = torch.FloatTensor(weights[n]) 61 | 62 | def forward(self, x): 63 | """ 64 | Arguments: 65 | x: a float tensor with shape [batch_size, 3, h, w]. 66 | Returns: 67 | b: a float tensor with shape [batch_size, 4, h', w']. 68 | a: a float tensor with shape [batch_size, 2, h', w']. 
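        b holds the bounding-box regression offsets and a the face scores
        produced by the softmax in forward(); h' and w' follow the size
        arithmetic noted in the comment in __init__.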
69 | """ 70 | x = self.features(x) 71 | a = self.conv4_1(x) 72 | b = self.conv4_2(x) 73 | a = F.softmax(a, dim=-1) 74 | return b, a 75 | 76 | 77 | class RNet(nn.Module): 78 | 79 | def __init__(self): 80 | super().__init__() 81 | 82 | self.features = nn.Sequential(OrderedDict([ 83 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 84 | ('prelu1', nn.PReLU(28)), 85 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 86 | 87 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 88 | ('prelu2', nn.PReLU(48)), 89 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 90 | 91 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 92 | ('prelu3', nn.PReLU(64)), 93 | 94 | ('flatten', Flatten()), 95 | ('conv4', nn.Linear(576, 128)), 96 | ('prelu4', nn.PReLU(128)) 97 | ])) 98 | 99 | self.conv5_1 = nn.Linear(128, 2) 100 | self.conv5_2 = nn.Linear(128, 4) 101 | 102 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 103 | for n, p in self.named_parameters(): 104 | p.data = torch.FloatTensor(weights[n]) 105 | 106 | def forward(self, x): 107 | """ 108 | Arguments: 109 | x: a float tensor with shape [batch_size, 3, h, w]. 110 | Returns: 111 | b: a float tensor with shape [batch_size, 4]. 112 | a: a float tensor with shape [batch_size, 2]. 113 | """ 114 | x = self.features(x) 115 | a = self.conv5_1(x) 116 | b = self.conv5_2(x) 117 | a = F.softmax(a, dim=-1) 118 | return b, a 119 | 120 | 121 | class ONet(nn.Module): 122 | 123 | def __init__(self): 124 | super().__init__() 125 | 126 | self.features = nn.Sequential(OrderedDict([ 127 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 128 | ('prelu1', nn.PReLU(32)), 129 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 130 | 131 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 132 | ('prelu2', nn.PReLU(64)), 133 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 134 | 135 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 136 | ('prelu3', nn.PReLU(64)), 137 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 138 | 139 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 140 | ('prelu4', nn.PReLU(128)), 141 | 142 | ('flatten', Flatten()), 143 | ('conv5', nn.Linear(1152, 256)), 144 | ('drop5', nn.Dropout(0.25)), 145 | ('prelu5', nn.PReLU(256)), 146 | ])) 147 | 148 | self.conv6_1 = nn.Linear(256, 2) 149 | self.conv6_2 = nn.Linear(256, 4) 150 | self.conv6_3 = nn.Linear(256, 10) 151 | 152 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 153 | for n, p in self.named_parameters(): 154 | p.data = torch.FloatTensor(weights[n]) 155 | 156 | def forward(self, x): 157 | """ 158 | Arguments: 159 | x: a float tensor with shape [batch_size, 3, h, w]. 160 | Returns: 161 | c: a float tensor with shape [batch_size, 10]. 162 | b: a float tensor with shape [batch_size, 4]. 163 | a: a float tensor with shape [batch_size, 2]. 
164 | """ 165 | x = self.features(x) 166 | a = self.conv6_1(x) 167 | b = self.conv6_2(x) 168 | c = self.conv6_3(x) 169 | a = F.softmax(a, dim=-1) 170 | return c, b, a 171 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jul 11 06:54:28 2017 4 | 5 | @author: zhaoyafei 6 | """ 7 | 8 | import numpy as np 9 | from numpy.linalg import inv, norm, lstsq 10 | from numpy.linalg import matrix_rank as rank 11 | 12 | 13 | class MatlabCp2tormException(Exception): 14 | def __str__(self): 15 | return 'In File {}:{}'.format( 16 | __file__, super.__str__(self)) 17 | 18 | 19 | def tformfwd(trans, uv): 20 | """ 21 | Function: 22 | ---------- 23 | apply affine transform 'trans' to uv 24 | 25 | Parameters: 26 | ---------- 27 | @trans: 3x3 np.array 28 | transform matrix 29 | @uv: Kx2 np.array 30 | each row is a pair of coordinates (x, y) 31 | 32 | Returns: 33 | ---------- 34 | @xy: Kx2 np.array 35 | each row is a pair of transformed coordinates (x, y) 36 | """ 37 | uv = np.hstack(( 38 | uv, np.ones((uv.shape[0], 1)) 39 | )) 40 | xy = np.dot(uv, trans) 41 | xy = xy[:, 0:-1] 42 | return xy 43 | 44 | 45 | def tforminv(trans, uv): 46 | """ 47 | Function: 48 | ---------- 49 | apply the inverse of affine transform 'trans' to uv 50 | 51 | Parameters: 52 | ---------- 53 | @trans: 3x3 np.array 54 | transform matrix 55 | @uv: Kx2 np.array 56 | each row is a pair of coordinates (x, y) 57 | 58 | Returns: 59 | ---------- 60 | @xy: Kx2 np.array 61 | each row is a pair of inverse-transformed coordinates (x, y) 62 | """ 63 | Tinv = inv(trans) 64 | xy = tformfwd(Tinv, uv) 65 | return xy 66 | 67 | 68 | def findNonreflectiveSimilarity(uv, xy, options=None): 69 | options = {'K': 2} 70 | 71 | K = options['K'] 72 | M = xy.shape[0] 73 | x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector 74 | y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector 75 | # print('--->x, y:\n', x, y 76 | 77 | tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) 78 | tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) 79 | X = np.vstack((tmp1, tmp2)) 80 | # print('--->X.shape: ', X.shape 81 | # print('X:\n', X 82 | 83 | u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector 84 | v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector 85 | U = np.vstack((u, v)) 86 | # print('--->U.shape: ', U.shape 87 | # print('U:\n', U 88 | 89 | # We know that X * r = U 90 | if rank(X) >= 2 * K: 91 | r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want 92 | r = np.squeeze(r) 93 | else: 94 | raise Exception('cp2tform:twoUniquePointsReq') 95 | 96 | # print('--->r:\n', r 97 | 98 | sc = r[0] 99 | ss = r[1] 100 | tx = r[2] 101 | ty = r[3] 102 | 103 | Tinv = np.array([ 104 | [sc, -ss, 0], 105 | [ss, sc, 0], 106 | [tx, ty, 1] 107 | ]) 108 | 109 | # print('--->Tinv:\n', Tinv 110 | 111 | T = inv(Tinv) 112 | # print('--->T:\n', T 113 | 114 | T[:, 2] = np.array([0, 0, 1]) 115 | 116 | return T, Tinv 117 | 118 | 119 | def findSimilarity(uv, xy, options=None): 120 | options = {'K': 2} 121 | 122 | # uv = np.array(uv) 123 | # xy = np.array(xy) 124 | 125 | # Solve for trans1 126 | trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) 127 | 128 | # Solve for trans2 129 | 130 | # manually reflect the xy data across the Y-axis 131 | xyR = xy 132 | xyR[:, 0] = -1 * xyR[:, 0] 
133 | 134 | trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) 135 | 136 | # manually reflect the tform to undo the reflection done on xyR 137 | TreflectY = np.array([ 138 | [-1, 0, 0], 139 | [0, 1, 0], 140 | [0, 0, 1] 141 | ]) 142 | 143 | trans2 = np.dot(trans2r, TreflectY) 144 | 145 | # Figure out if trans1 or trans2 is better 146 | xy1 = tformfwd(trans1, uv) 147 | norm1 = norm(xy1 - xy) 148 | 149 | xy2 = tformfwd(trans2, uv) 150 | norm2 = norm(xy2 - xy) 151 | 152 | if norm1 <= norm2: 153 | return trans1, trans1_inv 154 | else: 155 | trans2_inv = inv(trans2) 156 | return trans2, trans2_inv 157 | 158 | 159 | def get_similarity_transform(src_pts, dst_pts, reflective=True): 160 | """ 161 | Function: 162 | ---------- 163 | Find Similarity Transform Matrix 'trans': 164 | u = src_pts[:, 0] 165 | v = src_pts[:, 1] 166 | x = dst_pts[:, 0] 167 | y = dst_pts[:, 1] 168 | [x, y, 1] = [u, v, 1] * trans 169 | 170 | Parameters: 171 | ---------- 172 | @src_pts: Kx2 np.array 173 | source points, each row is a pair of coordinates (x, y) 174 | @dst_pts: Kx2 np.array 175 | destination points, each row is a pair of transformed 176 | coordinates (x, y) 177 | @reflective: True or False 178 | if True: 179 | use reflective similarity transform 180 | else: 181 | use non-reflective similarity transform 182 | 183 | Returns: 184 | ---------- 185 | @trans: 3x3 np.array 186 | transform matrix from uv to xy 187 | trans_inv: 3x3 np.array 188 | inverse of trans, transform matrix from xy to uv 189 | """ 190 | 191 | if reflective: 192 | trans, trans_inv = findSimilarity(src_pts, dst_pts) 193 | else: 194 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) 195 | 196 | return trans, trans_inv 197 | 198 | 199 | def cvt_tform_mat_for_cv2(trans): 200 | """ 201 | Function: 202 | ---------- 203 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be 204 | directly used by cv2.warpAffine(): 205 | u = src_pts[:, 0] 206 | v = src_pts[:, 1] 207 | x = dst_pts[:, 0] 208 | y = dst_pts[:, 1] 209 | [x, y].T = cv_trans * [u, v, 1].T 210 | 211 | Parameters: 212 | ---------- 213 | @trans: 3x3 np.array 214 | transform matrix from uv to xy 215 | 216 | Returns: 217 | ---------- 218 | @cv2_trans: 2x3 np.array 219 | transform matrix from src_pts to dst_pts, could be directly used 220 | for cv2.warpAffine() 221 | """ 222 | cv2_trans = trans[:, 0:2].T 223 | 224 | return cv2_trans 225 | 226 | 227 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): 228 | """ 229 | Function: 230 | ---------- 231 | Find Similarity Transform Matrix 'cv2_trans' which could be 232 | directly used by cv2.warpAffine(): 233 | u = src_pts[:, 0] 234 | v = src_pts[:, 1] 235 | x = dst_pts[:, 0] 236 | y = dst_pts[:, 1] 237 | [x, y].T = cv_trans * [u, v, 1].T 238 | 239 | Parameters: 240 | ---------- 241 | @src_pts: Kx2 np.array 242 | source points, each row is a pair of coordinates (x, y) 243 | @dst_pts: Kx2 np.array 244 | destination points, each row is a pair of transformed 245 | coordinates (x, y) 246 | reflective: True or False 247 | if True: 248 | use reflective similarity transform 249 | else: 250 | use non-reflective similarity transform 251 | 252 | Returns: 253 | ---------- 254 | @cv2_trans: 2x3 np.array 255 | transform matrix from src_pts to dst_pts, could be directly used 256 | for cv2.warpAffine() 257 | """ 258 | trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) 259 | cv2_trans = cvt_tform_mat_for_cv2(trans) 260 | 261 | return cv2_trans 262 | 263 | 264 | if __name__ == '__main__': 
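    # self-test: runs the three-point example from the MATLAB cp2tform session
    # quoted in the string below and prints the recovered similarity transform,
    # its inverse, and the forward/inverse mappings.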
265 | """ 266 | u = [0, 6, -2] 267 | v = [0, 3, 5] 268 | x = [-1, 0, 4] 269 | y = [-1, -10, 4] 270 | 271 | # In Matlab, run: 272 | # 273 | # uv = [u'; v']; 274 | # xy = [x'; y']; 275 | # tform_sim=cp2tform(uv,xy,'similarity'); 276 | # 277 | # trans = tform_sim.tdata.T 278 | # ans = 279 | # -0.0764 -1.6190 0 280 | # 1.6190 -0.0764 0 281 | # -3.2156 0.0290 1.0000 282 | # trans_inv = tform_sim.tdata.Tinv 283 | # ans = 284 | # 285 | # -0.0291 0.6163 0 286 | # -0.6163 -0.0291 0 287 | # -0.0756 1.9826 1.0000 288 | # xy_m=tformfwd(tform_sim, u,v) 289 | # 290 | # xy_m = 291 | # 292 | # -3.2156 0.0290 293 | # 1.1833 -9.9143 294 | # 5.0323 2.8853 295 | # uv_m=tforminv(tform_sim, x,y) 296 | # 297 | # uv_m = 298 | # 299 | # 0.5698 1.3953 300 | # 6.0872 2.2733 301 | # -2.6570 4.3314 302 | """ 303 | u = [0, 6, -2] 304 | v = [0, 3, 5] 305 | x = [-1, 0, 4] 306 | y = [-1, -10, 4] 307 | 308 | uv = np.array((u, v)).T 309 | xy = np.array((x, y)).T 310 | 311 | print('\n--->uv:') 312 | print(uv) 313 | print('\n--->xy:') 314 | print(xy) 315 | 316 | trans, trans_inv = get_similarity_transform(uv, xy) 317 | 318 | print('\n--->trans matrix:') 319 | print(trans) 320 | 321 | print('\n--->trans_inv matrix:') 322 | print(trans_inv) 323 | 324 | print('\n---> apply transform to uv') 325 | print('\nxy_m = uv_augmented * trans') 326 | uv_aug = np.hstack(( 327 | uv, np.ones((uv.shape[0], 1)) 328 | )) 329 | xy_m = np.dot(uv_aug, trans) 330 | print(xy_m) 331 | 332 | print('\nxy_m = tformfwd(trans, uv)') 333 | xy_m = tformfwd(trans, uv) 334 | print(xy_m) 335 | 336 | print('\n---> apply inverse transform to xy') 337 | print('\nuv_m = xy_augmented * trans_inv') 338 | xy_aug = np.hstack(( 339 | xy, np.ones((xy.shape[0], 1)) 340 | )) 341 | uv_m = np.dot(xy_aug, trans_inv) 342 | print(uv_m) 343 | 344 | print('\nuv_m = tformfwd(trans_inv, xy)') 345 | uv_m = tformfwd(trans_inv, xy) 346 | print(uv_m) 347 | 348 | uv_m = tforminv(trans, xy) 349 | print('\nuv_m = tforminv(trans, xy)') 350 | print(uv_m) 351 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /optimizer/LookAhead.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import chain 3 | from torch.optim.optimizer import Optimizer 4 | import torch 5 | import warnings 6 | 7 | 8 | class Lookahead(Optimizer): 9 | def __init__(self, optimizer, k=5, alpha=0.5): 10 | self.optimizer = optimizer 11 | self.k = k 12 | self.alpha = alpha 13 | self.param_groups = self.optimizer.param_groups 14 | self.state = defaultdict(dict) 15 | self.fast_state = self.optimizer.state 16 | for group in self.param_groups: 17 | group["counter"] = 0 18 | 19 | def update(self, group): 20 | for fast in group["params"]: 21 | param_state = self.state[fast] 22 | if "slow_param" not in param_state: 23 | param_state["slow_param"] = torch.zeros_like(fast.data) 24 | param_state["slow_param"].copy_(fast.data) 25 | slow = param_state["slow_param"] 26 | slow += (fast.data - slow) * self.alpha 27 | fast.data.copy_(slow) 28 | 29 | def update_lookahead(self): 30 | for group in self.param_groups: 31 | self.update(group) 32 | 33 | def step(self, closure=None): 34 | loss = self.optimizer.step(closure) 35 | for group in self.param_groups: 36 | if group["counter"] == 0: 37 | self.update(group) 38 | group["counter"] += 1 39 | if group["counter"] >= self.k: 40 | group["counter"] = 0 41 | return loss 42 | 43 | def state_dict(self): 44 | fast_state_dict = self.optimizer.state_dict() 45 | slow_state = { 46 | (id(k) if isinstance(k, torch.Tensor) else k): v 47 | for k, v in self.state.items() 48 | } 49 | fast_state = fast_state_dict["state"] 50 | param_groups = fast_state_dict["param_groups"] 51 | return { 52 | "fast_state": fast_state, 53 | "slow_state": slow_state, 54 | "param_groups": param_groups, 55 | } 56 | 57 | def load_state_dict(self, state_dict): 58 | slow_state_dict = { 59 | "state": state_dict["slow_state"], 60 | "param_groups": state_dict["param_groups"], 61 | } 62 | 
fast_state_dict = { 63 | "state": state_dict["fast_state"], 64 | "param_groups": state_dict["param_groups"], 65 | } 66 | super(Lookahead, self).load_state_dict(slow_state_dict) 67 | self.optimizer.load_state_dict(fast_state_dict) 68 | self.fast_state = self.optimizer.state 69 | 70 | def add_param_group(self, param_group): 71 | param_group["counter"] = 0 72 | self.optimizer.add_param_group(param_group) -------------------------------------------------------------------------------- /optimizer/RAdam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 9 | if not 0.0 <= lr: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | if not 0.0 <= eps: 12 | raise ValueError("Invalid epsilon value: {}".format(eps)) 13 | if not 0.0 <= betas[0] < 1.0: 14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 15 | if not 0.0 <= betas[1] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 17 | 18 | self.degenerated_to_sgd = degenerated_to_sgd 19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 20 | for param in params: 21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 22 | param['buffer'] = [[None, None, None] for _ in range(10)] 23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 24 | buffer=[[None, None, None] for _ in range(10)]) 25 | super(RAdam, self).__init__(params, defaults) 26 | 27 | def __setstate__(self, state): 28 | super(RAdam, self).__setstate__(state) 29 | 30 | def step(self, closure=None): 31 | 32 | loss = None 33 | if closure is not None: 34 | loss = closure() 35 | 36 | for group in self.param_groups: 37 | 38 | for p in group['params']: 39 | if p.grad is None: 40 | continue 41 | grad = p.grad.data.float() 42 | if grad.is_sparse: 43 | raise RuntimeError('RAdam does not support sparse gradients') 44 | 45 | p_data_fp32 = p.data.float() 46 | 47 | state = self.state[p] 48 | 49 | if len(state) == 0: 50 | state['step'] = 0 51 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 52 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 53 | else: 54 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 55 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 56 | 57 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 58 | beta1, beta2 = group['betas'] 59 | 60 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 61 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 62 | 63 | state['step'] += 1 64 | buffered = group['buffer'][int(state['step'] % 10)] 65 | if state['step'] == buffered[0]: 66 | N_sma, step_size = buffered[1], buffered[2] 67 | else: 68 | buffered[0] = state['step'] 69 | beta2_t = beta2 ** state['step'] 70 | N_sma_max = 2 / (1 - beta2) - 1 71 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 72 | buffered[1] = N_sma 73 | 74 | # more conservative since it's an approximated value 75 | if N_sma >= 5: 76 | step_size = math.sqrt( 77 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 78 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 79 | elif self.degenerated_to_sgd: 80 | step_size = 1.0 / (1 - beta1 ** state['step']) 81 | else: 82 | step_size = -1 83 | buffered[2] = step_size 84 
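# Note on the step size cached above (RAdam, Liu et al. 2020): with
# rho_inf = 2/(1 - beta2) - 1 and rho_t = rho_inf - 2*t*beta2^t/(1 - beta2^t)
# (N_sma_max and N_sma here), the update is rectified by
# r_t = sqrt(((rho_t - 4)(rho_t - 2)rho_inf) / ((rho_inf - 4)(rho_inf - 2)rho_t))
# whenever N_sma >= 5 (rho_t > 4 in the paper), combined with the usual Adam
# bias corrections sqrt(1 - beta2^t) and 1/(1 - beta1^t); otherwise, if
# degenerated_to_sgd is enabled, the step falls back to bias-corrected
# momentum SGD.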
| 85 | # more conservative since it's an approximated value 86 | if N_sma >= 5: 87 | if group['weight_decay'] != 0: 88 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 89 | denom = exp_avg_sq.sqrt().add_(group['eps']) 90 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 91 | p.data.copy_(p_data_fp32) 92 | elif step_size > 0: 93 | if group['weight_decay'] != 0: 94 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 95 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 96 | p.data.copy_(p_data_fp32) 97 | 98 | return loss 99 | 100 | 101 | class PlainRAdam(Optimizer): 102 | 103 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 104 | if not 0.0 <= lr: 105 | raise ValueError("Invalid learning rate: {}".format(lr)) 106 | if not 0.0 <= eps: 107 | raise ValueError("Invalid epsilon value: {}".format(eps)) 108 | if not 0.0 <= betas[0] < 1.0: 109 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 110 | if not 0.0 <= betas[1] < 1.0: 111 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 112 | 113 | self.degenerated_to_sgd = degenerated_to_sgd 114 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 115 | 116 | super(PlainRAdam, self).__init__(params, defaults) 117 | 118 | def __setstate__(self, state): 119 | super(PlainRAdam, self).__setstate__(state) 120 | 121 | def step(self, closure=None): 122 | 123 | loss = None 124 | if closure is not None: 125 | loss = closure() 126 | 127 | for group in self.param_groups: 128 | 129 | for p in group['params']: 130 | if p.grad is None: 131 | continue 132 | grad = p.grad.data.float() 133 | if grad.is_sparse: 134 | raise RuntimeError('RAdam does not support sparse gradients') 135 | 136 | p_data_fp32 = p.data.float() 137 | 138 | state = self.state[p] 139 | 140 | if len(state) == 0: 141 | state['step'] = 0 142 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 143 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 144 | else: 145 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 146 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 147 | 148 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 149 | beta1, beta2 = group['betas'] 150 | 151 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 152 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 153 | 154 | state['step'] += 1 155 | beta2_t = beta2 ** state['step'] 156 | N_sma_max = 2 / (1 - beta2) - 1 157 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 158 | 159 | # more conservative since it's an approximated value 160 | if N_sma >= 5: 161 | if group['weight_decay'] != 0: 162 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 163 | step_size = group['lr'] * math.sqrt( 164 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 165 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 166 | denom = exp_avg_sq.sqrt().add_(group['eps']) 167 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 168 | p.data.copy_(p_data_fp32) 169 | elif self.degenerated_to_sgd: 170 | if group['weight_decay'] != 0: 171 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 172 | step_size = group['lr'] / (1 - beta1 ** state['step']) 173 | p_data_fp32.add_(-step_size, exp_avg) 174 | p.data.copy_(p_data_fp32) 175 | 176 | return loss 177 | 178 | 179 | class AdamW(Optimizer): 180 | 181 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 
weight_decay=0, warmup=0): 182 | if not 0.0 <= lr: 183 | raise ValueError("Invalid learning rate: {}".format(lr)) 184 | if not 0.0 <= eps: 185 | raise ValueError("Invalid epsilon value: {}".format(eps)) 186 | if not 0.0 <= betas[0] < 1.0: 187 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 188 | if not 0.0 <= betas[1] < 1.0: 189 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 190 | 191 | defaults = dict(lr=lr, betas=betas, eps=eps, 192 | weight_decay=weight_decay, warmup=warmup) 193 | super(AdamW, self).__init__(params, defaults) 194 | 195 | def __setstate__(self, state): 196 | super(AdamW, self).__setstate__(state) 197 | 198 | def step(self, closure=None): 199 | loss = None 200 | if closure is not None: 201 | loss = closure() 202 | 203 | for group in self.param_groups: 204 | 205 | for p in group['params']: 206 | if p.grad is None: 207 | continue 208 | grad = p.grad.data.float() 209 | if grad.is_sparse: 210 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 211 | 212 | p_data_fp32 = p.data.float() 213 | 214 | state = self.state[p] 215 | 216 | if len(state) == 0: 217 | state['step'] = 0 218 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 219 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 220 | else: 221 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 222 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 223 | 224 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 225 | beta1, beta2 = group['betas'] 226 | 227 | state['step'] += 1 228 | 229 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 230 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 231 | 232 | denom = exp_avg_sq.sqrt().add_(group['eps']) 233 | bias_correction1 = 1 - beta1 ** state['step'] 234 | bias_correction2 = 1 - beta2 ** state['step'] 235 | 236 | if group['warmup'] > state['step']: 237 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 238 | else: 239 | scheduled_lr = group['lr'] 240 | 241 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 242 | 243 | if group['weight_decay'] != 0: 244 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 245 | 246 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 247 | 248 | p.data.copy_(p_data_fp32) 249 | 250 | return loss -------------------------------------------------------------------------------- /optimizer/Ranger.py: -------------------------------------------------------------------------------- 1 | 2 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 3 | 4 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 5 | # and/or 6 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 7 | 8 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 9 | 10 | # This version = 20.4.11 11 | 12 | # Credits: 13 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 14 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 15 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 16 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 17 | 18 | # summary of changes: 19 | # 4/11/20 - add gradient centralization option. 
Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 20 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 21 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 22 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 23 | # changed eps to 1e-5 as better default than 1e-8. 24 | 25 | import math 26 | import torch 27 | from torch.optim.optimizer import Optimizer 28 | 29 | 30 | class Ranger(Optimizer): 31 | 32 | def __init__(self, params, lr=1e-3, # lr 33 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 34 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 35 | use_gc=True, gc_conv_only=False 36 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 37 | ): 38 | 39 | # parameter checks 40 | if not 0.0 <= alpha <= 1.0: 41 | raise ValueError(f'Invalid slow update rate: {alpha}') 42 | if not 1 <= k: 43 | raise ValueError(f'Invalid lookahead steps: {k}') 44 | if not lr > 0: 45 | raise ValueError(f'Invalid Learning Rate: {lr}') 46 | if not eps > 0: 47 | raise ValueError(f'Invalid eps: {eps}') 48 | 49 | # parameter comments: 50 | # beta1 (momentum) of .95 seems to work better than .90... 51 | # N_sma_threshold of 5 seems better in testing than 4. 52 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 53 | 54 | # prep defaults and init torch.optim base 55 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 56 | eps=eps, weight_decay=weight_decay) 57 | super().__init__(params, defaults) 58 | 59 | # adjustable threshold 60 | self.N_sma_threshhold = N_sma_threshhold 61 | 62 | # look ahead params 63 | 64 | self.alpha = alpha 65 | self.k = k 66 | 67 | # radam buffer for state 68 | self.radam_buffer = [[None, None, None] for ind in range(10)] 69 | 70 | # gc on or off 71 | self.use_gc = use_gc 72 | 73 | # level of gradient centralization 74 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 75 | 76 | print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") 77 | if (self.use_gc and self.gc_gradient_threshold == 1): 78 | print(f"GC applied to both conv and fc layers") 79 | elif (self.use_gc and self.gc_gradient_threshold == 3): 80 | print(f"GC applied to conv layers only") 81 | 82 | def __setstate__(self, state): 83 | print("set state called") 84 | super(Ranger, self).__setstate__(state) 85 | 86 | def step(self, closure=None): 87 | loss = None 88 | # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 89 | # Uncomment if you need to use the actual closure... 
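# (In the standard torch.optim contract, `closure` is a callable that
# re-evaluates the model and returns the loss; because this fork never calls
# it, step() simply returns None for `loss`.)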
90 | 91 | # if closure is not None: 92 | # loss = closure() 93 | 94 | # Evaluate averages and grad, update param tensors 95 | for group in self.param_groups: 96 | 97 | for p in group['params']: 98 | if p.grad is None: 99 | continue 100 | grad = p.grad.data.float() 101 | 102 | if grad.is_sparse: 103 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 104 | 105 | p_data_fp32 = p.data.float() 106 | 107 | state = self.state[p] # get state dict for this param 108 | 109 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 110 | # if self.first_run_check==0: 111 | # self.first_run_check=1 112 | # print("Initializing slow buffer...should not see this at load from saved model!") 113 | state['step'] = 0 114 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 115 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 116 | 117 | # look ahead weight storage now in state dict 118 | state['slow_buffer'] = torch.empty_like(p.data) 119 | state['slow_buffer'].copy_(p.data) 120 | 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | # begin computations 126 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 127 | beta1, beta2 = group['betas'] 128 | 129 | # GC operation for Conv layers and FC layers 130 | if grad.dim() > self.gc_gradient_threshold: 131 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 132 | 133 | state['step'] += 1 134 | 135 | # compute variance mov avg 136 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 137 | # compute mean moving avg 138 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 139 | 140 | buffered = self.radam_buffer[int(state['step'] % 10)] 141 | 142 | if state['step'] == buffered[0]: 143 | N_sma, step_size = buffered[1], buffered[2] 144 | else: 145 | buffered[0] = state['step'] 146 | beta2_t = beta2 ** state['step'] 147 | N_sma_max = 2 / (1 - beta2) - 1 148 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 149 | buffered[1] = N_sma 150 | if N_sma > self.N_sma_threshhold: 151 | step_size = math.sqrt( 152 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 153 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 154 | else: 155 | step_size = 1.0 / (1 - beta1 ** state['step']) 156 | buffered[2] = step_size 157 | 158 | if group['weight_decay'] != 0: 159 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 160 | 161 | # apply lr 162 | if N_sma > self.N_sma_threshhold: 163 | denom = exp_avg_sq.sqrt().add_(group['eps']) 164 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 165 | else: 166 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 167 | 168 | p.data.copy_(p_data_fp32) 169 | 170 | # integrated look ahead... 
171 | # we do it at the param level instead of group level 172 | if state['step'] % group['k'] == 0: 173 | slow_p = state['slow_buffer'] # get access to slow param tensor 174 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 175 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 176 | 177 | return loss -------------------------------------------------------------------------------- /optimizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/optimizer/__init__.py -------------------------------------------------------------------------------- /pre_trained_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/pre_trained_model/__init__.py -------------------------------------------------------------------------------- /pre_trained_model/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut -------------------------------------------------------------------------------- /pre_trained_model/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from pre_trained_model.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | 
self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model -------------------------------------------------------------------------------- /stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MontaEllis/Framework-of-GAN-Inversion/43a1dca71bb5ec10b7de457b48166d7111252704/stylegan2/__init__.py -------------------------------------------------------------------------------- /stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /stylegan2/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 
52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 
182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input.contiguous(), 51 | gradgrad_bias, 52 | out, 53 | 3, 54 | 1, 55 | ctx.negative_slope, 56 | ctx.scale, 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | if not ctx.bias: 88 | grad_bias = None 
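# Only (input, bias) receive gradients; negative_slope and scale are plain
# Python numbers, hence the trailing None, None in the return below.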
89 | 90 | return grad_input, grad_bias, None, None 91 | 92 | 93 | class FusedLeakyReLU(nn.Module): 94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 95 | super().__init__() 96 | 97 | if bias: 98 | self.bias = nn.Parameter(torch.zeros(channel)) 99 | 100 | else: 101 | self.bias = None 102 | 103 | self.negative_slope = negative_slope 104 | self.scale = scale 105 | 106 | def forward(self, input): 107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 108 | 109 | 110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 111 | if input.device.type == "cpu": 112 | if bias is not None: 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return F.leaky_relu(input, negative_slope=0.2) * scale 123 | 124 | else: 125 | return FusedLeakyReLUFunction.apply( 126 | input.contiguous(), bias, negative_slope, scale 127 | ) 128 | -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? 
p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, 
pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 
| pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | 
p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /stylegan2/stylegan2_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 
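# Overview: infer_face (defined below) wraps a pre-trained StyleGAN2
# checkpoint. It restores the EMA generator ("g_ema") and the discriminator
# ("d"), precomputes a mean latent when truncation < 1, and exposes helpers to
# sample a latent code in W space, synthesize images from latent codes
# (optionally shifted by an editing direction), and score images with the
# non-saturating generator loss.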
| from torchvision import utils 3 | import sys 4 | from stylegan2.model import Generator,Discriminator 5 | from tqdm import tqdm 6 | from torch.nn import functional as F 7 | from hparams import hparams as hp 8 | 9 | 10 | 11 | class infer_face(): 12 | 13 | def __init__(self,weight_path): 14 | 15 | self.device = "cuda" 16 | self.weight_path = weight_path 17 | 18 | 19 | self.num_truncation_mean = 4096 20 | 21 | self.truncation =0.5 22 | 23 | self.checkpoint = torch.load(self.weight_path) 24 | 25 | self.g_ema = Generator(hp.img_size, 512, 8, channel_multiplier=2).to(self.device) 26 | self.g_ema.load_state_dict(self.checkpoint["g_ema"]) 27 | # self.g_ema.eval() 28 | # for parm in self.g_ema.parameters(): 29 | # parm.requires_grad = False 30 | 31 | 32 | self.discriminator = Discriminator(hp.img_size, channel_multiplier=2).to(self.device) 33 | self.discriminator.load_state_dict(self.checkpoint["d"]) 34 | 35 | 36 | if self.truncation < 1: 37 | with torch.no_grad(): 38 | self.mean_latent = self.g_ema.mean_latent(self.num_truncation_mean) 39 | else: 40 | self.mean_latent = None 41 | # self.mean_latent = 0 42 | 43 | 44 | 45 | def random_init_w(self): 46 | sample_z = torch.randn(1, 512, device=self.device) 47 | w = self.g_ema.get_latent(sample_z) 48 | 49 | # with torch.no_grad(): 50 | # _, w = self.g_ema([sample_z], truncation=0.5, return_latents=True,truncation_latent=self.mean_latent) 51 | return w[0] 52 | 53 | def g_nonsaturating_loss(self,fake_pred): 54 | loss = F.softplus(-fake_pred).mean() 55 | 56 | return loss 57 | 58 | 59 | 60 | def generate_from_synthesis(self, w, direction, randomize_noise, return_latents): 61 | if direction is not None: 62 | if torch.is_tensor(direction): 63 | pass 64 | else: 65 | direction = torch.Tensor(direction).float().cuda() 66 | latent_code = (w + direction) 67 | else: 68 | latent_code = w 69 | 70 | sample, _ = self.g_ema( 71 | # [latent_code], truncation=1, input_is_latent=True, truncation_latent=self.mean_latent 72 | [latent_code], 73 | input_is_latent=True, 74 | randomize_noise=randomize_noise, 75 | return_latents=return_latents 76 | ) 77 | 78 | return sample 79 | 80 | 81 | 82 | def disc(self,img): 83 | 84 | fake_pred = self.discriminator(img) 85 | loss = self.g_nonsaturating_loss(fake_pred) 86 | # print(fake_pred) 87 | return loss 88 | 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from weights_init.weight_init_normal import weights_init_normal 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3' 4 | devicess = [0,1,2] 5 | import re 6 | import time 7 | import argparse 8 | import numpy as np 9 | from torch._six import container_abcs, string_classes, int_classes 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch import nn 14 | import torch.distributed as dist 15 | import math 16 | import warnings 17 | from tqdm import tqdm 18 | from torch.optim.lr_scheduler import ReduceLROnPlateau,StepLR,MultiStepLR 19 | from torchvision import utils 20 | from hparams import hparams as hp 21 | from torch.autograd import Variable 22 | from torch_warmup_lr import WarmupLR 23 | from optimizer.LookAhead import Lookahead 24 | from optimizer.RAdam import RAdam 25 | from optimizer.Ranger import Ranger 26 | warnings.filterwarnings("ignore") 27 | from weights_init.weight_init_normal import weights_init_normal 28 | device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") 29 | 30 | 31 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 32 | 33 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 34 | 35 | def parse_training_args(parser): 36 | """ 37 | Parse commandline arguments. 38 | """ 39 | 40 | parser.add_argument('-o', '--output_dir', type=str, default=hp.output_dir, required=False, help='Directory to save checkpoints') 41 | parser.add_argument('--latest-checkpoint-file', type=str, default=hp.latest_checkpoint_file, help='Store the latest checkpoint in each epoch') 42 | 43 | # training 44 | training = parser.add_argument_group('training setup') 45 | training.add_argument('--epochs', type=int, default=hp.epochs, help='Number of total epochs to run') 46 | training.add_argument('--epochs-per-checkpoint', type=int, default=hp.epochs_per_checkpoint, help='Number of epochs per checkpoint') 47 | training.add_argument('--batch', type=int, default=hp.batch, help='batch-size') 48 | 49 | parser.add_argument( 50 | '-k', 51 | "--ckpt", 52 | type=str, 53 | default=hp.ckpt, 54 | help="path to the checkpoints to resume training", 55 | ) 56 | 57 | parser.add_argument("--init-lr", type=float, default=hp.init_lr, help="learning rate") 58 | 59 | 60 | parser.add_argument( 61 | "--local_rank", type=int, default=0, help="local rank for distributed training" 62 | ) 63 | 64 | training.add_argument('--cudnn-enabled', default=True, help='Enable cudnn') 65 | training.add_argument('--cudnn-benchmark', default=True, help='Run cudnn benchmark') 66 | 67 | return parser 68 | 69 | 70 | 71 | def train(): 72 | 73 | parser = argparse.ArgumentParser(description=hp.description) 74 | parser = parse_training_args(parser) 75 | args, _ = parser.parse_known_args() 76 | args = parser.parse_args() 77 | torch.backends.cudnn.deterministic = True 78 | torch.backends.cudnn.enabled = args.cudnn_enabled 79 | torch.backends.cudnn.benchmark = args.cudnn_benchmark 80 | 81 | 82 | 83 | os.makedirs(args.output_dir, exist_ok=True) 84 | 85 | 86 | from stylegan2.stylegan2_infer import infer_face 87 | class_generate = infer_face(hp.weight_path_pytorch) 88 | 89 | 90 | n_styles = 2*int(math.log(hp.img_size, 2))-2 91 | if hp.backbone == 'GradualStyleEncoder': 92 | from models.fpn_encoders import GradualStyleEncoder 93 | model = GradualStyleEncoder(num_layers=50,n_styles=n_styles) 94 | elif hp.backbone == 'ResNetGradualStyleEncoder': 95 | from models.fpn_encoders import ResNetGradualStyleEncoder 96 | model = ResNetGradualStyleEncoder(n_styles=n_styles) 97 | else: 98 | raise Exception('Backbone error!') 99 | 100 | 101 | if hp.apply_init: 102 | model.apply(weights_init_normal) 103 | 104 | 105 | model = torch.nn.DataParallel(model, device_ids=devicess) 106 | 107 | 108 | # params = list(model.parameters()) + list(class_generate.g_ema.parameters()) 109 | params = list(model.parameters()) 110 | if hp.optimizer_mode == 'adam': 111 | optimizer = torch.optim.Adam(params, lr=args.init_lr, betas=(0.95, 0.999)) 112 | elif hp.optimizer_mode == 'sgd': 113 | optimizer = torch.optim.SGD(params, lr=args.init_lr, momentum=0.9, weight_decay=0.0005) 114 | elif hp.optimizer_mode == 'radam': 115 | optimizer = RAdam(params, lr=args.init_lr, betas=(0.95, 0.999)) 116 | elif hp.optimizer_mode == 'lookahead': 117 | optimizer = Lookahead(params) 118 | elif hp.optimizer_mode == 'ranger': 119 | optimizer = Ranger(params, lr=args.init_lr) 120 | else: 121 | raise Exception('Optimizer error!') 122 | 123 | 124 | 125 | 126 | 127 | 128 | if hp.scheduler_mode == 'StepLR': 129 | scheduler = StepLR(optimizer, step_size=10, 
gamma=0.1) 130 | elif hp.scheduler_mode == 'MultiStepLR': 131 | scheduler = MultiStepLR(optimizer, milestones=[3,6,9], gamma=0.1) 132 | elif hp.scheduler_mode == 'ReduceLROnPlateau': 133 | scheduler = ReduceLROnPlateau(optimizer, threshold=0.99, mode='min', patience=2, cooldown=5) 134 | else: 135 | raise Exception('Scheduler error!') 136 | 137 | if hp.open_warn_up: 138 | scheduler = WarmupLR(scheduler, init_lr=hp.init_lr, num_warmup=hp.num_warmup, warmup_strategy=hp.warn_up_strategy) 139 | 140 | 141 | 142 | if args.ckpt: 143 | print("load model:", args.ckpt) 144 | print(os.path.join(args.output_dir, args.latest_checkpoint_file)) 145 | ckpt = torch.load(os.path.join(args.output_dir, args.latest_checkpoint_file), map_location=lambda storage, loc: storage) 146 | 147 | model.load_state_dict(ckpt["model"]) 148 | 149 | optimizer.load_state_dict(ckpt["optim"]) 150 | for state in optimizer.state.values(): 151 | for k, v in state.items(): 152 | if torch.is_tensor(v): 153 | state[k] = v.cuda() 154 | 155 | # scheduler.load_state_dict(ckpt["scheduler"]) 156 | elapsed_epochs = ckpt["epoch"] 157 | 158 | else: 159 | elapsed_epochs = 0 160 | 161 | 162 | # model cuda 163 | model.cuda() 164 | 165 | from criteria import all_loss 166 | criterion = all_loss.Base_Loss() 167 | 168 | writer = SummaryWriter(args.output_dir) 169 | 170 | 171 | 172 | 173 | from data_function import ImageData 174 | 175 | train_dataset = ImageData(hp.dataset_path, hp.transform['transform_train']) 176 | train_loader = DataLoader(train_dataset, 177 | batch_size=args.batch, 178 | shuffle=True, 179 | pin_memory=False, 180 | drop_last=True) 181 | 182 | 183 | 184 | epochs = args.epochs - elapsed_epochs 185 | iteration = elapsed_epochs * len(train_loader) 186 | 187 | 188 | model.train() 189 | 190 | for epoch in range(1, epochs + 1): 191 | 192 | 193 | epoch += elapsed_epochs 194 | print("epoch:"+str(epoch)) 195 | 196 | 197 | for i, batch in enumerate(train_loader): 198 | 199 | 200 | print(f"Batch: {i}/{len(train_loader)} epoch {epoch}") 201 | 202 | img = batch.cuda() 203 | 204 | outputs = model(img) 205 | 206 | predicts = class_generate.generate_from_synthesis(outputs,None,randomize_noise=True,return_latents=True) 207 | if hp.resize: 208 | predicts = face_pool(predicts) 209 | if hp.dataset_type == 'car': 210 | predicts = predicts[:, :, 32:224, :] 211 | 212 | # torch.set_grad_enabled(True) 213 | optimizer.zero_grad() 214 | 215 | 216 | loss_all,loss_mse,loss_lpips,loss_per = criterion(img,predicts) 217 | ## log 218 | writer.add_scalar('Refine/Loss', loss_all.item(), iteration) 219 | writer.add_scalar('Refine/loss_mse', loss_mse.item(), iteration) 220 | writer.add_scalar('Refine/loss_lpips', loss_lpips.item(), iteration) 221 | writer.add_scalar('Refine/loss_per', loss_per.item(), iteration) 222 | loss_all.backward() 223 | optimizer.step() 224 | print("loss:"+str(loss_all.item())) 225 | print('lr:'+str(scheduler._last_lr[0])) 226 | iteration += 1 227 | scheduler.step() 228 | 229 | # Store latest checkpoint in each epoch 230 | torch.save( 231 | { 232 | "model": model.state_dict(), 233 | "optim": optimizer.state_dict(), 234 | "scheduler":scheduler.state_dict(), 235 | "epoch": epoch, 236 | 237 | }, 238 | os.path.join(args.output_dir, args.latest_checkpoint_file), 239 | ) 240 | 241 | # Save checkpoint 242 | if epoch % args.epochs_per_checkpoint == 0: 243 | torch.save( 244 | { 245 | 246 | "model": model.state_dict(), 247 | "optim": optimizer.state_dict(), 248 | "epoch": epoch, 249 | }, 250 | os.path.join(args.output_dir, f"checkpoint_{epoch:04d}.pt"), 
251 | ) 252 | 253 | with torch.no_grad(): 254 | utils.save_image( 255 | predicts, 256 | os.path.join(args.output_dir,("step-{}-predict.png").format(epoch)), 257 | nrow=hp.row, 258 | normalize=hp.norm, 259 | range=hp.rangee, 260 | ) 261 | 262 | with torch.no_grad(): 263 | utils.save_image( 264 | img, 265 | os.path.join(args.output_dir,("step-{}-origin.png").format(epoch)), 266 | nrow=hp.row, 267 | normalize=hp.norm, 268 | range=hp.rangee, 269 | ) 270 | 271 | writer.close() 272 | 273 | 274 | 275 | 276 | 277 | if __name__ == '__main__': 278 | train() 279 | -------------------------------------------------------------------------------- /train_ddp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from weights_init.weight_init_normal import weights_init_normal 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3' 4 | devicess = [0,1,2] 5 | import re 6 | import time 7 | import argparse 8 | import numpy as np 9 | from torch._six import container_abcs, string_classes, int_classes 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch import nn 14 | import torch.distributed as dist 15 | import math 16 | import warnings 17 | from tqdm import tqdm 18 | from torch.optim.lr_scheduler import ReduceLROnPlateau,StepLR,MultiStepLR 19 | from torchvision import utils 20 | from hparams import hparams as hp 21 | from torch.autograd import Variable 22 | from torch_warmup_lr import WarmupLR 23 | from optimizer.LookAhead import Lookahead 24 | from optimizer.RAdam import RAdam 25 | from optimizer.Ranger import Ranger 26 | warnings.filterwarnings("ignore") 27 | from weights_init.weight_init_normal import weights_init_normal 28 | from distributed import ( 29 | get_rank, 30 | synchronize, 31 | reduce_loss_dict, 32 | reduce_sum, 33 | get_world_size, 34 | ) 35 | from torch.utils import data 36 | from time import sleep 37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 38 | 39 | 40 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 41 | 42 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 43 | 44 | 45 | def data_sampler(dataset, shuffle, distributed): 46 | if distributed: 47 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 48 | 49 | if shuffle: 50 | return data.RandomSampler(dataset) 51 | 52 | else: 53 | return data.SequentialSampler(dataset) 54 | 55 | 56 | 57 | 58 | def parse_training_args(parser): 59 | """ 60 | Parse commandline arguments. 
61 | """ 62 | 63 | parser.add_argument('-o', '--output_dir', type=str, default=hp.output_dir, required=False, help='Directory to save checkpoints') 64 | parser.add_argument('--latest-checkpoint-file', type=str, default=hp.latest_checkpoint_file, help='Store the latest checkpoint in each epoch') 65 | 66 | # training 67 | training = parser.add_argument_group('training setup') 68 | training.add_argument('--epochs', type=int, default=hp.epochs, help='Number of total epochs to run') 69 | training.add_argument('--epochs-per-checkpoint', type=int, default=hp.epochs_per_checkpoint, help='Number of epochs per checkpoint') 70 | training.add_argument('--batch', type=int, default=hp.batch, help='batch-size') 71 | 72 | parser.add_argument( 73 | '-k', 74 | "--ckpt", 75 | type=str, 76 | default=hp.ckpt, 77 | help="path to the checkpoints to resume training", 78 | ) 79 | 80 | parser.add_argument("--init-lr", type=float, default=hp.init_lr, help="learning rate") 81 | 82 | 83 | parser.add_argument( 84 | "--local_rank", type=int, default=0, help="local rank for distributed training" 85 | ) 86 | 87 | training.add_argument('--cudnn-enabled', default=True, help='Enable cudnn') 88 | training.add_argument('--cudnn-benchmark', default=True, help='Run cudnn benchmark') 89 | 90 | return parser 91 | 92 | 93 | 94 | def train(): 95 | 96 | parser = argparse.ArgumentParser(description=hp.description) 97 | parser = parse_training_args(parser) 98 | args, _ = parser.parse_known_args() 99 | args = parser.parse_args() 100 | torch.backends.cudnn.deterministic = True 101 | torch.backends.cudnn.enabled = args.cudnn_enabled 102 | torch.backends.cudnn.benchmark = args.cudnn_benchmark 103 | 104 | 105 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 106 | args.distributed = n_gpu > 1 107 | if args.distributed: 108 | torch.cuda.set_device(args.local_rank) 109 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 110 | synchronize() 111 | 112 | 113 | 114 | os.makedirs(args.output_dir, exist_ok=True) 115 | 116 | 117 | from stylegan2.stylegan2_infer import infer_face 118 | class_generate = infer_face(hp.weight_path_pytorch) 119 | 120 | 121 | n_styles = 2*int(math.log(hp.img_size, 2))-2 122 | if hp.backbone == 'GradualStyleEncoder': 123 | from models.fpn_encoders import GradualStyleEncoder 124 | model = GradualStyleEncoder(num_layers=50,n_styles=n_styles) 125 | elif hp.backbone == 'ResNetGradualStyleEncoder': 126 | from models.fpn_encoders import ResNetGradualStyleEncoder 127 | model = ResNetGradualStyleEncoder(n_styles=n_styles) 128 | else: 129 | raise Exception('Backbone error!') 130 | 131 | 132 | if hp.apply_init: 133 | model.apply(weights_init_normal) 134 | 135 | 136 | # model cuda 137 | model.cuda() 138 | 139 | 140 | if args.distributed: 141 | model = nn.parallel.DistributedDataParallel( 142 | model, 143 | device_ids=[args.local_rank], 144 | output_device=args.local_rank, 145 | broadcast_buffers=False, 146 | ) 147 | 148 | 149 | # params = list(model.parameters()) + list(class_generate.g_ema.parameters()) 150 | params = list(model.parameters()) 151 | if hp.optimizer_mode == 'adam': 152 | optimizer = torch.optim.Adam(params, lr=args.init_lr, betas=(0.95, 0.999)) 153 | elif hp.optimizer_mode == 'sgd': 154 | optimizer = torch.optim.SGD(params, lr=args.init_lr, momentum=0.9, weight_decay=0.0005) 155 | elif hp.optimizer_mode == 'radam': 156 | optimizer = RAdam(params, lr=args.init_lr, betas=(0.95, 0.999)) 157 | elif hp.optimizer_mode == 'lookahead': 158 | optimizer = Lookahead(params) 159 | 
elif hp.optimizer_mode == 'ranger': 160 | optimizer = Ranger(params, lr=args.init_lr) 161 | else: 162 | raise Exception('Optimizer error!') 163 | 164 | 165 | 166 | 167 | 168 | 169 | if hp.scheduler_mode == 'StepLR': 170 | scheduler = StepLR(optimizer, step_size=10, gamma=0.1) 171 | elif hp.scheduler_mode == 'MultiStepLR': 172 | scheduler = MultiStepLR(optimizer, milestones=[3,6,9], gamma=0.1) 173 | elif hp.scheduler_mode == 'ReduceLROnPlateau': 174 | scheduler = ReduceLROnPlateau(optimizer, threshold=0.99, mode='min', patience=2, cooldown=5) 175 | else: 176 | raise Exception('Scheduler error!') 177 | 178 | if hp.open_warn_up: 179 | scheduler = WarmupLR(scheduler, init_lr=hp.init_lr, num_warmup=hp.num_warmup, warmup_strategy=hp.warn_up_strategy) 180 | 181 | 182 | 183 | if args.ckpt: 184 | print("load model:", args.ckpt) 185 | print(os.path.join(args.output_dir, args.latest_checkpoint_file)) 186 | ckpt = torch.load(os.path.join(args.output_dir, args.latest_checkpoint_file), map_location=lambda storage, loc: storage) 187 | 188 | model.load_state_dict(ckpt["model"]) 189 | 190 | optimizer.load_state_dict(ckpt["optim"]) 191 | for state in optimizer.state.values(): 192 | for k, v in state.items(): 193 | if torch.is_tensor(v): 194 | state[k] = v.cuda() 195 | 196 | # scheduler.load_state_dict(ckpt["scheduler"]) 197 | elapsed_epochs = ckpt["epoch"] 198 | 199 | else: 200 | elapsed_epochs = 0 201 | 202 | 203 | 204 | 205 | from criteria import all_loss 206 | criterion = all_loss.Base_Loss() 207 | 208 | if get_rank() == 0: 209 | writer = SummaryWriter(args.output_dir) 210 | 211 | 212 | 213 | 214 | from data_function import ImageData 215 | 216 | train_dataset = ImageData(hp.dataset_path, hp.transform['transform_train']) 217 | train_loader = DataLoader(train_dataset, 218 | batch_size=args.batch, 219 | sampler=data_sampler(train_dataset, shuffle=True, distributed=args.distributed), 220 | pin_memory=False, 221 | drop_last=True) 222 | 223 | 224 | 225 | epochs = args.epochs - elapsed_epochs 226 | iteration = elapsed_epochs * len(train_loader) 227 | 228 | if args.distributed: 229 | model = model.module 230 | else: 231 | pass 232 | 233 | 234 | model.train() 235 | 236 | if get_rank() == 0: 237 | progress_bar = tqdm(total = args.epochs, desc = "Total progress", dynamic_ncols=True) 238 | progress_bar.update(epochs) 239 | interior_step_bar = tqdm(dynamic_ncols=True) 240 | 241 | 242 | 243 | for epoch in range(1, epochs + 1): 244 | 245 | if get_rank() == 0: 246 | progress_bar.set_description('Epoch %i' % epoch) 247 | 248 | sleep(0.1) 249 | progress_bar.update(1) 250 | 251 | interior_step_bar.reset(total=len(train_loader)) 252 | interior_step_bar.set_description(f"Interior steps") 253 | 254 | epoch += elapsed_epochs 255 | # print("epoch:"+str(epoch)) 256 | 257 | 258 | for i, batch in enumerate(train_loader): 259 | 260 | if get_rank() == 0: 261 | 262 | interior_step_bar.update(1) 263 | # print(f"Batch: {i}/{len(train_loader)} epoch {epoch}") 264 | 265 | img = batch.cuda(non_blocking=True) 266 | 267 | outputs = model(img) 268 | 269 | predicts = class_generate.generate_from_synthesis(outputs,None,randomize_noise=True,return_latents=True) 270 | 271 | 272 | if hp.resize: 273 | predicts = face_pool(predicts) 274 | 275 | 276 | if hp.dataset_type == 'car': 277 | predicts = predicts[:, :, 32:224, :] 278 | 279 | # torch.set_grad_enabled(True) 280 | optimizer.zero_grad() 281 | 282 | 283 | loss_all,loss_mse,loss_lpips,loss_per = criterion(img,predicts) 284 | ## log 285 | if get_rank() == 0: 286 | writer.add_scalar('Refine/Loss', 
loss_all.item(), iteration) 287 | writer.add_scalar('Refine/loss_mse', loss_mse.item(), iteration) 288 | writer.add_scalar('Refine/loss_lpips', loss_lpips.item(), iteration) 289 | writer.add_scalar('Refine/loss_per', loss_per.item(), iteration) 290 | interior_step_bar.set_postfix(loss=str(loss_all.item()),lr=str(scheduler._last_lr[0])) 291 | loss_all.backward() 292 | optimizer.step() 293 | # print("loss:"+str(loss_all.item())) 294 | # print('lr:'+str(scheduler._last_lr[0])) 295 | iteration += 1 296 | scheduler.step() 297 | 298 | # Store latest checkpoint in each epoch 299 | if get_rank() == 0: 300 | torch.save( 301 | { 302 | "model": model.state_dict(), 303 | "optim": optimizer.state_dict(), 304 | "scheduler":scheduler.state_dict(), 305 | "epoch": epoch, 306 | 307 | }, 308 | os.path.join(args.output_dir, args.latest_checkpoint_file), 309 | ) 310 | 311 | # Save checkpoint 312 | if epoch % args.epochs_per_checkpoint == 0: 313 | 314 | torch.save( 315 | { 316 | 317 | "model": model.state_dict(), 318 | "optim": optimizer.state_dict(), 319 | "epoch": epoch, 320 | }, 321 | os.path.join(args.output_dir, f"checkpoint_{epoch:04d}.pt"), 322 | ) 323 | 324 | with torch.no_grad(): 325 | utils.save_image( 326 | predicts, 327 | os.path.join(args.output_dir,("step-{}-predict.png").format(epoch)), 328 | nrow=hp.row, 329 | normalize=hp.norm, 330 | range=hp.rangee, 331 | ) 332 | 333 | with torch.no_grad(): 334 | utils.save_image( 335 | img, 336 | os.path.join(args.output_dir,("step-{}-origin.png").format(epoch)), 337 | nrow=hp.row, 338 | normalize=hp.norm, 339 | range=hp.rangee, 340 | ) 341 | 342 | if get_rank() == 0: writer.close() 343 | 344 | 345 | 346 | 347 | 348 | if __name__ == '__main__': 349 | train() 350 | -------------------------------------------------------------------------------- /transforms_config/car_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | 4 | 5 | def get_transforms(): 6 | transforms_dict = { 7 | 'transform_train': transforms.Compose([ 8 | transforms.Resize((192, 256)), 9 | transforms.RandomHorizontalFlip(0.5), 10 | transforms.ToTensor(), 11 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 12 | 'transform_inference': transforms.Compose([ 13 | transforms.Resize((192, 256)), 14 | transforms.ToTensor(), 15 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 16 | } 17 | return transforms_dict 18 | 19 | 20 | -------------------------------------------------------------------------------- /transforms_config/normal_transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | 4 | 5 | 6 | def get_transforms(): 7 | transforms_dict = { 8 | 'transform_train': transforms.Compose([ 9 | transforms.Resize((256, 256)), 10 | transforms.RandomHorizontalFlip(0.5), 11 | transforms.ToTensor(), 12 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 13 | 'transform_inference': transforms.Compose([ 14 | transforms.Resize((256, 256)), 15 | transforms.ToTensor(), 16 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 17 | } 18 | return transforms_dict 19 | 20 | 21 | -------------------------------------------------------------------------------- /weights_init/weight_init_normal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from hparams import hparams as hp 3 | def weights_init_normal(m): 4 | classname = m.__class__.__name__ 5 | gain = 0.02 6 | 
init_type = hp.init_type 7 | 8 | if classname.find('BatchNorm2d') != -1: 9 | if hasattr(m, 'weight') and m.weight is not None: 10 | torch.nn.init.normal_(m.weight.data, 1.0, gain) 11 | if hasattr(m, 'bias') and m.bias is not None: 12 | torch.nn.init.constant_(m.bias.data, 0.0) 13 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 14 | if init_type == 'normal': 15 | torch.nn.init.normal_(m.weight.data, 0.0, gain) 16 | elif init_type == 'xavier': 17 | torch.nn.init.xavier_normal_(m.weight.data, gain=gain) 18 | elif init_type == 'xavier_uniform': 19 | torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0) 20 | elif init_type == 'kaiming': 21 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 22 | elif init_type == 'orthogonal': 23 | torch.nn.init.orthogonal_(m.weight.data, gain=gain) 24 | elif init_type == 'none': # uses pytorch's default init method 25 | m.reset_parameters() 26 | else: 27 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 28 | if hasattr(m, 'bias') and m.bias is not None: 29 | torch.nn.init.constant_(m.bias.data, 0.0) --------------------------------------------------------------------------------
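
Quick usage sketch: the snippet below mirrors the forward pass in train.py (encoder output fed to the pre-trained StyleGAN2 generator) for single-image inference. It is a minimal sketch under a few assumptions, not the repo's own inference script (see also infer.py): hparams.py is assumed to expose img_size, weight_path_pytorch and a transform dict with a 'transform_inference' entry (as returned by transforms_config), and "encoder.pt" / "face.jpg" are placeholder paths for your trained encoder checkpoint and input image.

import math
import torch
from PIL import Image

from hparams import hparams as hp
from models.fpn_encoders import GradualStyleEncoder
from stylegan2.stylegan2_infer import infer_face

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the encoder the same way train.py does.
n_styles = 2 * int(math.log(hp.img_size, 2)) - 2
encoder = GradualStyleEncoder(num_layers=50, n_styles=n_styles).to(device)

# "encoder.pt" is a placeholder path; train.py stores the weights under the "model" key
# of a DataParallel-wrapped model, so the "module." prefix is stripped before loading.
ckpt = torch.load("encoder.pt", map_location="cpu")
encoder.load_state_dict({k.replace("module.", "", 1): v for k, v in ckpt["model"].items()})
encoder.eval()

# Pre-trained StyleGAN2 generator/discriminator wrapper from this repo.
generator = infer_face(hp.weight_path_pytorch)

# Invert one image and reconstruct it through the generator, as in the training loop.
img = hp.transform["transform_inference"](Image.open("face.jpg").convert("RGB"))
with torch.no_grad():
    w = encoder(img.unsqueeze(0).to(device))
    recon = generator.generate_from_synthesis(w, None, randomize_noise=True, return_latents=True)

The face_pool resizing and the car-specific cropping applied during training are omitted here for brevity.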