├── models ├── __init__.py ├── bisenet │ ├── __init__.py │ └── resnet.py ├── encoders │ ├── __init__.py │ ├── .DS_Store │ └── model_irse.py ├── mtcnn │ ├── __init__.py │ └── mtcnn_pytorch │ │ ├── __init__.py │ │ └── src │ │ ├── __init__.py │ │ ├── weights │ │ ├── onet.npy │ │ ├── pnet.npy │ │ └── rnet.npy │ │ ├── visualization_utils.py │ │ ├── first_stage.py │ │ ├── detector.py │ │ └── get_nets.py ├── stylegan2 │ ├── __init__.py │ └── op │ │ ├── __init__.py │ │ ├── fused_bias_act.cpp │ │ ├── upfirdn2d.cpp │ │ ├── fused_act.py │ │ └── fused_bias_act_kernel.cu ├── hypernetworks │ ├── __init__.py │ ├── shared_weights_hypernet.py │ └── hypernetwork.py ├── invertibility │ ├── __init__.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── unittest.py │ │ ├── replicate.py │ │ └── comm.py │ ├── backbone │ │ └── __init__.py │ ├── decoder.py │ ├── deeplab.py │ └── aspp.py ├── discriminator.py ├── segmenter.py ├── latent_codes_pool.py └── encoder.py ├── utils ├── __init__.py ├── facer │ ├── facer │ │ ├── version.py │ │ ├── face_landmark │ │ │ ├── __init__.py │ │ │ └── base.py │ │ ├── face_parsing │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── farl.py │ │ ├── face_detection │ │ │ ├── __init__.py │ │ │ └── base.py │ │ ├── io.py │ │ ├── show.py │ │ ├── __init__.py │ │ └── util.py │ ├── scripts │ │ ├── publish.sh │ │ └── build.sh │ ├── requirements.txt │ ├── samples │ │ ├── data │ │ │ ├── fire.webp │ │ │ ├── girl.jpg │ │ │ ├── sideface.jpg │ │ │ ├── twogirls.jpg │ │ │ ├── weirdface.jpg │ │ │ ├── weirdface2.jpg │ │ │ └── weirdface3.jpg │ │ ├── transform.ipynb │ │ ├── face_detect.ipynb │ │ └── face_parsing.ipynb │ ├── LICENSE │ ├── README.md │ └── setup.py ├── data_utils.py ├── dist.py ├── wandb_utils.py └── train_utils.py ├── configs ├── __init__.py ├── hfgi │ ├── hfgi.yaml │ └── README.md ├── pti │ ├── pti_pivot.yaml │ ├── pti.yaml │ └── README.md ├── hyperstyle │ ├── hyperstyle.yaml │ ├── wencoder_ffhq_r50.yaml │ └── README.md ├── dhr │ ├── dhr.yaml │ └── README.md ├── optim │ ├── optim_celeba-hq.yaml │ └── README.md ├── sam │ ├── sam.yaml │ └── README.md ├── paths_config.py ├── psp │ ├── psp_ffhq_r50.yaml │ └── README.md ├── lsap │ ├── lsap_ffhq_r50.yaml │ └── README.md ├── e4e │ ├── e4e_ffhq_r50.yaml │ └── README.md ├── restyle │ ├── restyle_e4e_ffhq_r50.yaml │ └── README.md └── transforms_config.py ├── criteria ├── __init__.py ├── lpips │ ├── __init__.py │ ├── utils.py │ ├── lpips.py │ └── networks.py ├── w_norm.py ├── moco_loss.py └── id_loss.py ├── datasets ├── __init__.py ├── images_dataset.py └── inference_dataset.py ├── options ├── __init__.py ├── base_options.py └── train_options.py ├── .gitignore ├── training └── __init__.py ├── inference ├── __init__.py ├── inference.py ├── code_infer.py ├── encoder_infer.py ├── hfgi_infer.py ├── two_stage_inference.py ├── restyle_infer.py └── hyper_infer.py ├── docs ├── HFGI.png ├── PTI.png ├── SAM.png ├── dhr.png ├── e4e.png ├── lsap.png ├── pSp.png ├── optim.png ├── restyle.png ├── hyperstyle.png ├── gan_inverter.jpeg ├── inference_pipeline.png ├── install.md └── dataset.md ├── requirements.txt ├── editing ├── interfacegan_directions │ ├── age.pt │ ├── pose.pt │ └── smile.pt ├── __init__.py ├── base_editing.py ├── interfacegan.py └── ganspace.py ├── licenses ├── LICENSE_TreB1eN ├── LICENSE_HuangYG123 ├── LICENSE_rosinality └── LICENSE_S-aiueo32 ├── LICENSE ├── scripts ├── train.py ├── edit.py ├── calc_id_loss.py ├── test.py └── infer.py └── pretrained_models └── download_models.sh /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/bisenet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/mtcnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/hypernetworks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/invertibility/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/facer/facer/version.py: -------------------------------------------------------------------------------- 1 | __version__="0.0.1" -------------------------------------------------------------------------------- /utils/facer/scripts/publish.sh: -------------------------------------------------------------------------------- 1 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | experiments/ 3 | latent_fid_results/ -------------------------------------------------------------------------------- /utils/facer/scripts/build.sh: -------------------------------------------------------------------------------- 1 | python setup.py bdist_wheel 
-------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder_trainer import EncoderTrainer -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .two_stage_inference import TwoStageInference 2 | -------------------------------------------------------------------------------- /docs/HFGI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/HFGI.png -------------------------------------------------------------------------------- /docs/PTI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/PTI.png -------------------------------------------------------------------------------- /docs/SAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/SAM.png -------------------------------------------------------------------------------- /docs/dhr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/dhr.png -------------------------------------------------------------------------------- /docs/e4e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/e4e.png -------------------------------------------------------------------------------- /docs/lsap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/lsap.png -------------------------------------------------------------------------------- /docs/pSp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/pSp.png -------------------------------------------------------------------------------- /utils/facer/facer/face_landmark/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import FaceLandmarkDetector -------------------------------------------------------------------------------- /docs/optim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/optim.png -------------------------------------------------------------------------------- /docs/restyle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/restyle.png -------------------------------------------------------------------------------- /docs/hyperstyle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/hyperstyle.png -------------------------------------------------------------------------------- /docs/gan_inverter.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/gan_inverter.jpeg 
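For orientation, the sketch below illustrates the two-stage (embed, then refine) pattern that the `embed_mode`/`refine_mode` keys in the configs further down select between; the class and method names here are hypothetical stand-ins, not the actual `TwoStageInference` API exported above.

```python
# Illustrative sketch only -- hypothetical names, not the repo's TwoStageInference API.
class TwoStageSketch:
    def __init__(self, embedder, refiner=None):
        self.embedder = embedder  # stage 1: encoder- or optimization-based inversion
        self.refiner = refiner    # stage 2: e.g. HFGI / HyperStyle / PTI / SAM / DHR refinement

    def inverse(self, image):
        codes = self.embedder.inverse(image)       # embed the image into latent codes
        if self.refiner is None:
            return codes
        return self.refiner.inverse(image, codes)  # refine codes, weights, or features
```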
-------------------------------------------------------------------------------- /models/encoders/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/models/encoders/.DS_Store -------------------------------------------------------------------------------- /utils/facer/facer/face_parsing/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import FaceParser 2 | from .farl import FaRLFaceParser -------------------------------------------------------------------------------- /docs/inference_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/docs/inference_pipeline.png -------------------------------------------------------------------------------- /utils/facer/facer/face_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import FaceDetector 2 | from .retinaface import RetinaFaceDetector -------------------------------------------------------------------------------- /models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | tqdm 3 | matplotlib 4 | ninja 5 | opencv-python 6 | pyyaml 7 | scikit-image 8 | wandb 9 | validators -------------------------------------------------------------------------------- /editing/interfacegan_directions/age.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/editing/interfacegan_directions/age.pt -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /utils/facer/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | pillow 3 | numpy 4 | ipywidgets 5 | scikit-image 6 | matplotlib 7 | validators 8 | colorsys -------------------------------------------------------------------------------- /editing/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_editing import BaseEditing 2 | from .interfacegan import InterFaceGAN 3 | from .ganspace import GANSpace -------------------------------------------------------------------------------- /editing/interfacegan_directions/pose.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/editing/interfacegan_directions/pose.pt -------------------------------------------------------------------------------- /editing/interfacegan_directions/smile.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/editing/interfacegan_directions/smile.pt 
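The `age.pt`, `pose.pt`, and `smile.pt` tensors above are pre-computed InterFaceGAN directions. As a minimal sketch of how such a direction is applied, the snippet below follows the `code + direction * factor` rule used in `editing/interfacegan.py` further down; the W+ latent shape `(1, 18, 512)` and the factor value are illustrative assumptions.

```python
import torch

# Minimal sketch following editing/interfacegan.py: edited = code + direction * factor.
direction = torch.load('editing/interfacegan_directions/age.pt', map_location='cpu')
if direction.dim() == 2:          # broadcast a 2-D direction over the batch, as InterFaceGAN does
    direction = direction[None]

w_plus = torch.randn(1, 18, 512)  # stand-in for a latent code produced by an inverter (assumed shape)
edited = w_plus + 3.0 * direction # positive/negative factors move along/against the attribute
```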
-------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caopulan/GANInverter/HEAD/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /utils/facer/samples/data/fire.webp: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eb4f6043f17a868cb6618a97fb5ba9a130c7f10b13b1db83fcf2df10ecbe1f23 3 | size 82698 4 | -------------------------------------------------------------------------------- /utils/facer/samples/data/girl.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a72bd4c8aa42981f5988c8311ba6df7a908cc5db3e49b69f3d9f37ff053d88da 3 | size 48013 4 | -------------------------------------------------------------------------------- /utils/facer/samples/transform.ipynb: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7dc0672f3462c131e559c3c2612806d3d9777191b6199fe40d26e137431e1d62 3 | size 429045 4 | -------------------------------------------------------------------------------- /utils/facer/samples/data/sideface.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3973d2e664195f6a26f2e6609a5ea1f42b43c6d5828ff99d8ba781aafb7423e0 3 | size 34871 4 | -------------------------------------------------------------------------------- /utils/facer/samples/data/twogirls.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c9edeb3b1ed1ca645af8cb350dccb8106aaf19eedbb80243fda4859c385e7477 3 | size 43364 4 | -------------------------------------------------------------------------------- /utils/facer/samples/data/weirdface.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3c97a31b51b239b8f727c15086cfe09a0b024900c4d0ebbba83e80fce8c6a51c 3 | size 153904 4 | -------------------------------------------------------------------------------- /utils/facer/samples/data/weirdface2.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4229c165563acf6b9ff101afa72daa0a67c28326a37a531aa1d784e93f38bd38 3 | size 49669 4 | -------------------------------------------------------------------------------- /utils/facer/samples/data/weirdface3.jpg: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:53f4734204658bf7d0570ba4aa6cd2adf2c8f2fc3c745d16cbc9578f98207ae3 3 | size 45389 4 | -------------------------------------------------------------------------------- /utils/facer/samples/face_detect.ipynb: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:34f556d2fff42f7ae1823153beb3af212df934b1465734fddfb8d78b34f3cba9 3 | size 180332 4 | -------------------------------------------------------------------------------- /utils/facer/samples/face_parsing.ipynb: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ca598e1831d76d924119a9a11038f0064131ac95e9dcc2e64dc04ed087fd5888 3 | size 367056 4 | -------------------------------------------------------------------------------- /editing/base_editing.py: -------------------------------------------------------------------------------- 1 | class BaseEditing(object): 2 | def __init__(self): 3 | pass 4 | 5 | def edit_code(self, code): 6 | pass 7 | 8 | def edit_feature(self, features): 9 | pass 10 | -------------------------------------------------------------------------------- /configs/hfgi/hfgi.yaml: -------------------------------------------------------------------------------- 1 | embed_mode: 'encoder' 2 | refine_mode: 'hfgi' 3 | test_dataset_path: "./data/CelebA-HQ/test" 4 | test_batch_size: 1 5 | output_dir: './inference_result/hfgi' 6 | checkpoint_path: "pretrained_models/hfgi/hfgi_ffhq_official.pt" -------------------------------------------------------------------------------- /configs/pti/pti_pivot.yaml: -------------------------------------------------------------------------------- 1 | embed_mode: 'optim' 2 | 3 | # optim options 4 | w_plus: False 5 | lr: 5e-3 6 | optim_step: 450 7 | noise: 0.05 8 | noise_ramp: 0.75 9 | noise_regularize: 1e5 10 | optim_l2_lambda: 0. 11 | optim_lpips_lambda: 1. 
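Each method directory under `configs/` pairs a YAML file like the two above with a README. As a quick orientation, here is a minimal sketch that loads one of them with PyYAML (listed in `requirements.txt`) and reads the keys selecting the two pipeline stages; how these values are merged with the defaults defined under `options/` is not shown and left as an assumption.

```python
import yaml

# Minimal sketch: inspect a method config shipped under configs/.
with open('configs/hfgi/hfgi.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['embed_mode'])       # 'encoder' -> first-stage embedding strategy
print(cfg['refine_mode'])      # 'hfgi'    -> second-stage refinement strategy
print(cfg['checkpoint_path'])  # path to the pretrained checkpoint for this method
```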
-------------------------------------------------------------------------------- /configs/hyperstyle/hyperstyle.yaml: -------------------------------------------------------------------------------- 1 | embed_mode: 'encoder' 2 | refine_mode: 'hyperstyle' 3 | test_dataset_path: "./data/CelebA-HQ/test" 4 | test_batch_size: 1 5 | output_dir: './inference_result/hyperstyle' 6 | hyperstyle_iteration: 5 7 | hypernet_checkpoint_path: "pretrained_models/hyperstyle/hyperstyle_ffhq_official.pt" -------------------------------------------------------------------------------- /inference/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseInference(object): 5 | def __init__(self): 6 | self.decoder = None 7 | 8 | def inverse(self, **kwargs): 9 | pass 10 | 11 | @torch.no_grad() 12 | def generate(self, codes): 13 | return self.decoder([codes], input_is_latent=True, return_latents=False)[0] 14 | 15 | def edit(self, **kwargs): 16 | pass 17 | -------------------------------------------------------------------------------- /configs/dhr/dhr.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: 'experiments/dhr' 2 | refine_mode: 'dhr' 3 | test_batch_size: 1 4 | output_dir: './inference_result/dhr' 5 | 6 | # coarse inv 7 | optim_step: 40 8 | w_plus: False 9 | 10 | # dhr options 11 | dhr_feature_idx: 11 12 | dhr_weight_lr: 1.5e-3 13 | dhr_feature_lr: 9e-2 14 | dhr_weight_step: 50 15 | dhr_feature_step: 100 16 | dhr_l2_lambda: 1. 17 | dhr_lpips_lambda: 1. 18 | dhr_theta1: 0.7 19 | dhr_theta2: 0.8 -------------------------------------------------------------------------------- /criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /configs/optim/optim_celeba-hq.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: "experiments/optim/ffhq" 2 | 3 | # Data 4 | test_dataset_path: './data/CelebA-HQ/test' 5 | test_batch_size: 8 6 | test_workers: 8 7 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 8 | 9 | # Hyper-parameter of optimization 10 | optim_step: 1000 11 | optim_l2_lambda: 0. 12 | optim_lpips_lambda: 1.
# only use lpips loss 13 | 14 | noise: 0.05 15 | noise_ramp: 0.75 16 | noise_regularize: 1e5 17 | -------------------------------------------------------------------------------- /utils/facer/facer/face_detection/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FaceDetector(nn.Module): 6 | """ face detector 7 | 8 | Args: 9 | images (torch.Tensor): b x c x h x w 10 | 11 | Returns: 12 | data (Dict[str, torch.Tensor]): 13 | 14 | * rects: nfaces x 4 (x1, y1, x2, y2) 15 | * points: nfaces x 5 x 2 (x, y) 16 | * scores: nfaces 17 | * image_ids: nfaces 18 | """ 19 | pass 20 | -------------------------------------------------------------------------------- /configs/sam/sam.yaml: -------------------------------------------------------------------------------- 1 | refine_mode: 'sam' 2 | test_dataset_path: "./data/CelebA-HQ/test" 3 | test_batch_size: 1 4 | test_workers: 0 5 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 6 | output_dir: './inference_result/sam' 7 | 8 | # SAM options 9 | latent_names: "W+,F4,F6,F10" 10 | thresh: 0.225 11 | sam_lr: 0.05 12 | sam_step: 1001 13 | sam_rec_l2_lambda: 1. 14 | sam_rec_lpips_lambda: 1. 15 | sam_lat_mvg_lambda: 1e-8 16 | sam_lat_delta_lambda: 1e-3 17 | sam_lat_frec_lambda: 5. 18 | 19 | -------------------------------------------------------------------------------- /configs/pti/pti.yaml: -------------------------------------------------------------------------------- 1 | refine_mode: 'pti' 2 | test_dataset_path: './data/CelebA-HQ/test' 3 | test_batch_size: 1 4 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 5 | output_dir: './inference_result/pti' 6 | 7 | # PTI options 8 | pti_lr: 3e-4 9 | pti_step: 350 10 | pti_l2_lambda: 1. 11 | pti_lpips_lambda: 1. 12 | pti_regulizer_lambda: 1. 13 | pti_regulizer_l2_lambda: 0.1 14 | pti_r_lpips_lambda: 0.1 15 | pti_locality_regularization_interval: 50 16 | pti_regulizer_alpha: 30. 17 | pti_latent_ball_num_of_samples: 1 -------------------------------------------------------------------------------- /models/invertibility/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /configs/paths_config.py: -------------------------------------------------------------------------------- 1 | model_paths = { 2 | 'ir_se50': 'pretrained_models/model_ir_se50.pth', 3 | 'circular_face': 'pretrained_models/CurricularFace_Backbone.pth', 4 | 'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy', 5 | 'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy', 6 | 'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy', 7 | 'shape_predictor': 'shape_predictor_68_face_landmarks.dat', 8 | 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar', 9 | 'invert_predictor_faces': 'pretrained_models/invertibility_faces_sg2.pt', 10 | 'segmenter_faces': 'pretrained_models/79999_iter.pth' 11 | } 12 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LatentCodesDiscriminator(nn.Module): 5 | def __init__(self, style_dim, n_mlp): 6 | super().__init__() 7 | 8 | self.style_dim = style_dim 9 | 10 | layers = [] 11 | for i in range(n_mlp-1): 12 | layers.append( 13 | nn.Linear(style_dim, style_dim) 14 | ) 15 | layers.append(nn.LeakyReLU(0.2)) 16 | layers.append(nn.Linear(512, 1)) 17 | self.mlp = nn.Sequential(*layers) 18 | 19 | def forward(self, w): 20 | return self.mlp(w) 21 | -------------------------------------------------------------------------------- /models/invertibility/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from models.invertibility.backbone import resnet, xception, drn, mobilenet 2 | 3 | 4 | def build_backbone(backbone, output_stride, BatchNorm): 5 | if backbone == 'resnet': 6 | return resnet.ResNet101(output_stride, BatchNorm) 7 | elif backbone == 'xception': 8 | return xception.AlignedXception(output_stride, BatchNorm) 9 | elif backbone == 'drn': 10 | return drn.drn_d_54(BatchNorm) 11 | elif backbone == 'mobilenet': 12 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 13 | else: 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /configs/psp/psp_ffhq_r50.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: "experiments/psp/ffhq_psp_r50" 2 | 3 | encoder_type: "GradualStyleEncoder" 4 | 5 | # Data 6 | train_dataset_path: './data/FFHQ' 7 | test_dataset_path: './data/CelebA-HQ/test' 8 | 9 | # Hyper-parameter of training 10 | checkpoint_path: "pretrained_models/psp/psp_ffhq_r50_wp_official.pt" 11 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 12 | batch_size: 8 13 | workers: 8 14 | test_batch_size: 8 15 | test_workers: 8 16 | start_from_latent_avg: True 17 | 18 | # Loss 19 | lpips_lambda: 0.8 20 | id_lambda: 0.1 21 | l2_lambda: 1.0 22 | w_norm_lambda: 0. 
23 | 24 | # Wandb 25 | #use_wandb: True 26 | wandb_project: ffhq-inversion -------------------------------------------------------------------------------- /utils/facer/facer/io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def read_hwc(path: str) -> torch.Tensor: 7 | """Read an image from a given path. 8 | 9 | Args: 10 | path (str): The given path. 11 | """ 12 | image = Image.open(path) 13 | np_image = np.array(image.convert('RGB')) 14 | return torch.from_numpy(np_image) 15 | 16 | 17 | def write_hwc(image: torch.Tensor, path: str): 18 | """Write an image to a given path. 19 | 20 | Args: 21 | image (torch.Tensor): The image. 22 | path (str): The given path. 23 | """ 24 | 25 | Image.fromarray(image.cpu().numpy()).save(path) 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /configs/hyperstyle/wencoder_ffhq_r50.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: "experiments/hyperstyle/wencoder_psp_r50" 2 | 3 | encoder_type: "BackboneEncoderUsingLastLayerIntoW" 4 | 5 | # Data 6 | train_dataset_path: './data/FFHQ' 7 | test_dataset_path: './data/CelebA-HQ/test' 8 | 9 | # Hyper-parameter of training 10 | checkpoint_path: "pretrained_models/hyperstyle/hyperstyle_ffhq_r50_w_official.pt" 11 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 12 | batch_size: 8 13 | workers: 8 14 | test_batch_size: 8 15 | test_workers: 8 16 | start_from_latent_avg: True 17 | 18 | # Loss 19 | lpips_lambda: 0.8 20 | id_lambda: 0.1 21 | l2_lambda: 1.0 22 | w_norm_lambda: 0. 23 | 24 | # Wandb 25 | #use_wandb: True 26 | wandb_project: ffhq-inversion -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Installation Instructions 2 | 3 | - Python 3.8 4 | - PyTorch 1.12.1 with CUDA 11.6 5 | ### 1. Create the conda environment 6 | ```bash 7 | conda create -n gan_inverter python=3.8 8 | conda activate gan_inverter 9 | ``` 10 | ### 2. Install PyTorch 11 | 12 | PyTorch 1.12.1 with CUDA 11.6 has been tested. Please install PyTorch following the [official website](https://pytorch.org/get-started/locally/). 13 | ### 3. Install Requirements 14 | ```bash 15 | pip3 install -r requirements.txt 16 | ``` 17 | 18 | ### Note 19 | - If you encounter any errors while building the StyleGAN ops, please first search the issues in the official StyleGAN-series repositories. 20 | - We recommend CUDA 11.x and a PyTorch version > 1.9. 21 | - We have not tested a CPU-only environment. -------------------------------------------------------------------------------- /configs/lsap/lsap_ffhq_r50.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: "experiments/lsap/ffhq_lsap_r50" 2 | 3 | # Data 4 | train_dataset_path: './data/FFHQ' 5 | test_dataset_path: './data/CelebA-HQ/test' 6 | 7 | # Hyper-parameter of training 8 | checkpoint_path: "pretrained_models/lsap/lsap_ffhq_r50_wp_official.pt" 9 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 10 | batch_size: 8 11 | workers: 8 12 | test_batch_size: 8 13 | test_workers: 8 14 | start_from_latent_avg: True 15 | 16 | # Image Loss 17 | lpips_lambda: 0.8 18 | id_lambda: 0.1 19 | l2_lambda: 1.0 20 | w_norm_lambda: 0.
21 | delta_norm_lambda: 2e-4 22 | 23 | # Alignment Loss 24 | sncd_lambda: 0.5 25 | 26 | # Progressive Training 27 | progressive_start: 20000 28 | progressive_step_every: 100 #2000 29 | 30 | # Wandb 31 | #use_wandb: True 32 | wandb_project: ffhq-inversion -------------------------------------------------------------------------------- /utils/facer/facer/face_landmark/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FaceLandmarkDetector(nn.Module): 6 | """ face landmark detector 7 | 8 | Args: 9 | images (torch.Tensor): b x c x h x w 10 | 11 | data (Dict[str, Any]): 12 | 13 | * image_ids (torch.Tensor): nfaces 14 | * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2) 15 | * points (torch.Tensor): nfaces x 5 x 2 (x, y) 16 | 17 | Returns: 18 | data (Dict[str, Any]): 19 | 20 | * image_ids (torch.Tensor): nfaces 21 | * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2) 22 | * points (torch.Tensor): nfaces x 5 x 2 (x, y) 23 | * landmarks (torch.Tensor): nfaces x nlandmarks x 2 (x, y) 24 | """ 25 | pass 26 | -------------------------------------------------------------------------------- /configs/e4e/e4e_ffhq_r50.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: "experiments/e4e/ffhq_e4e_r50" 2 | 3 | # Data 4 | train_dataset_path: './data/FFHQ' 5 | test_dataset_path: './data/CelebA-HQ/test' 6 | 7 | # Hyper-parameter of training 8 | checkpoint_path: "pretrained_models/e4e/e4e_ffhq_r50_wp_official.pt" 9 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 10 | batch_size: 8 11 | workers: 8 12 | test_batch_size: 8 13 | test_workers: 8 14 | start_from_latent_avg: True 15 | 16 | # Image Loss 17 | lpips_lambda: 0.8 18 | id_lambda: 0.1 19 | l2_lambda: 1.0 20 | w_norm_lambda: 0. 
21 | delta_norm_lambda: 2e-4 22 | 23 | # Discriminator 24 | w_discriminator_lambda: 0.1 25 | r1: 10 26 | use_w_pool: True 27 | 28 | # Progressive Training 29 | progressive_start: 20000 30 | progressive_step_every: 2000 31 | 32 | # Wandb 33 | #use_wandb: True 34 | wandb_project: ffhq-inversion -------------------------------------------------------------------------------- /utils/facer/facer/face_parsing/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FaceParser(nn.Module): 5 | """ face parser 6 | 7 | Args: 8 | images (torch.Tensor): b x c x h x w 9 | 10 | data (Dict[str, Any]): 11 | 12 | * image_ids (torch.Tensor): nfaces 13 | * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2) 14 | * points (torch.Tensor): nfaces x 5 x 2 (x, y) 15 | 16 | Returns: 17 | data (Dict[str, Any]): 18 | 19 | * image_ids (torch.Tensor): nfaces 20 | * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2) 21 | * points (torch.Tensor): nfaces x 5 x 2 (x, y) 22 | * seg (Dict[str, Any]): 23 | 24 | * logits (torch.Tensor): nfaces x nclasses x h x w 25 | * label_names (List[str]): nclasses 26 | """ 27 | pass -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 10 | ] 11 | 12 | IMG_EXTENSIONS_WINDOWS = [ 13 | '.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tiff' 14 | ] 15 | 16 | 17 | def is_image_file(filename): 18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 19 | 20 | 21 | def make_dataset(dir): 22 | images = [] 23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 24 | for root, _, fnames in sorted(os.walk(dir)): 25 | for fname in fnames: 26 | if is_image_file(fname): 27 | path = os.path.join(root, fname) 28 | images.append(path) 29 | return images 30 | -------------------------------------------------------------------------------- /editing/interfacegan.py: -------------------------------------------------------------------------------- 1 | from .base_editing import BaseEditing 2 | import torch 3 | import os 4 | 5 | 6 | class InterFaceGAN(BaseEditing): 7 | def __init__(self, opts): 8 | super(InterFaceGAN, self).__init__() 9 | self.opts = opts 10 | self.edit_vector = torch.load(opts.edit_path, map_location='cpu').cuda() 11 | self.factor = opts.edit_factor 12 | 13 | if opts.edit_save_path == '': 14 | self.save_folder = f'{os.path.basename(opts.edit_path).split(".")[0]}_{self.factor}' 15 | else: 16 | self.save_folder = opts.edit_save_path 17 | 18 | if self.edit_vector.dim() == 2: 19 | self.edit_vector = self.edit_vector[None] 20 | # elif self.edit_vector.dim() == 3: 21 | # self.edit_vector = self.edit_vector[None] 22 | 23 | def edit_code(self, code): 24 | return code + self.edit_vector * self.factor 25 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define 
CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /models/invertibility/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /configs/restyle/restyle_e4e_ffhq_r50.yaml: -------------------------------------------------------------------------------- 1 | exp_dir: "experiments/e4e/ffhq_e4e_r50" 2 | output_dir: './inference_result/restyle' 3 | embed_mode: restyle 4 | 5 | # Data 6 | train_dataset_path: './data/FFHQ' 7 | test_dataset_path: './data/CelebA-HQ/test' 8 | 9 | # Hyper-parameter of training 10 | checkpoint_path: 'pretrained_models/restyle/restyle-e4e_ffhq_r50_wp_official.pt' 11 | stylegan_weights: "pretrained_models/stylegan2-ffhq-config-f.pt" 12 | batch_size: 8 13 | workers: 8 14 | test_batch_size: 8 15 | test_workers: 8 16 | start_from_latent_avg: False 17 | encoder_type: ProgressiveBackboneEncoder 18 | input_nc: 6 19 | 20 | # Image Loss 21 | lpips_lambda: 0.8 22 | id_lambda: 0.1 23 | l2_lambda: 1.0 24 | w_norm_lambda: 0. 25 | delta_norm_lambda: 2e-4 26 | 27 | # Discriminator 28 | w_discriminator_lambda: 0.1 29 | r1: 10 30 | use_w_pool: True 31 | 32 | # Progressive Training 33 | progressive_start: 20000 34 | progressive_step_every: 2000 35 | 36 | # Wandb 37 | #use_wandb: True 38 | wandb_project: ffhq-inversion -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class ImagesDataset(Dataset): 7 | 8 | def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): 9 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 10 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 11 | self.source_transform = source_transform 12 | self.target_transform = 
target_transform 13 | self.opts = opts 14 | 15 | def __len__(self): 16 | return len(self.source_paths) 17 | 18 | def __getitem__(self, index): 19 | from_path = self.source_paths[index] 20 | from_im = Image.open(from_path) 21 | from_im = from_im.convert('RGB') 22 | 23 | to_path = self.target_paths[index] 24 | to_im = Image.open(to_path).convert('RGB') 25 | if self.target_transform: 26 | to_im = self.target_transform(to_im) 27 | 28 | if self.source_transform: 29 | from_im = self.source_transform(from_im) 30 | else: 31 | from_im = to_im 32 | 33 | return from_im, to_im 34 | -------------------------------------------------------------------------------- /models/segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import transforms 3 | 4 | from models.bisenet.model import BiSeNet 5 | 6 | 7 | class SegmenterFace: 8 | def __init__(self, ckpt_path="ckpt/79999_iter.pth", fuse_face_regions=True): 9 | 10 | self.net = BiSeNet(n_classes=19).cuda() 11 | self.net.load_state_dict(torch.load(ckpt_path)) 12 | self.net.eval() 13 | self.fuse_face_regions = fuse_face_regions 14 | 15 | def segment_pil(self, img_pil): 16 | out = self.net(img_pil)[0] 17 | parsed = out.squeeze(0).detach().cpu().numpy().argmax(0) 18 | if self.fuse_face_regions: 19 | """ 20 | 1 - skin 21 | 2/3 - left/right brow 22 | 4/5 - left/right eye 23 | 7/8 - left/right ear 24 | 10 - nose 25 | 11 - mouth 26 | 12/13 - upper/lower lips 27 | 14 - neck 28 | 17 - hair 29 | """ 30 | for idx in [1, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 17]: 31 | parsed[parsed == idx] = 3 32 | return parsed 33 | -------------------------------------------------------------------------------- /utils/facer/facer/show.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import torch 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | 6 | from .util import bchw2hwc 7 | 8 | 9 | def set_figsize(*args): 10 | if len(args) == 0: 11 | plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"] 12 | elif len(args) == 1: 13 | plt.rcParams["figure.figsize"] = (args[0], args[0]) 14 | elif len(args) == 2: 15 | plt.rcParams["figure.figsize"] = tuple(args) 16 | else: 17 | raise RuntimeError( 18 | f'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)') 19 | 20 | 21 | def show_hwc(image: torch.Tensor): 22 | if image.dtype != torch.uint8: 23 | image = image.to(torch.uint8) 24 | if image.size(2) == 1: 25 | image = image.repeat(1, 1, 3) 26 | pimage = Image.fromarray(image.cpu().numpy()) 27 | plt.imshow(pimage) 28 | plt.show() 29 | 30 | 31 | def show_bchw(image: torch.Tensor): 32 | show_hwc(bchw2hwc(image)) 33 | 34 | 35 | def show_bhw(image: torch.Tensor): 36 | show_bchw(image.unsqueeze(1)) 37 | -------------------------------------------------------------------------------- /licenses/LICENSE_TreB1eN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TreB1eN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_HuangYG123: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 HuangYG123 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_rosinality: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /utils/facer/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 FacePerceiver 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | 4 | 5 | class TransformsConfig(object): 6 | 7 | def __init__(self, opts): 8 | self.opts = opts 9 | 10 | @abstractmethod 11 | def get_transforms(self): 12 | pass 13 | 14 | 15 | class EncodeTransforms(TransformsConfig): 16 | 17 | def __init__(self, opts): 18 | super(EncodeTransforms, self).__init__(opts) 19 | 20 | def get_transforms(self): 21 | transforms_dict = { 22 | 'transform_gt_train': transforms.Compose([ 23 | transforms.Resize((256, 256)), 24 | transforms.RandomHorizontalFlip(0.5), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 27 | 'transform_source': None, 28 | 'transform_test': transforms.Compose([ 29 | transforms.Resize((256, 256)), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 32 | 'transform_inference': transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 36 | 'transform_apply':transforms.Compose([ 37 | # transforms.Resize((256, 256)), 38 | transforms.ToTensor(), 39 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 40 | } 41 | return transforms_dict -------------------------------------------------------------------------------- /editing/ganspace.py: -------------------------------------------------------------------------------- 1 | from .base_editing import BaseEditing 2 | import torch 3 | 4 | 5 | class GANSpace(BaseEditing): 6 | def __init__(self, opts): 7 | super(GANSpace, self).__init__() 8 | ganspace_pca = torch.load(opts.edit_path, map_location='cpu') 9 | self.pca_idx, self.start, self.end, self.strength = opts.ganspace_directions 10 | self.code_mean = ganspace_pca['mean'].cuda() 11 | self.code_comp = ganspace_pca['comp'].cuda()[self.pca_idx] 12 | self.code_std = ganspace_pca['std'].cuda()[self.pca_idx] 13 | if opts.edit_save_path == '': 14 | self.save_folder = f'ganspace_{self.pca_idx}_{self.start}_{self.end}_{self.strength}' 15 | else: 16 | self.save_folder = opts.edit_save_path 17 | 18 | def edit_code(self, code): 19 | edit_codes = [] 20 | for c in code: 21 | w_centered = c - self.code_mean 22 | w_coord = torch.sum(w_centered[0].reshape(-1) * self.code_comp.reshape(-1)) / self.code_std 23 | delta = (self.strength - w_coord) * self.code_comp * self.code_std 24 | delta_padded = torch.zeros(c.shape).to('cuda') 25 | delta_padded[self.start:self.end] += delta.repeat(self.end - self.start, 1) 26 | edit_codes.append(c + delta_padded) 27 | return torch.stack(edit_codes) -------------------------------------------------------------------------------- /datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from utils import data_utils 6 | import glob 7 | import os 8 | 9 | 10 | class InversionDataset(Dataset): 11 | 12 | def __init__(self, root, transform=None, transform_no_resize=None): 13 | self.paths = sorted(data_utils.make_dataset(root)) 14 | self.transform = transform 15 | self.transform_no_resize = transform_no_resize 16 | 17 | def __len__(self): 18 | return len(self.paths) 19 | 20 | def __getitem__(self, index): 21 | from_path = 
self.paths[index] 22 | from_im = Image.open(from_path) 23 | from_im = from_im.convert('RGB') 24 | if self.transform: 25 | from_im_aug = self.transform(from_im) 26 | else: 27 | from_im_aug = from_im 28 | 29 | if self.transform_no_resize is not None: 30 | from_im_no_resize_aug = self.transform_no_resize(from_im) 31 | return from_im_aug, from_path, from_im_no_resize_aug 32 | else: 33 | return from_im_aug, from_path 34 | 35 | 36 | class InversionCodeDataset(Dataset): 37 | 38 | def __init__(self, root): 39 | self.paths = sorted(glob.glob(os.path.join(root, '*.pt'))) 40 | 41 | def __len__(self): 42 | return len(self.paths) 43 | 44 | def __getitem__(self, index): 45 | code_path = self.paths[index] 46 | return torch.load(code_path, map_location='cpu'), code_path 47 | -------------------------------------------------------------------------------- /utils/facer/facer/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import torch 3 | 4 | from .io import read_hwc, write_hwc 5 | from .util import hwc2bchw, bchw2hwc 6 | from .draw import draw_bchw 7 | from .show import show_bchw, show_bhw 8 | 9 | from .face_detection import FaceDetector 10 | from .face_parsing import FaceParser 11 | from .face_landmark import FaceLandmarkDetector 12 | 13 | 14 | def _split_name(name: str) -> Tuple[str, Optional[str]]: 15 | if '/' in name: 16 | detector_type, conf_name = name.split('/', 1) 17 | else: 18 | detector_type, conf_name = name, None 19 | return detector_type, conf_name 20 | 21 | 22 | def face_detector(name: str, device: torch.device) -> FaceDetector: 23 | detector_type, conf_name = _split_name(name) 24 | if detector_type == 'retinaface': 25 | from .face_detection import RetinaFaceDetector 26 | return RetinaFaceDetector(conf_name).to(device) 27 | else: 28 | raise RuntimeError(f'Unknown detector type: {detector_type}') 29 | 30 | 31 | def face_parser(name: str, device: torch.device) -> FaceParser: 32 | parser_type, conf_name = _split_name(name) 33 | if parser_type == 'farl': 34 | from .face_parsing import FaRLFaceParser 35 | return FaRLFaceParser(conf_name).to(device) 36 | else: 37 | raise RuntimeError(f'Unknown parser type: {parser_type}') 38 | -------------------------------------------------------------------------------- /licenses/LICENSE_S-aiueo32: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /utils/facer/README.md: -------------------------------------------------------------------------------- 1 | # FACER 2 | 3 | Face-related toolkit. This repo is still under construction; more models will be added. 4 | 5 | ## Install 6 | 7 | The easiest way to install it is using pip: 8 | 9 | ```bash 10 | pip install pyfacer 11 | ``` 12 | No extra setup is needed; pretrained weights are downloaded automatically. 13 | 14 | 15 | ## Face Detection 16 | 17 | We simply wrap a retinaface detector for easy usage. 18 | Check [this notebook](./samples/face_detect.ipynb). 19 | 20 | Please consider citing 21 | ``` 22 | @inproceedings{deng2020retinaface, 23 | title={Retinaface: Single-shot multi-level face localisation in the wild}, 24 | author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos}, 25 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 26 | pages={5203--5212}, 27 | year={2020} 28 | } 29 | ``` 30 | 31 | ## Face Parsing 32 | 33 | We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face parsing. 34 | Check [this notebook](./samples/face_parsing.ipynb). 35 | 36 | Please consider citing 37 | ``` 38 | @article{zheng2021farl, 39 | title={General Facial Representation Learning in a Visual-Linguistic Manner}, 40 | author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang}, 41 | journal={arXiv preprint arXiv:2112.03109}, 42 | year={2021} 43 | } 44 | ``` 45 | 46 | -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1.
15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type)#.to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list)#.to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor, keep_res=False): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | 34 | if keep_res: 35 | res = [l(d) for d, l in zip(diff, self.lin)] 36 | return res 37 | else: 38 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 39 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 40 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | _init_dist_pytorch(backend, **kwargs) 14 | 15 | 16 | def _init_dist_pytorch(backend, **kwargs): 17 | rank = int(os.environ['RANK']) 18 | num_gpus = torch.cuda.device_count() 19 | torch.cuda.set_device(rank % num_gpus) 20 | dist.init_process_group(backend=backend, **kwargs) 21 | 22 | 23 | # ---------------------------------- 24 | # get rank and world_size 25 | # ---------------------------------- 26 | def get_dist_info(): 27 | if dist.is_available(): 28 | initialized = dist.is_initialized() 29 | else: 30 | initialized = False 31 | if initialized: 32 | rank = dist.get_rank() 33 | world_size = dist.get_world_size() 34 | else: 35 | rank = 0 36 | world_size = 1 37 | return rank, world_size 38 | 39 | 40 | def get_rank(): 41 | if not dist.is_available(): 42 | return 0 43 | 44 | if not dist.is_initialized(): 45 | return 0 46 | 47 | return dist.get_rank() 48 | 49 | 50 | def get_world_size(): 51 | if not dist.is_available(): 52 | return 1 53 | 54 | if not dist.is_initialized(): 55 | return 1 56 | 57 | return dist.get_world_size() -------------------------------------------------------------------------------- /models/hypernetworks/shared_weights_hypernet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class SharedWeightsHypernet(nn.Module): 7 | 8 | def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None): 9 | super(SharedWeightsHypernet, self).__init__() 10 | self.mode = mode 11 | self.z_dim = z_dim 12 | self.f_size = f_size 13 | if self.mode == 'delta_per_channel': 14 | self.f_size = 1 15 | self.out_size = out_size 16 | self.in_size = in_size 17 | 18 | self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).cuda() / 40, 2)) 19 | self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).cuda() / 40, 2)) 20 | 21 | self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).cuda() / 40, 2)) 22 | self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * 
self.z_dim)).cuda() / 40, 2)) 23 | 24 | def forward(self, z): 25 | batch_size = z.shape[0] 26 | h_in = torch.matmul(z, self.w2) + self.b2 27 | h_in = h_in.view(batch_size, self.in_size, self.z_dim) 28 | 29 | h_final = torch.matmul(h_in, self.w1) + self.b1 30 | kernel = h_final.view(batch_size, self.out_size, self.in_size, self.f_size, self.f_size) 31 | if self.mode == 'delta_per_channel': # repeat per channel values to the 3x3 conv kernels 32 | kernel = kernel.repeat(1, 1, 1, 3, 3) 33 | return kernel 34 | -------------------------------------------------------------------------------- /utils/facer/setup.py: -------------------------------------------------------------------------------- 1 | from os import path as os_path 2 | 3 | from setuptools import setup, find_packages 4 | 5 | this_directory = os_path.abspath(os_path.dirname(__file__)) 6 | 7 | 8 | # Read a file's contents 9 | def read_file(filename): 10 | with open(os_path.join(this_directory, filename), encoding="utf-8") as f: 11 | long_description = f.read() 12 | return long_description 13 | 14 | 15 | # Collect dependencies from a requirements file 16 | def read_requirements(filename): 17 | return [ 18 | line.strip() 19 | for line in read_file(filename).splitlines() 20 | if not line.startswith("#") 21 | ] 22 | 23 | def get_version(): 24 | version_file = 'facer/version.py' 25 | with open(version_file, 'r', encoding='utf-8') as f: 26 | exec(compile(f.read(), version_file, 'exec')) 27 | return locals()['__version__'] 28 | 29 | setup( 30 | name="facer", 31 | version=get_version(), 32 | description="Face related toolkit", 33 | author="FacePerceiver", 34 | author_email="admin@hypercube.top", 35 | url="https://github.com/FacePerceiver/facer", 36 | license="MIT", 37 | keywords="face-detection pytorch RetinaFace face-parsing farl", 38 | project_urls={ 39 | "Documentation": "https://github.com/FacePerceiver/facer", 40 | "Source": "https://github.com/FacePerceiver/facer", 41 | "Tracker": "https://github.com/FacePerceiver/facer/issues", 42 | }, 43 | long_description=read_file("README.md"), # use the README contents as the long description 44 | long_description_content_type="text/markdown", # the long description is in Markdown format 45 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 46 | install_requires=["numpy", "torch", "torchvision", "opencv-python","validators"] 47 | ) -------------------------------------------------------------------------------- /utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import numpy as np 4 | import wandb 5 | 6 | from utils import common 7 | 8 | 9 | class WBLogger: 10 | 11 | def __init__(self, opts): 12 | wandb_run_name = os.path.basename(opts.exp_dir) 13 | wandb.init(project=opts.wandb_project, config=vars(opts), name=wandb_run_name) 14 | 15 | @staticmethod 16 | def log_best_model(): 17 | wandb.run.summary["best-model-save-time"] = datetime.datetime.now() 18 | 19 | @staticmethod 20 | def log(prefix, metrics_dict, global_step): 21 | log_dict = {f'{prefix}_{key}': value for key, value in metrics_dict.items()} 22 | log_dict["global_step"] = global_step 23 | wandb.log(log_dict) 24 | 25 | @staticmethod 26 | def log_dataset_wandb(dataset, dataset_name, n_images=16): 27 | idxs = np.random.choice(a=range(len(dataset)), size=n_images, replace=False) 28 | data = [wandb.Image(dataset.source_paths[idx]) for idx in idxs] 29 | wandb.log({f"{dataset_name} Data Samples": data}) 30 | 31 | @staticmethod 32 | def log_images_to_wandb(x, y, y_hat, id_logs, prefix, step, opts): 33 | im_data = [] 34 | column_names = ["Source",
"Target", "Output"] 35 | if id_logs is not None: 36 | column_names.append("ID Diff Output to Target") 37 | for i in range(len(x)): 38 | cur_im_data = [ 39 | wandb.Image(common.log_input_image(x[i], opts)), 40 | wandb.Image(common.tensor2im(y[i])), 41 | wandb.Image(common.tensor2im(y_hat[i])), 42 | ] 43 | if id_logs is not None: 44 | cur_im_data.append(id_logs[i]["diff_target"]) 45 | im_data.append(cur_im_data) 46 | outputs_table = wandb.Table(data=im_data, columns=column_names) 47 | wandb.log({f"{prefix.title()} Step {step} Output Samples": outputs_table}) 48 | -------------------------------------------------------------------------------- /configs/optim/README.md: -------------------------------------------------------------------------------- 1 | # StyleGAN2 (Optim) [CVPR2020] 2 | 3 | > [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958) 4 | 5 | ## Abstract 6 | 7 | The style-based GAN architecture (StyleGAN) yields state-of-the-art results in data-driven unconditional generative image modeling. We expose and analyze several of its characteristic artifacts, and propose changes in both model architecture and training methods to address them. In particular, we redesign the generator normalization, revisit progressive growing, and regularize the generator to encourage good conditioning in the mapping from latent codes to images. In addition to improving image quality, this path length regularizer yields the additional benefit that the generator becomes significantly easier to invert. This makes it possible to reliably attribute a generated image to a particular network. We furthermore visualize how well the generator utilizes its output resolution, and identify a capacity problem, motivating us to train larger models for additional quality improvements. Overall, our improved model redefines the state of the art in unconditional image modeling, both in terms of existing distribution quality metrics as well as perceived image quality. 
8 | 9 | ![Optim](../../docs/optim.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ``` 18 | python scripts/infer.py \ 19 | --config configs/optim/optim_celeba-hq.yaml \ 20 | --test_dataset_path /path/to/test/data 21 | --output_dir /path/to/output/dir 22 | ``` 23 | 24 | ## Citation 25 | 26 | ```latex 27 | @inproceedings{karras2020analyzing, 28 | title={Analyzing and improving the image quality of stylegan}, 29 | author={Karras, Tero and Laine, Samuli and Aittala, Miika and Hellsten, Janne and Lehtinen, Jaakko and Aila, Timo}, 30 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 31 | pages={8110--8119}, 32 | year={2020} 33 | } 34 | ``` -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import os 5 | import json 6 | import math 7 | import sys 8 | import pprint 9 | import torch 10 | import random 11 | import numpy as np 12 | from argparse import Namespace 13 | 14 | sys.path.append(".") 15 | sys.path.append("..") 16 | 17 | from options.train_options import TrainOptions 18 | from training import * 19 | 20 | 21 | def main(): 22 | opts = TrainOptions().parse() 23 | set_seed(opts.seed) 24 | setup_progressive_steps(opts) 25 | trainer = EncoderTrainer(opts) 26 | if opts.rank == 0: 27 | create_initial_experiment_dir(opts) 28 | trainer.train() 29 | 30 | 31 | def set_seed(seed): 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | 37 | 38 | def setup_progressive_steps(opts): 39 | log_size = int(math.log(opts.resolution, 2)) 40 | num_style_layers = 2*log_size - 2 41 | num_deltas = num_style_layers - 1 42 | if opts.progressive_start is not None: # If progressive delta training 43 | opts.progressive_steps = [0] 44 | next_progressive_step = opts.progressive_start 45 | for i in range(num_deltas): 46 | opts.progressive_steps.append(next_progressive_step) 47 | next_progressive_step += opts.progressive_step_every 48 | 49 | assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \ 50 | "Invalid progressive training input" 51 | 52 | 53 | def is_valid_progressive_steps(opts, num_style_layers): 54 | return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0 55 | 56 | 57 | def create_initial_experiment_dir(opts): 58 | if not os.path.exists(opts.exp_dir): 59 | os.makedirs(opts.exp_dir) 60 | opts_dict = vars(opts) 61 | # pprint.pprint(opts_dict) 62 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 63 | json.dump(opts_dict, f, indent=4, sort_keys=True) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /inference/code_infer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from models.stylegan2.model import Generator 4 | import torch 5 | from utils.train_utils import load_train_checkpoint 6 | from inference.inference import BaseInference 7 | 8 | 9 | class CodeInference(BaseInference): 10 | 11 | def __init__(self, opts, decoder=None): 12 | super(CodeInference, self).__init__() 13 | self.opts = opts 14 | self.device = 'cuda' 15 | self.opts.device = self.device 16 | self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2 17 | self.code_path = 
opts.code_path 18 | # resume from checkpoint 19 | checkpoint = load_train_checkpoint(opts) 20 | 21 | # initialize decoder 22 | if decoder is not None: 23 | self.decoder = decoder 24 | else: 25 | self.decoder = Generator(opts.resolution, 512, 8).to(self.device) 26 | self.decoder.train() 27 | if checkpoint is not None: 28 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 29 | else: 30 | decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu') 31 | self.decoder.load_state_dict(decoder_checkpoint['g_ema']) 32 | 33 | def inverse(self, images, images_resize, image_name): 34 | codes = [] 35 | for path in image_name: 36 | code_path = os.path.join(self.code_path, f'{os.path.basename(path[:-4])}.pt') 37 | codes.append(torch.load(code_path, map_location='cpu')) 38 | codes = torch.stack(codes, dim=0).to(images.device) 39 | with torch.no_grad(): 40 | images, result_latent = self.decoder([codes], input_is_latent=True, return_latents=True) 41 | return images, result_latent, None 42 | 43 | def edit(self, images, images_resize, image_paths, editor): 44 | images, codes, _ = self.inverse(images, images_resize, image_paths) 45 | edit_codes = editor.edit_code(codes) 46 | edit_images = self.generate(edit_codes) 47 | return images, edit_images, codes, edit_codes, None 48 | -------------------------------------------------------------------------------- /configs/sam/README.md: -------------------------------------------------------------------------------- 1 | # Spatially-Adaptive Multilayer Selection (SAM) [CVPR2022] 2 | 3 | > [Spatially-Adaptive Multilayer Selection for GAN Inversion and Editing](https://arxiv.org/abs/2206.08357) 4 | 5 | ## Abstract 6 | 7 | Existing GAN inversion and editing methods work well for aligned objects with a clean background, such as portraits and animal faces, but often struggle for more difficult categories with complex scene layouts and object occlusions, such as cars, animals, and outdoor images. We propose a new method to invert and edit such complex images in the latent space of GANs, such as StyleGAN2. Our key idea is to explore inversion with a collection of layers, spatially adapting the inversion process to the difficulty of the image. We learn to predict the “invertibility” of different image segments and project each segment into a latent layer. Easier regions can be inverted into an earlier layer in the generator’s latent space, while more challenging regions can be inverted into a later feature space. Experiments show that our method obtains better inversion results compared to the recent approaches on complex categories, while maintaining downstream editability. Please refer to our project page at https://www.cs.cmu.edu/~SAMInversion. 8 | 9 | ![SAM](../../docs/SAM.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ``` 18 | python scripts/infer.py \ 19 | --config configs/e4e/e4e_ffhq_r50.yaml configs/sam/sam.yaml \ 20 | --test_dataset_path /path/to/test/data \ 21 | --output_dir /path/to/output/dir \ 22 | --checkpoint_path /path/to/e4e/weight 23 | ``` 24 | 25 | - `--save_intermidiated`: If true, intermediate results of the inversion will be saved.
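The routing idea in the abstract can be sketched in a few lines: estimate how hard each segment is to invert, send easy segments to an early, editable latent space, and hard segments to a later feature space. The layer names, thresholds, and interface below are illustrative placeholders, not SAM's exact implementation (the invertibility predictor in this repo lives under `models/invertibility/`).

```python
import torch

# Candidate latent spaces, ordered from most editable to highest rate.
# Names and thresholds are illustrative only.
CANDIDATE_LAYERS = ("W+", "F4", "F8", "F16")


def assign_layers(invertibility: torch.Tensor, segments: torch.Tensor,
                  thresholds=(0.15, 0.3, 0.5)):
    """Pick one latent layer per segment from a per-pixel difficulty map.

    invertibility: (H, W) predicted reconstruction difficulty in [0, 1]
    segments:      (H, W) integer segment ids
    """
    assignment = {}
    for seg_id in segments.unique().tolist():
        difficulty = invertibility[segments == seg_id].mean().item()
        layer = CANDIDATE_LAYERS[-1]  # default: the latest (highest-rate) space
        for threshold, name in zip(thresholds, CANDIDATE_LAYERS):
            if difficulty <= threshold:  # easy enough for an earlier, more editable layer
                layer = name
                break
        assignment[seg_id] = layer
    return assignment
```

For example, `assign_layers(torch.rand(256, 256), torch.randint(0, 5, (256, 256)))` returns a dict mapping each segment id to one of the candidate layers.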
26 | 27 | ## Citation 28 | 29 | ```latex 30 | @inproceedings{parmar2022spatially, 31 | title={Spatially-adaptive multilayer selection for gan inversion and editing}, 32 | author={Parmar, Gaurav and Li, Yijun and Lu, Jingwan and Zhang, Richard and Zhu, Jun-Yan and Singh, Krishna Kumar}, 33 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 34 | pages={11399--11409}, 35 | year={2022} 36 | } 37 | ``` 38 | 39 | -------------------------------------------------------------------------------- /configs/e4e/README.md: -------------------------------------------------------------------------------- 1 | # Encoder For Editing (e4e) [SIGGRAPH 2021] 2 | 3 | > [Designing an Encoder for StyleGAN Image Manipulation](https://arxiv.org/abs/2102.02766) 4 | 5 | ## Abstract 6 | 7 | Recently, there has been a surge of diverse methods for performing image editing by employing pre-trained unconditional generators. Applying these methods on real images, however, remains a challenge, as it necessarily requires the inversion of the images into their latent space. To successfully invert a real image, one needs to find a latent code that reconstructs the input image accurately, and more importantly, allows for its meaningful manipulation. In this paper, we carefully study the latent space of StyleGAN, the state-of-the-art unconditional generator. We identify and analyze the existence of a distortion-editability tradeoff and a distortion-perception tradeoff within the StyleGAN latent space. We then suggest two principles for designing encoders in a manner that allows one to control the proximity of the inversions to regions that StyleGAN was originally trained on. We present an encoder based on our two principles that is specifically designed for facilitating editing on real images by balancing these tradeoffs. By evaluating its performance qualitatively and quantitatively on numerous challenging domains, including cars and horses, we show that our inversion method, followed by common editing techniques, achieves superior real-image editing quality, with only a small reconstruction accuracy drop. 8 | 9 | ![e4e](../../docs/e4e.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ```bash 18 | python scripts/infer.py \ 19 | --config configs/e4e/e4e_ffhq_r50.yaml \ 20 | --test_dataset_path /path/to/test/data \ 21 | --output_dir /path/to/output/dir 22 | ``` 23 | 24 | ## Citation 25 | 26 | ```latex 27 | @article{tov2021designing, 28 | title={Designing an encoder for stylegan image manipulation}, 29 | author={Tov, Omer and Alaluf, Yuval and Nitzan, Yotam and Patashnik, Or and Cohen-Or, Daniel}, 30 | journal={ACM Transactions on Graphics (TOG)}, 31 | volume={40}, 32 | number={4}, 33 | pages={1--14}, 34 | year={2021}, 35 | publisher={ACM New York, NY, USA} 36 | } 37 | ``` -------------------------------------------------------------------------------- /configs/restyle/README.md: -------------------------------------------------------------------------------- 1 | # ReStyle: A Residual-Based StyleGAN Encoder via Iterative Refinement (ICCV 2021) 2 | 3 | > [ReStyle: A Residual-Based StyleGAN Encoder via Iterative Refinement](https://arxiv.org/abs/2104.02699) 4 | 5 | ## Abstract 6 | 7 | Recently, the power of unconditional image synthesis has significantly advanced through the use of Generative Adversarial Networks (GANs). 
The task of inverting an image into its corresponding latent code of the trained GAN is of utmost importance as it allows for the manipulation of real 8 | images, leveraging the rich semantics learned by the network. Recognizing the limitations of current inversion approaches, in this work we present a novel inversion scheme that extends current encoder-based inversion methods by introducing an iterative refinement mechanism. Instead of directly predicting the latent code of a given real image using a single pass, the encoder is tasked with predicting a residual with respect to the current estimate of the inverted latent code in a self-correcting manner. Our residualbased encoder, named ReStyle, attains improved accuracy compared to current state-of-the-art encoder-based methods with a negligible increase in inference time. We analyze the behavior of ReStyle to gain valuable insights into its iterative nature. We then evaluate the performance of our residual encoder and analyze its robustness compared to optimization-based inversion and state-of-the-art encoders. Code is available via our project page: https://yuval-alaluf.github.io/restyle-encoder/ 9 | 10 | ![Restyle](../../docs/restyle.png) 11 | 12 | ## Results 13 | 14 | TODO 15 | 16 | ## Inference 17 | 18 | ``` 19 | python scripts/infer.py \ 20 | --config configs/restyle/restyle_e4e_ffhq_r50.yaml \ 21 | --test_dataset_path /path/to/test/data 22 | --output_dir /path/to/output/dir 23 | ``` 24 | 25 | ## Citation 26 | 27 | ```latex 28 | @inproceedings{alaluf2021restyle, 29 | title={Restyle: A residual-based stylegan encoder via iterative refinement}, 30 | author={Alaluf, Yuval and Patashnik, Or and Cohen-Or, Daniel}, 31 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 32 | pages={6711--6720}, 33 | year={2021} 34 | } 35 | ``` 36 | 37 | -------------------------------------------------------------------------------- /configs/hfgi/README.md: -------------------------------------------------------------------------------- 1 | # HFGI: High-Fidelity GAN Inversion for Image Attribute Editing (CVPR 2022) 2 | 3 | > [High-Fidelity GAN Inversion for Image Attribute Editing](https://arxiv.org/abs/2109.06590.pdf) 4 | 5 | ## Abstract 6 | 7 | We present a novel high-fidelity generative adversarial network (GAN) inversion framework that enables attribute editing with image-specific details well-preserved (e.g., background, appearance, and illumination). We first analyze the challenges of high-fidelity GAN inversion from the perspective of lossy data compression. With a low bitrate latent code, previous works have difficulties in preserving high-fidelity details in reconstructed and edited images. Increasing the size of a latent code can improve the accuracy of GAN inversion but at the cost of inferior editability. To improve image fidelity without compromising editability, we propose a distortion consultation approach that employs a distortion map as a reference for high-fidelity reconstruction. In the distortion consultation inversion (DCI), the distortion map is first projected to a high-rate latent map, which then complements the basic low-rate latent code with more details via consultation fusion. To achieve high-fidelity editing, we propose an adaptive distortion alignment (ADA) module with a self-supervised training scheme, which bridges the gap between the edited and inversion images. Extensive experiments in the face and car domains show a clear improvement in both inversion and editing quality. 
The project page is https://tengfeiwang.github.io/HFGI/. 8 | 9 | ![HFGI](../../docs/HFGI.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ``` 18 | python scripts/infer.py \ 19 | --config configs/hfgi/hfgi.yaml \ 20 | --test_dataset_path /path/to/test/data 21 | --output_dir /path/to/output/dir 22 | ``` 23 | 24 | ## Citation 25 | 26 | ```latex 27 | @inproceedings{wang2022high, 28 | title={High-fidelity gan inversion for image attribute editing}, 29 | author={Wang, Tengfei and Zhang, Yong and Fan, Yanbo and Wang, Jue and Chen, Qifeng}, 30 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 31 | pages={11379--11388}, 32 | year={2022} 33 | } 34 | ``` 35 | 36 | -------------------------------------------------------------------------------- /configs/psp/README.md: -------------------------------------------------------------------------------- 1 | # pixel2style2pixel (pSp) [CVPR2021] 2 | 3 | > [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/abs/2008.00951) 4 | 5 | ## Abstract 6 | 7 | We present a generic image-to-image translation framework, pixel2style2pixel (pSp). Our pSp framework is based on a novel encoder network that directly generates a series of style vectors which are fed into a pretrained StyleGAN generator, forming the extended W+ latent space. We first show that our encoder can directly embed real images into W+, with no additional optimization. Next, we propose utilizing our encoder to directly solve image-to-image translation tasks, defining them as encoding problems from some input domain into the latent domain. By deviating from the standard “invert first, edit later” methodology used with previous StyleGAN encoders, our approach can handle a variety of tasks even when the input image is not represented in the StyleGAN domain. We show that solving translation tasks through StyleGAN significantly simplifies the training process, as no adversary is required, has better support for solving tasks without pixel-to-pixel correspondence, and inherently supports multi-modal synthesis via the resampling of styles. Finally, we demonstrate the potential of our framework on a variety of facial image-to-image translation tasks, even when compared to state-of-the-art solutions designed specifically for a single task, and further show that it can be extended beyond the human facial domain. Code is available at https://github.com/eladrich/pixel2style2pixel. 
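In this codebase, pSp-style inversion is a single feed-forward pass: the encoder predicts a W+ code and the StyleGAN2 decoder renders it. The sketch below follows the flow of `inference/encoder_infer.py` and is only a rough outline; `opts` is assumed to be the parsed config (e.g. `configs/psp/psp_ffhq_r50.yaml`), the weight paths are placeholders, and the random tensor stands in for a transformed 256x256 face.

```python
import math
import torch
from models.encoder import Encoder
from models.stylegan2.model import Generator

# `opts` comes from the repo's config parsing; n_styles is derived as in encoder_infer.py.
opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2

decoder = Generator(opts.resolution, 512, 8).cuda().eval()
ckpt = torch.load(opts.stylegan_weights, map_location='cpu')
decoder.load_state_dict(ckpt['g_ema'])
latent_avg = ckpt['latent_avg'] if 'latent_avg' in ckpt else decoder.mean_latent(int(1e5))[0].detach()

encoder = Encoder(opts, None, latent_avg, device='cuda').cuda().eval()
encoder.set_progressive_stage(opts.n_styles)

face = torch.randn(1, 3, 256, 256).cuda()  # stand-in for a normalized, aligned input image
with torch.no_grad():
    codes = encoder(face)                  # (1, n_styles, 512) W+ code
    recon, _ = decoder([codes], input_is_latent=True,
                       return_latents=True, randomize_noise=False)
```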
8 | 9 | ![pSp](../../docs/pSp.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ``` 18 | python scripts/infer.py \ 19 | --config configs/psp/psp_ffhq_r50.yaml \ 20 | --test_dataset_path /path/to/test/data 21 | --output_dir /path/to/output/dir 22 | ``` 23 | 24 | ## Citation 25 | 26 | ```latex 27 | @inproceedings{richardson2021encoding, 28 | title={Encoding in style: a stylegan encoder for image-to-image translation}, 29 | author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel}, 30 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 31 | pages={2287--2296}, 32 | year={2021} 33 | } 34 | ``` 35 | 36 | -------------------------------------------------------------------------------- /inference/encoder_infer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from models.encoder import Encoder 3 | from models.stylegan2.model import Generator 4 | import torch 5 | from utils.train_utils import load_train_checkpoint 6 | from inference.inference import BaseInference 7 | 8 | 9 | class EncoderInference(BaseInference): 10 | 11 | def __init__(self, opts, decoder=None): 12 | super(EncoderInference, self).__init__() 13 | self.opts = opts 14 | self.device = 'cuda' 15 | self.opts.device = self.device 16 | self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2 17 | 18 | # resume from checkpoint 19 | checkpoint = load_train_checkpoint(opts) 20 | 21 | # initialize encoder and decoder 22 | latent_avg = None 23 | if decoder is not None: 24 | self.decoder = decoder 25 | else: 26 | self.decoder = Generator(opts.resolution, 512, 8).to(self.device) 27 | self.decoder.eval() 28 | if checkpoint is not None: 29 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 30 | else: 31 | decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu') 32 | self.decoder.load_state_dict(decoder_checkpoint['g_ema']) 33 | latent_avg = decoder_checkpoint['latent_avg'] 34 | if latent_avg is None: 35 | latent_avg = self.decoder.mean_latent(int(1e5))[0].detach() if checkpoint is None else None 36 | self.encoder = Encoder(opts, checkpoint, latent_avg, device=self.device).to(self.device) 37 | self.encoder.set_progressive_stage(self.opts.n_styles) 38 | self.encoder.eval() 39 | 40 | def inverse(self, images, images_resize, image_path): 41 | with torch.no_grad(): 42 | codes = self.encoder(images_resize) 43 | images, result_latent = self.decoder([codes], input_is_latent=True, return_latents=True, randomize_noise=False) 44 | return images, result_latent, None 45 | 46 | def edit(self, images, images_resize, image_path, editor): 47 | images, codes, _ = self.inverse(images, images_resize, image_path) 48 | edit_codes = editor.edit_code(codes) 49 | edit_images = self.generate(edit_codes) 50 | return images, edit_images, codes, edit_codes, None 51 | -------------------------------------------------------------------------------- /configs/dhr/README.md: -------------------------------------------------------------------------------- 1 | # Domain-Specific Hybrid Refinement (DHR) [Arxiv2023] 2 | 3 | > [What Decreases Editing Capability? 
Domain-Specific Hybrid Refinement for Improved GAN Inversion](https://arxiv.org/abs/2301.12141) 4 | 5 | ## Abstract 6 | 7 | Recently, inversion methods have focused on additional high-rate information in the generator (e.g., weights or intermediate features) to refine inversion and editing results from embedded latent codes. Although these techniques gain reasonable improvement in reconstruction, they decrease editing capability, especially on complex images (e.g., containing occlusions, detailed backgrounds, and artifacts). A vital crux is refining inversion results, avoiding editing capability degradation. To tackle this problem, we introduce \textbf{D}omain-Specific \textbf{H}ybrid \textbf{R}efinement (DHR), which draws on the advantages and disadvantages of two mainstream refinement techniques to maintain editing ability with fidelity improvement. Specifically, we first propose Domain-Specific Segmentation to segment images into two parts: in-domain and out-of-domain parts. The refinement process aims to maintain the editability for in-domain areas and improve two domains' fidelity. We refine these two parts by weight modulation and feature modulation, which we call Hybrid Modulation Refinement. Our proposed method is compatible with all latent code embedding methods. Extension experiments demonstrate that our approach achieves state-of-the-art in real image inversion and editing. 8 | 9 | ![DHR](../../docs/dhr.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ```bash 18 | python scripts/infer.py \ 19 | --config configs/e4e/e4e_ffhq_r50.yaml configs/dhr/dhr.yaml \ 20 | --test_dataset_path /path/to/test/data \ 21 | --output_dir /path/to/output/dir \ 22 | --checkpoint_path /path/to/e4e/weight 23 | ``` 24 | 25 | - `--save_intermidiated`: If true, DHR will save intermediated information like segmentation, modulated feature, and modulated weight. 26 | 27 | ## Citation 28 | 29 | ```latex 30 | @article{cao2023decreases, 31 | title={What Decreases Editing Capability? Domain-Specific Hybrid Refinement for Improved GAN Inversion}, 32 | author={Cao, Pu and Yang, Lu and Liu, Dongxu and Liu, Zhiwei and Li, Shan and Song, Qing}, 33 | journal={arXiv preprint arXiv:2301.12141}, 34 | year={2023} 35 | } 36 | ``` 37 | 38 | -------------------------------------------------------------------------------- /configs/lsap/README.md: -------------------------------------------------------------------------------- 1 | # LSAP: Rethinking Inversion Fidelity, Perception and Editability in GAN Latent Space [Arxiv2022] 2 | 3 | > [LSAP: Rethinking Inversion Fidelity, Perception and Editability in GAN Latent Space](https://arxiv.org/abs/2301.12141) 4 | 5 | ## Abstract 6 | 7 | As the research progresses, inversion is mainly divided into two steps. The first step is \emph{Image Embedding}, in which an encoder or optimization process embeds images to get the corresponding latent codes. Afterward, the second step aims to refine the inversion and editing results, which we named \emph{Result Refinement}. Although the second step significantly improves fidelity, perception and editability are almost unchanged and deeply depend on inverse latent codes from first step. Therefore, a crucial problem is gaining the latent codes with better perception and editability while retaining the reconstruction fidelity. In this work, we first point out that these two characteristics are related to the degree of alignment (or disalignment) of the inverse codes with the synthetic distribution. 
Then, we propose \textbf{L}atent \textbf{S}pace \textbf{A}lignment Inversion \textbf{P}aradigm (LSAP), which consists of an evaluation metric and solutions for inversion. Specifically, we introduce Normalized Style Space ($\mathcal{S^N}$ space) and **N**ormalized **S**tyle Space **C**osine **D**istance (NSCD) to measure disalignment of inversion methods. Meanwhile, it can be optimized in both encoder-based and optimization-based embedding methods to conduct a uniform alignment solution. Extensive experiments in various domains demonstrate that NSCD effectively reflects perception and editability, and our alignment paradigm achieves the state-of-the-art in both stages. 8 | 9 | ![LSAP](../../docs/lsap.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ``` 18 | python scripts/infer.py \ 19 | --config configs/lsap/lsap_ffhq_r50.yaml \ 20 | --test_dataset_path /path/to/test/data \ 21 | --output_dir /path/to/output/dir 22 | ``` 23 | 24 | ## Citation 25 | 26 | ```latex 27 | @article{cao2022lsap, 28 | title={LSAP: Rethinking Inversion Fidelity, Perception and Editability in GAN Latent Space}, 29 | author={Cao, Pu and Yang, Lu and Liu, Dongxv and Liu, Zhiwei and Li, Shan and Song, Qing}, 30 | journal={arXiv preprint arXiv:2209.12746}, 31 | year={2022} 32 | } 33 | ``` 34 | 35 | -------------------------------------------------------------------------------- /configs/hyperstyle/README.md: -------------------------------------------------------------------------------- 1 | # HyperStyle: StyleGAN Inversion with HyperNetworks for Real Image Editing (CVPR 2022) 2 | 3 | > [HyperStyle: StyleGAN Inversion with HyperNetworks for Real Image Editing](https://arxiv.org/abs/2111.15666) 4 | 5 | ## Abstract 6 | 7 | The inversion of real images into StyleGAN’s latent space is a well-studied problem. Nevertheless, applying existing approaches to real-world scenarios remains an open challenge, due to an inherent trade-off between reconstruction and editability: latent space regions which can accurately represent real images typically suffer from degraded semantic control. Recent work proposes to mitigate this trade-off by fine-tuning the generator to add the target image to well-behaved, editable regions of the latent space. While promising, this fine-tuning scheme is impractical for 8 | prevalent use as it requires a lengthy training phase for each new image. In this work, we introduce this approach into the realm of encoder-based inversion. We propose HyperStyle, a hypernetwork that learns to modulate StyleGAN’s weights to faithfully express a given image in editable regions of the latent space. A naive modulation approach would require training a hypernetwork with over three billion parameters. Through careful network design, we reduce this to be in line with existing encoders. HyperStyle yields reconstructions comparable to those of optimization techniques with the near real-time inference capabilities of encoders. Lastly, we demonstrate HyperStyle’s effectiveness on several applications beyond the inversion task, including the editing of out-of-domain images which were never seen during training. Code is available on our project page: https://yuval-alaluf.github.io/hyperstyle/.
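The central mechanism here is weight modulation: instead of searching for a better latent code, a hypernetwork predicts per-layer offsets that rescale the generator's convolution weights. Below is a toy sketch using the `SharedWeightsHypernet` from `models/hypernetworks/shared_weights_hypernet.py`; the single 512x512x3x3 layer, the input embedding, and the multiplicative update are illustrative of the idea rather than the full multi-layer, iterative HyperStyle pipeline.

```python
import torch
from models.hypernetworks.shared_weights_hypernet import SharedWeightsHypernet

# One shared hypernetwork head for a single 512x512x3x3 convolution.
# Note: this module allocates its parameters on the GPU, so CUDA is required.
hypernet = SharedWeightsHypernet(f_size=3, z_dim=512, out_size=512, in_size=512)

embedding = torch.randn(1, 512).cuda()             # stand-in for a per-image feature
delta = hypernet(embedding)                        # (1, 512, 512, 3, 3) predicted offsets

base_weight = torch.randn(512, 512, 3, 3).cuda()   # stand-in for a generator conv weight
modulated = base_weight.unsqueeze(0) * (1.0 + delta)   # per-sample modulated weights
```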
9 | 10 | ![HyperStyle](../../docs/hyperstyle.png) 11 | 12 | ## Results 13 | 14 | TODO 15 | 16 | ## Inference 17 | 18 | ``` 19 | python scripts/infer.py \ 20 | --config configs/hyperstyle/wencoder_ffhq_r50.yaml configs/hyperstyle/hyperstyle.yaml \ 21 | --test_dataset_path /path/to/test/data 22 | --output_dir /path/to/output/dir 23 | ``` 24 | 25 | ## Citation 26 | 27 | ```latex 28 | @inproceedings{alaluf2022hyperstyle, 29 | title={Hyperstyle: Stylegan inversion with hypernetworks for real image editing}, 30 | author={Alaluf, Yuval and Tov, Omer and Mokady, Ron and Gal, Rinon and Bermano, Amit}, 31 | booktitle={Proceedings of the IEEE/CVF conference on computer Vision and pattern recognition}, 32 | pages={18511--18521}, 33 | year={2022} 34 | } 35 | ``` 36 | 37 | -------------------------------------------------------------------------------- /models/invertibility/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.invertibility.sync_batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, num_classes, backbone, BatchNorm): 10 | super(Decoder, self).__init__() 11 | if backbone == 'resnet' or backbone == 'drn': 12 | low_level_inplanes = 256 13 | elif backbone == 'xception': 14 | low_level_inplanes = 128 15 | elif backbone == 'mobilenet': 16 | low_level_inplanes = 24 17 | else: 18 | raise NotImplementedError 19 | 20 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 21 | self.bn1 = BatchNorm(48) 22 | self.relu = nn.ReLU() 23 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 24 | BatchNorm(256), 25 | nn.ReLU(), 26 | nn.Dropout(0.5), 27 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 28 | BatchNorm(256), 29 | nn.ReLU(), 30 | nn.Dropout(0.1), 31 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 32 | self._init_weight() 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | torch.nn.init.kaiming_normal_(m.weight) 49 | elif isinstance(m, SynchronizedBatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | 56 | 57 | def build_decoder(num_classes, backbone, BatchNorm): 58 | return Decoder(num_classes, backbone, BatchNorm) -------------------------------------------------------------------------------- /models/latent_codes_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class LatentCodesPool: 6 | """This class implements latent codes buffer that stores previously generated w latent codes. 7 | This buffer enables us to update discriminators using a history of generated w's 8 | rather than the ones produced by the latest encoder. 
9 | """ 10 | 11 | def __init__(self, pool_size): 12 | """Initialize the ImagePool class 13 | Parameters: 14 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 15 | """ 16 | self.pool_size = pool_size 17 | if self.pool_size > 0: # create an empty pool 18 | self.num_ws = 0 19 | self.ws = [] 20 | 21 | def query(self, ws): 22 | """Return w's from the pool. 23 | Parameters: 24 | ws: the latest generated w's from the generator 25 | Returns w's from the buffer. 26 | By 50/100, the buffer will return input w's. 27 | By 50/100, the buffer will return w's previously stored in the buffer, 28 | and insert the current w's to the buffer. 29 | """ 30 | if self.pool_size == 0: # if the buffer size is 0, do nothing 31 | return ws 32 | return_ws = [] 33 | for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) 34 | # w = torch.unsqueeze(image.data, 0) 35 | if w.ndim == 2: 36 | i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate 37 | w = w[i] 38 | self.handle_w(w, return_ws) 39 | return_ws = torch.stack(return_ws, 0) # collect all the images and return 40 | return return_ws 41 | 42 | def handle_w(self, w, return_ws): 43 | if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer 44 | self.num_ws = self.num_ws + 1 45 | self.ws.append(w) 46 | return_ws.append(w) 47 | else: 48 | p = random.uniform(0, 1) 49 | if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer 50 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 51 | tmp = self.ws[random_id].clone() 52 | self.ws[random_id] = w 53 | return_ws.append(tmp) 54 | else: # by another 50% chance, the buffer will return the current image 55 | return_ws.append(w) 56 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | 
ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | from loguru import logger 6 | 7 | 8 | class MocoLoss(nn.Module): 9 | 10 | def __init__(self): 11 | super(MocoLoss, self).__init__() 12 | logger.info("Loading MOCO model from path: {}".format(model_paths["moco"])) 13 | self.model = self.__load_model() 14 | self.model.cuda() 15 | self.model.eval() 16 | 17 | @staticmethod 18 | def __load_model(): 19 | import torchvision.models as models 20 | model = models.__dict__["resnet50"]() 21 | # freeze all layers but the last fc 22 | for name, param in model.named_parameters(): 23 | if name not in ['fc.weight', 'fc.bias']: 24 | param.requires_grad = False 25 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 26 | state_dict = checkpoint['state_dict'] 27 | # rename moco pre-trained keys 28 | for k in list(state_dict.keys()): 29 | # retain only encoder_q up to before the embedding layer 30 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 31 | # remove prefix 32 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 33 | # delete renamed or unused k 34 | del state_dict[k] 35 | msg = model.load_state_dict(state_dict, strict=False) 36 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 37 | # remove output layer 38 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 39 | return model 40 | 41 | def extract_feats(self, x): 42 | x = F.interpolate(x, size=224) 43 | x_feats = self.model(x) 44 | x_feats = nn.functional.normalize(x_feats, dim=1) 45 | x_feats = x_feats.squeeze() 46 | return x_feats 47 | 48 | def forward(self, y_hat, y, x): 49 | n_samples = x.shape[0] 50 | x_feats = self.extract_feats(x) 51 | y_feats = self.extract_feats(y) 52 | y_hat_feats = self.extract_feats(y_hat) 53 | y_feats = y_feats.detach() 54 | loss = 0 55 | sim_improvement = 0 56 | sim_logs = [] 57 | count = 0 58 | for i in range(n_samples): 59 | diff_target = y_hat_feats[i].dot(y_feats[i]) 60 | diff_input = y_hat_feats[i].dot(x_feats[i]) 61 | diff_views = y_feats[i].dot(x_feats[i]) 62 | sim_logs.append({'diff_target': float(diff_target), 63 | 'diff_input': float(diff_input), 64 | 'diff_views': float(diff_views)}) 65 | loss += 1 - diff_target 66 | sim_diff = float(diff_target) - float(diff_views) 67 | sim_improvement += sim_diff 68 | count += 1 69 | 70 | return loss / count, sim_improvement / count, sim_logs 71 | 
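# Usage sketch (illustration only; shapes, device, and values below are assumptions,
# not part of the original file):
#
#   moco = MocoLoss()                                # loads the MoCo ResNet-50 checkpoint at model_paths['moco']
#   y_hat = torch.randn(4, 3, 256, 256).cuda()       # reconstructions
#   y = torch.randn(4, 3, 256, 256).cuda()           # targets
#   x = torch.randn(4, 3, 256, 256).cuda()           # source inputs
#   loss, sim_improvement, logs = moco(y_hat, y, x)  # mean (1 - cosine sim), mean similarity gain, per-sample dicts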
-------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths_config import model_paths 4 | from models.encoders.model_irse import Backbone 5 | from loguru import logger 6 | 7 | 8 | class IDLoss(nn.Module): 9 | def __init__(self): 10 | super(IDLoss, self).__init__() 11 | logger.info(f'Loading ResNet ArcFace from path {model_paths["ir_se50"]}.') 12 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 13 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'], map_location='cpu')) 14 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 15 | self.facenet.eval() 16 | 17 | def extract_feats(self, x): 18 | x = x[:, :, 35:223, 32:220] # Crop interesting region 19 | x = self.face_pool(x) 20 | x_feats = self.facenet(x) 21 | return x_feats 22 | 23 | def forward(self, y_hat, y, x): 24 | n_samples = x.shape[0] 25 | x_feats = self.extract_feats(x) 26 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 27 | y_hat_feats = self.extract_feats(y_hat) 28 | y_feats = y_feats.detach() 29 | loss = 0 30 | sim_improvement = 0 31 | id_logs = [] 32 | count = 0 33 | for i in range(n_samples): 34 | diff_target = y_hat_feats[i].dot(y_feats[i]) 35 | diff_input = y_hat_feats[i].dot(x_feats[i]) 36 | diff_views = y_feats[i].dot(x_feats[i]) 37 | id_logs.append({'diff_target': float(diff_target), 38 | 'diff_input': float(diff_input), 39 | 'diff_views': float(diff_views)}) 40 | loss += 1 - diff_target 41 | id_diff = float(diff_target) - float(diff_views) 42 | sim_improvement += id_diff 43 | count += 1 44 | 45 | return loss / count, sim_improvement / count, id_logs 46 | 47 | 48 | class IDLoss1(nn.Module): 49 | def __init__(self): 50 | super(IDLoss1, self).__init__() 51 | logger.info(f'Loading ResNet ArcFace from path {model_paths["ir_se50"]}.') 52 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 53 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'], map_location='cpu')) 54 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 55 | self.facenet.eval() 56 | 57 | def extract_feats(self, x): 58 | x = x[:, :, 35:223, 32:220] # Crop interesting region 59 | x = self.face_pool(x) 60 | x_feats = self.facenet(x) 61 | return x_feats 62 | 63 | def forward(self, y_hat, x): 64 | n_samples = x.shape[0] 65 | x_feats = self.extract_feats(x) 66 | y_hat_feats = self.extract_feats(y_hat) 67 | loss = 0 68 | count = 0 69 | for i in range(n_samples): 70 | diff_input = y_hat_feats[i].dot(x_feats[i]) 71 | loss += 1 - diff_input 72 | count += 1 73 | 74 | return loss / count -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: 
Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /configs/pti/README.md: -------------------------------------------------------------------------------- 1 | # Pivotal Tuning (PTI) [ACM TOG 2022] 2 | 3 | > [Pivotal Tuning for Latent-based Editing of Real Images](https://arxiv.org/abs/2106.05744) 4 | 5 | ## Abstract 6 | 7 | Recently, a surge of advanced facial editing techniques have been proposed that leverage the generative power of a pre-trained StyleGAN. To successfully edit an image this way, one must first project (or invert) the image into the pre-trained generator’s domain. As it turns out, however, StyleGAN’s latent space induces an inherent tradeoff between distortion and editability, i.e. between maintaining the original appearance and convincingly altering some of its attributes. Practically, this means it is still challenging to apply ID-preserving facial latent-space editing to faces which are out of the generator’s domain. In this paper, we present an approach to bridge this gap. Our technique slightly alters the generator, so that an out-of-domain image is faithfully mapped into an in-domain latent code. The key idea is pivotal tuning — a brief training process that preserves the editing quality of an in-domain latent region, while changing its portrayed identity and appearance. In Pivotal Tuning Inversion (PTI), an initial inverted latent code serves as a pivot, around which the generator is finedtuned. 
At the same time, a regularization term keeps nearby identities intact, to locally contain the effect. This surgical training process ends up altering appearance features that represent mostly identity, without affecting editing capabilities. To supplement this, we further show that pivotal tuning can also adjust the generator to accommodate a multitude of faces, while introducing negligible distortion on the rest of the domain. We validate our technique through inversion and editing metrics, and show preferable scores to state-of-the-art methods. We further qualitatively demonstrate our technique by applying advanced edits (such as pose, age, or expression) to numerous images of well-known and recognizable identities. Finally, we demonstrate resilience to harder cases, including heavy make-up, elaborate hairstyles and/or headwear, which otherwise could not have been successfully inverted and edited by state-of-the-art methods. Source code can be found at:https://github.com/danielroich/PTI. 8 | 9 | ![PTI](../../docs/PTI.png) 10 | 11 | ## Results 12 | 13 | TODO 14 | 15 | ## Inference 16 | 17 | ``` 18 | python scripts/infer.py \ 19 | --config configs/pti/pti_pivot.yaml configs/pti/pti.yaml \ 20 | --test_dataset_path /path/to/test/data 21 | --output_dir /path/to/output/dir 22 | ``` 23 | 24 | - `--save_intermidiated`: If true, DHR will save intermediated information like segmentation, modulated feature, and modulated weight. 25 | 26 | ## Citation 27 | 28 | ```latex 29 | @article{roich2022pivotal, 30 | title={Pivotal tuning for latent-based editing of real images}, 31 | author={Roich, Daniel and Mokady, Ron and Bermano, Amit H and Cohen-Or, Daniel}, 32 | journal={ACM Transactions on Graphics (TOG)}, 33 | volume={42}, 34 | number={1}, 35 | pages={1--13}, 36 | year={2022}, 37 | publisher={ACM New York, NY} 38 | } 39 | ``` 40 | 41 | -------------------------------------------------------------------------------- /models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 
| bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/invertibility/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.append(".") 7 | sys.path.append("..") 8 | sys.path.append("...") 9 | 10 | from models.invertibility.aspp import build_aspp 11 | from models.invertibility.decoder import build_decoder 12 | from models.invertibility.backbone import build_backbone 13 | from models.invertibility.sync_batchnorm import SynchronizedBatchNorm2d 14 | 15 | 16 | class DeepLab(nn.Module): 17 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21, 18 | sync_bn=True, freeze_bn=False): 19 | super(DeepLab, self).__init__() 20 | if backbone == 'drn': 21 | output_stride = 8 22 | 23 | if sync_bn == True: 24 | BatchNorm = SynchronizedBatchNorm2d 25 | else: 26 | BatchNorm = nn.BatchNorm2d 27 | 28 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 29 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 30 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 31 | 32 | self.freeze_bn = freeze_bn 33 | 34 | def forward(self, input): 35 | x, low_level_feat = self.backbone(input) 36 | x = self.aspp(x) 37 | x = self.decoder(x, low_level_feat) 38 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 39 | 40 | return x 41 | 42 | def freeze_bn(self): 43 | for m in self.modules(): 44 | if isinstance(m, SynchronizedBatchNorm2d): 45 | m.eval() 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.eval() 48 | 49 | def get_1x_lr_params(self): 50 | modules = [self.backbone] 51 | for i in range(len(modules)): 52 | for m in modules[i].named_modules(): 53 | if self.freeze_bn: 54 | if isinstance(m[1], nn.Conv2d): 55 | for p in m[1].parameters(): 56 | if p.requires_grad: 57 | yield p 58 | else: 59 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 60 | or isinstance(m[1], nn.BatchNorm2d): 61 | for p in m[1].parameters(): 62 | if p.requires_grad: 63 | yield p 64 | 65 | def get_10x_lr_params(self): 66 | modules = [self.aspp, self.decoder] 67 | for i in range(len(modules)): 68 | for m in modules[i].named_modules(): 69 | if self.freeze_bn: 70 | if isinstance(m[1], nn.Conv2d): 71 | for p in m[1].parameters(): 72 | if p.requires_grad: 73 | yield p 74 | else: 75 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 76 | or isinstance(m[1], nn.BatchNorm2d): 77 | for p in m[1].parameters(): 78 | if p.requires_grad: 79 | yield p 80 | 81 | 82 | -------------------------------------------------------------------------------- /inference/hfgi_infer.py: -------------------------------------------------------------------------------- 1 | 
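# HFGI refinement, in brief (usage sketch; `opts` is assumed to be parsed by the options/ package and
# the first-stage embedding results to come from an embedding module such as EncoderInference):
#     hfgi = HFGIInference(opts)
#     inv_images, latents, _ = hfgi.inverse(images, images_resize, image_paths,
#                                           emb_codes, emb_images, emb_info)
# inverse() takes the residual between the resized input and its first-stage reconstruction,
# aligns it with ResidualAligner, encodes it into per-layer conditions with ResidualEncoder,
# and re-runs the StyleGAN2 decoder with those conditions (hfgi_conditions) to restore detail.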
from tqdm import tqdm 2 | from criteria.lpips.lpips import LPIPS 3 | import math 4 | from models.stylegan2.model import Generator 5 | from models.encoders.psp_encoders import ResidualAligner, ResidualEncoder 6 | from models.encoder import get_keys 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from utils.train_utils import load_train_checkpoint 11 | from inference.inference import BaseInference 12 | 13 | 14 | class HFGIInference(BaseInference): 15 | 16 | def __init__(self, opts, decoder=None): 17 | super(HFGIInference, self).__init__() 18 | self.opts = opts 19 | self.device = 'cuda' 20 | self.opts.device = self.device 21 | self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2 22 | 23 | # resume from checkpoint 24 | checkpoint = load_train_checkpoint(opts) 25 | 26 | # initialize encoder and decoder 27 | if decoder is not None: 28 | self.decoder = decoder 29 | else: 30 | self.decoder = Generator(opts.resolution, 512, 8).to(self.device) 31 | self.decoder.eval() 32 | if checkpoint is not None: 33 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 34 | else: 35 | decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu') 36 | self.decoder.load_state_dict(decoder_checkpoint['g_ema']) 37 | 38 | self.align = ResidualAligner().to(self.device).eval() 39 | self.align.load_state_dict(checkpoint['align'], strict=False) 40 | self.residue = ResidualEncoder().to(self.device).eval() 41 | self.residue.load_state_dict(checkpoint['res'], strict=False) 42 | 43 | def inverse(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info): 44 | with torch.no_grad(): 45 | emb_images_resize = torch.nn.functional.interpolate(torch.clamp(emb_images, -1., 1.), size=(256, 256), mode='bilinear') 46 | res = images_resize - emb_images_resize 47 | res_align = self.align(torch.cat((res, emb_images_resize), 1)) 48 | conditions = self.residue(res_align) 49 | 50 | images, result_latent = self.decoder([emb_codes], 51 | input_is_latent=True, 52 | return_latents=True, 53 | randomize_noise=False, 54 | hfgi_conditions=conditions) 55 | 56 | return images, result_latent, None 57 | 58 | def edit(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info, editor): 59 | images, codes, refine_info = self.inverse(images, images_resize, image_paths, emb_codes, emb_images, emb_info) 60 | refine_info = refine_info[0] 61 | with torch.no_grad(): 62 | decoder = Generator(self.opts.resolution, 512, 8).to(self.device) 63 | decoder.train() 64 | decoder.load_state_dict(refine_info['generator'], strict=True) 65 | edit_codes = editor.edit_code(codes) 66 | 67 | edit_images, edit_codes = decoder([edit_codes], input_is_latent=True, randomize_noise=False) 68 | return images, edit_images, codes, edit_codes, refine_info -------------------------------------------------------------------------------- /inference/two_stage_inference.py: -------------------------------------------------------------------------------- 1 | from .code_infer import CodeInference 2 | from .encoder_infer import EncoderInference 3 | from .hfgi_infer import HFGIInference 4 | from .optim_infer import OptimizerInference 5 | from .pti_infer import PTIInference 6 | from .dhr_infer import DHRInference 7 | from .sam_infer import SamInference 8 | from .restyle_infer import RestyleInference 9 | from .hyper_infer import HyperstyleInference 10 | from inference.inference import BaseInference 11 | 12 | 13 | class TwoStageInference(): 14 | def __init__(self, opts, decoder=None): 15 | 
super(TwoStageInference, self).__init__() 16 | # mode in two stages 17 | embed_mode = opts.embed_mode 18 | refine_mode = opts.refine_mode 19 | self.refine_mode = refine_mode 20 | 21 | # Image Embedding 22 | if embed_mode == 'encoder': 23 | self.embedding_module = EncoderInference(opts) 24 | elif embed_mode == 'optim': 25 | self.embedding_module = OptimizerInference(opts) 26 | elif embed_mode == 'code': 27 | self.embedding_module = CodeInference(opts) 28 | elif embed_mode == 'restyle': 29 | self.embedding_module = RestyleInference(opts) 30 | else: 31 | raise Exception(f'Wrong embedding mode: {embed_mode}.') 32 | 33 | # Result Refinement 34 | if refine_mode == 'pti': 35 | self.refinement_module = PTIInference(opts) 36 | elif refine_mode == 'dhr': 37 | self.refinement_module = DHRInference(opts) 38 | elif refine_mode == 'sam': 39 | self.refinement_module = SamInference(opts) 40 | elif refine_mode == 'hfgi': 41 | self.refinement_module = HFGIInference(opts) 42 | elif refine_mode == 'hyperstyle': 43 | self.refinement_module = HyperstyleInference(opts) 44 | elif refine_mode is None: 45 | self.refinement_module = None 46 | else: 47 | raise Exception(f'Wrong embedding mode: {refine_mode}.') 48 | 49 | def inverse(self, images, images_resize, image_paths): 50 | emb_images, emb_codes, emb_info = self.embedding_module.inverse(images, images_resize, image_paths) 51 | if self.refine_mode is not None: 52 | refine_images, refine_codes, refine_info = \ 53 | self.refinement_module.inverse(images, images_resize, image_paths, emb_codes, emb_images, emb_info) 54 | else: 55 | refine_images, refine_codes, refine_info = None, None, None 56 | 57 | return emb_images, emb_codes, emb_info, refine_images, refine_codes, refine_info 58 | 59 | def edit(self, images, images_resize, image_paths, editor): 60 | emb_codes, emb_codes_edit, emb_images, emb_images_edit, emb_info, \ 61 | refine_codes, refine_codes_edit, refine_images, refine_images_edit, refine_info = [None] * 10 62 | emb_images, emb_images_edit, emb_codes, emb_codes_edit, emb_info = \ 63 | self.embedding_module.edit(images, images_resize, image_paths, editor) 64 | if self.refine_mode is not None: 65 | refine_images, refine_images_edit, refine_codes, refine_codes_edit, refine_info = \ 66 | self.refinement_module.edit(images, images_resize, image_paths, emb_codes, emb_images, emb_info, editor) 67 | 68 | return emb_images_edit, emb_codes_edit, emb_info, refine_images_edit, refine_codes_edit, refine_info 69 | -------------------------------------------------------------------------------- /models/invertibility/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 
69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /scripts/edit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | sys.path.append('..') 5 | 6 | from tqdm import tqdm 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | 11 | from datasets.inference_dataset import InversionDataset 12 | from inference import TwoStageInference 13 | from editing import * 14 | from utils.common import tensor2im 15 | from options.test_options import TestOptions 16 | import torchvision.transforms as transforms 17 | 18 | 19 | def main(): 20 | opts = TestOptions().parse() 21 | if opts.checkpoint_path is None: 22 | opts.auto_resume = True 23 | 24 | # load edit direction 25 | if opts.edit_mode == 'interfacegan': 26 | editor = InterFaceGAN(opts) 27 | elif opts.edit_mode == 'ganspace': 28 | editor = GANSpace(opts) 29 | else: 30 | raise ValueError(f'Undefined editing mode: {opts.edit_mode}') 31 | 32 | save_folder = editor.save_folder 33 | inversion = TwoStageInference(opts) 34 | 35 | if opts.output_dir is None: 36 | opts.output_dir = opts.exp_dir 37 | os.makedirs(opts.output_dir, exist_ok=True) 38 | os.makedirs(os.path.join(opts.output_dir, save_folder), exist_ok=True) 39 | 40 | if opts.output_resolution is not None and len(opts.output_resolution) == 1: 41 | opts.output_resolution = (opts.output_resolution, opts.output_resolution) 42 | 43 | transform = transforms.Compose([ 44 | transforms.Resize((256, 256)), 45 | transforms.ToTensor(), 46 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 47 | transform_no_resize = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 50 | 51 | if os.path.isdir(opts.test_dataset_path): 52 | dataset = InversionDataset(root=opts.test_dataset_path, transform=transform, 53 | transform_no_resize=transform_no_resize) 54 | dataloader = DataLoader(dataset, 55 | batch_size=opts.test_batch_size, 56 | shuffle=False, 57 | num_workers=int(opts.test_workers), 58 | drop_last=False) 59 | else: 60 | img = Image.open(opts.test_dataset_path) 61 | img = img.convert('RGB') 62 | img_aug = transform(img) 63 | img_aug_no_resize = transform_no_resize(img) 64 | dataloader = [(img_aug[None], [opts.test_dataset_path], img_aug_no_resize[None])] 65 | 66 | for input_batch in tqdm(dataloader): 67 | images_resize, img_paths, images = input_batch 68 | images_resize, images = images_resize.cuda(), images.cuda() 69 | 70 | # with torch.no_grad(): 71 | emb_images_edit, emb_codes_edit, emb_info, refine_images_edit, refine_codes_edit, refine_info \ 72 | = inversion.edit(images, images_resize, img_paths, editor) 73 | 74 | edit_images = refine_images_edit if refine_images_edit is not None else 
emb_images_edit 75 | 76 | H, W = edit_images.shape[2:] 77 | for path, edit_img in zip(img_paths, edit_images): 78 | basename = os.path.basename(path).split('.')[0] 79 | if opts.output_resolution is not None and ((H, W) != opts.output_resolution): 80 | edit_img = torch.nn.functional.interpolate(edit_img.unsqueeze(0), size=opts.output_resolution, mode='bilinear', align_corners=False).squeeze(0) 81 | edit_result = tensor2im(edit_img) 82 | edit_result.save(os.path.join(opts.output_dir, save_folder, f'{basename}.jpg')) 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /docs/dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset Instruction 2 | 3 | ## Prepare Datasets 4 | Link your dataset to ```./data/``` 5 | ```bash 6 | ln -sf /path/to/dataset data/ 7 | ``` 8 | For example: 9 | ```bash 10 | ln -sf /Database/FFHQ data/ 11 | ln -sf /Database/CelebA-HQ data/ 12 | ``` 13 | Default paths in our configs: 14 | ```bash 15 | # Human face 16 | data/FFHQ 17 | data/CelebA-HQ 18 | # Animal 19 | data/AFHQ 20 | # Church 21 | data/Lsun-church 22 | ``` 23 | --- 24 | ## Commonly Used Datasets 25 | | Domain | Dataset | Generator | Height | Width | Comment | 26 | | :---------: | :----------: | :-------: | :----: | :---: | :-----: | 27 | | Face | [FFHQ](#ffhq) | [StyleGAN2*](https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-f.pkl)
[convert](https://drive.google.com/file/d/1fgehC3QTtEayc_AFdsCx-UkxU8d1d9zW/view?usp=sharing) | 1024 | 1024 | | 28 | | Face | [CelebA](#celeba-hq) | - | 1024 | 1024 | evaluation | 29 | | Cat | [AFHQ Cat](#afhq) | [StyleGAN2-ADA](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl) | 512 | 512 | | 30 | | Cat | [LSUN Cat](#lsun) | [StyleGAN2*](https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl)
[convert](https://drive.google.com/file/d/1s6x2BApGb0Sfp_hQyi59yM0IES5qYlBM/view?usp=sharing) | 256 | 256 | | 31 | | Dog | [AFHQ Dog](#afhq) | [StyleGAN2-ADA](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqdog.pkl) | 512 | 512 | | 32 | | Wild Animal | [AFHQ Wild](#afhq) | [StyleGAN2-ADA](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl) | 512 | 512 | | 33 | | Horse | [LSUN Horse](#lsun) | [StyleGAN2*](https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-horse-config-f.pkl)
convert | 256 | 256 | | 34 | | Car | [Stanford Car](#stanford-car) | - | | | train inversion | 35 | | Car | [LSUN Car](#lsun) | [StyleGAN2*](https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-f.pkl)
[convert](https://drive.google.com/file/d/1d5ATre9K1cVo9m6WCzjrONOPsMdhrHL3/view?usp=sharing) | 384 | 512 | Crop | 36 | | Church | [LSUN Church](#lsun) | [StyleGAN2*](https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-church-config-f.pkl)
[convert](https://drive.google.com/file/d/1TWeD0I2zfB53LAKoQ16hTnblZ2AhspJu/view?usp=sharing) | 256 | 256 | | 37 | 38 | * ```*```: weight is from the original [StyleGAN2]([NVlabs/stylegan2: StyleGAN2 - Official TensorFlow Implementation (github.com)](https://github.com/NVlabs/stylegan2)) (TensorFlow-based), which needs to convert by [script](scripts/convert_weight.py). We also provide the converted weights, which are converted by [this implementation](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/main/export_weights.py). 39 | 40 | ## Face Dataset 41 | We use FFHQ (70,000 images) for training and CelebA-HQ test dataset (2824 images) for testing. 42 | ### FFHQ 43 | Download script can be found in [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset). 44 | 45 | By default, the image path follows: ```data/FFHQ/xxxxx.png```. 46 | 47 | ### CelebA-HQ Test 48 | CelebA-HQ is a subset of CelebA dataset, we share the test split (2824 images) on drive. 49 | 50 | By default, the image path follows: ```data/CelebA-HQ/test/xxxxxx.jpg```. 51 | 52 | ## AFHQ (Animal) 53 | AFHQ consists of three categories: cat, dog, and wild. There are two versions (i.e., AFHQ and AFHQv2), and we use AFHQ by default. 54 | 55 | TODO 56 | 57 | ## LSUN (Scene and Object) 58 | 59 | TODO -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from loguru import logger 4 | 5 | 6 | def load_train_checkpoint(opts, best=False): 7 | if opts.auto_resume: 8 | if best: 9 | train_ckpt_path = os.path.join(opts.exp_dir, 'checkpoints/last.pt') 10 | else: 11 | train_ckpt_path = os.path.join(opts.exp_dir, 'checkpoints/best_model.pt') 12 | if os.path.isfile(train_ckpt_path): 13 | previous_train_ckpt = torch.load(train_ckpt_path, map_location='cpu') 14 | else: 15 | previous_train_ckpt = None 16 | else: 17 | train_ckpt_path = opts.checkpoint_path 18 | if train_ckpt_path is None: 19 | previous_train_ckpt = None 20 | else: 21 | previous_train_ckpt = torch.load(opts.checkpoint_path, map_location='cpu') 22 | 23 | previous_train_ckpt = convert_weight(previous_train_ckpt, opts) 24 | if previous_train_ckpt is not None: 25 | opts.checkpoint_path = train_ckpt_path 26 | return previous_train_ckpt 27 | 28 | 29 | def aggregate_loss_dict(agg_loss_dict): 30 | mean_vals = {} 31 | for output in agg_loss_dict: 32 | for key in output: 33 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 34 | for key in mean_vals: 35 | if len(mean_vals[key]) > 0: 36 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 37 | else: 38 | print('{} has no value'.format(key)) 39 | mean_vals[key] = 0 40 | return mean_vals 41 | 42 | 43 | def get_train_progressive_stage(stages, step): 44 | if stages is None: 45 | return -1 46 | for i in range(len(stages) - 1): 47 | if stages[i] <= step < stages[i + 1]: 48 | return i 49 | return len(stages) - 1 50 | 51 | 52 | def requires_grad(model, flag=True): 53 | for p in model.parameters(): 54 | p.requires_grad = flag 55 | 56 | 57 | def convert_weight(weight, opts): 58 | """Convert psp/e4e weights from original repo to GAN Inverter.""" 59 | if weight is not None: 60 | if 'encoder' not in weight: 61 | logger.info('Resume from official weight. 
Converting to GAN Inverter weight.......') 62 | encoder_weight, decoder_weight = dict(), dict() 63 | if opts.refine_mode == 'hfgi': 64 | align_weight, res_weight = dict(), dict() 65 | elif opts.refine_mode == 'hyperstyle': 66 | hypernet_weight = dict() 67 | for k, v in weight['state_dict'].items(): 68 | if k.startswith('encoder.'): 69 | encoder_weight[k] = v 70 | elif k.startswith('decoder.'): 71 | decoder_weight[k[8:]] = v 72 | 73 | if opts.refine_mode == 'hfgi': 74 | if k.startswith('residue.'): 75 | res_weight[k[8:]] = v 76 | elif k.startswith('grid_align.'): 77 | align_weight[k[11:]] = v 78 | elif opts.refine_mode == 'hyperstyle': 79 | if k.startswith('hypernet.'): 80 | hypernet_weight[k[9:]] = v 81 | 82 | encoder_weight['latent_avg'] = weight['latent_avg'] 83 | weight = dict( 84 | encoder=encoder_weight, 85 | decoder=decoder_weight, 86 | ) 87 | if opts.refine_mode == 'hfgi': 88 | weight['align'] = align_weight 89 | weight['res'] = res_weight 90 | elif opts.refine_mode == 'hyperstyle': 91 | weight['hypernet'] = hypernet_weight 92 | return weight 93 | -------------------------------------------------------------------------------- /utils/facer/facer/util.py: -------------------------------------------------------------------------------- 1 | from numpy import isin 2 | import torch 3 | from typing import Any, Generator, Optional, Tuple, Union, List, Dict 4 | import math 5 | import os 6 | from urllib.parse import urlparse 7 | import errno 8 | import sys 9 | import validators 10 | 11 | 12 | def hwc2bchw(images: torch.Tensor) -> torch.Tensor: 13 | return images.unsqueeze(0).permute(0, 3, 1, 2) 14 | 15 | 16 | def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2, 17 | background_value: float = 0) -> torch.Tensor: 18 | """ make a grid image from an image batch. 19 | 20 | Args: 21 | images (torch.Tensor): input image batch. 22 | nrows: rows of grid. 23 | border: border size in pixel. 24 | background_value: color value of background. 25 | """ 26 | assert images.ndim == 4 # n x c x h x w 27 | images = images.permute(0, 2, 3, 1) # n x h x w x c 28 | n, h, w, c = images.shape 29 | if nrows is None: 30 | nrows = max(int(math.sqrt(n)), 1) 31 | ncols = (n + nrows - 1) // nrows 32 | result = torch.full([(h + border) * nrows - border, 33 | (w + border) * ncols - border, c], background_value, 34 | device=images.device, 35 | dtype=images.dtype) 36 | 37 | for i, single_image in enumerate(images): 38 | row = i // ncols 39 | col = i % ncols 40 | yy = (h + border) * row 41 | xx = (w + border) * col 42 | result[yy:(yy+h), xx:(xx+w), :] = single_image 43 | return result 44 | 45 | 46 | def select_data(selection, data): 47 | if isinstance(data, dict): 48 | return {name: select_data(selection, val) for name, val in data.items()} 49 | elif isinstance(data, (list, tuple)): 50 | return [select_data(selection, val) for val in data] 51 | elif isinstance(data, torch.Tensor): 52 | return data[selection] 53 | return data 54 | 55 | 56 | def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None): 57 | if isinstance(url_or_paths, str): 58 | url_or_paths = [url_or_paths] 59 | 60 | for url_or_path in url_or_paths: 61 | try: 62 | if validators.url(url_or_path): 63 | url = url_or_path 64 | if model_dir is None: 65 | hub_dir = torch.hub.get_dir() 66 | model_dir = os.path.join(hub_dir, 'checkpoints') 67 | 68 | try: 69 | os.makedirs(model_dir) 70 | except OSError as e: 71 | if e.errno == errno.EEXIST: 72 | # Directory already exists, ignore. 
73 | pass 74 | else: 75 | # Unexpected OSError, re-raise. 76 | raise 77 | 78 | parts = urlparse(url) 79 | filename = os.path.basename(parts.path) 80 | cached_file = os.path.join(model_dir, filename) 81 | if not os.path.exists(cached_file): 82 | sys.stderr.write( 83 | 'Downloading: "{}" to {}\n'.format(url, cached_file)) 84 | hash_prefix = None 85 | torch.hub.download_url_to_file( 86 | url, cached_file, hash_prefix, progress=True) 87 | else: 88 | cached_file = url_or_path 89 | 90 | return torch.jit.load(cached_file, map_location=map_location) 91 | except: 92 | sys.stderr.write(f'failed downloading from {url_or_path}\n') 93 | 94 | raise RuntimeError('failed to download jit models from all given urls') 95 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from .box_utils import nms, _preprocess 7 | 8 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | device = 'cuda' 10 | 11 | 12 | def run_first_stage(image, net, scale, threshold): 13 | """Run P-Net, generate bounding boxes, and do NMS. 14 | 15 | Arguments: 16 | image: an instance of PIL.Image. 17 | net: an instance of pytorch's nn.Module, P-Net. 18 | scale: a float number, 19 | scale width and height of the image by this number. 20 | threshold: a float number, 21 | threshold on the probability of a face when generating 22 | bounding boxes from predictions of the net. 23 | 24 | Returns: 25 | a float numpy array of shape [n_boxes, 9], 26 | bounding boxes with scores and offsets (4 + 1 + 4). 27 | """ 28 | 29 | # scale the image and convert it to a float array 30 | width, height = image.size 31 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 32 | img = image.resize((sw, sh), Image.BILINEAR) 33 | img = np.asarray(img, 'float32') 34 | 35 | img = torch.FloatTensor(_preprocess(img)).to(device) 36 | with torch.no_grad(): 37 | output = net(img) 38 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 39 | offsets = output[0].cpu().data.numpy() 40 | # probs: probability of a face at each sliding window 41 | # offsets: transformations to true bounding boxes 42 | 43 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 44 | if len(boxes) == 0: 45 | return None 46 | 47 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 48 | return boxes[keep] 49 | 50 | 51 | def _generate_bboxes(probs, offsets, scale, threshold): 52 | """Generate bounding boxes at places 53 | where there is probably a face. 54 | 55 | Arguments: 56 | probs: a float numpy array of shape [n, m]. 57 | offsets: a float numpy array of shape [1, 4, n, m]. 58 | scale: a float number, 59 | width and height of the image were scaled by this number. 60 | threshold: a float number. 
61 | 62 | Returns: 63 | a float numpy array of shape [n_boxes, 9] 64 | """ 65 | 66 | # applying P-Net is equivalent, in some sense, to 67 | # moving 12x12 window with stride 2 68 | stride = 2 69 | cell_size = 12 70 | 71 | # indices of boxes where there is probably a face 72 | inds = np.where(probs > threshold) 73 | 74 | if inds[0].size == 0: 75 | return np.array([]) 76 | 77 | # transformations of bounding boxes 78 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 79 | # they are defined as: 80 | # w = x2 - x1 + 1 81 | # h = y2 - y1 + 1 82 | # x1_true = x1 + tx1*w 83 | # x2_true = x2 + tx2*w 84 | # y1_true = y1 + ty1*h 85 | # y2_true = y2 + ty2*h 86 | 87 | offsets = np.array([tx1, ty1, tx2, ty2]) 88 | score = probs[inds[0], inds[1]] 89 | 90 | # P-Net is applied to scaled images 91 | # so we need to rescale bounding boxes back 92 | bounding_boxes = np.vstack([ 93 | np.round((stride * inds[1] + 1.0) / scale), 94 | np.round((stride * inds[0] + 1.0) / scale), 95 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 96 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 97 | score, offsets 98 | ]) 99 | # why one is added? 100 | 101 | return bounding_boxes.T 102 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import yaml 4 | from argparse import ArgumentParser 5 | 6 | 7 | def str2bool(arg): 8 | ua = str(arg).upper() 9 | if 'TRUE'.startswith(ua): 10 | return True 11 | elif 'FALSE'.startswith(ua): 12 | return False 13 | else: 14 | raise Exception('Error!') 15 | 16 | 17 | class BaseOptions: 18 | 19 | def __init__(self): 20 | self.config_parser = self.parser = ArgumentParser() 21 | self.initialize() 22 | 23 | def initialize(self): 24 | self.parser.add_argument('-c', '--config', default='', type=str, nargs='+', metavar='FILE', help='YAML config file ' 25 | 'specifying default ' 26 | 'arguments') 27 | self.parser.add_argument('--exp_dir', default='one_shot', type=str, help='Path to experiment output directory') 28 | self.parser.add_argument('--resolution', default=1024, type=int, help='Resolution of generator') 29 | 30 | # Data options 31 | self.parser.add_argument('--transform_type', default='encodetransforms', type=str, help='Type of dataset trans') 32 | self.parser.add_argument('--train_dataset_path', default=None, type=str) 33 | self.parser.add_argument('--test_dataset_path', default='../test', type=str) 34 | self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training.') 35 | self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference.') 36 | self.parser.add_argument('--workers', default=0, type=int, help='Number of train dataloader workers.') 37 | self.parser.add_argument('--test_workers', default=0, type=int, 38 | help='Number of test/inference dataloader workers.') 39 | self.parser.add_argument('--auto_resume', default='False', type=str2bool, 40 | help='Whether to automatically resume. 
During training, the last checkpoint will be ' 41 | 'used to resume while the best checkpoint will be used in ' 42 | 'inference/evaluation/editing process.') 43 | 44 | # Model options 45 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint.') 46 | self.parser.add_argument('--stylegan_weights', default="", type=str, help='Path to StyleGAN model weights.') 47 | self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use') 48 | self.parser.add_argument('--start_from_latent_avg', action='store_true', help='Whether to add average latent ' 49 | 'vector to generate codes from ' 50 | 'encoder.') 51 | self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+') 52 | self.parser.add_argument('--input_nc', default=3, type=int, help='number of channels of the first encoder layer') 53 | self.parser.add_argument('--layers', default=50, type=int, help='Number of layers of backbone') 54 | 55 | def parse(self): 56 | opts = self.config_parser.parse_args() 57 | if opts.config: 58 | for config in opts.config: 59 | with open(config, 'r') as f: 60 | cfg = yaml.safe_load(f) 61 | self.parser.set_defaults(**cfg) 62 | opts = self.parser.parse_args() 63 | return opts 64 | -------------------------------------------------------------------------------- /models/bisenet/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.model_zoo as modelzoo 7 | 8 | # from modules.bn import InPlaceABNSync as BatchNorm2d 9 | 10 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | def __init__(self, in_chan, out_chan, stride=1): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = conv3x3(in_chan, out_chan, stride) 23 | self.bn1 = nn.BatchNorm2d(out_chan) 24 | self.conv2 = conv3x3(out_chan, out_chan) 25 | self.bn2 = nn.BatchNorm2d(out_chan) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.downsample = None 28 | if in_chan != out_chan or stride != 1: 29 | self.downsample = nn.Sequential( 30 | nn.Conv2d(in_chan, out_chan, 31 | kernel_size=1, stride=stride, bias=False), 32 | nn.BatchNorm2d(out_chan), 33 | ) 34 | 35 | def forward(self, x): 36 | residual = self.conv1(x) 37 | residual = F.relu(self.bn1(residual)) 38 | residual = self.conv2(residual) 39 | residual = self.bn2(residual) 40 | 41 | shortcut = x 42 | if self.downsample is not None: 43 | shortcut = self.downsample(x) 44 | 45 | out = shortcut + residual 46 | out = self.relu(out) 47 | return out 48 | 49 | 50 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 51 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 52 | for i in range(bnum-1): 53 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 54 | return nn.Sequential(*layers) 55 | 56 | 57 | class Resnet18(nn.Module): 58 | def __init__(self): 59 | super(Resnet18, self).__init__() 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 61 | bias=False) 62 | self.bn1 = nn.BatchNorm2d(64) 63 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 64 | self.layer1 = create_layer_basic(64, 64, 
bnum=2, stride=1) 65 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 66 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 67 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 68 | self.init_weight() 69 | 70 | def forward(self, x): 71 | x = self.conv1(x) 72 | x = F.relu(self.bn1(x)) 73 | x = self.maxpool(x) 74 | 75 | x = self.layer1(x) 76 | feat8 = self.layer2(x) # 1/8 77 | feat16 = self.layer3(feat8) # 1/16 78 | feat32 = self.layer4(feat16) # 1/32 79 | return feat8, feat16, feat32 80 | 81 | def init_weight(self): 82 | state_dict = modelzoo.load_url(resnet18_url) 83 | self_state_dict = self.state_dict() 84 | for k, v in state_dict.items(): 85 | if 'fc' in k: continue 86 | self_state_dict.update({k: v}) 87 | self.load_state_dict(self_state_dict) 88 | 89 | def get_params(self): 90 | wd_params, nowd_params = [], [] 91 | for name, module in self.named_modules(): 92 | if isinstance(module, (nn.Linear, nn.Conv2d)): 93 | wd_params.append(module.weight) 94 | if not module.bias is None: 95 | nowd_params.append(module.bias) 96 | elif isinstance(module, nn.BatchNorm2d): 97 | nowd_params += list(module.parameters()) 98 | return wd_params, nowd_params 99 | -------------------------------------------------------------------------------- /pretrained_models/download_models.sh: -------------------------------------------------------------------------------- 1 | root="pretrained_models" 2 | 3 | function download_model() { 4 | if [ ! -d "$root/$1" ]; then 5 | mkdir "$root/$1" 6 | fi 7 | if [ -f "$root/$1/$2" ]; then 8 | echo "File $2 already exists, skipping download" 9 | else 10 | wget "$3" -O "$root/$1/$2" 11 | echo "Downloaded file: $2" 12 | fi 13 | } 14 | 15 | # If no arguments are given, download all files 16 | if [ $# -eq 0 ]; then 17 | echo "No arguments given, downloading all files..." 18 | download_model "generator" "stylegan2-ffhq-config-f.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/stylegan2-ffhq-config-f.pt" 19 | download_model "e4e" "e4e_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/e4e_ffhq_r50_wp_official.pt" 20 | download_model "hfgi" "hfgi_ffhq_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/hfgi_ffhq_official.pt" 21 | download_model "hyperstyle" "hyperstyle_ffhq_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/hyperstyle_ffhq_official.pt" 22 | download_model "hyperstyle" "hyperstyle_ffhq_r50_w_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/hyperstyle_ffhq_r50_w_official.pt" 23 | download_model "lsap" "lsap_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/lsap_ffhq_r50_wp_official.pt" 24 | download_model "other" "model_ir_se50.pth" "https://github.com/caopulan/GANInverter/releases/download/v0.1/model_ir_se50.pth" 25 | download_model "psp" "psp_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/psp_ffhq_r50_wp_official.pt" 26 | download_model "psp" "psp_ffhq_r50_w_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/psp_ffhq_r50_w_official.pt" 27 | download_model "restyle" "restyle-e4e_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/restyle-e4e_ffhq_r50_wp_official.pt" 28 | echo "All files downloaded..."
29 | exit 0 30 | fi 31 | 32 | # Download only the specified files 33 | for type in "$@"; do 34 | if [ "$type" == "generator" ]; then 35 | download_model "generator" "stylegan2-ffhq-config-f.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/stylegan2-ffhq-config-f.pt" 36 | fi 37 | if [ "$type" == "e4e" ]; then 38 | download_model "e4e" "e4e_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/e4e_ffhq_r50_wp_official.pt" 39 | fi 40 | if [ "$type" == "hfgi" ]; then 41 | download_model "hfgi" "hfgi_ffhq_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/hfgi_ffhq_official.pt" 42 | fi 43 | if [ "$type" == "hyperstyle" ]; then 44 | download_model "hyperstyle" "hyperstyle_ffhq_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/hyperstyle_ffhq_official.pt" 45 | download_model "hyperstyle" "hyperstyle_ffhq_r50_w_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/hyperstyle_ffhq_r50_w_official.pt" 46 | fi 47 | if [ "$type" == "lsap" ]; then 48 | download_model "lsap" "lsap_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/lsap_ffhq_r50_wp_official.pt" 49 | fi 50 | if [ "$type" == "other" ]; then 51 | download_model "other" "model_ir_se50.pth" "https://github.com/caopulan/GANInverter/releases/download/v0.1/model_ir_se50.pth" 52 | fi 53 | if [ "$type" == "psp" ]; then 54 | download_model "psp" "psp_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/psp_ffhq_r50_wp_official.pt" 55 | download_model "psp" "psp_ffhq_r50_w_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/psp_ffhq_r50_w_official.pt" 56 | fi 57 | if [ "$type" == "restyle" ]; then 58 | download_model "restyle" "restyle-e4e_ffhq_r50_wp_official.pt" "https://github.com/caopulan/GANInverter/releases/download/v0.1/restyle-e4e_ffhq_r50_wp_official.pt" 59 | fi 60 | done 61 | -------------------------------------------------------------------------------- /models/invertibility/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.invertibility.sync_batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class _ASPPModule(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 10 | super(_ASPPModule, self).__init__() 11 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 12 | stride=1, padding=padding, dilation=dilation, bias=False) 13 | self.bn = BatchNorm(planes) 14 | self.relu = nn.ReLU() 15 | 16 | self._init_weight() 17 | 18 | def forward(self, x): 19 | x = self.atrous_conv(x) 20 | x = self.bn(x) 21 | 22 | return self.relu(x) 23 | 24 | def _init_weight(self): 25 | for m in self.modules(): 26 | if isinstance(m, nn.Conv2d): 27 | torch.nn.init.kaiming_normal_(m.weight) 28 | elif isinstance(m, SynchronizedBatchNorm2d): 29 | m.weight.data.fill_(1) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.BatchNorm2d): 32 | m.weight.data.fill_(1) 33 | m.bias.data.zero_() 34 | 35 | 36 | class ASPP(nn.Module): 37 | def __init__(self, backbone, output_stride, BatchNorm): 38 | super(ASPP, self).__init__() 39 | if backbone == 'drn': 40 | inplanes = 512 41 | elif backbone == 'mobilenet': 42 | inplanes = 320 43 | else: 44 | inplanes = 2048 45 | if output_stride == 16: 46 | dilations = [1, 6, 12, 18] 47 | elif output_stride == 8: 48 | dilations = [1, 12,
24, 36] 49 | else: 50 | raise NotImplementedError 51 | 52 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 53 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 54 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 55 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 56 | 57 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 58 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 59 | BatchNorm(256), 60 | nn.ReLU()) 61 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 62 | self.bn1 = BatchNorm(256) 63 | self.relu = nn.ReLU() 64 | self.dropout = nn.Dropout(0.5) 65 | self._init_weight() 66 | 67 | def forward(self, x): 68 | x1 = self.aspp1(x) 69 | x2 = self.aspp2(x) 70 | x3 = self.aspp3(x) 71 | x4 = self.aspp4(x) 72 | x5 = self.global_avg_pool(x) 73 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 74 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 75 | 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | 80 | return self.dropout(x) 81 | 82 | def _init_weight(self): 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | torch.nn.init.kaiming_normal_(m.weight) 88 | elif isinstance(m, SynchronizedBatchNorm2d): 89 | m.weight.data.fill_(1) 90 | m.bias.data.zero_() 91 | elif isinstance(m, nn.BatchNorm2d): 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | 95 | 96 | def build_aspp(backbone, output_stride, BatchNorm): 97 | return ASPP(backbone, output_stride, BatchNorm) 98 | -------------------------------------------------------------------------------- /inference/restyle_infer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | from models.encoder import Encoder 5 | from models.stylegan2.model import Generator 6 | import torch 7 | 8 | from utils.common import tensor2im 9 | from utils.train_utils import load_train_checkpoint 10 | from inference.inference import BaseInference 11 | 12 | 13 | class RestyleInference(BaseInference): 14 | 15 | def __init__(self, opts, decoder=None): 16 | super(RestyleInference, self).__init__() 17 | self.opts = opts 18 | self.device = 'cuda' 19 | self.opts.device = self.device 20 | self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2 21 | 22 | # resume from checkpoint 23 | checkpoint = load_train_checkpoint(opts) 24 | # initialize encoder and decoder 25 | latent_avg = None 26 | if decoder is not None: 27 | self.decoder = decoder 28 | else: 29 | self.decoder = Generator(opts.resolution, 512, 8).to(self.device) 30 | self.decoder.eval() 31 | if checkpoint is not None: 32 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 33 | else: 34 | decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu') 35 | self.decoder.load_state_dict(decoder_checkpoint['g_ema']) 36 | latent_avg = decoder_checkpoint['latent_avg'] 37 | if latent_avg is None: 38 | latent_avg = self.decoder.mean_latent(int(1e5))[0].detach() if checkpoint is None else checkpoint['encoder']['latent_avg'].unsqueeze(0).to(self.device) 39 | self.encoder = Encoder(opts, checkpoint, latent_avg, device=self.device).to(self.device).eval() 40 | 
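# ReStyle refines iteratively: starting from the average latent/image, each pass feeds the input
# concatenated with the current reconstruction back through the encoder and accumulates the code
# offsets (see inverse() below). The progressive stage is set to n_styles so all style layers are active.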
self.encoder.set_progressive_stage(self.opts.n_styles) 41 | 42 | with torch.no_grad(): 43 | self.avg_image, self.avg_latent = self.decoder([latent_avg], 44 | input_is_latent=True, 45 | randomize_noise=False, 46 | return_latents=True) 47 | self.avg_image = self.avg_image.float().detach() 48 | 49 | # inv_result = tensor2im(self.avg_image[0]) 50 | # inv_result.save(os.path.join(self.opts.output_dir, 'inversion', f'avg.jpg')) 51 | 52 | def inverse(self, images, images_resize, image_path, **kwargs): 53 | with torch.no_grad(): 54 | for iter in range(self.opts.restyle_iteration): 55 | if iter == 0: 56 | avg_image = torch.nn.AdaptiveAvgPool2d((256, 256))(self.avg_image) 57 | avg_images = avg_image.repeat(images_resize.shape[0], 1, 1, 1) 58 | x_input = torch.cat([images_resize, avg_images], dim=1) 59 | result_latent = self.avg_latent.repeat(images_resize.shape[0], 1, 1) 60 | else: 61 | images = torch.nn.AdaptiveAvgPool2d((256, 256))(images) 62 | x_input = torch.cat([images_resize, images], dim=1) 63 | 64 | codes = self.encoder(x_input) 65 | codes = codes + result_latent 66 | images, result_latent = self.decoder([codes], 67 | input_is_latent=True, 68 | randomize_noise=False, 69 | return_latents=True) 70 | # for path, inv_img in zip(image_path, images): 71 | # basename = os.path.basename(path).split('.')[0] + '_' + str(iter) 72 | # inv_result = tensor2im(inv_img) 73 | # inv_result.save(os.path.join(self.opts.output_dir, 'inversion', f'{basename}.jpg')) 74 | 75 | return images, result_latent, None 76 | 77 | def edit(self, images, images_resize, image_path, editor): 78 | images, codes, _ = self.inverse(images, images_resize, image_path) 79 | edit_codes = editor.edit_code(codes) 80 | edit_images = self.generate(edit_codes) 81 | return images, edit_images, codes, edit_codes, None -------------------------------------------------------------------------------- /scripts/calc_id_loss.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import time 3 | import numpy as np 4 | import os 5 | import json 6 | import sys 7 | from PIL import Image 8 | import multiprocessing as mp 9 | import math 10 | import torch 11 | import torchvision.transforms as trans 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from models.mtcnn.mtcnn import MTCNN 17 | from models.encoders.model_irse import IR_101 18 | from configs.paths_config import model_paths 19 | 20 | CIRCULAR_FACE_PATH = model_paths['circular_face'] 21 | 22 | 23 | def chunks(lst, n): 24 | """Yield successive n-sized chunks from lst.""" 25 | for i in range(0, len(lst), n): 26 | yield lst[i:i + n] 27 | 28 | 29 | def extract_on_paths(file_paths): 30 | facenet = IR_101(input_size=112) 31 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 32 | facenet.cuda() 33 | facenet.eval() 34 | mtcnn = MTCNN() 35 | id_transform = trans.Compose([ 36 | trans.ToTensor(), 37 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 38 | ]) 39 | 40 | pid = mp.current_process().name 41 | print('\t{} is starting to extract on {} images'.format(pid, len(file_paths))) 42 | tot_count = len(file_paths) 43 | count = 0 44 | 45 | scores_dict = {} 46 | for res_path, gt_path in file_paths: 47 | count += 1 48 | if count % 100 == 0: 49 | print('{} done with {}/{}'.format(pid, count, tot_count)) 50 | if True: 51 | input_im = Image.open(res_path) 52 | input_im, _ = mtcnn.align(input_im) 53 | if input_im is None: 54 | print('{} skipping {}'.format(pid, res_path)) 55 | continue 56 | 57 | input_id = 
facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 58 | 59 | result_im = Image.open(gt_path) 60 | result_im, _ = mtcnn.align(result_im) 61 | if result_im is None: 62 | print('{} skipping {}'.format(pid, gt_path)) 63 | continue 64 | 65 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 66 | score = float(input_id.dot(result_id)) 67 | scores_dict[os.path.basename(gt_path)] = score 68 | 69 | return scores_dict 70 | 71 | 72 | def parse_args(): 73 | parser = ArgumentParser(add_help=False) 74 | parser.add_argument('--num_threads', type=int, default=4) 75 | parser.add_argument('--data_path', type=str, default='results') 76 | parser.add_argument('--gt_path', type=str, default='gt_images') 77 | args = parser.parse_args() 78 | return args 79 | 80 | 81 | def run(args): 82 | file_paths = [] 83 | for f in os.listdir(args.data_path): 84 | image_path = os.path.join(args.data_path, f) 85 | gt_path = os.path.join(args.gt_path, f) 86 | if f.endswith(".jpg") or f.endswith('.png'): 87 | file_paths.append([image_path, gt_path.replace('.png', '.jpg')]) 88 | 89 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 90 | pool = mp.Pool(args.num_threads) 91 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 92 | 93 | tic = time.time() 94 | results = pool.map(extract_on_paths, file_chunks) 95 | scores_dict = {} 96 | for d in results: 97 | scores_dict.update(d) 98 | 99 | all_scores = list(scores_dict.values()) 100 | mean = np.mean(all_scores) 101 | std = np.std(all_scores) 102 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 103 | print(result_str) 104 | 105 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 106 | if not os.path.exists(out_path): 107 | os.makedirs(out_path) 108 | 109 | with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f: 110 | f.write(result_str) 111 | with open(os.path.join(out_path, 'scores_id.json'), 'w') as f: 112 | json.dump(scores_dict, f) 113 | 114 | toc = time.time() 115 | print('Mischief managed in {}s'.format(toc - tic)) 116 | 117 | 118 | if __name__ == '__main__': 119 | args = parse_args() 120 | run(args) 121 | -------------------------------------------------------------------------------- /utils/facer/facer/face_parsing/farl.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any 2 | import functools 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ..util import download_jit 7 | from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, 8 | make_inverted_tanh_warp_grid, make_tanh_warp_grid) 9 | from .base import FaceParser 10 | 11 | pretrain_settings = { 12 | 'celebm/448': { 13 | 'url': [ 14 | 'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt', 15 | ], 16 | 'matrix_src_tag': 'points', 17 | 'get_matrix_fn': functools.partial(get_face_align_matrix, 18 | target_shape=(448, 448), target_face_scale=0.8), 19 | 'get_grid_fn': functools.partial(make_tanh_warp_grid, 20 | warp_factor=0.0, warped_shape=(448, 448)), 21 | 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid, 22 | warp_factor=0.0, warped_shape=(448, 448)), 23 | 'label_names': ['background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're', 24 | 'le', 'nose', 'imouth', 'llip', 'ulip', 'hair', 25 | 'glass', 'hat', 'earr', 'neckl'] 26 | }, 27 | 'lapa/448': { 28 | 'url': [ 29 | 
'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit.pt', 30 | ], 31 | 'matrix_src_tag': 'points', 32 | 'get_matrix_fn': functools.partial(get_face_align_matrix, 33 | target_shape=(448, 448), target_face_scale=1.0), 34 | 'get_grid_fn': functools.partial(make_tanh_warp_grid, 35 | warp_factor=0.8, warped_shape=(448, 448)), 36 | 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid, 37 | warp_factor=0.8, warped_shape=(448, 448)), 38 | 'label_names': ['background', 'face', 'rb', 'lb', 're', 39 | 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair'] 40 | } 41 | } 42 | 43 | 44 | class FaRLFaceParser(FaceParser): 45 | """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL). 46 | 47 | Please consider citing 48 | ```bibtex 49 | @article{zheng2021farl, 50 | title={General Facial Representation Learning in a Visual-Linguistic Manner}, 51 | author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, 52 | Dongdong and Huang, Yangyu and Yuan, Lu and Chen, 53 | Dong and Zeng, Ming and Wen, Fang}, 54 | journal={arXiv preprint arXiv:2112.03109}, 55 | year={2021} 56 | } 57 | ``` 58 | """ 59 | 60 | def __init__(self, conf_name: Optional[str] = None, 61 | model_path: Optional[str] = None) -> None: 62 | super().__init__() 63 | if conf_name is None: 64 | conf_name = 'lapa/448' 65 | if model_path is None: 66 | model_path = pretrain_settings[conf_name]['url'] 67 | self.conf_name = conf_name 68 | self.net = download_jit(model_path) 69 | self.eval() 70 | 71 | def forward(self, images: torch.Tensor, data: Dict[str, Any]): 72 | setting = pretrain_settings[self.conf_name] 73 | images = images.float() / 255.0 74 | _, _, h, w = images.shape 75 | 76 | simages = images[data['image_ids']] 77 | matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']]) 78 | grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w)) 79 | inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w)) 80 | 81 | w_images = F.grid_sample( 82 | simages, grid, mode='bilinear', align_corners=False) 83 | 84 | w_seg_logits, _ = self.net(w_images) # (b*n) x c x h x w 85 | 86 | seg_logits = F.grid_sample( 87 | w_seg_logits, inv_grid, mode='bilinear', align_corners=False) 88 | 89 | data['seg'] = {'logits': seg_logits, 90 | 'label_names': setting['label_names']} 91 | return data 92 | -------------------------------------------------------------------------------- /inference/hyper_infer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from criteria.lpips.lpips import LPIPS 3 | import math 4 | 5 | from models.hypernetworks.hypernetwork import SharedWeightsHyperNetResNet, SharedWeightsHyperNetResNetSeparable 6 | from models.stylegan2.model import Generator 7 | from models.encoders.psp_encoders import ResidualAligner, ResidualEncoder 8 | from models.encoder import get_keys 9 | import torch 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | from utils.train_utils import load_train_checkpoint, convert_weight 13 | from inference.inference import BaseInference 14 | 15 | 16 | class HyperstyleInference(BaseInference): 17 | 18 | def __init__(self, opts, decoder=None): 19 | super(HyperstyleInference, self).__init__() 20 | self.opts = opts 21 | self.device = 'cuda' 22 | self.opts.device = self.device 23 | self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2 24 | 25 | # resume from checkpoint 26 | # TODO: hyperstyle ckpt load 27 | checkpoint = 
torch.load(opts.hypernet_checkpoint_path, map_location='cpu') 28 | checkpoint = convert_weight(checkpoint, opts) 29 | 30 | # initialize encoder and decoder 31 | if decoder is not None: 32 | self.decoder = decoder 33 | else: 34 | self.decoder = Generator(opts.resolution, 512, 8).to(self.device) 35 | self.decoder.eval() 36 | if checkpoint is not None: 37 | self.decoder.load_state_dict(checkpoint['decoder'], strict=True) 38 | else: 39 | decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu') 40 | self.decoder.load_state_dict(decoder_checkpoint['g_ema']) 41 | 42 | if self.opts.hyperstyle_encoder_type == "SharedWeightsHyperNetResNet": 43 | self.hypernet = SharedWeightsHyperNetResNet(opts=self.opts).to(self.device) 44 | elif self.opts.hyperstyle_encoder_type == "SharedWeightsHyperNetResNetSeparable": 45 | self.hypernet = SharedWeightsHyperNetResNetSeparable(opts=self.opts).to(self.device) 46 | self.hypernet.eval() 47 | self.hypernet.load_state_dict(checkpoint['hypernet'], strict=True) 48 | 49 | def inverse(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info): 50 | with torch.no_grad(): 51 | weights_deltas = None 52 | for iter in range(self.opts.hyperstyle_iteration): 53 | emb_images = torch.nn.AdaptiveAvgPool2d((256, 256))(emb_images) 54 | x_input = torch.cat([images_resize, emb_images], dim=1) 55 | 56 | hypernet_outputs = self.hypernet(x_input) 57 | if weights_deltas is None: 58 | weights_deltas = hypernet_outputs 59 | else: 60 | weights_deltas = [weights_deltas[i] + hypernet_outputs[i] if weights_deltas[i] is not None else None 61 | for i in range(len(hypernet_outputs))] 62 | 63 | emb_images, result_latent = self.decoder([emb_codes], 64 | weights_deltas=weights_deltas, 65 | input_is_latent=True, 66 | randomize_noise=False, 67 | return_latents=True) 68 | # for path, inv_img in zip(image_path, images): 69 | # basename = os.path.basename(path).split('.')[0] + '_' + str(iter) 70 | # inv_result = tensor2im(inv_img) 71 | # inv_result.save(os.path.join(self.opts.output_dir, 'inversion', f'{basename}.jpg')) 72 | 73 | return emb_images, result_latent, None 74 | 75 | def edit(self, images, images_resize, image_paths, emb_codes, emb_images, emb_info, editor): 76 | images, codes, refine_info = self.inverse(images, images_resize, image_paths, emb_codes, emb_images, emb_info) 77 | refine_info = refine_info[0] 78 | with torch.no_grad(): 79 | decoder = Generator(self.opts.resolution, 512, 8).to(self.device) 80 | decoder.train() 81 | decoder.load_state_dict(refine_info['generator'], strict=True) 82 | edit_codes = editor.edit_code(codes) 83 | 84 | edit_images, edit_codes = decoder([edit_codes], input_is_latent=True, randomize_noise=False) 85 | return images, edit_images, codes, edit_codes, refine_info -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.encoders import psp_encoders 4 | from configs.paths_config import model_paths 5 | from loguru import logger 6 | from torch.nn.parallel import DistributedDataParallel 7 | 8 | 9 | def get_keys(d, name): 10 | if 'state_dict' in d: 11 | d = d['state_dict'] 12 | d_ = dict() 13 | for k, v in d.items(): 14 | if k.startswith('module.'): 15 | d_[k[7:]] = v 16 | else: 17 | d_[k] = v 18 | d_filt = {k[len(name) + 1:]: v for k, v in d_.items() if k[:len(name)] == name} 19 | 20 | return d_filt 21 | 22 | 23 | class Encoder(nn.Module): 24 | 25 | def 
__init__(self, opts, checkpoint=None, latent_avg=None, device="cuda"): 26 | super(Encoder, self).__init__() 27 | self.opts = opts 28 | 29 | # Define architecture 30 | self.encoder = self.set_encoder().to(device) 31 | self.log_parameters() 32 | self.load_weights(checkpoint, latent_avg) 33 | 34 | if 'dist' in opts and opts.dist: 35 | self.encoder = DistributedDataParallel(self.encoder, device_ids=[torch.cuda.current_device()], 36 | find_unused_parameters=True) 37 | self.dist = True 38 | else: 39 | self.dist = False 40 | 41 | def log_parameters(self): 42 | parameter = 0 43 | for v in list(self.encoder.parameters()): 44 | parameter += v.view(-1).shape[0] 45 | logger.info(f'Encoder parameters: {parameter/1e6}M.') 46 | self.opts.parameters = parameter 47 | 48 | def set_encoder(self): 49 | if self.opts.encoder_type == 'GradualStyleEncoder': 50 | encoder = psp_encoders.GradualStyleEncoder(self.opts.layers, 'ir_se', self.opts) 51 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': 52 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(self.opts.layers, 'ir_se', self.opts) 53 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': 54 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(self.opts.layers, 'ir_se', self.opts) 55 | elif self.opts.encoder_type == 'Encoder4Editing': 56 | encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) 57 | elif self.opts.encoder_type == 'ProgressiveBackboneEncoder': 58 | encoder = psp_encoders.ProgressiveBackboneEncoder(50, 'ir_se', self.opts) 59 | else: 60 | raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) 61 | return encoder 62 | 63 | @staticmethod 64 | def check_module(ckpt, module_name): 65 | for key in ckpt['state_dict'].keys(): 66 | if key.startswith(f'{module_name}.'): 67 | return True 68 | return False 69 | 70 | def set_progressive_stage(self, stage): 71 | if self.dist: 72 | self.encoder.module.progressive_stage = stage 73 | else: 74 | self.encoder.progressive_stage = stage 75 | 76 | def load_weights(self, checkpoint, latent_avg): 77 | if checkpoint is not None: 78 | logger.info('Loading Encoder from checkpoint: {}'.format(self.opts.checkpoint_path)) 79 | encoder_load_status = self.encoder.load_state_dict(get_keys(checkpoint['encoder'], 'encoder'), strict=False) 80 | latent_avg = checkpoint['encoder']['latent_avg'] 81 | logger.info(f"encoder loading results: {encoder_load_status}") 82 | else: 83 | if self.opts.layers == 50: 84 | logger.info('Loading encoders weights from irse50!') 85 | encoder_ckpt = torch.load(model_paths['ir_se50'], map_location='cpu') 86 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 87 | else: 88 | logger.warning("Randomly initialize the Encoder!") 89 | 90 | self.register_buffer("latent_avg", latent_avg) 91 | 92 | def forward(self, x): 93 | codes = self.encoder(x) 94 | 95 | # normalize with respect to the center of an average latent codes 96 | if self.opts.start_from_latent_avg: 97 | if self.opts.learn_in_w or codes.dim() == 2: 98 | codes = codes + self.latent_avg[0].repeat(codes.shape[0], 1) 99 | else: 100 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 101 | 102 | return codes 103 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import torch 8 | import tqdm 9 | from PIL import Image 10 | from 
torch.utils.data import DataLoader 11 | from datasets.inference_dataset import InversionDataset 12 | from inference import TwoStageInference 13 | from utils.common import tensor2im 14 | from options.test_options import TestOptions 15 | import torchvision.transforms as transforms 16 | from criteria.lpips.lpips import LPIPS 17 | 18 | 19 | def main(): 20 | opts = TestOptions().parse() 21 | if opts.checkpoint_path is None: 22 | opts.auto_resume = True 23 | 24 | if opts.output_dir is None: 25 | opts.output_dir = os.path.join(opts.exp_dir, 'inference_results') 26 | os.makedirs(opts.output_dir, exist_ok=True) 27 | os.makedirs(os.path.join(opts.output_dir, 'inversion'), exist_ok=True) 28 | 29 | inversion = TwoStageInference(opts) 30 | lpips_cri = LPIPS(net_type='alex').cuda().eval() 31 | 32 | float2uint2float = lambda x: (((x + 1) / 2 * 255.).clamp(min=0, max=255).to(torch.uint8).float().div(255.) - 0.5) / 0.5 33 | 34 | if opts.output_resolution is not None and len(opts.output_resolution) == 1: 35 | opts.output_resolution = (opts.output_resolution[0], opts.output_resolution[0]) 36 | 37 | transform = transforms.Compose([ 38 | transforms.Resize((256, 256)), 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 41 | transform_no_resize = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 44 | 45 | if os.path.isdir(opts.test_dataset_path): 46 | dataset = InversionDataset(root=opts.test_dataset_path, transform=transform, 47 | transform_no_resize=transform_no_resize) 48 | dataloader = DataLoader(dataset, 49 | batch_size=opts.test_batch_size, 50 | shuffle=False, 51 | num_workers=int(opts.test_workers), 52 | drop_last=False) 53 | else: 54 | img = Image.open(opts.test_dataset_path) 55 | img = img.convert('RGB') 56 | img_aug = transform(img) 57 | img_aug_no_resize = transform_no_resize(img) 58 | dataloader = [(img_aug[None], [opts.test_dataset_path], img_aug_no_resize[None])] 59 | 60 | lpips, count = 0, 0.
61 | mse, psnr, id = torch.zeros([0]).cuda(), torch.zeros([0]).cuda(), torch.zeros([0]).cuda() 62 | for input_batch in tqdm.tqdm(dataloader): 63 | # Inversion 64 | images_resize, img_paths, images = input_batch 65 | images_resize, images = images_resize.cuda(), images.cuda() 66 | count += len(img_paths) 67 | emb_images, emb_codes, emb_info, refine_images, refine_codes, refine_info = \ 68 | inversion.inverse(images, images_resize, img_paths) 69 | H, W = emb_images.shape[2:] 70 | if refine_images is not None: 71 | images_inv, codes = refine_images, refine_codes 72 | else: 73 | images_inv, codes = emb_images, emb_codes 74 | 75 | for path, inv_img in zip(img_paths, images_inv): 76 | basename = os.path.basename(path).split('.')[0] 77 | if opts.output_resolution is not None and ((H, W) != opts.output_resolution): 78 | inv_img = transforms.Resize(opts.output_resolution, antialias=True)(inv_img) 79 | inv_result = tensor2im(inv_img) 80 | inv_result.save(os.path.join(opts.output_dir, 'inversion', f'{basename}.png')) 81 | 82 | # Evaluation 83 | images_inv = float2uint2float(images_inv) 84 | images_inv_resize = transforms.Resize((256, 256), antialias=True)(images_inv) 85 | batch_mse, batch_psnr = calculate_mse_and_psnr(images_inv_resize, images_resize) 86 | batch_lpips = lpips_cri(images_inv_resize, images_resize) 87 | 88 | mse = torch.cat([mse, batch_mse]) 89 | psnr = torch.cat([psnr, batch_psnr]) 90 | lpips += len(img_paths) * batch_lpips.item() 91 | print(f'Batch result: MSE {batch_mse.mean().item()}, PSNR {batch_psnr.mean().item()}') 92 | 93 | print('MSE:', mse.mean().item()) 94 | print('PSNR:', psnr.mean().item()) 95 | print('LPIPS:', lpips / count) 96 | 97 | 98 | def calculate_mse_and_psnr(img1, img2): 99 | mse = ((img1 - img2) ** 2).mean(dim=[1, 2, 3]) 100 | psnr = 10 * torch.log10(4 / mse) # peak-to-peak range is 2 for images in [-1, 1], so MAX^2 = 4 101 | return mse, psnr 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from .get_nets import PNet, RNet, ONet 5 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from .first_stage import run_first_stage 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks.
22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = 
bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /models/invertibility/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 
91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /models/hypernetworks/hypernetwork.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import BatchNorm2d, PReLU, Sequential, Module 3 | from torchvision.models import resnet34 4 | 5 | from models.hypernetworks.refinement_blocks import HyperRefinementBlock, RefinementBlock, RefinementBlockSeparable 6 | from models.hypernetworks.shared_weights_hypernet import SharedWeightsHypernet 7 | 8 | 9 | class SharedWeightsHyperNetResNet(Module): 10 | 11 | def __init__(self, opts): 12 | super(SharedWeightsHyperNetResNet, self).__init__() 13 | 14 | self.conv1 = nn.Conv2d(opts.hypernet_input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 15 | self.bn1 = BatchNorm2d(64) 16 | self.relu = PReLU(64) 17 | 18 | resnet_basenet = resnet34(pretrained=True) 19 | blocks = [ 20 | resnet_basenet.layer1, 21 | resnet_basenet.layer2, 22 | resnet_basenet.layer3, 23 | resnet_basenet.layer4 24 | ] 25 | modules = [] 26 | for block in blocks: 27 | for bottleneck in block: 28 | modules.append(bottleneck) 29 | self.body = Sequential(*modules) 30 | 31 | if len(opts.layers_to_tune) == 0: 32 | self.layers_to_tune = list(range(opts.n_hypernet_outputs)) 33 | else: 34 | self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')] 35 | 36 | self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12] 37 | self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None) 38 | 39 | self.refinement_blocks = nn.ModuleList() 40 | self.n_outputs = opts.n_hypernet_outputs 41 | for layer_idx in range(self.n_outputs): 42 | if layer_idx in self.layers_to_tune: 43 | if layer_idx in self.shared_layers: 44 | refinement_block = HyperRefinementBlock(self.shared_weight_hypernet, n_channels=512, inner_c=128) 45 | else: 46 | refinement_block = RefinementBlock(layer_idx, opts, n_channels=512, inner_c=256) 47 | else: 48 | refinement_block = None 49 | self.refinement_blocks.append(refinement_block) 50 | 51 | def forward(self, x): 52 | x = 
self.conv1(x) 53 | x = self.bn1(x) 54 | x = self.relu(x) 55 | x = self.body(x) 56 | weight_deltas = [] 57 | for j in range(self.n_outputs): 58 | if self.refinement_blocks[j] is not None: 59 | delta = self.refinement_blocks[j](x) 60 | else: 61 | delta = None 62 | weight_deltas.append(delta) 63 | return weight_deltas 64 | 65 | 66 | class SharedWeightsHyperNetResNetSeparable(Module): 67 | 68 | def __init__(self, opts): 69 | super(SharedWeightsHyperNetResNetSeparable, self).__init__() 70 | 71 | self.conv1 = nn.Conv2d(opts.hypernet_input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 72 | self.bn1 = BatchNorm2d(64) 73 | self.relu = PReLU(64) 74 | 75 | resnet_basenet = resnet34(pretrained=True) 76 | blocks = [ 77 | resnet_basenet.layer1, 78 | resnet_basenet.layer2, 79 | resnet_basenet.layer3, 80 | resnet_basenet.layer4 81 | ] 82 | modules = [] 83 | for block in blocks: 84 | for bottleneck in block: 85 | modules.append(bottleneck) 86 | self.body = Sequential(*modules) 87 | 88 | if len(opts.layers_to_tune) == 0: 89 | self.layers_to_tune = list(range(opts.n_hypernet_outputs)) 90 | else: 91 | self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')] 92 | 93 | self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12] 94 | self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None) 95 | 96 | self.refinement_blocks = nn.ModuleList() 97 | self.n_outputs = opts.n_hypernet_outputs 98 | for layer_idx in range(self.n_outputs): 99 | if layer_idx in self.layers_to_tune: 100 | if layer_idx in self.shared_layers: 101 | refinement_block = HyperRefinementBlock(self.shared_weight_hypernet, n_channels=512, inner_c=128) 102 | else: 103 | refinement_block = RefinementBlockSeparable(layer_idx, opts, n_channels=512, inner_c=256) 104 | else: 105 | refinement_block = None 106 | self.refinement_blocks.append(refinement_block) 107 | 108 | def forward(self, x): 109 | x = self.conv1(x) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.body(x) 113 | weight_deltas = [] 114 | for j in range(self.n_outputs): 115 | if self.refinement_blocks[j] is not None: 116 | delta = self.refinement_blocks[j](x) 117 | else: 118 | delta = None 119 | weight_deltas.append(delta) 120 | return weight_deltas 121 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from configs.paths_config import model_paths 3 | from utils.dist import init_dist, get_dist_info 4 | from options.base_options import BaseOptions, str2bool 5 | 6 | 7 | class TrainOptions(BaseOptions): 8 | def initialize(self): 9 | super(TrainOptions, self).initialize() 10 | 11 | # 1. 
Basic training options 12 | self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps.') 13 | self.parser.add_argument('--image_interval', default=500, type=int, 14 | help='Interval for logging train images during training.') 15 | self.parser.add_argument('--board_interval', default=50, type=int, 16 | help='Interval for logging metrics to tensorboard') 17 | self.parser.add_argument('--val_interval', default=5000, type=int, help='Validation interval.') 18 | self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval.') 19 | self.parser.add_argument('--start_step', default=0, type=int, help='Initial step.') 20 | self.parser.add_argument('--seed', default=0, type=int, help="Random seed.") 21 | 22 | # optimizer 23 | self.parser.add_argument('--optimizer', default='ranger', type=str, help='Which optimizer to use.') 24 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate.') 25 | self.parser.add_argument('--weight_decay', default=0., type=float, help='Weight decay.') 26 | self.parser.add_argument('--optim_beta1', default=0.95, type=float, help='beta1.') 27 | self.parser.add_argument('--optim_beta2', default=0.999, type=float, help='beta2.') 28 | 29 | # Wandb 30 | self.parser.add_argument('--use_wandb', default=False, type=str2bool, 31 | help='Whether to use Weights & Biases to track experiment.') 32 | self.parser.add_argument('--wandb_project', default='GAN_Inverter', type=str, help='continue to train.') 33 | 34 | # 2. Loss options 35 | self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor.') 36 | self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor.') 37 | self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor.') 38 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor.') 39 | self.parser.add_argument('--moco_lambda', default=0, type=float, 40 | help='Moco-based feature similarity loss multiplier factor.') 41 | self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas") 42 | self.parser.add_argument('--delta_norm_lambda', type=float, default=0., help="lambda for delta norm loss") 43 | 44 | # e4e 45 | self.parser.add_argument('--w_discriminator_lambda', default=0., type=float, help='Dw loss multiplier.') 46 | self.parser.add_argument("--r1", type=float, default=10, help="Weight of the r1 regularization.") 47 | self.parser.add_argument("--d_reg_every", type=int, default=16, 48 | help="Interval for applying r1 regularization.") 49 | self.parser.add_argument('--discriminator_lr', default=2e-5, type=float, help='Dw learning rate') 50 | 51 | self.parser.add_argument('--use_w_pool', action='store_true', 52 | help='Whether to store a latnet codes pool for the discriminator\'s training') 53 | self.parser.add_argument("--w_pool_size", type=int, default=50, 54 | help="W\'s pool size, depends on --use_w_pool") 55 | 56 | self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None, 57 | help="The training steps of training new deltas. 
steps[i] starts the delta_i training") 58 | self.parser.add_argument('--progressive_start', type=int, default=0, 59 | help="The training step to start training the deltas, overrides progressive_steps") 60 | self.parser.add_argument('--progressive_step_every', type=int, default=0, 61 | help="Amount of training steps for each progressive step") 62 | 63 | # lsap 64 | self.parser.add_argument('--sncd_lambda', default=0., type=float, help='SNCD loss multiplier factor.') 65 | 66 | # HFGI 67 | self.parser.add_argument('--distortion_scale', type=float, default=0.15, help="Distortion scale of the augmentation used in HFGI consistency training.") 68 | self.parser.add_argument('--aug_rate', type=float, default=0.8, help="Probability of applying the augmentation in HFGI consistency training.") 69 | self.parser.add_argument('--res_lambda', default=0., type=float, help='HFGI residual (consistency) loss multiplier factor.') 70 | 71 | # Distributed Training 72 | self.parser.add_argument('--local_rank', default=0, type=int, help='local rank for distributed training.') 73 | self.parser.add_argument('--gpu_num', default=1, type=int, help='num of gpu.') 74 | 75 | def parse(self): 76 | opts = super(TrainOptions, self).parse() 77 | opts.dist = True if opts.gpu_num != 1 else False 78 | if opts.dist: 79 | init_dist() 80 | opts.rank, opts.world_size = get_dist_info() 81 | return opts 82 | -------------------------------------------------------------------------------- /scripts/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import cv2 8 | import tqdm 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from torch.utils.data import DataLoader 13 | from datasets.inference_dataset import InversionDataset 14 | from inference import TwoStageInference 15 | from utils.common import tensor2im 16 | from options.test_options import TestOptions 17 | import torchvision.transforms as transforms 18 | 19 | 20 | def save_intermediate(info_dict, output_dir, basename, keys): 21 | if info_dict is None: 22 | return None 23 | for k, v in info_dict.items(): 24 | if keys is not None and k not in keys: 25 | continue 26 | os.makedirs(os.path.join(output_dir, k), exist_ok=True) 27 | if isinstance(v, torch.Tensor): 28 | # image tensor 29 | if v.dim() == 4 and v.shape[0] == 1 and v.shape[1] == 3: 30 | img = tensor2im(v[0]) 31 | img.save(os.path.join(output_dir, k, f'{basename}.png')) 32 | elif v.dim() == 3 and v.shape[0] == 3: 33 | img = tensor2im(v) 34 | img.save(os.path.join(output_dir, k, f'{basename}.png')) 35 | else: # tensor but not image 36 | torch.save(v, os.path.join(output_dir, k, f'{basename}.pt')) 37 | # model weight 38 | elif (isinstance(v, dict) and isinstance(list(v.values())[0], torch.Tensor)): 39 | torch.save(v, os.path.join(output_dir, k, f'{basename}.pt')) 40 | # numpy array 41 | elif isinstance(v, np.ndarray): 42 | if v.dtype == np.uint8: 43 | cv2.imwrite(os.path.join(output_dir, k, f'{basename}.png'), v) 44 | else: 45 | np.save(os.path.join(output_dir, k, f'{basename}.npy'), v) 46 | else: 47 | raise Exception('Intermediate information cannot be saved:', k) 48 | 49 | 50 | def main(): 51 | opts = TestOptions().parse() 52 | if opts.checkpoint_path is None: 53 | opts.auto_resume = True 54 | 55 | inversion = TwoStageInference(opts) 56 | 57 | if opts.output_dir is None: 58 | opts.output_dir = os.path.join(opts.exp_dir, 'inference_results') 59 | os.makedirs(opts.output_dir, exist_ok=True) 60 | os.makedirs(os.path.join(opts.output_dir, 'inversion'), exist_ok=True) 61 | 62 | if
opts.save_code: 63 | os.makedirs(os.path.join(opts.output_dir, 'code'), exist_ok=True) 64 | 65 | if opts.output_resolution is not None and len(opts.output_resolution) == 1: 66 | opts.output_resolution = (opts.output_resolution[0], opts.output_resolution[0]) 67 | 68 | transform = transforms.Compose([ 69 | transforms.Resize((256, 256)), 70 | transforms.ToTensor(), 71 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 72 | transform_no_resize = transforms.Compose([ 73 | transforms.ToTensor(), 74 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 75 | 76 | if os.path.isdir(opts.test_dataset_path): 77 | dataset = InversionDataset(root=opts.test_dataset_path, transform=transform, 78 | transform_no_resize=transform_no_resize) 79 | dataloader = DataLoader(dataset, 80 | batch_size=opts.test_batch_size, 81 | shuffle=False, 82 | num_workers=int(opts.test_workers), 83 | drop_last=False) 84 | else: 85 | img = Image.open(opts.test_dataset_path) 86 | img = img.convert('RGB') 87 | img_aug = transform(img) 88 | img_aug_no_resize = transform_no_resize(img) 89 | dataloader = [(img_aug[None], [opts.test_dataset_path], img_aug_no_resize[None])] 90 | 91 | for input_batch in tqdm.tqdm(dataloader): 92 | images_resize, img_paths, images = input_batch 93 | images_resize, images = images_resize.cuda(), images.cuda() 94 | emb_images, emb_codes, emb_info, refine_images, refine_codes, refine_info = \ 95 | inversion.inverse(images, images_resize, img_paths) 96 | 97 | H, W = emb_images.shape[2:] 98 | if refine_images is not None: 99 | images, codes = refine_images, refine_codes 100 | else: 101 | images, codes = emb_images, emb_codes 102 | 103 | emb_info = [None] * len(img_paths) if emb_info is None else emb_info 104 | refine_info = [None] * len(img_paths) if refine_info is None else refine_info 105 | 106 | for path, inv_img, code, e_info, r_info in zip(img_paths, images, codes, emb_info, refine_info): 107 | basename = os.path.basename(path).split('.')[0] 108 | if opts.save_code: 109 | torch.save(code, os.path.join(opts.output_dir, 'code', f'{basename}.pt')) 110 | if opts.output_resolution is not None and ((H, W) != opts.output_resolution): 111 | inv_img = transforms.Resize(opts.output_resolution, antialias=True)(inv_img) 112 | inv_result = tensor2im(inv_img) 113 | inv_result.save(os.path.join(opts.output_dir, 'inversion', f'{basename}.png')) 114 | 115 | # save intermediate info 116 | if opts.save_intermediate: 117 | save_intermediate(e_info, opts.output_dir, basename, opts.save_keys) 118 | save_intermediate(r_info, opts.output_dir, basename, opts.save_keys) 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | from configs.paths_config import model_paths 8 | PNET_PATH = model_paths["mtcnn_pnet"] 9 | ONET_PATH = model_paths["mtcnn_onet"] 10 | RNET_PATH = model_paths["mtcnn_rnet"] 11 | 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w].
24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | 37 | # suppose we have input with size HxW, then 38 | # after first layer: H - 2, 39 | # after pool: ceil((H - 2)/2), 40 | # after second conv: ceil((H - 2)/2) - 2, 41 | # after last conv: ceil((H - 2)/2) - 4, 42 | # and the same for W 43 | 44 | self.features = nn.Sequential(OrderedDict([ 45 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 46 | ('prelu1', nn.PReLU(10)), 47 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 48 | 49 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 50 | ('prelu2', nn.PReLU(16)), 51 | 52 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 53 | ('prelu3', nn.PReLU(32)) 54 | ])) 55 | 56 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 57 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 58 | 59 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 60 | for n, p in self.named_parameters(): 61 | p.data = torch.FloatTensor(weights[n]) 62 | 63 | def forward(self, x): 64 | """ 65 | Arguments: 66 | x: a float tensor with shape [batch_size, 3, h, w]. 67 | Returns: 68 | b: a float tensor with shape [batch_size, 4, h', w']. 69 | a: a float tensor with shape [batch_size, 2, h', w']. 70 | """ 71 | x = self.features(x) 72 | a = self.conv4_1(x) 73 | b = self.conv4_2(x) 74 | a = F.softmax(a, dim=-1) 75 | return b, a 76 | 77 | 78 | class RNet(nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential(OrderedDict([ 84 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 85 | ('prelu1', nn.PReLU(28)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 87 | 88 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 89 | ('prelu2', nn.PReLU(48)), 90 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 91 | 92 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 93 | ('prelu3', nn.PReLU(64)), 94 | 95 | ('flatten', Flatten()), 96 | ('conv4', nn.Linear(576, 128)), 97 | ('prelu4', nn.PReLU(128)) 98 | ])) 99 | 100 | self.conv5_1 = nn.Linear(128, 2) 101 | self.conv5_2 = nn.Linear(128, 4) 102 | 103 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 104 | for n, p in self.named_parameters(): 105 | p.data = torch.FloatTensor(weights[n]) 106 | 107 | def forward(self, x): 108 | """ 109 | Arguments: 110 | x: a float tensor with shape [batch_size, 3, h, w]. 111 | Returns: 112 | b: a float tensor with shape [batch_size, 4]. 113 | a: a float tensor with shape [batch_size, 2]. 
114 | """ 115 | x = self.features(x) 116 | a = self.conv5_1(x) 117 | b = self.conv5_2(x) 118 | a = F.softmax(a, dim=-1) 119 | return b, a 120 | 121 | 122 | class ONet(nn.Module): 123 | 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 129 | ('prelu1', nn.PReLU(32)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 131 | 132 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 133 | ('prelu2', nn.PReLU(64)), 134 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 135 | 136 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 137 | ('prelu3', nn.PReLU(64)), 138 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | 140 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 141 | ('prelu4', nn.PReLU(128)), 142 | 143 | ('flatten', Flatten()), 144 | ('conv5', nn.Linear(1152, 256)), 145 | ('drop5', nn.Dropout(0.25)), 146 | ('prelu5', nn.PReLU(256)), 147 | ])) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | --------------------------------------------------------------------------------
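A minimal usage sketch (not part of the repository) for the three-stage MTCNN cascade defined in detector.py and get_nets.py above. It assumes the repository root is on PYTHONPATH and that the P-Net/R-Net/O-Net .npy weights referenced in configs/paths_config.py have been downloaded via pretrained_models/download_models.sh; the image path is a placeholder.

    from PIL import Image
    from models.mtcnn.mtcnn_pytorch.src.detector import detect_faces

    # detect_faces runs P-Net, R-Net and O-Net in sequence on a PIL image and
    # returns numpy arrays: boxes as (x1, y1, x2, y2, score) rows and landmarks
    # as rows of five x coordinates followed by five y coordinates.
    image = Image.open('path/to/face.jpg').convert('RGB')  # placeholder path
    boxes, landmarks = detect_faces(image, min_face_size=20.0)
    for x1, y1, x2, y2, score in boxes:
        print(f'face at ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}), confidence {score:.3f}')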