├── models
│   ├── __init__.py
│   ├── stylegan2
│   │   ├── __init__.py
│   │   └── op
│   │       ├── __init__.py
│   │       ├── fused_bias_act.cpp
│   │       ├── upfirdn2d.cpp
│   │       ├── fused_act.py
│   │       └── fused_bias_act_kernel.cu
│   └── stylegan3
│       ├── __init__.py
│       └── model.py
├── utils
│   ├── __init__.py
│   ├── model_utils.py
│   ├── data_utils.py
│   ├── common.py
│   ├── train_utils.py
│   └── inference_utils.py
├── configs
│   ├── __init__.py
│   ├── data_configs.py
│   ├── transforms_config.py
│   └── paths_config.py
├── criteria
│   ├── __init__.py
│   ├── lpips
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── lpips.py
│   │   └── networks.py
│   ├── w_norm.py
│   ├── clip_loss.py
│   ├── id_loss.py
│   ├── moco_loss.py
│   └── ms_ssim.py
├── editing
│   ├── __init__.py
│   ├── interfacegan
│   │   ├── __init__.py
│   │   ├── helpers
│   │   │   ├── __init__.py
│   │   │   ├── anycostgan.py
│   │   │   └── manipulator.py
│   │   ├── face_editor.py
│   │   ├── train_boundaries.py
│   │   └── generate_latents_and_attribute_scores.py
│   ├── styleclip_mapper
│   │   ├── __init__.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   └── latents_dataset.py
│   │   ├── options
│   │   │   ├── __init__.py
│   │   │   ├── test_options.py
│   │   │   └── train_options.py
│   │   ├── scripts
│   │   │   ├── __init__.py
│   │   │   ├── train.py
│   │   │   └── inference.py
│   │   ├── training
│   │   │   └── __init__.py
│   │   ├── styleclip_mapper.py
│   │   └── latent_mappers.py
│   └── styleclip_global_directions
│       ├── __init__.py
│       ├── preprocess
│       │   ├── __init__.py
│       │   ├── download_all_files.py
│       │   ├── s_statistics.py
│       │   └── create_delta_i_c.py
│       ├── templates.txt
│       └── global_direction.py
├── inversion
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── map2style.py
│   │   │   ├── model_irse.py
│   │   │   ├── restyle_psp_encoders.py
│   │   │   └── helpers.py
│   │   ├── mtcnn
│   │   │   ├── __init__.py
│   │   │   └── mtcnn_pytorch
│   │   │       ├── __init__.py
│   │   │       └── src
│   │   │           ├── __init__.py
│   │   │           ├── weights
│   │   │           │   ├── onet.npy
│   │   │           │   ├── pnet.npy
│   │   │           │   └── rnet.npy
│   │   │           ├── visualization_utils.py
│   │   │           ├── first_stage.py
│   │   │           └── detector.py
│   │   ├── e4e_modules
│   │   │   ├── __init__.py
│   │   │   ├── discriminator.py
│   │   │   └── latent_codes_pool.py
│   │   └── psp3.py
│   ├── video
│   │   ├── __init__.py
│   │   ├── generate_videos.py
│   │   ├── post_processing.py
│   │   ├── video_handler.py
│   │   └── video_config.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── gt_res_dataset.py
│   │   ├── images_dataset.py
│   │   ├── pti_dataset.py
│   │   └── inference_dataset.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── e4e_train_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── training
│   │   └── __init__.py
│   └── scripts
│       ├── train_restyle_psp.py
│       ├── train_restyle_e4e.py
│       ├── calc_losses_on_images.py
│       ├── calc_id_loss_parallel.py
│       ├── inference_iterative.py
│       └── create_inversion_animation.py
├── torch_utils
│   ├── __init__.py
│   └── ops
│       ├── __init__.py
│       ├── bias_act.h
│       ├── filtered_lrelu_rd.cu
│       ├── filtered_lrelu_wr.cu
│       ├── filtered_lrelu_ns.cu
│       ├── upfirdn2d.h
│       ├── fma.py
│       ├── grid_sample_gradfix.py
│       ├── filtered_lrelu.h
│       └── bias_act.cpp
├── prepare_data
│   ├── __init__.py
│   ├── compute_landmarks_transforms.py
│   ├── landmarks_handler.py
│   └── preparing_faces_parallel.py
├── docs
│   ├── teaser.jpg
│   ├── inversion.jpg
│   ├── styleclip.jpg
│   ├── styleclip_afhq.jpg
│   ├── styleclip_ffhq.jpg
│   ├── interfacegan_edits.jpg
│   ├── interfacegan_smith.jpg
│   ├── styleclip_landscapes.jpg
│   ├── stylegan_nada_afhq.jpg
│   ├── stylegan_nada_ffhq.jpg
│   └── interfacegan_witherspoon.jpg
├── edit
│   ├── pic
│   │   ├── 001.jpg
│   │   ├── 002.jpg
│   │   ├── 003.jpg
│   │   ├── 004.jpg
│   │   └── 005.jpg
│   └── video
│       ├── 01.mp4
│       └── 02.mp4
├── notebooks
│   ├── images
│   │   └── face_image.jpg
│   └── notebook_utils.py
├── dnnlib
│   └── __init__.py
├── LICENSE
├── environment
│   └── sg3_env.yaml
└── function.py

/models/__init__.py:
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/video/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/stylegan3/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prepare_data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/interfacegan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/models/encoders/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/interfacegan/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/models/e4e_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_global_directions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /editing/styleclip_global_directions/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/teaser.jpg -------------------------------------------------------------------------------- /edit/pic/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/pic/001.jpg -------------------------------------------------------------------------------- /edit/pic/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/pic/002.jpg -------------------------------------------------------------------------------- /edit/pic/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/pic/003.jpg -------------------------------------------------------------------------------- /edit/pic/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/pic/004.jpg 
-------------------------------------------------------------------------------- /edit/pic/005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/pic/005.jpg -------------------------------------------------------------------------------- /docs/inversion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/inversion.jpg -------------------------------------------------------------------------------- /docs/styleclip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/styleclip.jpg -------------------------------------------------------------------------------- /edit/video/01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/video/01.mp4 -------------------------------------------------------------------------------- /edit/video/02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/edit/video/02.mp4 -------------------------------------------------------------------------------- /docs/styleclip_afhq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/styleclip_afhq.jpg -------------------------------------------------------------------------------- /docs/styleclip_ffhq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/styleclip_ffhq.jpg -------------------------------------------------------------------------------- /docs/interfacegan_edits.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/interfacegan_edits.jpg -------------------------------------------------------------------------------- /docs/interfacegan_smith.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/interfacegan_smith.jpg -------------------------------------------------------------------------------- /docs/styleclip_landscapes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/styleclip_landscapes.jpg -------------------------------------------------------------------------------- /docs/stylegan_nada_afhq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/stylegan_nada_afhq.jpg -------------------------------------------------------------------------------- /docs/stylegan_nada_ffhq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/stylegan_nada_ffhq.jpg -------------------------------------------------------------------------------- /notebooks/images/face_image.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/notebooks/images/face_image.jpg -------------------------------------------------------------------------------- /docs/interfacegan_witherspoon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/docs/interfacegan_witherspoon.jpg -------------------------------------------------------------------------------- /models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .detector import detect_faces 2 | from .visualization_utils import show_bboxes 3 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/inversion/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/inversion/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/stylegan3-editing/main/inversion/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # specify the encoder types for pSp and e4e - this is mainly used for the inference alignment 2 | ENCODER_TYPES = { 3 | 'pSp': ['BackboneEncoder', 'ResNetBackboneEncoder'], 4 | 'e4e': ['ProgressiveBackboneEncoder', 'ResNetProgressiveBackboneEncoder'] 5 | } 6 | -------------------------------------------------------------------------------- /configs/data_configs.py: -------------------------------------------------------------------------------- 1 | from configs import transforms_config 2 | from configs.paths_config import dataset_paths 3 | 4 | 5 | DATASETS = { 6 | 'ffhq_encode': { 7 | 'transforms': transforms_config.EncodeTransforms, 8 | 'train_source_root': dataset_paths['ffhq'], 9 | 'train_target_root': dataset_paths['ffhq'], 10 | 'test_source_root': dataset_paths['celeba_test'], 11 | 'test_target_root': dataset_paths['celeba_test'] 12 | } 13 | } -------------------------------------------------------------------------------- /criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | 
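# latent is typically of shape [batch, n_latents, 512]; when start_from_latent_avg is set,
# the offset from the average latent code is penalized rather than the code itself, and the
# returned value is the mean per-sample L2 norm over the batch.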
if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/datasets/latents_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class LatentsDataset(Dataset): 6 | 7 | def __init__(self, latents, opts, transforms=None): 8 | self.latents = latents 9 | self.transforms = transforms 10 | self.opts = opts 11 | 12 | def __len__(self): 13 | return self.latents.shape[0] 14 | 15 | def __getitem__(self, index): 16 | if self.transforms is not None: 17 | return self.latents[index], torch.from_numpy(self.transforms[index][3]).float() 18 | return self.latents[index] 19 | -------------------------------------------------------------------------------- /criteria/clip_loss.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | 4 | 5 | class CLIPLoss(torch.nn.Module): 6 | 7 | def __init__(self, opts): 8 | super(CLIPLoss, self).__init__() 9 | self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") 10 | self.upsample = torch.nn.Upsample(scale_factor=7) 11 | self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32) 12 | 13 | def forward(self, image, text): 14 | image = self.avg_pool(self.upsample(image)) 15 | similarity = 1 - self.model(image, text)[0] / 100 16 | return similarity 17 | -------------------------------------------------------------------------------- /inversion/models/e4e_modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LatentCodesDiscriminator(nn.Module): 5 | def __init__(self, style_dim, n_mlp): 6 | super().__init__() 7 | 8 | self.style_dim = style_dim 9 | 10 | layers = [] 11 | for i in range(n_mlp-1): 12 | layers.append( 13 | nn.Linear(style_dim, style_dim) 14 | ) 15 | layers.append(nn.LeakyReLU(0.2)) 16 | layers.append(nn.Linear(512, 1)) 17 | self.mlp 
= nn.Sequential(*layers) 18 | 19 | def forward(self, w): 20 | return self.mlp(w) 21 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD (https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py) 3 | """ 4 | from pathlib import Path 5 | 6 | IMG_EXTENSIONS = [ 7 | '.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 9 | ] 10 | 11 | 12 | def is_image_file(filename: Path): 13 | return any(str(filename).endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | 16 | def make_dataset(dir: Path): 17 | images = [] 18 | assert dir.is_dir(), '%s is not a valid directory' % dir 19 | for fname in dir.glob("*"): 20 | if is_image_file(fname): 21 | path = dir / fname 22 | images.append(path) 23 | return images 24 | -------------------------------------------------------------------------------- /inversion/scripts/train_restyle_psp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pprint 3 | import sys 4 | 5 | import dataclasses 6 | import pyrallis 7 | 8 | sys.path.append(".") 9 | sys.path.append("..") 10 | 11 | from inversion.options.train_options import TrainOptions 12 | from inversion.training.coach_restyle_psp import Coach 13 | 14 | 15 | @pyrallis.wrap() 16 | def main(opts: TrainOptions): 17 | opts.exp_dir.mkdir(exist_ok=True, parents=True) 18 | 19 | opts_dict = dataclasses.asdict(opts) 20 | pprint.pprint(opts_dict) 21 | with open(opts.exp_dir / 'opt.json', 'w') as f: 22 | json.dump(opts_dict, f, indent=4, sort_keys=True, default=str) 23 | 24 | coach = Coach(opts) 25 | coach.train() 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import json 5 | import os 6 | import pprint 7 | import sys 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from editing.styleclip_mapper.options.train_options import TrainOptions 13 | from editing.styleclip_mapper.training.coach import Coach 14 | 15 | 16 | def main(opts): 17 | if os.path.exists(opts.exp_dir): 18 | raise Exception('Oops... 
{} already exists'.format(opts.exp_dir)) 19 | os.makedirs(opts.exp_dir, exist_ok=True) 20 | 21 | opts_dict = vars(opts) 22 | pprint.pprint(opts_dict) 23 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 24 | json.dump(opts_dict, f, indent=4, sort_keys=True, default=str) 25 | 26 | coach = Coach(opts) 27 | coach.train() 28 | 29 | 30 | if __name__ == '__main__': 31 | opts = TrainOptions().parse() 32 | main(opts) 33 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /inversion/models/encoders/map2style.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn import Conv2d, Module 4 | 5 | from models.stylegan2.model import EqualLinear 6 | 7 | 8 | class GradualStyleBlock(Module): 9 | def __init__(self, in_c, out_c, spatial): 10 | super(GradualStyleBlock, self).__init__() 11 | self.out_c = out_c 12 | self.spatial = spatial 13 | num_pools = int(np.log2(spatial)) 14 | modules = [] 15 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 16 | for i in range(num_pools - 1): 17 | modules += [Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] 18 | self.convs = nn.Sequential(*modules) 19 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 20 | 21 | def forward(self, x): 22 | x = self.convs(x) 23 | x = x.view(-1, self.out_c) 24 | x = self.linear(x) 25 | return x 26 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /inversion/datasets/gt_res_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class GTResDataset(Dataset): 9 | 10 | def __init__(self, root_path: Path, gt_dir: Path, transform=None, transform_train=None): 11 | self.pairs = [] 12 | for f in root_path.glob("*"): 13 | image_path = root_path / f 14 | gt_path = gt_dir / f 15 | if f.suffix in [".jpg", ".png", ".jpeg"]: 16 | self.pairs.append([image_path, gt_path, None]) 17 | self.transform = transform 18 | self.transform_train = transform_train 19 | 20 | def __len__(self): 21 | return len(self.pairs) 22 | 23 | def __getitem__(self, index): 24 | from_path, to_path, _ = self.pairs[index] 25 | from_im = Image.open(from_path).convert('RGB') 26 | to_im = Image.open(to_path).convert('RGB') 27 | if self.transform: 28 | to_im = self.transform(to_im) 29 | from_im = self.transform(from_im) 30 | return from_im, to_im 31 | -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return 
upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /inversion/datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | 5 | from utils import data_utils 6 | 7 | 8 | class ImagesDataset(Dataset): 9 | 10 | def __init__(self, source_root: Path, target_root: Path, target_transform=None, source_transform=None): 11 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 12 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 13 | self.source_transform = source_transform 14 | self.target_transform = target_transform 15 | 16 | def __len__(self): 17 | return len(self.source_paths) 18 | 19 | def __getitem__(self, index): 20 | from_path = self.source_paths[index] 21 | to_path = self.target_paths[index] 22 | 23 | from_im = Image.open(from_path).convert('RGB') 24 | to_im = Image.open(to_path).convert('RGB') 25 | 26 | if self.target_transform: 27 | to_im = self.target_transform(to_im) 28 | 29 | if self.source_transform: 30 | from_im = self.source_transform(from_im) 31 | else: 32 | from_im = to_im 33 | 34 | return from_im, to_im 35 | -------------------------------------------------------------------------------- /configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | 4 | 5 | class TransformsConfig: 6 | 7 | def __init__(self, opts): 8 | self.opts = opts 9 | 10 | @abstractmethod 11 | def get_transforms(self): 12 | pass 13 | 14 | 15 | class EncodeTransforms(TransformsConfig): 16 | 17 | def __init__(self, opts): 18 | super(EncodeTransforms, self).__init__(opts) 19 | 20 | def get_transforms(self): 21 | transforms_dict = { 22 | 'transform_gt_train': transforms.Compose([ 23 | transforms.Resize((256, 256)), 24 | transforms.RandomHorizontalFlip(0.5), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 27 | 'transform_source': None, 28 | 'transform_test': transforms.Compose([ 29 | transforms.Resize((256, 256)), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 32 | 'transform_inference': transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 36 | } 37 | return transforms_dict 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yuval Alaluf, Or Patashnik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions 
of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /prepare_data/compute_landmarks_transforms.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pyrallis 4 | from dataclasses import dataclass 5 | 6 | from prepare_data.landmarks_handler import LandmarksHandler 7 | 8 | 9 | @dataclass 10 | class Options: 11 | 12 | """ Input Args """ 13 | # Path to raw images 14 | raw_root: Path 15 | # Path to aligned images 16 | aligned_root: Path 17 | # Path to cropped images 18 | cropped_root: Path 19 | 20 | """ Output Args """ 21 | # Path to output directory 22 | output_root: Path 23 | 24 | """ General Args """ 25 | # Replacing the landmarks if file already exist 26 | replace: bool = True 27 | 28 | 29 | @pyrallis.wrap() 30 | def main(args: Options): 31 | args.output_root.mkdir(exist_ok=True, parents=True) 32 | landmarks_handler = LandmarksHandler(args.output_root) 33 | input_images = list(args.raw_root.iterdir()) 34 | landmarks_handler.get_landmarks_transforms(input_paths=input_images, 35 | cropped_frames_path=args.cropped_root, 36 | aligned_frames_path=args.aligned_root, 37 | force_computing=args.replace) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() -------------------------------------------------------------------------------- /inversion/datasets/pti_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class PTIDataset(Dataset): 7 | 8 | def __init__(self, latents, targets, landmarks_transforms=None, transforms=None): 9 | self.latents = latents 10 | self.targets = targets 11 | if landmarks_transforms is not None: 12 | self.landmarks_transforms = [] 13 | for t in landmarks_transforms: 14 | if type(t) == np.ndarray: 15 | t = torch.from_numpy(t) 16 | self.landmarks_transforms.append(t.cpu().numpy()) 17 | else: 18 | self.landmarks_transforms = None 19 | self.transforms = transforms 20 | 21 | def __len__(self): 22 | return len(self.targets) 23 | 24 | def __getitem__(self, index): 25 | latent = self.latents[index] 26 | target = self.targets[index] 27 | landmarks_transforms = self.landmarks_transforms[index] if self.landmarks_transforms is not None else None 28 | if self.transforms is not None: 29 | target = self.transforms(target) 30 | if self.landmarks_transforms is not None: 31 | return target, latent, landmarks_transforms, index 32 | else: 33 | return target, latent, index 34 | -------------------------------------------------------------------------------- /environment/sg3_env.yaml: -------------------------------------------------------------------------------- 1 | name: sg3_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - ca-certificates=2020.4.5.1=hecc5488_0 8 | - certifi=2020.4.5.1=py36h9f0ad1d_0 9 | - libedit=3.1.20181209=hc058e9b_0 10 | - libffi=3.2.1=hd88cf55_4 11 | - 
libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - ninja=1.10.0=hc9558a2_0 15 | - openssl=1.1.1g=h516909a_0 16 | - pip=20.0.2=py36_3 17 | - python=3.6.7=h0371630_0 18 | - python_abi=3.6=1_cp36m 19 | - readline=7.0=h7b6447c_5 20 | - setuptools=46.4.0=py36_0 21 | - sqlite=3.31.1=h62c20be_1 22 | - tk=8.6.8=hbc83047_0 23 | - wheel=0.34.2=py36_0 24 | - xz=5.2.5=h7b6447c_0 25 | - zlib=1.2.11=h7b6447c_3 26 | - pip: 27 | - scipy==1.4.1 28 | - matplotlib==3.2.1 29 | - tqdm==4.46.0 30 | - numpy==1.18.4 31 | - opencv-python==4.2.0.34 32 | - pillow==7.1.2 33 | - tensorboard==2.2.1 34 | - torch==1.10.0 35 | - torchvision==0.11.1 36 | - scikit-learn==0.24.2 37 | - imageio==2.9.0 38 | - imageio-ffmpeg==0.4.3 39 | - dataclasses==0.8 40 | - dlib==19.22.1 41 | - ftfy==6.0.3 42 | - gdown==4.2.0 43 | - git+https://github.com/openai/CLIP.git 44 | - pyrallis==0.2.2 45 | prefix: ~/anaconda3/envs/sg3_env 46 | -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 
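// The parameter struct below is presumably filled in by the host-side op (bias_act.cpp) and handed
// to the CUDA kernels: x is the input, b the optional bias, and xref/yref/dy are only needed when
// gradients are being computed (grad > 0).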
11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /editing/styleclip_global_directions/preprocess/download_all_files.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import subprocess 3 | import os 4 | 5 | PATHS = { 6 | "sg3-r-ffhq-1024": {"delta_i_c.npy": "1HOUGvtumLFwjbwOZrTbIloAwBBzs2NBN", 7 | "s_stats": "1FVm_Eh7qmlykpnSBN1Iy533e_A2xM78z"}, 8 | "sg3-r-ffhqu-1024": {"delta_i_c.npy": "1EcLy3ya7p-cWs8kQZKgudyZnhGvsrRUO", 9 | "s_stats": "1It-M23K31ABgGiH7CAUfmmj_SEqFMfM_"}, 10 | "sg3-r-afhq-512": {"delta_i_c.npy": "1CKDn0BcbAosGLEYo4fW2YnAyaERvtJ7s", 11 | "s_stats": "1omJCjPSyamP01Pr1rPx0wO4eI1Jpohat"}, 12 | "sg3-t-landscape-256": {"delta_i_c.npy": "1Po4S_zPuefQZFttT4tW9z7dt-nu4P4iF", 13 | "s_stats": "12XqJ4DX31n2AtVpPFZXfOiUxTUJ5CxhK"} 14 | } 15 | 16 | 17 | def main(): 18 | save_dir = Path("editing") / "styleclip_global_directions" 19 | save_dir.mkdir(exist_ok=True, parents=True) 20 | for name, file_ids in PATHS.items(): 21 | model_dir = save_dir / name 22 | model_dir.mkdir(exist_ok=True, parents=True) 23 | print(f"Downloading models for {name}...") 24 | for file_name, file_id in file_ids.items(): 25 | subprocess.run(["gdown", "--id", file_id, "-O", model_dir / file_name]) 26 | # remove extra files 27 | try: 28 | for path in model_dir.glob("*"): 29 | if str(path.name).startswith("sg3"): 30 | os.remove(path) 31 | except Exception: 32 | pass 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /inversion/options/e4e_train_options.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | from dataclasses import dataclass, field 4 | 5 | from inversion.options.train_options import TrainOptions 6 | 7 | 8 | @dataclass 9 | class e4eTrainOptions(TrainOptions): 10 | """ Training args for e4e-based models. 
""" 11 | 12 | """ General args """ 13 | # Dw loss multiplier factor 14 | w_discriminator_lambda: float = 0 15 | # Dw learning rate 16 | w_discriminator_lr: float = 2e-5 17 | # Weight of the R1 regularization 18 | r1: float = 10 19 | # Interval for applying R1 regularization 20 | d_reg_every: int = 16 21 | # Whether to store a latent codes pool for the discriminator training 22 | use_w_pool: bool = True 23 | # W pool size when using pool for discrminator training 24 | w_pool_size: int = 50 25 | # Truncation psi for sampling real latents for discriminator 26 | truncation_psi: float = 1 27 | 28 | """ e4e modules args """ 29 | # Norm type for delta loss 30 | delta_norm: int = 2 31 | # Delta regularization loss weight 32 | delta_norm_lambda: float = 2e-4 33 | 34 | """ Progressive training args """ 35 | progressive_steps: Optional[List[int]] = None 36 | progressive_start: Optional[int] = None 37 | progressive_step_every: Optional[int] = 2000 38 | 39 | """ Saving and resume training args """ 40 | save_training_data: bool = False 41 | sub_exp_dir: Optional[str] = None 42 | resume_training_from_ckpt: Optional[Path] = None 43 | update_param_list: Optional[str] = None 44 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import imageio 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | def make_transform(translate: Tuple[float, float], angle: float): 10 | m = np.eye(3) 11 | s = np.sin(angle / 360.0 * np.pi * 2) 12 | c = np.cos(angle / 360.0 * np.pi * 2) 13 | m[0][0] = c 14 | m[0][1] = s 15 | m[0][2] = translate[0] 16 | m[1][0] = -s 17 | m[1][1] = c 18 | m[1][2] = translate[1] 19 | return m 20 | 21 | 22 | def get_identity_transform(): 23 | translate = (0, 0) 24 | rotate = 0. 
25 | m = make_transform(translate, rotate) 26 | m = np.linalg.inv(m) 27 | return m 28 | 29 | 30 | def generate_random_transform(translate=0.3, rotate=25): 31 | rotate = np.random.uniform(low=-1 * rotate, high=rotate) 32 | translate = (np.random.uniform(low=-1 * translate, high=translate), 33 | np.random.uniform(low=-1 * translate, high=translate)) 34 | m = make_transform(translate, rotate) 35 | user_transforms = np.linalg.inv(m) 36 | return user_transforms 37 | 38 | 39 | def tensor2im(var: torch.tensor): 40 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 41 | var = ((var + 1) / 2) 42 | var[var < 0] = 0 43 | var[var > 1] = 1 44 | var = var * 255 45 | return Image.fromarray(var.astype('uint8')) 46 | 47 | 48 | def generate_mp4(out_name, images: List[np.ndarray], kwargs): 49 | writer = imageio.get_writer(str(out_name) + '.mp4', **kwargs) 50 | for image in images: 51 | writer.append_data(np.array(image)) 52 | writer.close() 53 | -------------------------------------------------------------------------------- /inversion/video/generate_videos.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | from inversion.video.video_config import VideoConfig 6 | from inversion.video.video_handler import VideoHandler 7 | from utils.common import generate_mp4 8 | 9 | OUTPUT_SIZE = (1024, 1024) 10 | 11 | 12 | def generate_reconstruction_videos(input_images: List, result_images: List, result_images_smoothed: List, 13 | video_handler: VideoHandler, opts: VideoConfig): 14 | kwargs = {'fps': video_handler.fps} 15 | 16 | # save the original cropped input 17 | output_path = opts.output_path / 'input_video' 18 | generate_mp4(output_path, [np.array(im) for im in input_images], kwargs) 19 | 20 | # generate video of original reconstruction (without smoothing) 21 | output_path = opts.output_path / 'result_video' 22 | generate_mp4(output_path, result_images, kwargs) 23 | 24 | # generate video of smoothed reconstruction 25 | output_path_smoothed = opts.output_path / "result_video_smoothed" 26 | generate_mp4(output_path_smoothed, result_images_smoothed, kwargs) 27 | 28 | # generate coupled video of original frames and smoothed side-by-side 29 | coupled_images = [] 30 | for im, smooth_im in zip(input_images[2:-2], result_images_smoothed): 31 | height, width = smooth_im.shape[:2] 32 | coupled_im = np.concatenate([im.resize((height, height)), smooth_im], axis=1) 33 | coupled_images.append(coupled_im) 34 | output_path_coupled = opts.output_path / "result_video_coupled" 35 | generate_mp4(output_path_coupled, coupled_images, kwargs) 36 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/styleclip_mapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from editing.styleclip_mapper import latent_mappers 5 | from models.stylegan3.model import SG3Generator 6 | 7 | 8 | def get_keys(d, name): 9 | if 'state_dict' in d: 10 | d = d['state_dict'] 11 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 12 | return d_filt 13 | 14 | 15 | class StyleCLIPMapper(nn.Module): 16 | 17 | def __init__(self, opts): 18 | super(StyleCLIPMapper, self).__init__() 19 | self.opts = opts 20 | # Define architecture 21 | self.mapper = self.set_mapper() 22 | self.decoder = SG3Generator(opts.stylegan_weights, res=opts.stylegan_size).decoder 23 | self.face_pool = 
torch.nn.AdaptiveAvgPool2d((256, 256)) 24 | # Load weights if needed 25 | self.load_weights() 26 | 27 | def set_mapper(self): 28 | if self.opts.mapper_type == 'SingleMapper': 29 | mapper = latent_mappers.SingleMapper(self.opts) 30 | elif self.opts.mapper_type == 'LevelsMapper': 31 | mapper = latent_mappers.LevelsMapper(self.opts) 32 | else: 33 | raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type)) 34 | return mapper 35 | 36 | def load_weights(self): 37 | if self.opts.checkpoint_path is not None: 38 | print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path)) 39 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 40 | self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True) 41 | 42 | def forward(self, x, input_code=False): 43 | if input_code: 44 | codes = x 45 | else: 46 | codes = self.mapper(x) 47 | 48 | return codes 49 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def aggregate_loss_dict(agg_loss_dict): 7 | mean_vals = {} 8 | for output in agg_loss_dict: 9 | for key in output: 10 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 11 | for key in mean_vals: 12 | if len(mean_vals[key]) > 0: 13 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 14 | else: 15 | print(f'{key} has no value') 16 | mean_vals[key] = 0 17 | return mean_vals 18 | 19 | 20 | def vis_faces(log_hooks: List[Dict]): 21 | display_count = len(log_hooks) 22 | n_outputs = len(log_hooks[0]['output_face']) if type(log_hooks[0]['output_face']) == list else 1 23 | fig = plt.figure(figsize=(6 + (n_outputs * 2), 4 * display_count)) 24 | gs = fig.add_gridspec(display_count, (2 + n_outputs)) 25 | for i in range(display_count): 26 | hooks_dict = log_hooks[i] 27 | fig.add_subplot(gs[i, 0]) 28 | vis_faces_iterative(hooks_dict, fig, gs, i) 29 | plt.tight_layout() 30 | return fig 31 | 32 | 33 | def vis_faces_iterative(hooks_dict: Dict[str, Any], fig, gs, i: int): 34 | plt.imshow(hooks_dict['input_face']) 35 | plt.title(f'Input\nOut Sim={float(hooks_dict["diff_input"]):.2f}') 36 | fig.add_subplot(gs[i, 1]) 37 | plt.imshow(hooks_dict['target_face']) 38 | plt.title(f'Target\nIn={float(hooks_dict["diff_views"]):.2f}, Out={float(hooks_dict["diff_target"]):.2f}') 39 | for idx, output_idx in enumerate(range(len(hooks_dict['output_face']) - 1, -1, -1)): 40 | output_image, similarity = hooks_dict['output_face'][output_idx] 41 | fig.add_subplot(gs[i, 2 + idx]) 42 | plt.imshow(output_image) 43 | plt.title(f'Output {output_idx}\n Target Sim={float(similarity):.2f}') 44 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 
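// The _rd/_wr/_ns variants of this file differ only in how the saturation sign buffer is handled:
// read back for the backward pass, written during the forward pass, or skipped entirely when no
// gradients are required.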
12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TestOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | # arguments for inference script 12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 13 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint') 14 | self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side') 15 | 16 | self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') 17 | self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") 18 | self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") 19 | self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") 20 | self.parser.add_argument('--stylegan_size', default=1024, type=int) 21 | 22 | 23 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 24 | self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation") 25 | self.parser.add_argument('--test_workers', default=0, type=int, help='Number of test/inference dataloader workers') 26 | 27 | self.parser.add_argument('--fourier_features_transforms_path', default=None, type=str, help="Optional path to tranforms") 28 | 29 | self.parser.add_argument('--n_images', type=int, default=None, help='Number of images to output. 
If None, run on all data') 30 | 31 | def parse(self): 32 | opts = self.parser.parse_args() 33 | return opts -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from configs.paths_config import model_paths 5 | from inversion.models.encoders.model_irse import Backbone 6 | 7 | 8 | class IDLoss(nn.Module): 9 | def __init__(self): 10 | super(IDLoss, self).__init__() 11 | print('Loading ResNet ArcFace') 12 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 13 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) 14 | self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 15 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 16 | self.facenet.eval() 17 | 18 | def extract_feats(self, x): 19 | if x.shape[2] != 256: 20 | x = self.pool(x) 21 | x = x[:, :, 35:223, 32:220] # Crop interesting region 22 | x = self.face_pool(x) 23 | x_feats = self.facenet(x) 24 | return x_feats 25 | 26 | def forward(self, y_hat, y, x): 27 | n_samples = x.shape[0] 28 | x_feats = self.extract_feats(x) 29 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 30 | y_hat_feats = self.extract_feats(y_hat) 31 | y_feats = y_feats.detach() 32 | loss = 0 33 | sim_improvement = 0 34 | id_logs = [] 35 | count = 0 36 | for i in range(n_samples): 37 | diff_target = y_hat_feats[i].dot(y_feats[i]) 38 | diff_input = y_hat_feats[i].dot(x_feats[i]) 39 | diff_views = y_feats[i].dot(x_feats[i]) 40 | id_logs.append({'diff_target': float(diff_target), 41 | 'diff_input': float(diff_input), 42 | 'diff_views': float(diff_views)}) 43 | loss += 1 - diff_target 44 | id_diff = float(diff_target) - float(diff_views) 45 | sim_improvement += id_diff 46 | count += 1 47 | 48 | return loss / count, sim_improvement / count, id_logs 49 | -------------------------------------------------------------------------------- /inversion/datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | from utils import data_utils 7 | from utils.common import get_identity_transform 8 | 9 | 10 | class InferenceDataset(Dataset): 11 | 12 | def __init__(self, root: Path, landmarks_transforms_path: Path = None, transform=None): 13 | self.paths = sorted(data_utils.make_dataset(root)) 14 | self.landmarks_transforms = self._get_landmarks_transforms(landmarks_transforms_path) 15 | self.transform = transform 16 | 17 | def __len__(self): 18 | return len(self.paths) 19 | 20 | def _get_landmarks_transforms(self, landmarks_transforms_path): 21 | if landmarks_transforms_path is not None: 22 | if not landmarks_transforms_path.exists(): 23 | raise ValueError(f"Invalid path for landmarks transforms: {landmarks_transforms_path}") 24 | landmarks_transforms = np.load(landmarks_transforms_path, allow_pickle=True).item() 25 | # filter out images not appearing in landmarks transforms 26 | valid_files = list(landmarks_transforms.keys()) 27 | self.paths = [f for f in self.paths if f.name in valid_files] 28 | else: 29 | landmarks_transforms = None 30 | return landmarks_transforms 31 | 32 | def _get_transform(self, from_path): 33 | landmarks_transform = self.landmarks_transforms[from_path.name][-1] 34 | return landmarks_transform 35 | 36 | def __getitem__(self, index): 37 | 
from_path = self.paths[index] 38 | from_im = Image.open(from_path).convert('RGB') 39 | if self.landmarks_transforms is not None: 40 | landmarks_transform = self._get_transform(from_path) 41 | else: 42 | landmarks_transform = get_identity_transform() 43 | if self.transform: 44 | from_im = self.transform(from_im) 45 | return from_im, landmarks_transform 46 | -------------------------------------------------------------------------------- /editing/styleclip_global_directions/templates.txt: -------------------------------------------------------------------------------- 1 | a bad photo of a {}. 2 | a sculpture of a {}. 3 | a photo of the hard to see {}. 4 | a low resolution photo of the {}. 5 | a rendering of a {}. 6 | graffiti of a {}. 7 | a bad photo of the {}. 8 | a cropped photo of the {}. 9 | a tattoo of a {}. 10 | the embroidered {}. 11 | a photo of a hard to see {}. 12 | a bright photo of a {}. 13 | a photo of a clean {}. 14 | a photo of a dirty {}. 15 | a dark photo of the {}. 16 | a drawing of a {}. 17 | a photo of my {}. 18 | the plastic {}. 19 | a photo of the cool {}. 20 | a close-up photo of a {}. 21 | a black and white photo of the {}. 22 | a painting of the {}. 23 | a painting of a {}. 24 | a pixelated photo of the {}. 25 | a sculpture of the {}. 26 | a bright photo of the {}. 27 | a cropped photo of a {}. 28 | a plastic {}. 29 | a photo of the dirty {}. 30 | a jpeg corrupted photo of a {}. 31 | a blurry photo of the {}. 32 | a photo of the {}. 33 | a good photo of the {}. 34 | a rendering of the {}. 35 | a {} in a video game. 36 | a photo of one {}. 37 | a doodle of a {}. 38 | a close-up photo of the {}. 39 | a photo of a {}. 40 | the origami {}. 41 | the {} in a video game. 42 | a sketch of a {}. 43 | a doodle of the {}. 44 | a origami {}. 45 | a low resolution photo of a {}. 46 | the toy {}. 47 | a rendition of the {}. 48 | a photo of the clean {}. 49 | a photo of a large {}. 50 | a rendition of a {}. 51 | a photo of a nice {}. 52 | a photo of a weird {}. 53 | a blurry photo of a {}. 54 | a cartoon {}. 55 | art of a {}. 56 | a sketch of the {}. 57 | a embroidered {}. 58 | a pixelated photo of a {}. 59 | itap of the {}. 60 | a jpeg corrupted photo of the {}. 61 | a good photo of a {}. 62 | a plushie {}. 63 | a photo of the nice {}. 64 | a photo of the small {}. 65 | a photo of the weird {}. 66 | the cartoon {}. 67 | art of the {}. 68 | a drawing of the {}. 69 | a photo of the large {}. 70 | a black and white photo of a {}. 71 | the plushie {}. 72 | a dark photo of a {}. 73 | itap of a {}. 74 | graffiti of the {}. 75 | a toy {}. 76 | itap of my {}. 77 | a photo of a cool {}. 78 | a photo of a small {}. 79 | a tattoo of the {}. -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /inversion/options/test_options.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | 4 | import dataclasses 5 | from dataclasses import dataclass 6 | from pyrallis import field 7 | 8 | 9 | @dataclass 10 | class TestOptions: 11 | """ Defines all inference arguments. """ 12 | 13 | """ General args """ 14 | # Path to output inference results to 15 | output_path: Path = Path("./experiments/inference") 16 | # Path to the pretrained encoder 17 | checkpoint_path: Path = Path("./experiments/checkpoints/best_model.pt") 18 | # Path to images to run inference on 19 | data_path: Path = Path("./gt_images") 20 | # Whether to resize output images to 256. By default, keeps original resolution 21 | resize_outputs: bool = False 22 | # Batch size for running inference 23 | test_batch_size: int = 2 24 | # Number of workers for test dataloader 25 | test_workers: int = 2 26 | # Number of images to run inference on. If None, runs inference on all images 27 | n_images: Optional[int] = None 28 | # Number of forward passes per batch during inference 29 | n_iters_per_batch: int = 3 30 | # Path to pkl file with landmarks-based transformations for unaligned images 31 | landmarks_transforms_path: Optional[Path] = None 32 | 33 | """ Editing args """ 34 | # List of edits to perform 35 | edit_directions: List[str] = field(default=["age", "smile", "pose"], is_mutable=True) 36 | # List of ranges for each edit. 
For example, (-4_5) defines an editing range from -4 to 5 37 | factor_ranges: List[str] = dataclasses.field(default_factory=lambda: ["(-5_5)", "(-5_5)", "(-5_5)"]) 38 | 39 | def __post_init__(self): 40 | self.factor_ranges = self._parse_factor_ranges() 41 | if len(self.edit_directions) != len(self.factor_ranges): 42 | raise ValueError("Invalid edit directions and factor ranges. Please provide a single factor range for each" 43 | f"edit direction. Given: {self.edit_directions} and {self.factor_ranges}") 44 | 45 | def _parse_factor_ranges(self): 46 | factor_ranges = [] 47 | for factor in self.factor_ranges: 48 | start, end = factor.strip("()").split("_") 49 | factor_ranges.append((int(start), int(end))) 50 | return factor_ranges 51 | -------------------------------------------------------------------------------- /configs/paths_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | dataset_paths = { 4 | 'celeba_train': Path(''), 5 | 'celeba_test': Path(''), 6 | 7 | 'ffhq': Path(''), 8 | 'ffhq_unaligned': Path('') 9 | } 10 | 11 | model_paths = { 12 | # models for backbones and losses 13 | 'ir_se50': Path('pretrained_models/model_ir_se50.pth'), 14 | # stylegan3 generators 15 | 'stylegan3_ffhq': Path('pretrained_models/stylegan3-r-ffhq-1024x1024.pkl'), 16 | 'stylegan3_ffhq_pt': Path('pretrained_models/sg3-r-ffhq-1024.pt'), 17 | 'stylegan3_ffhq_unaligned': Path('pretrained_models/stylegan3-r-ffhqu-1024x1024.pkl'), 18 | 'stylegan3_ffhq_unaligned_pt': Path('pretrained_models/sg3-r-ffhqu-1024.pt'), 19 | # model for face alignment 20 | 'shape_predictor': Path('pretrained_models/shape_predictor_68_face_landmarks.dat'), 21 | # models for ID similarity computation 22 | 'curricular_face': Path('pretrained_models/CurricularFace_Backbone.pth'), 23 | 'mtcnn_pnet': Path('pretrained_models/mtcnn/pnet.npy'), 24 | 'mtcnn_rnet': Path('pretrained_models/mtcnn/rnet.npy'), 25 | 'mtcnn_onet': Path('pretrained_models/mtcnn/onet.npy'), 26 | # classifiers used for interfacegan training 27 | 'age_estimator': Path('pretrained_models/dex_age_classifier.pth'), 28 | 'pose_estimator': Path('pretrained_models/hopenet_robust_alpha1.pkl') 29 | } 30 | 31 | styleclip_directions = { 32 | "ffhq": { 33 | 'delta_i_c': Path('editing/styleclip_global_directions/sg3-r-ffhq-1024/delta_i_c.npy'), 34 | 's_statistics': Path('editing/styleclip_global_directions/sg3-r-ffhq-1024/s_stats'), 35 | }, 36 | 'templates': Path('editing/styleclip_global_directions/templates.txt') 37 | } 38 | 39 | interfacegan_aligned_edit_paths = { 40 | 'age': Path('editing/interfacegan/boundaries/ffhq/age_boundary.npy'), 41 | 'smile': Path('editing/interfacegan/boundaries/ffhq/Smiling_boundary.npy'), 42 | 'pose': Path('editing/interfacegan/boundaries/ffhq/pose_boundary.npy'), 43 | 'Male': Path('editing/interfacegan/boundaries/ffhq/Male_boundary.npy'), 44 | } 45 | 46 | interfacegan_unaligned_edit_paths = { 47 | 'age': Path('editing/interfacegan/boundaries/ffhqu/age_boundary.npy'), 48 | 'smile': Path('editing/interfacegan/boundaries/ffhqu/Smiling_boundary.npy'), 49 | 'pose': Path('editing/interfacegan/boundaries/ffhqu/pose_boundary.npy'), 50 | 'Male': Path('editing/interfacegan/boundaries/ffhqu/Male_boundary.npy'), 51 | } 52 | -------------------------------------------------------------------------------- /inversion/models/e4e_modules/latent_codes_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | 6 | class 
LatentCodesPool: 7 | """This class implements latent codes buffer that stores previously generated w latent codes. 8 | This buffer enables us to update discriminators using a history of generated w's 9 | rather than the ones produced by the latest encoder. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | Parameters: 15 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 16 | """ 17 | self.pool_size = pool_size 18 | if self.pool_size > 0: # create an empty pool 19 | self.num_ws = 0 20 | self.ws = [] 21 | 22 | def query(self, ws): 23 | """Return w's from the pool. 24 | Parameters: 25 | ws: the latest generated w's from the generator 26 | Returns w's from the buffer. 27 | By 50/100, the buffer will return input w's. 28 | By 50/100, the buffer will return w's previously stored in the buffer, 29 | and insert the current w's to the buffer. 30 | """ 31 | if self.pool_size == 0: # if the buffer size is 0, do nothing 32 | return ws 33 | return_ws = [] 34 | for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) 35 | # w = torch.unsqueeze(image.data, 0) 36 | if w.ndim == 2: 37 | i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate 38 | w = w[i] 39 | self.handle_w(w, return_ws) 40 | return_ws = torch.stack(return_ws, 0) # collect all the images and return 41 | return return_ws 42 | 43 | def handle_w(self, w, return_ws): 44 | if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer 45 | self.num_ws = self.num_ws + 1 46 | self.ws.append(w) 47 | return_ws.append(w) 48 | else: 49 | p = random.uniform(0, 1) 50 | if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer 51 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 52 | tmp = self.ws[random_id].clone() 53 | self.ws[random_id] = w 54 | return_ws.append(tmp) 55 | else: # by another 50% chance, the buffer will return the current image 56 | return_ws.append(w) 57 | -------------------------------------------------------------------------------- /models/stylegan3/model.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from enum import Enum 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import torch 7 | 8 | from models.stylegan3.networks_stylegan3 import Generator 9 | 10 | 11 | class GeneratorType(str, Enum): 12 | ALIGNED = "aligned" 13 | UNALIGNED = "unaligned" 14 | 15 | def __str__(self): 16 | return str(self.value) 17 | 18 | 19 | class SG3Generator(torch.nn.Module): 20 | 21 | def __init__(self, checkpoint_path: Optional[Path] = None, res: int = 1024, config: str = None): 22 | super(SG3Generator, self).__init__() 23 | print(f"Loading StyleGAN3 generator from path: {checkpoint_path}") 24 | if str(checkpoint_path).endswith("pkl"): 25 | with open(checkpoint_path, "rb") as f: 26 | self.decoder = pickle.load(f)['G_ema'].cuda() 27 | print('Done!') 28 | return 29 | elif config == "landscape": 30 | self.decoder = Generator( 31 | z_dim=512, 32 | c_dim=0, 33 | w_dim=512, 34 | img_resolution=res, 35 | img_channels=3, 36 | channel_base=32768, 37 | channel_max=512, 38 | magnitude_ema_beta=0.9988915792636801, 39 | mapping_kwargs={'num_layers': 2} 40 | ).cuda() 41 | else: 42 | self.decoder = Generator(z_dim=512, 43 | c_dim=0, 44 | w_dim=512, 45 | img_resolution=res, 46 | img_channels=3, 47 | channel_base=65536, 48 | channel_max=1024, 
49 | conv_kernel=1, 50 | filter_size=6, 51 | magnitude_ema_beta=0.9988915792636801, 52 | output_scale=0.25, 53 | use_radial_filters=True 54 | ).cuda() 55 | if checkpoint_path is not None: 56 | self._load_checkpoint(checkpoint_path) 57 | print('Done!') 58 | 59 | def _load_checkpoint(self, checkpoint_path): 60 | try: 61 | self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True) 62 | except: 63 | ckpt = torch.load(checkpoint_path) 64 | ckpt = {k: v for k, v in ckpt.items() if "synthesis.input.transform" not in k} 65 | self.decoder.load_state_dict(ckpt, strict=False) 66 | -------------------------------------------------------------------------------- /inversion/scripts/train_restyle_e4e.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pprint 3 | import sys 4 | from plistlib import Dict 5 | from typing import Any 6 | 7 | import dataclasses 8 | import pyrallis 9 | import torch 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from inversion.options.e4e_train_options import e4eTrainOptions 15 | from inversion.training.coach_restyle_e4e import Coach 16 | 17 | 18 | @pyrallis.wrap() 19 | def main(opts: e4eTrainOptions): 20 | previous_train_ckpt = None 21 | if opts.resume_training_from_ckpt: 22 | opts, previous_train_ckpt = load_train_checkpoint(opts) 23 | else: 24 | setup_progressive_steps(opts) 25 | create_initial_experiment_dir(opts) 26 | 27 | coach = Coach(opts, previous_train_ckpt) 28 | coach.train() 29 | 30 | 31 | def load_train_checkpoint(opts: e4eTrainOptions): 32 | train_ckpt_path = opts.resume_training_from_ckpt 33 | previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu') 34 | new_opts_dict = dataclasses.asdict(opts) 35 | opts = previous_train_ckpt['opts'] 36 | opts['resume_training_from_ckpt'] = train_ckpt_path 37 | update_new_configs(opts, new_opts_dict) 38 | pprint.pprint(opts) 39 | opts = e4eTrainOptions(**opts) 40 | if opts.sub_exp_dir is not None: 41 | sub_exp_dir = opts.sub_exp_dir 42 | opts.exp_dir = opts.exp_dir / sub_exp_dir 43 | create_initial_experiment_dir(opts) 44 | return opts, previous_train_ckpt 45 | 46 | 47 | def setup_progressive_steps(opts: e4eTrainOptions): 48 | num_style_layers = 16 49 | num_deltas = num_style_layers - 1 50 | if opts.progressive_start is not None: # If progressive delta training 51 | opts.progressive_steps = [0] 52 | next_progressive_step = opts.progressive_start 53 | for i in range(num_deltas): 54 | opts.progressive_steps.append(next_progressive_step) 55 | next_progressive_step += opts.progressive_step_every 56 | 57 | assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \ 58 | "Invalid progressive training input" 59 | 60 | 61 | def is_valid_progressive_steps(opts: e4eTrainOptions, num_style_layers: int): 62 | return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0 63 | 64 | 65 | def create_initial_experiment_dir(opts: e4eTrainOptions): 66 | opts.exp_dir.mkdir(exist_ok=True, parents=True) 67 | opts_dict = dataclasses.asdict(opts) 68 | pprint.pprint(opts_dict) 69 | with open(opts.exp_dir / 'opt.json', 'w') as f: 70 | json.dump(opts_dict, f, indent=4, sort_keys=True, default=str) 71 | 72 | 73 | def update_new_configs(ckpt_opts: Dict[str, Any], new_opts: Dict[str, Any]): 74 | for k, v in new_opts.items(): 75 | ckpt_opts[k] = v 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | 
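A minimal sketch of the schedule that setup_progressive_steps() above builds (illustration only, not part of the repository; the concrete values for progressive_start and progressive_step_every are hypothetical, chosen purely to show the arithmetic):

# Hypothetical values standing in for opts.progressive_start / opts.progressive_step_every.
progressive_start = 20000
progressive_step_every = 2000
num_style_layers = 16                          # StyleGAN3 style layers; 15 deltas after the base entry

steps = [0]
next_step = progressive_start
for _ in range(num_style_layers - 1):          # one unlock step per delta
    steps.append(next_step)
    next_step += progressive_step_every

# Matches is_valid_progressive_steps(): 16 entries, starting at 0.
assert len(steps) == num_style_layers and steps[0] == 0
print(steps)                                   # [0, 20000, 22000, 24000, ..., 48000]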
-------------------------------------------------------------------------------- /models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /editing/styleclip_global_directions/global_direction.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import clip 4 | import torch 5 | 6 | 7 | def features_channels_to_s(channels, std, example_s): 8 | result_s = {} 9 | start_index = 0 10 | for key in example_s: 11 | curr_num_channels = example_s[key].shape[1] 12 | end_index = start_index + curr_num_channels 13 | curr_channels = channels[start_index:end_index] 14 | curr_channels = curr_channels * std[key] 15 | result_s[key] = curr_channels.unsqueeze(0) 16 | start_index = end_index 17 | return result_s 18 | 19 | 20 | class StyleCLIPGlobalDirection: 21 | 22 | def __init__(self, delta_i_c, s_std, text_prompts_templates, s_avg): 23 | super(StyleCLIPGlobalDirection, self).__init__() 24 | self.delta_i_c = delta_i_c 25 | self.s_std = s_std 
26 | self.text_prompts_templates = text_prompts_templates 27 | self.clip_model, _ = clip.load("ViT-B/32", device="cuda") 28 | self.s_avg = s_avg 29 | 30 | def get_delta_s(self, neutral_text, target_text, beta): 31 | delta_i = self.get_delta_i([target_text, neutral_text]).float() 32 | r_c = torch.matmul(self.delta_i_c, delta_i) 33 | delta_s = copy.copy(r_c) 34 | channels_to_zero = torch.abs(r_c) < beta 35 | delta_s[channels_to_zero] = 0 36 | max_channel_value = torch.abs(delta_s).max() 37 | if max_channel_value > 0: 38 | delta_s /= max_channel_value 39 | direction = features_channels_to_s(delta_s, self.s_std, self.s_avg) 40 | return direction 41 | 42 | def get_delta_i(self, text_prompts): 43 | text_features = self._get_averaged_text_features(text_prompts) 44 | delta_t = text_features[0] - text_features[1] 45 | delta_i = delta_t / torch.norm(delta_t) 46 | return delta_i 47 | 48 | def _get_averaged_text_features(self, text_prompts): 49 | with torch.no_grad(): 50 | text_features_list = [] 51 | for text_prompt in text_prompts: 52 | formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates] # format with class 53 | formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda() # tokenize 54 | text_embeddings = self.clip_model.encode_text(formatted_text_prompts) # embed with text encoder 55 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 56 | text_embedding = text_embeddings.mean(dim=0) 57 | text_embedding /= text_embedding.norm() 58 | text_features_list.append(text_embedding) 59 | text_features = torch.stack(text_features_list, dim=1).cuda() 60 | return text_features.t() 61 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | # --- picup frame --- 2 | import cv2 3 | import os 4 | def save_frame(video_name): 5 | # setting 6 | video_folder = 'examples/videos' 7 | pic_folder ='pic' 8 | frame_num = 10 9 | # content 10 | video_path = video_folder + '/' + video_name 11 | cap = cv2.VideoCapture(video_path) 12 | if not cap.isOpened(): 13 | return 14 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num) 15 | ret, frame = cap.read() 16 | result_path = pic_folder + '/' + os.path.splitext(video_name)[0]+'.jpg' 17 | if ret: 18 | cv2.imwrite(result_path, frame) 19 | 20 | # --- display_movie --- 21 | import matplotlib.pyplot as plt 22 | from PIL import Image 23 | import numpy as np 24 | import os 25 | def display_movie(folder, name): 26 | fig = plt.figure(figsize=(20, 45)) 27 | files = sorted(os.listdir(folder)) 28 | for i, file in enumerate(files): 29 | if file=='.ipynb_checkpoints': 30 | continue 31 | if file=='.DS_Store': 32 | continue 33 | img = Image.open(folder+'/'+file) 34 | images = np.asarray(img) 35 | ax = fig.add_subplot(10, 3, i+1, xticks=[], yticks=[]) 36 | image_plt = np.array(images) 37 | ax.imshow(image_plt) 38 | ax.set_xlabel(name[i], fontsize=30) 39 | fig.tight_layout() 40 | plt.show() 41 | plt.close() 42 | 43 | 44 | # --- display_mp4 --- 45 | from IPython.display import display, HTML 46 | from IPython.display import HTML 47 | 48 | def display_mp4(path): 49 | from base64 import b64encode 50 | mp4 = open(path,'rb').read() 51 | data_url = "data:video/mp4;base64," + b64encode(mp4).decode() 52 | display(HTML(""" 53 | 56 | """ % data_url)) 57 | #print('Display finished.') ### 58 | 59 | 60 | # --- display_pic --- 61 | import matplotlib.pyplot as plt 62 | from PIL import Image 63 | import numpy as np 64 | import os 65 
| 66 | def display_pic(folder): 67 | fig = plt.figure(figsize=(30, 60)) 68 | files = os.listdir(folder) 69 | files.sort() 70 | for i, file in enumerate(files): 71 | if file=='.ipynb_checkpoints': 72 | continue 73 | if file=='.DS_Store': 74 | continue 75 | img = Image.open(folder+'/'+file) 76 | images = np.asarray(img) 77 | ax = fig.add_subplot(10, 5, i+1, xticks=[], yticks=[]) 78 | image_plt = np.array(images) 79 | ax.imshow(image_plt) 80 | #name = os.path.splitext(file) 81 | ax.set_xlabel(file, fontsize=30) 82 | plt.show() 83 | plt.close() 84 | 85 | 86 | # --- reset_folder --- 87 | import shutil 88 | 89 | def reset_folder(path): 90 | if os.path.isdir(path): 91 | shutil.rmtree(path) 92 | os.makedirs(path,exist_ok=True) 93 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from configs.paths_config import model_paths 6 | 7 | 8 | class MocoLoss(nn.Module): 9 | 10 | def __init__(self): 11 | super(MocoLoss, self).__init__() 12 | print(f"Loading MOCO model from path: {model_paths['moco']}") 13 | self.model = self.__load_model() 14 | self.model.cuda() 15 | self.model.eval() 16 | 17 | @staticmethod 18 | def __load_model(): 19 | import torchvision.models as models 20 | model = models.__dict__["resnet50"]() 21 | # freeze all layers but the last fc 22 | for name, param in model.named_parameters(): 23 | if name not in ['fc.weight', 'fc.bias']: 24 | param.requires_grad = False 25 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 26 | state_dict = checkpoint['state_dict'] 27 | # rename moco pre-trained keys 28 | for k in list(state_dict.keys()): 29 | # retain only encoder_q up to before the embedding layer 30 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 31 | # remove prefix 32 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 33 | # delete renamed or unused k 34 | del state_dict[k] 35 | msg = model.load_state_dict(state_dict, strict=False) 36 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 37 | # remove output layer 38 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 39 | return model 40 | 41 | def extract_feats(self, x): 42 | x = F.interpolate(x, size=224) 43 | x_feats = self.model(x) 44 | x_feats = nn.functional.normalize(x_feats, dim=1) 45 | x_feats = x_feats.squeeze() 46 | return x_feats 47 | 48 | def forward(self, y_hat, y, x): 49 | n_samples = x.shape[0] 50 | x_feats = self.extract_feats(x) 51 | y_feats = self.extract_feats(y) 52 | y_hat_feats = self.extract_feats(y_hat) 53 | y_feats = y_feats.detach() 54 | loss = 0 55 | sim_improvement = 0 56 | sim_logs = [] 57 | count = 0 58 | for i in range(n_samples): 59 | diff_target = y_hat_feats[i].dot(y_feats[i]) 60 | diff_input = y_hat_feats[i].dot(x_feats[i]) 61 | diff_views = y_feats[i].dot(x_feats[i]) 62 | sim_logs.append({'diff_target': float(diff_target), 63 | 'diff_input': float(diff_input), 64 | 'diff_views': float(diff_views)}) 65 | loss += 1 - diff_target 66 | sim_diff = float(diff_target) - float(diff_views) 67 | sim_improvement += sim_diff 68 | count += 1 69 | 70 | return loss / count, sim_improvement / count, sim_logs 71 | -------------------------------------------------------------------------------- /inversion/options/train_options.py: -------------------------------------------------------------------------------- 1 | from pathlib 
import Path 2 | from typing import Optional 3 | from dataclasses import dataclass 4 | 5 | from configs.paths_config import model_paths 6 | 7 | 8 | @dataclass 9 | class TrainOptions: 10 | """ Defines all training arguments. """ 11 | 12 | """ General Args """ 13 | # Path to experiment output directory 14 | exp_dir: Path = "./experiments/experiment" 15 | # Type of dataset/experiment to run 16 | dataset_type: str = "ffhq_encode" 17 | # Which encoder to use 18 | encoder_type: str = "BackboneEncoder" 19 | # Number of input image channels to the ReStyl encoder. Should be set to 6 20 | input_nc: int = 6 21 | # Output resolution of the generator 22 | output_size: int = 1024 23 | # Number of forward passes per batch during training 24 | n_iters_per_batch: int = 3 25 | 26 | """ Dataset args """ 27 | # Batch size for training 28 | batch_size: int = 2 29 | # Batch size of testing/validation 30 | test_batch_size: int = 2 31 | # Number of workers for train dataloader 32 | workers: int = 4 33 | # Number of works for test dataloader 34 | test_workers: int = 4 35 | 36 | """ Optimizer args """ 37 | # Optimizer learning rate 38 | learning_rate: float = 0.0001 39 | # Which optimizer to use 40 | optim_name: str = "ranger" 41 | # Whether to train the decoder during training 42 | train_decoder: bool = False 43 | # Whether to add average latent vector to generate codes from encoder 44 | start_from_latent_avg: bool = True 45 | 46 | """ Loss args """ 47 | # LPIPS loss multiplier factor 48 | lpips_lambda: float = 0 49 | # ID loss multiplier factor 50 | id_lambda: float = 0 51 | # L2 loss multiplier factor 52 | l2_lambda: float = 0 53 | # W-norm loss multiplier factor 54 | w_norm_lambda: float = 0 55 | # Moco feature loss multiplier factor 56 | moco_lambda: float = 0 57 | 58 | """ Checkpoint args """ 59 | # Path to StyleGAN model weights 60 | stylegan_weights: Path = Path(model_paths['stylegan3_ffhq_pt']) 61 | # Path to ReStyle model checkpoint for resuming training from 62 | checkpoint_path: Optional[Path] = None 63 | 64 | """ Logging args """ 65 | # Maximum number of training steps 66 | max_steps: int = 500000 67 | # Interval for logging train images during training 68 | image_interval: int = 100 69 | # Interval for logging metrics to tensorboard 70 | board_interval: int = 50 71 | # Validation interval 72 | val_interval: int = 1000 73 | # Model checkpoint interval 74 | save_interval: Optional[int] = None 75 | # Number of batches to run validation on. If None, run on all batches 76 | max_val_batches: Optional[int] = None 77 | 78 | device: Optional[str] = None 79 | 80 | def update(self, new_opts): 81 | for key, value in new_opts.items(): 82 | setattr(self, key, value) 83 | -------------------------------------------------------------------------------- /prepare_data/landmarks_handler.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | 4 | import dlib 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from configs.paths_config import model_paths 9 | from utils.alignment_utils import get_stylegan_transform 10 | 11 | 12 | class LandmarksHandler: 13 | """ 14 | Computes the landmarks-based transforms between the given aligned and cropped video frames. If the landmarks 15 | have already been saved to the given `landmarks_transforms_path`, simply load and return them. If they have not 16 | been saved yet, save them to the `landmarks_transforms_path` for next time. 
17 | """ 18 | def __init__(self, output_path: Path, landmarks_transforms_path: Optional[Path] = None): 19 | if landmarks_transforms_path is None: 20 | landmarks_transforms_path = output_path / "landmarks_transforms.npy" 21 | self.landmarks_transforms_path = landmarks_transforms_path 22 | 23 | def get_landmarks_transforms(self, input_paths: List[Path], cropped_frames_path: Path, 24 | aligned_frames_path: Path, force_computing: bool = False): 25 | if self.landmarks_transforms_path is None: 26 | return None 27 | else: 28 | if self.landmarks_transforms_path.exists() and not force_computing: 29 | print(f"Using pre-computed landmarks from path: {self.landmarks_transforms_path}") 30 | landmarks_transforms = np.load(str(self.landmarks_transforms_path), allow_pickle=True).item() 31 | else: 32 | landmarks_transforms = self._compute_landmarks_transforms(input_paths, 33 | cropped_frames_path, 34 | aligned_frames_path) 35 | np.save(str(self.landmarks_transforms_path), landmarks_transforms) 36 | return landmarks_transforms 37 | 38 | @staticmethod 39 | def _compute_landmarks_transforms(input_paths: List[Path], cropped_frames_path: Path, aligned_frames_path: Path): 40 | print("Computing landmarks transforms...") 41 | detector = dlib.get_frontal_face_detector() 42 | predictor = dlib.shape_predictor(str(model_paths['shape_predictor'])) 43 | landmarks_transforms = {} 44 | for path in tqdm(input_paths): 45 | cropped_path = cropped_frames_path / path.name 46 | aligned_path = aligned_frames_path / path.name 47 | res = get_stylegan_transform(str(cropped_path), str(aligned_path), detector, predictor) 48 | if res is None: 49 | print(f"Failed on: {cropped_path}") 50 | continue 51 | else: 52 | rotation_angle, translation, transform, inverse_transform = res 53 | landmarks_transforms[path.name] = (rotation_angle, translation, transform, inverse_transform) 54 | return landmarks_transforms 55 | -------------------------------------------------------------------------------- /prepare_data/preparing_faces_parallel.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing as mp 3 | import sys 4 | import time 5 | from functools import partial 6 | from pathlib import Path 7 | 8 | import pyrallis 9 | 10 | import dlib 11 | from dataclasses import dataclass 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from configs.paths_config import model_paths 17 | from utils.alignment_utils import align_face, crop_face 18 | 19 | SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"] 20 | 21 | 22 | @dataclass 23 | class Options: 24 | # Number of threads to run in parallel 25 | num_threads: int = 1 26 | # Path to raw data 27 | root_path: str = "" 28 | # Should be 'align' / 'crop' 29 | mode: str = "align" 30 | # In case of cropping, amount of random shifting to perform 31 | random_shift: float = 0.05 32 | 33 | 34 | def chunks(lst, n): 35 | """Yield successive n-sized chunks from lst.""" 36 | for i in range(0, len(lst), n): 37 | yield lst[i:i + n] 38 | 39 | 40 | def extract_on_paths(file_paths, args: Options): 41 | 42 | predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH) 43 | detector = dlib.get_frontal_face_detector() 44 | pid = mp.current_process().name 45 | print(f'\t{pid} is starting to extract on #{len(file_paths)} images') 46 | tot_count = len(file_paths) 47 | count = 0 48 | for file_path, res_path in file_paths: 49 | count += 1 50 | if count % 100 == 0: 51 | print(f'{pid} done with {count}/{tot_count}') 52 | try: 53 | if args.mode == "align": 54 
| res = align_face(file_path, detector, predictor) 55 | else: 56 | res = crop_face(file_path, detector, predictor, random_shift=args.random_shift) 57 | res = res.convert('RGB') 58 | Path(res_path).parent.mkdir(exist_ok=True, parents=True) 59 | res.save(res_path) 60 | except Exception: 61 | continue 62 | print('\tDone!') 63 | 64 | 65 | @pyrallis.wrap() 66 | def run(args: Options): 67 | 68 | assert args.mode in ["align", "crop"], "Expected extractions mode to be one of 'align' or 'crop'" 69 | 70 | root_path = Path(args.root_path) 71 | out_crops_path = root_path.parent / Path(root_path.name + "_" + args.mode + "ed") 72 | if not out_crops_path.exists(): 73 | out_crops_path.mkdir(exist_ok=True, parents=True) 74 | 75 | file_paths = [] 76 | for file in root_path.iterdir(): 77 | res_path = out_crops_path / file.name 78 | file_paths.append((str(file), str(res_path))) 79 | 80 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 81 | print(len(file_chunks)) 82 | pool = mp.Pool(args.num_threads) 83 | print(f'Running on {len(file_paths)} paths\nHere we goooo') 84 | tic = time.time() 85 | pool.map(partial(extract_on_paths, args=args), file_chunks) 86 | toc = time.time() 87 | print(f'Mischief managed in {tic - toc}s') 88 | 89 | 90 | if __name__ == '__main__': 91 | run() 92 | -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Sequence 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import models 7 | 8 | from criteria.lpips.utils import normalize_activation 9 | 10 | 11 | def get_network(net_type: str): 12 | if net_type == 'alex': 13 | return AlexNet() 14 | elif net_type == 'squeeze': 15 | return SqueezeNet() 16 | elif net_type == 'vgg': 17 | return VGG16() 18 | else: 19 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 20 | 21 | 22 | class LinLayers(nn.ModuleList): 23 | def __init__(self, n_channels_list: Sequence[int]): 24 | super(LinLayers, self).__init__([ 25 | nn.Sequential( 26 | nn.Identity(), 27 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 28 | ) for nc in n_channels_list 29 | ]) 30 | 31 | for param in self.parameters(): 32 | param.requires_grad = False 33 | 34 | 35 | class BaseNet(nn.Module): 36 | def __init__(self): 37 | super(BaseNet, self).__init__() 38 | 39 | # register buffer 40 | self.register_buffer( 41 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 42 | self.register_buffer( 43 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 44 | 45 | def set_requires_grad(self, state: bool): 46 | for param in chain(self.parameters(), self.buffers()): 47 | param.requires_grad = state 48 | 49 | def z_score(self, x: torch.Tensor): 50 | return (x - self.mean) / self.std 51 | 52 | def forward(self, x: torch.Tensor): 53 | x = self.z_score(x) 54 | 55 | output = [] 56 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 57 | x = layer(x) 58 | if i in self.target_layers: 59 | output.append(normalize_activation(x)) 60 | if len(output) == len(self.target_layers): 61 | break 62 | return output 63 | 64 | 65 | class SqueezeNet(BaseNet): 66 | def __init__(self): 67 | super(SqueezeNet, self).__init__() 68 | 69 | self.layers = models.squeezenet1_1(True).features 70 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 71 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 72 | 73 | 
self.set_requires_grad(False) 74 | 75 | 76 | class AlexNet(BaseNet): 77 | def __init__(self): 78 | super(AlexNet, self).__init__() 79 | 80 | self.layers = models.alexnet(True).features 81 | self.target_layers = [2, 5, 8, 10, 12] 82 | self.n_channels_list = [64, 192, 384, 256, 256] 83 | 84 | self.set_requires_grad(False) 85 | 86 | 87 | class VGG16(BaseNet): 88 | def __init__(self): 89 | super(VGG16, self).__init__() 90 | 91 | self.layers = models.vgg16(True).features 92 | self.target_layers = [4, 9, 16, 23, 30] 93 | self.n_channels_list = [64, 128, 256, 512, 512] 94 | 95 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /editing/interfacegan/helpers/anycostgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | 4 | """ 5 | Code is adopted from: AnyCostGAN (https://github.com/mit-han-lab/anycost-gan) 6 | """ 7 | 8 | URL_TEMPLATE = 'https://hanlab.mit.edu/projects/anycost-gan/files/{}_{}.pt' 9 | attr_list = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 10 | 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 11 | 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 12 | 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 13 | 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 14 | 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'] 15 | 16 | 17 | def safe_load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, 18 | file_name=None): 19 | # a safe version of torch.hub.load_state_dict_from_url in distributed environment 20 | # the main idea is to only download the file on worker 0 21 | try: 22 | import horovod.torch as hvd 23 | world_size = hvd.size() 24 | except: # load horovod failed, just normal environment 25 | return torch.hub.load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name) 26 | 27 | if world_size == 1: 28 | return torch.hub.load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name) 29 | else: # world size > 1 30 | if hvd.rank() == 0: # possible download... 
let it only run on worker 0 to prevent conflict 31 | _ = torch.hub.load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name) 32 | hvd.broadcast(torch.tensor(0), root_rank=0, name='dummy') 33 | return torch.hub.load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name) 34 | 35 | 36 | def load_state_dict_from_url(url, key=None): 37 | if url.startswith('http'): 38 | sd = safe_load_state_dict_from_url(url, map_location='cpu', progress=True) 39 | else: 40 | sd = torch.load(url, map_location='cpu') 41 | if key is not None: 42 | return sd[key] 43 | return sd 44 | 45 | 46 | def get_pretrained(model, config=None): 47 | if model in ['attribute-predictor', 'inception']: 48 | assert config is None 49 | url = URL_TEMPLATE.format('attribute', 'predictor') # not used for inception 50 | else: 51 | assert config is not None 52 | url = URL_TEMPLATE.format(model, config) 53 | 54 | if model == 'attribute-predictor': # attribute predictor is general 55 | predictor = models.resnet50() 56 | predictor.fc = torch.nn.Linear(predictor.fc.in_features, 40 * 2) 57 | predictor.load_state_dict(load_state_dict_from_url(url, 'state_dict')) 58 | return predictor 59 | else: 60 | raise NotImplementedError -------------------------------------------------------------------------------- /inversion/video/post_processing.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from utils.fov_expansion import Expander 8 | from inversion.video.video_config import VideoConfig 9 | from utils.common import tensor2im, get_identity_transform 10 | 11 | 12 | def postprocess_and_smooth_inversions(results: Dict, net, opts: VideoConfig): 13 | result_latents = np.array(list(results["result_latents"].values())) 14 | # average fine layers 15 | result_latents[:, 9:, :] = result_latents[:, 9:, :].mean(axis=0) 16 | # smooth latents and landmarks transforms 17 | smoothed_latents, smoothed_transforms = smooth_latents_and_transforms(result_latents, 18 | results["landmarks_transforms"], 19 | opts=opts) 20 | # generate the smoothed video frames 21 | result_images_smoothed = [] 22 | expander = Expander(G=net.decoder) 23 | print("Generating smoothed frames...") 24 | for latent, trans in tqdm(zip(smoothed_latents, smoothed_transforms)): 25 | with torch.no_grad(): 26 | if trans is None: 27 | trans = get_identity_transform() 28 | im = expander.generate_expanded_image(ws=latent.unsqueeze(0), 29 | landmark_t=trans.cpu().numpy(), 30 | pixels_left=opts.expansion_amounts[0], 31 | pixels_right=opts.expansion_amounts[1], 32 | pixels_top=opts.expansion_amounts[2], 33 | pixels_bottom=opts.expansion_amounts[3]) 34 | result_images_smoothed.append(np.array(tensor2im(im[0]))) 35 | return result_images_smoothed 36 | 37 | 38 | def smooth_latents_and_transforms(result_latents: np.ndarray, result_landmarks_transforms: List[torch.tensor], 39 | opts: VideoConfig): 40 | smoothed_latents = smooth_ws(result_latents) 41 | smoothed_latents = torch.from_numpy(smoothed_latents).float().cuda() 42 | if opts.landmarks_transforms_path is not None: 43 | smoothed_transforms = smooth_ws(torch.cat([t.unsqueeze(0) for t in result_landmarks_transforms])) 44 | else: 45 | smoothed_transforms = [None] * len(smoothed_latents) 46 | return smoothed_latents, smoothed_transforms 47 | 48 | 49 | def smooth_ws(ws: np.ndarray): 50 | ws_p = ws[2:-2] + 0.75 * ws[3:-1] + 0.75 * ws[1:-3] + 0.25 * ws[:-4] + 0.25 * 
ws[4:] 51 | ws_p = ws_p / 3 52 | return ws_p 53 | 54 | 55 | def smooth_s(s): 56 | batched_s = {} 57 | for c in s[0]: 58 | bathced_c = torch.cat([s[i][c] for i in range(len(s))]) 59 | batched_s[c] = bathced_c 60 | new_s = {} 61 | for c in batched_s: 62 | new_s[c] = smooth_ws(batched_s[c]) 63 | new_smooth_s = [] 64 | for i in range(new_s['input'].shape[0]): 65 | curr_s = {c: new_s[c][i].unsqueeze(0) for c in new_s} 66 | new_smooth_s.append(curr_s) 67 | return new_smooth_s 68 | -------------------------------------------------------------------------------- /inversion/models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | 3 | from inversion.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 4 | 5 | """ 6 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Backbone(Module): 11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 12 | super(Backbone, self).__init__() 13 | assert input_size in [112, 224], "input_size should be 112 or 224" 14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 16 | blocks = get_blocks(num_layers) 17 | if mode == 'ir': 18 | unit_module = bottleneck_IR 19 | elif mode == 'ir_se': 20 | unit_module = bottleneck_IR_SE 21 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 22 | BatchNorm2d(64), 23 | PReLU(64)) 24 | if input_size == 112: 25 | self.output_layer = Sequential(BatchNorm2d(512), 26 | Dropout(drop_ratio), 27 | Flatten(), 28 | Linear(512 * 7 * 7, 512), 29 | BatchNorm1d(512, affine=affine)) 30 | else: 31 | self.output_layer = Sequential(BatchNorm2d(512), 32 | Dropout(drop_ratio), 33 | Flatten(), 34 | Linear(512 * 14 * 14, 512), 35 | BatchNorm1d(512, affine=affine)) 36 | 37 | modules = [] 38 | for block in blocks: 39 | for bottleneck in block: 40 | modules.append(unit_module(bottleneck.in_channel, 41 | bottleneck.depth, 42 | bottleneck.stride)) 43 | self.body = Sequential(*modules) 44 | 45 | def forward(self, x): 46 | x = self.input_layer(x) 47 | x = self.body(x) 48 | x = self.output_layer(x) 49 | return l2_norm(x) 50 | 51 | 52 | def IR_50(input_size): 53 | """Constructs a ir-50 model.""" 54 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 55 | return model 56 | 57 | 58 | def IR_101(input_size): 59 | """Constructs a ir-101 model.""" 60 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 61 | return model 62 | 63 | 64 | def IR_152(input_size): 65 | """Constructs a ir-152 model.""" 66 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 67 | return model 68 | 69 | 70 | def IR_SE_50(input_size): 71 | """Constructs a ir_se-50 model.""" 72 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 73 | return model 74 | 75 | 76 | def IR_SE_101(input_size): 77 | """Constructs a ir_se-101 model.""" 78 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 79 | return model 80 | 81 | 82 | def IR_SE_152(input_size): 83 | """Constructs a ir_se-152 model.""" 84 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 85 | return model 86 | 
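A usage sketch for the Backbone defined above (illustrative only: it mirrors how criteria/id_loss.py instantiates the IR-SE50 identity network; the weight-loading line is left commented out because it depends on the pretrained file referenced by model_paths['ir_se50'] in configs/paths_config.py):

import torch

from inversion.models.encoders.model_irse import Backbone

# 50-layer IR-SE backbone: expects 112x112 aligned face crops, returns L2-normalized 512-d embeddings.
facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
# facenet.load_state_dict(torch.load('pretrained_models/model_ir_se50.pth'))
facenet.eval()

with torch.no_grad():
    faces = torch.randn(2, 3, 112, 112)        # stand-in for a batch of cropped, aligned faces
    feats = facenet(faces)                     # shape (2, 512); rows are unit-norm via l2_norm()
    similarity = feats[0].dot(feats[1])        # the cosine similarity used by IDLoss / MocoLoss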
-------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /editing/interfacegan/face_editor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from configs.paths_config import interfacegan_aligned_edit_paths, interfacegan_unaligned_edit_paths 7 | from models.stylegan3.model import GeneratorType 8 | from models.stylegan3.networks_stylegan3 import Generator 9 | from utils.common import tensor2im, generate_random_transform 10 | 11 | 12 | class FaceEditor: 13 | 14 | def __init__(self, stylegan_generator: Generator, generator_type=GeneratorType.ALIGNED): 15 | self.generator = stylegan_generator 16 | if generator_type == GeneratorType.ALIGNED: 17 | paths = interfacegan_aligned_edit_paths 18 | else: 19 | paths = interfacegan_unaligned_edit_paths 20 | 21 | self.interfacegan_directions = { 22 | 'age': torch.from_numpy(np.load(paths['age'])).cuda(), 23 | 'smile': torch.from_numpy(np.load(paths['smile'])).cuda(), 24 | 'pose': torch.from_numpy(np.load(paths['pose'])).cuda(), 25 | 'Male': torch.from_numpy(np.load(paths['Male'])).cuda(), 26 | } 27 | 28 | def edit(self, latents: torch.tensor, direction: str, factor: int = 1, factor_range: Optional[Tuple[int, int]] = None, 29 | user_transforms: Optional[np.ndarray] = None, apply_user_transformations: Optional[bool] = False): 30 | edit_latents = [] 31 | edit_images = [] 32 | direction = self.interfacegan_directions[direction] 33 | if factor_range is not None: # Apply a range of editing factors. 
for example, (-5, 5) 34 | for f in range(*factor_range): 35 | edit_latent = latents + f * direction 36 | edit_image, user_transforms = self._latents_to_image(edit_latent, 37 | apply_user_transformations, 38 | user_transforms) 39 | edit_latents.append(edit_latent) 40 | edit_images.append(edit_image) 41 | else: 42 | edit_latents = latents + factor * direction 43 | edit_images, _ = self._latents_to_image(edit_latents, apply_user_transformations) 44 | return edit_images, edit_latents 45 | 46 | def _latents_to_image(self, all_latents: torch.tensor, apply_user_transformations: bool = False, 47 | user_transforms: Optional[torch.tensor] = None): 48 | with torch.no_grad(): 49 | if apply_user_transformations: 50 | if user_transforms is None: 51 | # if no transform provided, generate a random transformation 52 | user_transforms = generate_random_transform(translate=0.3, rotate=25) 53 | # apply the user-specified transformation 54 | if type(user_transforms) == np.ndarray: 55 | user_transforms = torch.from_numpy(user_transforms) 56 | self.generator.synthesis.input.transform = user_transforms.cuda().float() 57 | # generate the images 58 | images = self.generator.synthesis(all_latents, noise_mode='const') 59 | images = [tensor2im(image) for image in images] 60 | return images, user_transforms 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
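# A hypothetical usage sketch (names and tensors are illustrative, not part of this file):
# a training step that needs double-backward through grid_sample can flip the flag above
# and call the wrapper defined below, e.g.
#   from torch_utils.ops import grid_sample_gradfix
#   grid_sample_gradfix.enabled = True                      # opt in to the custom autograd op
#   warped = grid_sample_gradfix.grid_sample(images, grid)  # images: [N, C, H, W], grid: [N, H, W, 2]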
23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TrainOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 12 | self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') 13 | self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") 14 | self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") 15 | self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") 16 | self.parser.add_argument('--latents_train_path', type=str, required=True, help="The latents for the training") 17 | self.parser.add_argument('--latents_test_path', type=str, required=True, help="The latents for the validation") 18 | self.parser.add_argument('--train_dataset_size', default=10000, type=int, help="Will be used only if no latents are given") 19 | self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given") 20 | 21 | self.parser.add_argument('--batch_size', default=2, type=int, 
help='Batch size for training') 22 | self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference') 23 | self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') 24 | self.parser.add_argument('--test_workers', default=1, type=int, help='Number of test/inference dataloader workers') 25 | 26 | self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate') 27 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 28 | 29 | self.parser.add_argument('--id_lambda', default=0.06, type=float, help='ID loss multiplier factor') 30 | self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor') 31 | self.parser.add_argument('--latent_l2_lambda', default=0.3, type=float, help='Latent L2 loss multiplier factor') 32 | 33 | self.parser.add_argument('--stylegan_weights', default="/path/to/weights", type=str, help='Path to StyleGAN model weights') 34 | self.parser.add_argument('--truncation_psi', default=0.7, type=int) 35 | self.parser.add_argument('--stylegan_size', default=1024, type=int) 36 | self.parser.add_argument('--ir_se50_weights', default='/path/to/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") 37 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIP model checkpoint') 38 | 39 | self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps') 40 | self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') 41 | self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard') 42 | self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval') 43 | self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval') 44 | 45 | self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt') 46 | 47 | 48 | def parse(self): 49 | opts = self.parser.parse_args() 50 | return opts -------------------------------------------------------------------------------- /editing/interfacegan/train_boundaries.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pyrallis 6 | from dataclasses import dataclass 7 | from tqdm import tqdm 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from editing.interfacegan.helpers.anycostgan import attr_list 13 | from editing.interfacegan.helpers.manipulator import train_boundary 14 | 15 | 16 | @dataclass 17 | class TrainConfig: 18 | # Path to the `npy` saved from the script `generate_latents_and_attribute_scores.py` 19 | input_path: Path = Path("./latents") 20 | # Where to ave the boundary `npy` files to 21 | output_path: Path = Path("./boundaries") 22 | 23 | 24 | @pyrallis.wrap() 25 | def main(opts: TrainConfig): 26 | all_latent_codes, all_attribute_scores, all_ages, all_poses = [], [], [], [] 27 | for batch_dir in tqdm(opts.input_path.glob("*")): 28 | if not str(batch_dir.name).startswith("id_"): 29 | continue 30 | # load batch latents 31 | latent_codes = np.load(opts.input_path / batch_dir / 'ws.npy', allow_pickle=True) 32 | all_latent_codes.extend(latent_codes.tolist()) 33 | # load batch attribute 
scores 34 | scores = np.load(opts.input_path / batch_dir / 'scores.npy', allow_pickle=True) 35 | all_attribute_scores.extend(scores.tolist()) 36 | # load batch ages 37 | ages = np.load(opts.input_path / batch_dir / 'ages.npy', allow_pickle=True) 38 | all_ages.extend(ages.tolist()) 39 | # load batch poses 40 | poses = np.load(opts.input_path / batch_dir / 'poses.npy', allow_pickle=True) 41 | all_poses.extend(poses.tolist()) 42 | 43 | opts.output_path.mkdir(exist_ok=True, parents=True) 44 | 45 | print(f"Obtained a total of {len(all_latent_codes)} latent codes!") 46 | 47 | all_latent_codes = np.array(all_latent_codes) 48 | all_latent_codes = np.array([l[0] for l in all_latent_codes]) 49 | 50 | # train all boundaries for all attributes predicted from the AnyCostGAN classifier 51 | for attribute_name in attr_list: 52 | print(f"Training boundary for: {attribute_name}") 53 | attr_scores = [s[attr_list.index(attribute_name)][1] for s in all_attribute_scores] 54 | attr_scores = np.array(attr_scores)[:, np.newaxis] 55 | boundary = train_boundary(latent_codes=np.array(all_latent_codes), 56 | scores=attr_scores, 57 | chosen_num_or_ratio=0.02, 58 | split_ratio=0.7, 59 | invalid_value=None) 60 | np.save(opts.output_path / f'{attribute_name}_boundary.npy', boundary) 61 | 62 | # train the age boundary 63 | boundary = train_boundary(latent_codes=np.array(all_latent_codes), 64 | scores=np.array(all_ages), 65 | chosen_num_or_ratio=0.02, 66 | split_ratio=0.7, 67 | invalid_value=None) 68 | np.save(opts.output_path / f'age_boundary.npy', boundary) 69 | 70 | boundary = train_boundary(latent_codes=np.array(all_latent_codes), 71 | scores=np.array(all_poses), 72 | chosen_num_or_ratio=0.02, 73 | split_ratio=0.7, 74 | invalid_value=None) 75 | np.save(opts.output_path / f'pose_boundary.npy', boundary) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /inversion/scripts/calc_losses_on_images.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from pathlib import Path 4 | from typing import List, Optional 5 | 6 | import numpy as np 7 | import pyrallis 8 | import torch 9 | import torchvision.transforms as transforms 10 | from dataclasses import dataclass 11 | from pyrallis import field 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | 15 | sys.path.append(".") 16 | sys.path.append("..") 17 | 18 | from criteria.lpips.lpips import LPIPS 19 | from criteria.ms_ssim import MSSSIM 20 | from inversion.datasets.gt_res_dataset import GTResDataset 21 | 22 | 23 | @dataclass 24 | class RunConfig: 25 | # Path to reconstructed images 26 | output_path: Path 27 | # Path to gt images 28 | gt_path: Path 29 | # List of metrics to compute 30 | metrics: List[str] = field(default=["lpips", "l2", "msssim"], is_mutable=True) 31 | # Number of workers for the dataloader 32 | workers: int = 4 33 | # Batch size for computing losses 34 | batch_size: int = 4 35 | # Stores current metric 36 | metric: Optional[str] = None 37 | 38 | 39 | @pyrallis.wrap() 40 | def run(opts: RunConfig): 41 | for metric in opts.metrics: 42 | opts.metric = metric 43 | for step in sorted(opts.output_path.iterdir()): 44 | if not str(step.name).isdigit(): 45 | continue 46 | step_outputs_path = opts.output_path / step.name 47 | if step_outputs_path.is_dir(): 48 | print('#' * 80) 49 | print(f'Computing {opts.metric} on step: {step.name}') 50 | print('#' * 80) 51 | run_on_step_output(step=step.name,
opts=opts) 52 | 53 | 54 | def run_on_step_output(step: str, opts: RunConfig): 55 | 56 | transform = transforms.Compose([transforms.Resize((256, 256)), 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 59 | 60 | step_outputs_path = opts.output_path / step 61 | 62 | print('Loading dataset') 63 | dataset = GTResDataset(root_path=step_outputs_path, 64 | gt_dir=opts.gt_path, 65 | transform=transform) 66 | 67 | dataloader = DataLoader(dataset, 68 | batch_size=opts.batch_size, 69 | shuffle=False, 70 | num_workers=int(opts.workers), 71 | drop_last=True) 72 | 73 | if opts.metric == 'lpips': 74 | loss_func = LPIPS(net_type='alex') 75 | elif opts.metric == 'l2': 76 | loss_func = torch.nn.MSELoss() 77 | elif opts.metric == 'msssim': 78 | loss_func = MSSSIM() 79 | else: 80 | raise Exception(f'Not a valid metric: {opts.metric}!') 81 | 82 | loss_func.cuda() 83 | 84 | global_i = 0 85 | scores_dict = {} 86 | all_scores = [] 87 | for result_batch, gt_batch in tqdm(dataloader): 88 | for i in range(opts.batch_size): 89 | loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda())) 90 | all_scores.append(loss) 91 | im_path = dataset.pairs[global_i][0] 92 | scores_dict[im_path.name] = loss 93 | global_i += 1 94 | 95 | all_scores = list(scores_dict.values()) 96 | mean = np.mean(all_scores) 97 | std = np.std(all_scores) 98 | result_str = f'Average loss is {mean:.2f}+-{std:.2f}' 99 | print('Finished with ', step_outputs_path) 100 | print(result_str) 101 | 102 | out_path = opts.output_path.parent / 'inference_metrics' 103 | out_path.mkdir(exist_ok=True, parents=True) 104 | 105 | with open(out_path / f'stat_{opts.metric}_step_{step}.txt', 'w') as f: 106 | f.write(result_str) 107 | with open(out_path / f'scores_{opts.metric}_step_{step}.json', 'w') as f: 108 | json.dump(scores_dict, f) 109 | 110 | 111 | if __name__ == '__main__': 112 | run() 113 | -------------------------------------------------------------------------------- /editing/styleclip_global_directions/preprocess/s_statistics.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pyrallis 6 | import torch 7 | from dataclasses import dataclass 8 | 9 | from configs.paths_config import model_paths 10 | from models.stylegan3.model import SG3Generator 11 | 12 | 13 | @dataclass 14 | class Options: 15 | """ StyleGAN Args """ 16 | # Path to StyleGAN model weights 17 | checkpoint_path: Path = Path(model_paths['stylegan3_ffhq_pt']) 18 | # Images resolution generated by the StyleGAN model 19 | stylegan_size: int = 1024 20 | # Is it the landscape model? 
If so, different init_kwargs are used to load the pretained StyleGAN model 21 | is_landscape: bool = False 22 | 23 | """ Global Direction Args """ 24 | # Truncation used for generating the images 25 | truncation_psi: float = 0.5 26 | # Truncation cutoff used for generating the images 27 | truncation_cutoff: int = 8 28 | # Correct pose when generating images 29 | pseudo_align: bool = True 30 | 31 | """ General Args """ 32 | # Path to directory in which result files are saved 33 | output_path: Path = Path("stats") 34 | # Number of samples used for computing the stats 35 | num_images: int = 100_000 36 | # Seed for random state 37 | random_state: int = 0 38 | 39 | 40 | def save_stats(G, random_state, num_images, truncation_psi, truncation_cutoff, output_path): 41 | rnd = np.random.RandomState(random_state) 42 | z = rnd.randn(num_images, 512) 43 | z = torch.tensor(z).cuda() 44 | 45 | with torch.no_grad(): 46 | ws = G.mapping(z=z, c=None, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) 47 | all_s = G.synthesis.W2S(ws) 48 | 49 | w = ws.cpu().numpy()[:, 0, :] 50 | result_w_path = output_path / 'W' 51 | np.save(result_w_path, w) 52 | 53 | all_s_np = {} 54 | all_s_1000 = {} 55 | s_mean = {} 56 | s_std = {} 57 | for layer in all_s.keys(): 58 | s = all_s[layer].cpu().numpy() 59 | all_s_np[layer] = s 60 | all_s_1000[layer] = s[:1000] 61 | s_mean[layer] = s.mean(axis=0) 62 | s_std[layer] = s.std(axis=0) 63 | 64 | result_s_path = output_path / 'S' 65 | with open(result_s_path, "wb") as fp: 66 | pickle.dump(all_s_np, fp) 67 | 68 | result_s1000_path = output_path / 'S_1000' 69 | with open(result_s1000_path, "wb") as fp: 70 | pickle.dump(all_s_1000, fp) 71 | 72 | fourier_features_channels = all_s['input'].cpu().numpy() 73 | theta = np.arccos(fourier_features_channels[:, 0]) 74 | 75 | theta_mean = np.mean(theta) 76 | x = fourier_features_channels[:, 2].mean() 77 | y = fourier_features_channels[:, 3].mean() 78 | 79 | transform = {} 80 | transform['theta'] = theta_mean 81 | transform['x'] = x 82 | transform['y'] = y 83 | 84 | s_stats = [transform, s_mean, s_std] 85 | 86 | result_s_stats_path = output_path / 's_stats' 87 | with open(result_s_stats_path, "wb") as fp: 88 | pickle.dump(s_stats, fp) 89 | 90 | 91 | @pyrallis.wrap() 92 | def main(args: Options): 93 | args.output_path.mkdir(exist_ok=True, parents=True) 94 | G = SG3Generator(args.checkpoint_path, res=args.stylegan_size, 95 | config="landscape" if args.is_landscape else None).decoder 96 | save_stats(G, args.random_state, args.num_images, args.truncation_psi, args.truncation_cutoff, args.output_path) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | from .box_utils import nms, _preprocess 8 | 9 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | device = 'cuda:0' 11 | 12 | 13 | def run_first_stage(image, net, scale, threshold): 14 | """Run P-Net, generate bounding boxes, and do NMS. 15 | 16 | Arguments: 17 | image: an instance of PIL.Image. 18 | net: an instance of pytorch's nn.Module, P-Net. 19 | scale: a float number, 20 | scale width and height of the image by this number. 
21 | threshold: a float number, 22 | threshold on the probability of a face when generating 23 | bounding boxes from predictions of the net. 24 | 25 | Returns: 26 | a float numpy array of shape [n_boxes, 9], 27 | bounding boxes with scores and offsets (4 + 1 + 4). 28 | """ 29 | 30 | # scale the image and convert it to a float array 31 | width, height = image.size 32 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 33 | img = image.resize((sw, sh), Image.BILINEAR) 34 | img = np.asarray(img, 'float32') 35 | 36 | img = torch.FloatTensor(_preprocess(img)).to(device) 37 | with torch.no_grad(): 38 | output = net(img) 39 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 40 | offsets = output[0].cpu().data.numpy() 41 | # probs: probability of a face at each sliding window 42 | # offsets: transformations to true bounding boxes 43 | 44 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 45 | if len(boxes) == 0: 46 | return None 47 | 48 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 49 | return boxes[keep] 50 | 51 | 52 | def _generate_bboxes(probs, offsets, scale, threshold): 53 | """Generate bounding boxes at places 54 | where there is probably a face. 55 | 56 | Arguments: 57 | probs: a float numpy array of shape [n, m]. 58 | offsets: a float numpy array of shape [1, 4, n, m]. 59 | scale: a float number, 60 | width and height of the image were scaled by this number. 61 | threshold: a float number. 62 | 63 | Returns: 64 | a float numpy array of shape [n_boxes, 9] 65 | """ 66 | 67 | # applying P-Net is equivalent, in some sense, to 68 | # moving 12x12 window with stride 2 69 | stride = 2 70 | cell_size = 12 71 | 72 | # indices of boxes where there is probably a face 73 | inds = np.where(probs > threshold) 74 | 75 | if inds[0].size == 0: 76 | return np.array([]) 77 | 78 | # transformations of bounding boxes 79 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 80 | # they are defined as: 81 | # w = x2 - x1 + 1 82 | # h = y2 - y1 + 1 83 | # x1_true = x1 + tx1*w 84 | # x2_true = x2 + tx2*w 85 | # y1_true = y1 + ty1*h 86 | # y2_true = y2 + ty2*h 87 | 88 | offsets = np.array([tx1, ty1, tx2, ty2]) 89 | score = probs[inds[0], inds[1]] 90 | 91 | # P-Net is applied to scaled images 92 | # so we need to rescale bounding boxes back 93 | bounding_boxes = np.vstack([ 94 | np.round((stride * inds[1] + 1.0) / scale), 95 | np.round((stride * inds[0] + 1.0) / scale), 96 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 97 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 98 | score, offsets 99 | ]) 100 | # why one is added? 
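    # (presumably to match the reference MATLAB/Caffe MTCNN implementation, which maps
    #  window indices to 1-based pixel coordinates -- an assumption, since the original
    #  comment above leaves the question open)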
101 | 102 | return bounding_boxes.T 103 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/scripts/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from argparse import Namespace 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | from editing.styleclip_mapper.datasets.latents_dataset import LatentsDataset 16 | 17 | from editing.styleclip_mapper.options.test_options import TestOptions 18 | from editing.styleclip_mapper.styleclip_mapper import StyleCLIPMapper 19 | 20 | 21 | def run(test_opts): 22 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 23 | os.makedirs(out_path_results, exist_ok=True) 24 | 25 | # update test options with options used during training 26 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 27 | opts = ckpt['opts'] 28 | opts.update(vars(test_opts)) 29 | opts = Namespace(**opts) 30 | 31 | net = StyleCLIPMapper(opts) 32 | net.eval() 33 | net.cuda() 34 | 35 | test_latents = torch.load(opts.latents_test_path) 36 | if opts.fourier_features_transforms_path: 37 | transforms = np.load(opts.fourier_features_transforms_path, allow_pickle=True) 38 | else: 39 | transforms = None 40 | dataset = LatentsDataset(latents=test_latents.cpu(), opts=opts, transforms=transforms) 41 | dataloader = DataLoader(dataset, 42 | batch_size=opts.test_batch_size, 43 | shuffle=False, 44 | num_workers=int(opts.test_workers), 45 | drop_last=True) 46 | 47 | if opts.n_images is None: 48 | opts.n_images = len(dataset) 49 | 50 | global_i = 0 51 | global_time = [] 52 | for input_batch in tqdm(dataloader): 53 | if global_i >= opts.n_images: 54 | break 55 | with torch.no_grad(): 56 | if opts.fourier_features_transforms_path: 57 | input_cuda, transform = input_batch 58 | transform = transform.cuda() 59 | else: 60 | input_cuda = input_batch 61 | transform = None 62 | input_cuda = input_cuda.cuda() 63 | 64 | tic = time.time() 65 | result_batch = run_on_batch(input_cuda, transform, net, opts.couple_outputs) 66 | toc = time.time() 67 | global_time.append(toc - tic) 68 | 69 | for i in range(opts.test_batch_size): 70 | im_path = str(global_i).zfill(5) 71 | if test_opts.couple_outputs: 72 | couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)]) 73 | torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"), normalize=True, range=(-1, 1)) 74 | else: 75 | torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"), normalize=True, range=(-1, 1)) 76 | torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt")) 77 | 78 | global_i += 1 79 | 80 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 81 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 82 | print(result_str) 83 | 84 | with open(stats_path, 'w') as f: 85 | f.write(result_str) 86 | 87 | 88 | def run_on_batch(inputs, transform, net, couple_outputs=False): 89 | w = inputs 90 | with torch.no_grad(): 91 | w_hat = w + 0.1 * net.mapper(w) 92 | if transform is not None: 93 | net.decoder.synthesis.input.transform = transform 94 | x_hat = net.decoder.synthesis(w_hat) 95 | result_batch = (x_hat, w_hat) 96 | if couple_outputs: 97 | x = 
net.decoder.synthesis(w) 98 | result_batch = (x_hat, w_hat, x) 99 | return result_batch 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | test_opts = TestOptions().parse() 105 | run(test_opts) 106 | -------------------------------------------------------------------------------- /inversion/models/encoders/restyle_psp_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module 4 | from torchvision.models.resnet import resnet34 5 | 6 | from inversion.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE 7 | from inversion.models.encoders.map2style import GradualStyleBlock 8 | 9 | 10 | class BackboneEncoder(Module): 11 | """ 12 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 13 | map of the encoder. This classes uses the simplified architecture applied over an ResNet IRSE-50 backbone. 14 | Note this class is designed to be used for the human facial domain. 15 | """ 16 | def __init__(self, num_layers, mode='ir', n_styles=18, opts=None): 17 | super(BackboneEncoder, self).__init__() 18 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 19 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 20 | blocks = get_blocks(num_layers) 21 | if mode == 'ir': 22 | unit_module = bottleneck_IR 23 | elif mode == 'ir_se': 24 | unit_module = bottleneck_IR_SE 25 | 26 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 27 | BatchNorm2d(64), 28 | PReLU(64)) 29 | modules = [] 30 | for block in blocks: 31 | for bottleneck in block: 32 | modules.append(unit_module(bottleneck.in_channel, 33 | bottleneck.depth, 34 | bottleneck.stride)) 35 | self.body = Sequential(*modules) 36 | 37 | self.styles = nn.ModuleList() 38 | self.style_count = n_styles 39 | for i in range(self.style_count): 40 | style = GradualStyleBlock(512, 512, 16) 41 | self.styles.append(style) 42 | 43 | def forward(self, x): 44 | x = self.input_layer(x) 45 | x = self.body(x) 46 | latents = [] 47 | for j in range(self.style_count): 48 | latents.append(self.styles[j](x)) 49 | out = torch.stack(latents, dim=1) 50 | return out 51 | 52 | 53 | class ResNetBackboneEncoder(Module): 54 | """ 55 | The simpler backbone architecture used by ReStyle where all style vectors are extracted from the final 16x16 feature 56 | map of the encoder. This classes uses the simplified architecture applied over an ResNet34 backbone. 
57 | """ 58 | def __init__(self, n_styles=18, opts=None): 59 | super(ResNetBackboneEncoder, self).__init__() 60 | 61 | self.conv1 = nn.Conv2d(opts.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False) 62 | self.bn1 = BatchNorm2d(64) 63 | self.relu = PReLU(64) 64 | 65 | resnet_basenet = resnet34(pretrained=True) 66 | blocks = [ 67 | resnet_basenet.layer1, 68 | resnet_basenet.layer2, 69 | resnet_basenet.layer3, 70 | resnet_basenet.layer4 71 | ] 72 | modules = [] 73 | for block in blocks: 74 | for bottleneck in block: 75 | modules.append(bottleneck) 76 | self.body = Sequential(*modules) 77 | 78 | self.styles = nn.ModuleList() 79 | self.style_count = n_styles 80 | for i in range(self.style_count): 81 | style = GradualStyleBlock(512, 512, 16) 82 | self.styles.append(style) 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = self.bn1(x) 87 | x = self.relu(x) 88 | x = self.body(x) 89 | latents = [] 90 | for j in range(self.style_count): 91 | latents.append(self.styles[j](x)) 92 | out = torch.stack(latents, dim=1) 93 | return out 94 | -------------------------------------------------------------------------------- /inversion/models/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 5 | 6 | """ 7 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | return output 20 | 21 | 22 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 23 | """ A named tuple describing a ResNet block. """ 24 | 25 | 26 | def get_block(in_channel, depth, num_units, stride=2): 27 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 28 | 29 | 30 | def get_blocks(num_layers): 31 | if num_layers == 50: 32 | blocks = [ 33 | get_block(in_channel=64, depth=64, num_units=3), 34 | get_block(in_channel=64, depth=128, num_units=4), 35 | get_block(in_channel=128, depth=256, num_units=14), 36 | get_block(in_channel=256, depth=512, num_units=3) 37 | ] 38 | elif num_layers == 100: 39 | blocks = [ 40 | get_block(in_channel=64, depth=64, num_units=3), 41 | get_block(in_channel=64, depth=128, num_units=13), 42 | get_block(in_channel=128, depth=256, num_units=30), 43 | get_block(in_channel=256, depth=512, num_units=3) 44 | ] 45 | elif num_layers == 152: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=8), 49 | get_block(in_channel=128, depth=256, num_units=36), 50 | get_block(in_channel=256, depth=512, num_units=3) 51 | ] 52 | else: 53 | raise ValueError(f"Invalid number of layers: {num_layers}. 
Must be one of [50, 100, 152]") 54 | return blocks 55 | 56 | 57 | class SEModule(Module): 58 | def __init__(self, channels, reduction): 59 | super(SEModule, self).__init__() 60 | self.avg_pool = AdaptiveAvgPool2d(1) 61 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 62 | self.relu = ReLU(inplace=True) 63 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 64 | self.sigmoid = Sigmoid() 65 | 66 | def forward(self, x): 67 | module_input = x 68 | x = self.avg_pool(x) 69 | x = self.fc1(x) 70 | x = self.relu(x) 71 | x = self.fc2(x) 72 | x = self.sigmoid(x) 73 | return module_input * x 74 | 75 | 76 | class bottleneck_IR(Module): 77 | def __init__(self, in_channel, depth, stride): 78 | super(bottleneck_IR, self).__init__() 79 | if in_channel == depth: 80 | self.shortcut_layer = MaxPool2d(1, stride) 81 | else: 82 | self.shortcut_layer = Sequential( 83 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 84 | BatchNorm2d(depth) 85 | ) 86 | self.res_layer = Sequential( 87 | BatchNorm2d(in_channel), 88 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 89 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 90 | ) 91 | 92 | def forward(self, x): 93 | shortcut = self.shortcut_layer(x) 94 | res = self.res_layer(x) 95 | return res + shortcut 96 | 97 | 98 | class bottleneck_IR_SE(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR_SE, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth) 107 | ) 108 | self.res_layer = Sequential( 109 | BatchNorm2d(in_channel), 110 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 111 | PReLU(depth), 112 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 113 | BatchNorm2d(depth), 114 | SEModule(depth, 16) 115 | ) 116 | 117 | def forward(self, x): 118 | shortcut = self.shortcut_layer(x) 119 | res = self.res_layer(x) 120 | return res + shortcut 121 | -------------------------------------------------------------------------------- /inversion/scripts/calc_id_loss_parallel.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import multiprocessing as mp 4 | import sys 5 | import time 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import pyrallis 10 | import torch 11 | import torchvision.transforms as trans 12 | from PIL import Image 13 | from dataclasses import dataclass 14 | 15 | sys.path.append(".") 16 | sys.path.append("..") 17 | 18 | from inversion.models.mtcnn.mtcnn import MTCNN 19 | from inversion.models.encoders.model_irse import IR_101 20 | from configs.paths_config import model_paths 21 | 22 | 23 | CIRCULAR_FACE_PATH = model_paths['curricular_face'] 24 | 25 | 26 | @dataclass 27 | class RunConfig: 28 | # Path to reconstructed images 29 | output_path: Path 30 | # Path to gt images 31 | gt_path: Path 32 | # Number of works to use for computing losses in parallel 33 | num_threads: int = 4 34 | 35 | 36 | @pyrallis.wrap() 37 | def run(opts: RunConfig): 38 | for step in sorted(opts.output_path.glob("*")): 39 | if not str(step.name).isdigit(): 40 | continue 41 | step_outputs_path = opts.output_path / step.name 42 | if step_outputs_path.is_dir(): 43 | print('#' * 80) 44 | print(f'Running on step: {step.name}') 45 | print('#' * 80) 46 | 
run_on_step_output(step=step.name, args=opts) 47 | 48 | 49 | def run_on_step_output(step: str, args: RunConfig): 50 | file_paths = [] 51 | step_outputs_path = args.output_path / step 52 | for f in step_outputs_path.glob("*"): 53 | image_path = step_outputs_path / f 54 | gt_path = args.gt_path / f 55 | if f.suffix in [".jpg", ".png", ".jpeg"]: 56 | file_paths.append([image_path, gt_path]) 57 | 58 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 59 | pool = mp.Pool(args.num_threads) 60 | print(f'Running on {len(file_paths)} paths\nHere we goooo') 61 | 62 | tic = time.time() 63 | results = pool.map(extract_on_paths, file_chunks) 64 | scores_dict = {} 65 | for d in results: 66 | scores_dict.update(d) 67 | 68 | all_scores = list(scores_dict.values()) 69 | mean = np.mean(all_scores) 70 | std = np.std(all_scores) 71 | result_str = f'New Average score is {mean:.2f}+-{std:.2f}' 72 | print(result_str) 73 | 74 | out_path = args.output_path.parent / 'inference_metrics' 75 | out_path.mkdir(exist_ok=True, parents=True) 76 | 77 | with open(out_path / f'stat_id_step_{step}.txt', 'w') as f: 78 | f.write(result_str) 79 | with open(out_path / f'scores_id_step_{step}.json', 'w') as f: 80 | json.dump(scores_dict, f) 81 | 82 | toc = time.time() 83 | print(f'Mischief managed in {toc - tic}s') 84 | 85 | 86 | def chunks(lst, n): 87 | """Yield successive n-sized chunks from lst.""" 88 | for i in range(0, len(lst), n): 89 | yield lst[i:i + n] 90 | 91 | 92 | def extract_on_paths(file_paths): 93 | facenet = IR_101(input_size=112) 94 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 95 | facenet.cuda() 96 | facenet.eval() 97 | mtcnn = MTCNN() 98 | id_transform = trans.Compose([ 99 | trans.ToTensor(), 100 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 101 | ]) 102 | 103 | pid = mp.current_process().name 104 | print(f'\t{pid} is starting to extract on {len(file_paths)} images') 105 | tot_count = len(file_paths) 106 | count = 0 107 | 108 | scores_dict = {} 109 | for res_path, gt_path in file_paths: 110 | count += 1 111 | if count % 100 == 0: 112 | print(f'{pid} done with {count}/{tot_count}') 113 | if True: 114 | input_im = Image.open(res_path) 115 | input_im, _ = mtcnn.align(input_im) 116 | if input_im is None: 117 | print(f'{pid} skipping {res_path}') 118 | continue 119 | 120 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 121 | 122 | result_im = Image.open(gt_path) 123 | result_im, _ = mtcnn.align(result_im) 124 | if result_im is None: 125 | print(f'{pid} skipping {gt_path}') 126 | continue 127 | 128 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 129 | score = float(input_id.dot(result_id)) 130 | scores_dict[gt_path.name] = score 131 | 132 | return scores_dict 133 | 134 | 135 | if __name__ == '__main__': 136 | run() 137 | -------------------------------------------------------------------------------- /editing/interfacegan/generate_latents_and_attribute_scores.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import numpy as np 7 | import pyrallis 8 | import torch 9 | from dataclasses import dataclass 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from configs.paths_config import model_paths 15 | from editing.interfacegan.helpers import anycostgan 16 | from editing.interfacegan.helpers.pose_estimator import PoseEstimator 17 | from
editing.interfacegan.helpers.age_estimator import AgeEstimator 18 | 19 | 20 | @dataclass 21 | class EditConfig: 22 | # Path to StyleGAN3 generator 23 | generator_path: Path = Path(model_paths["stylegan3_ffhq"]) 24 | # Number of latents to sample 25 | n_images: int = 500000 26 | # Truncation psi for sampling 27 | truncation_psi: float = 0.7 28 | # Where to save the `npy` files with latents and scores to 29 | output_path: Path = Path("./latents") 30 | # How often to save sample latents/scores to `npy` files 31 | save_interval: int = 10000 32 | 33 | 34 | @pyrallis.wrap() 35 | def run(opts: EditConfig): 36 | generate_images(generator_path=opts.generator_path, 37 | n_images=opts.n_images, 38 | truncation_psi=opts.truncation_psi, 39 | output_path=opts.output_path, 40 | save_interval=opts.save_interval) 41 | 42 | 43 | def generate_images(generator_path: Path, n_images: int, truncation_psi: float, output_path: Path, save_interval: int): 44 | 45 | print('Loading generator from "%s"...' % generator_path) 46 | device = torch.device('cuda') 47 | with open(generator_path, "rb") as f: 48 | G = pickle.load(f)['G_ema'].cuda() 49 | 50 | output_path.mkdir(exist_ok=True, parents=True) 51 | 52 | # estimator for all attributes 53 | estimator = anycostgan.get_pretrained('attribute-predictor').to('cuda:0') 54 | estimator.eval() 55 | 56 | # estimators for age and pose 57 | age_estimator = AgeEstimator() 58 | pose_estimator = PoseEstimator() 59 | 60 | face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 61 | 62 | preds, ages, poses, ws = [], [], [], [] 63 | saving_batch_id = 0 64 | for seed_idx, seed in enumerate(range(n_images)): 65 | 66 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 67 | w = G.mapping(z, None, truncation_psi=truncation_psi) 68 | ws.append(w.detach().cpu().numpy()) 69 | 70 | # if using unaligned generator, before generating the image and predicting attribute scores, align the image 71 | if generator_path == Path(model_paths["stylegan3_ffhq_unaligned"]): 72 | w[:, 0] = G.mapping.w_avg 73 | 74 | img = G.synthesis(w, noise_mode="const") 75 | img = face_pool(img) 76 | 77 | # get attribute scores for the generated image 78 | logits = estimator(img).view(-1, 40, 2)[0] 79 | attr_preds = torch.nn.functional.softmax(logits).cpu().detach().numpy() 80 | preds.append(attr_preds) 81 | 82 | # get predicted age 83 | age = age_estimator.extract_ages(img).cpu().detach().numpy()[0] 84 | ages.append(age) 85 | 86 | # get predicted pose 87 | pose = pose_estimator.extract_yaw(img).cpu().detach().numpy()[0] 88 | poses.append(pose) 89 | 90 | if seed_idx % save_interval == 0 and seed > 0: 91 | save_latents_and_scores(preds, ws, ages, poses, saving_batch_id, output_path) 92 | saving_batch_id = saving_batch_id + 1 93 | preds, ages, poses, ws = [], [], [], [] 94 | print(f'Generated {save_interval} images!') 95 | 96 | 97 | def save_latents_and_scores(preds: List[np.ndarray], ws: List[np.ndarray], ages: List[float], poses: List[float], 98 | batch_id: int, output_path: Path): 99 | ws = np.vstack(ws) 100 | preds = np.array(preds) 101 | ages = np.vstack(ages) 102 | poses = np.vstack(poses) 103 | dir_path = output_path / f'id_{batch_id}' 104 | dir_path.mkdir(exist_ok=True, parents=True) 105 | np.save(dir_path / 'ws.npy', ws) 106 | np.save(dir_path / 'scores.npy', preds) 107 | np.save(dir_path / 'ages.npy', ages) 108 | np.save(dir_path / 'poses.npy', poses) 109 | 110 | 111 | if __name__ == '__main__': 112 | run() 113 | -------------------------------------------------------------------------------- 
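For reference, each `id_<batch>` directory written by `save_latents_and_scores` above is exactly what `editing/interfacegan/train_boundaries.py` globs for. A minimal sketch of loading one such batch back (the `./latents/id_0` path is illustrative):

import numpy as np
from pathlib import Path

batch_dir = Path("./latents/id_0")                             # one batch directory written by save_latents_and_scores
ws = np.load(batch_dir / "ws.npy", allow_pickle=True)          # stacked W+ latents, one entry per sampled seed
scores = np.load(batch_dir / "scores.npy", allow_pickle=True)  # per-image 40x2 attribute softmax scores
ages = np.load(batch_dir / "ages.npy", allow_pickle=True)      # predicted ages
poses = np.load(batch_dir / "poses.npy", allow_pickle=True)    # predicted yaw angles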
/inversion/scripts/inference_iterative.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | import numpy as np 5 | import pyrallis 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from configs import data_configs 15 | from inversion.options.test_options import TestOptions 16 | from inversion.datasets.inference_dataset import InferenceDataset 17 | from utils.common import tensor2im 18 | from utils.inference_utils import get_average_image, run_on_batch, load_encoder 19 | 20 | 21 | @pyrallis.wrap() 22 | def run(test_opts: TestOptions): 23 | 24 | out_path_results = test_opts.output_path / 'inference_results' 25 | out_path_coupled = test_opts.output_path / 'inference_coupled' 26 | out_path_results.mkdir(exist_ok=True, parents=True) 27 | out_path_coupled.mkdir(exist_ok=True, parents=True) 28 | 29 | # update test options with options used during training 30 | net, opts = load_encoder(checkpoint_path=test_opts.checkpoint_path, test_opts=test_opts) 31 | 32 | print(f'Loading dataset for {opts.dataset_type}') 33 | dataset_args = data_configs.DATASETS[opts.dataset_type] 34 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 35 | dataset = InferenceDataset(root=opts.data_path, 36 | landmarks_transforms_path=opts.landmarks_transforms_path, 37 | transform=transforms_dict['transform_inference']) 38 | dataloader = DataLoader(dataset, 39 | batch_size=opts.test_batch_size, 40 | shuffle=False, 41 | num_workers=int(opts.test_workers), 42 | drop_last=False) 43 | 44 | if opts.n_images is None: 45 | opts.n_images = len(dataset) 46 | 47 | # get the image corresponding to the latent average 48 | avg_image = get_average_image(net) 49 | 50 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 51 | 52 | global_i = 0 53 | global_time = [] 54 | all_latents = {} 55 | for input_batch in tqdm(dataloader): 56 | if global_i >= opts.n_images: 57 | break 58 | 59 | with torch.no_grad(): 60 | input_batch, landmarks_transform = input_batch 61 | tic = time.time() 62 | result_batch, result_latents = run_on_batch(inputs=input_batch.cuda().float(), 63 | net=net, 64 | opts=opts, 65 | avg_image=avg_image, 66 | landmarks_transform=landmarks_transform.cuda().float()) 67 | toc = time.time() 68 | global_time.append(toc - tic) 69 | 70 | for i in range(input_batch.shape[0]): 71 | results = [tensor2im(result_batch[i][iter_idx]) for iter_idx in range(opts.n_iters_per_batch)] 72 | im_path = dataset.paths[global_i] 73 | 74 | # save individual step results 75 | for idx, result in enumerate(results): 76 | save_dir = out_path_results / str(idx) 77 | save_dir.mkdir(exist_ok=True, parents=True) 78 | result.resize(resize_amount).save(save_dir / im_path.name) 79 | 80 | # save step-by-step results side-by-side 81 | input_im = tensor2im(input_batch[i]) 82 | res = np.array(results[0].resize(resize_amount)) 83 | for idx, result in enumerate(results[1:]): 84 | res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1) 85 | res = np.concatenate([res, input_im.resize(resize_amount)], axis=1) 86 | Image.fromarray(res).save(out_path_coupled / im_path.name) 87 | 88 | # store all latents with dict pairs (image_name, latents) 89 | all_latents[im_path.name] = result_latents[i] 90 | 91 | global_i += 1 92 | 93 | stats_path = opts.output_path / 'stats.txt' 94 | result_str = f'Runtime 
{np.mean(global_time):.4f}+-{np.std(global_time):.4f}' 95 | print(result_str) 96 | 97 | with open(stats_path, 'w') as f: 98 | f.write(result_str) 99 | 100 | # save all latents as npy file 101 | np.save(test_opts.output_path / 'latents.npy', all_latents) 102 | 103 | 104 | if __name__ == '__main__': 105 | run() 106 | -------------------------------------------------------------------------------- /utils/inference_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | import dataclasses 5 | import torch 6 | from torchvision import transforms 7 | 8 | from configs.paths_config import model_paths 9 | from inversion.models.e4e3 import e4e 10 | from inversion.models.psp3 import pSp 11 | from inversion.options.e4e_train_options import e4eTrainOptions 12 | from inversion.options.test_options import TestOptions 13 | from inversion.options.train_options import TrainOptions 14 | from models.stylegan3.model import SG3Generator 15 | from utils.model_utils import ENCODER_TYPES 16 | 17 | IMAGE_TRANSFORMS = transforms.Compose([ 18 | transforms.Resize((256, 256)), 19 | transforms.ToTensor(), 20 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 21 | 22 | FULL_IMAGE_TRANSFORMS = transforms.Compose([ 23 | transforms.Resize((1024, 1024)), 24 | transforms.ToTensor(), 25 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 26 | 27 | 28 | def load_encoder(checkpoint_path: Path, test_opts: Optional[TestOptions] = None, generator_path: Optional[Path] = None): 29 | ckpt = torch.load(checkpoint_path, map_location='cpu') 30 | opts = ckpt['opts'] 31 | opts["checkpoint_path"] = checkpoint_path 32 | 33 | if opts['stylegan_weights'] == Path(model_paths["stylegan3_ffhq"]): 34 | opts['stylegan_weights'] = Path(model_paths["stylegan3_ffhq_pt"]) 35 | if opts['stylegan_weights'] == Path(model_paths["stylegan3_ffhq_unaligned"]): 36 | opts['stylegan_weights'] = Path(model_paths["stylegan3_ffhq_unaligned_pt"]) 37 | 38 | if opts["encoder_type"] in ENCODER_TYPES['pSp']: 39 | opts = TrainOptions(**opts) 40 | if test_opts is not None: 41 | opts.update(dataclasses.asdict(test_opts)) 42 | net = pSp(opts) 43 | else: 44 | opts = e4eTrainOptions(**opts) 45 | if test_opts is not None: 46 | opts.update(dataclasses.asdict(test_opts)) 47 | net = e4e(opts) 48 | 49 | print('Model successfully loaded!') 50 | if generator_path is not None: 51 | print(f"Updating SG3 generator with generator from path: {generator_path}") 52 | net.decoder = SG3Generator(checkpoint_path=generator_path).decoder 53 | 54 | net.eval() 55 | net.cuda() 56 | return net, opts 57 | 58 | 59 | def get_average_image(net): 60 | avg_image = net(net.latent_avg.repeat(16, 1).unsqueeze(0).cuda(), 61 | input_code=True, 62 | return_latents=False)[0] 63 | avg_image = avg_image.to('cuda').float().detach() 64 | return avg_image 65 | 66 | 67 | def run_on_batch(inputs: torch.tensor, net, opts: TrainOptions, avg_image: torch.tensor, 68 | landmarks_transform: Optional[torch.tensor] = None): 69 | results_batch = {idx: [] for idx in range(inputs.shape[0])} 70 | results_latent = {idx: [] for idx in range(inputs.shape[0])} 71 | y_hat, latent = None, None 72 | if "resize_outputs" not in dataclasses.asdict(opts): 73 | opts.resize_outputs = False 74 | 75 | for iter in range(opts.n_iters_per_batch): 76 | if iter == 0: 77 | avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1) 78 | x_input = torch.cat([inputs, avg_image_for_batch], dim=1) 79 | else: 80 | x_input 
= torch.cat([inputs, y_hat], dim=1) 81 | 82 | is_last_iteration = iter == opts.n_iters_per_batch - 1 83 | 84 | res = net.forward(x_input, 85 | latent=latent, 86 | landmarks_transform=landmarks_transform, 87 | return_aligned_and_unaligned=True, 88 | return_latents=True, 89 | resize=opts.resize_outputs) 90 | 91 | # if no landmark transforms are given, return the aligned output image 92 | if landmarks_transform is None: 93 | y_hat, latent = res 94 | 95 | # otherwise, if current iteration is not the last, return the aligned output; else return final unaligned output 96 | else: 97 | # note: res = images, unaligned_images, codes 98 | if is_last_iteration: 99 | _, y_hat, latent = res 100 | else: 101 | y_hat, _, latent = res 102 | 103 | # store intermediate outputs 104 | for idx in range(inputs.shape[0]): 105 | results_batch[idx].append(y_hat[idx]) 106 | results_latent[idx].append(latent[idx].cpu().numpy()) 107 | 108 | # resize input to 256 before feeding into next iteration 109 | y_hat = net.face_pool(y_hat) 110 | 111 | return results_batch, results_latent 112 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 
49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template <class T, class index_t, bool signWrite, bool signRead, int MODE> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void); 88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel.
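    // (Each thread processes up to p.loopX elements and each block runs 4 warps of 32
    // threads, so the grid size below is chosen to cover all p.sizeX elements of x.)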
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /inversion/models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 5 | from .first_stage import run_first_stage 6 | from .get_nets import PNet, RNet, ONet 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = 
bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /editing/styleclip_mapper/latent_mappers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import Module 6 | from torch.nn import functional as F 7 | 8 | 9 | class Mapper(Module): 10 | 11 | def __init__(self, opts, latent_dim=512): 12 | super(Mapper, self).__init__() 13 | 14 | self.opts = opts 15 | layers = [PixelNorm()] 16 | 17 | for i in range(4): 18 | layers.append( 19 | EqualLinear( 20 | latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu' 21 | ) 22 | ) 23 | 24 | self.mapping = nn.Sequential(*layers) 25 | 26 | 27 | def forward(self, x): 28 | x = self.mapping(x) 29 | return x 30 | 31 | 32 | class SingleMapper(Module): 33 | 34 | def __init__(self, opts): 35 | super(SingleMapper, self).__init__() 36 | 37 | self.opts = opts 38 | 39 | self.mapping = Mapper(opts) 40 | 41 | def forward(self, x): 42 | out = self.mapping(x) 43 | return out 44 | 45 | 46 | class LevelsMapper(Module): 47 | 48 | def __init__(self, opts): 49 | super(LevelsMapper, self).__init__() 50 | 51 | self.opts = opts 52 | 53 | if not opts.no_coarse_mapper: 54 | self.course_mapping = Mapper(opts) 55 | if not opts.no_medium_mapper: 56 | self.medium_mapping = Mapper(opts) 57 | if not opts.no_fine_mapper: 58 | self.fine_mapping = Mapper(opts) 59 | 60 | def forward(self, x): 61 | x_coarse = x[:, :5, :] 62 | x_medium = x[:, 5:8, :] 63 | x_fine = x[:, 8:, :] 64 | 65 | if not self.opts.no_coarse_mapper: 66 | x_coarse = self.course_mapping(x_coarse) 67 | else: 68 | x_coarse = torch.zeros_like(x_coarse) 69 | if not self.opts.no_medium_mapper: 70 | x_medium = self.medium_mapping(x_medium) 71 | else: 72 | x_medium = torch.zeros_like(x_medium) 73 | if not self.opts.no_fine_mapper: 74 | x_fine = self.fine_mapping(x_fine) 75 | else: 76 | x_fine = torch.zeros_like(x_fine) 77 | 78 | out = torch.cat([x_coarse, x_medium, x_fine], dim=1) 79 | 80 | return 
out 81 | 82 | class FusedLeakyReLU(nn.Module): 83 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 95 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 96 | input = input.cuda() 97 | if input.ndim == 3: 98 | return ( 99 | F.leaky_relu( 100 | input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope 101 | ) 102 | * scale 103 | ) 104 | else: 105 | return ( 106 | F.leaky_relu( 107 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 108 | ) 109 | * scale 110 | ) 111 | 112 | class EqualLinear(nn.Module): 113 | def __init__( 114 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 115 | ): 116 | super().__init__() 117 | 118 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 119 | 120 | if bias: 121 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 122 | 123 | else: 124 | self.bias = None 125 | 126 | self.activation = activation 127 | 128 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 129 | self.lr_mul = lr_mul 130 | 131 | def forward(self, input): 132 | if self.activation: 133 | out = F.linear(input, self.weight * self.scale) 134 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 135 | 136 | else: 137 | out = F.linear( 138 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 139 | ) 140 | 141 | return out 142 | 143 | def __repr__(self): 144 | return ( 145 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 146 | ) 147 | 148 | class PixelNorm(nn.Module): 149 | def __init__(self): 150 | super().__init__() 151 | 152 | def forward(self, input): 153 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 154 | 155 | 156 | def make_kernel(k): 157 | k = torch.tensor(k, dtype=torch.float32) 158 | 159 | if k.ndim == 1: 160 | k = k[None, :] * k[:, None] 161 | 162 | k /= k.sum() 163 | 164 | return k -------------------------------------------------------------------------------- /notebooks/notebook_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pydrive.auth import GoogleAuth 3 | from pydrive.drive import GoogleDrive 4 | from google.colab import auth 5 | from oauth2client.client import GoogleCredentials 6 | import dlib 7 | import subprocess 8 | 9 | from utils.alignment_utils import align_face, crop_face, get_stylegan_transform 10 | 11 | 12 | ENCODER_PATHS = { 13 | "restyle_e4e_ffhq": {"id": "1z_cB187QOc6aqVBdLvYvBjoc93-_EuRm", "name": "restyle_e4e_ffhq.pt"}, 14 | "restyle_pSp_ffhq": {"id": "12WZi2a9ORVg-j6d9x4eF-CKpLaURC2W-", "name": "restyle_pSp_ffhq.pt"}, 15 | } 16 | INTERFACEGAN_PATHS = { 17 | "age": {'id': '1NQVOpKX6YZKVbz99sg94HiziLXHMUbFS', 'name': 'age_boundary.npy'}, 18 | "smile": {'id': '1KgfJleIjrKDgdBTN4vAz0XlgSaa9I99R', 'name': 'Smiling_boundary.npy'}, 19 | "pose": {'id': '1nCzCR17uaMFhAjcg6kFyKnCCxAKOCT2d', 'name': 'pose_boundary.npy'}, 20 | "Male": {'id': '18dpXS5j1h54Y3ah5HaUpT03y58Ze2YEY', 'name': 'Male_boundary.npy'} 21 | } 22 | STYLECLIP_PATHS = { 23 | "delta_i_c": {"id": "1HOUGvtumLFwjbwOZrTbIloAwBBzs2NBN", "name": "delta_i_c.npy"}, 24 | "s_stats": {"id": "1FVm_Eh7qmlykpnSBN1Iy533e_A2xM78z", 
"name": "s_stats"}, 25 | } 26 | 27 | 28 | class Downloader: 29 | 30 | def __init__(self, code_dir, use_pydrive, subdir): 31 | self.use_pydrive = use_pydrive 32 | current_directory = os.getcwd() 33 | self.save_dir = os.path.join(os.path.dirname(current_directory), code_dir, subdir) 34 | os.makedirs(self.save_dir, exist_ok=True) 35 | if self.use_pydrive: 36 | self.authenticate() 37 | 38 | def authenticate(self): 39 | auth.authenticate_user() 40 | gauth = GoogleAuth() 41 | gauth.credentials = GoogleCredentials.get_application_default() 42 | self.drive = GoogleDrive(gauth) 43 | 44 | def download_file(self, file_id, file_name): 45 | file_dst = f'{self.save_dir}/{file_name}' 46 | if os.path.exists(file_dst): 47 | print(f'{file_name} already exists!') 48 | return 49 | if self.use_pydrive: 50 | downloaded = self.drive.CreateFile({'id': file_id}) 51 | downloaded.FetchMetadata(fetch_all=True) 52 | downloaded.GetContentFile(file_dst) 53 | else: 54 | command = self._get_download_model_command(file_id=file_id, file_name=file_name) 55 | subprocess.run(command, shell=True, stdout=subprocess.PIPE) 56 | 57 | def _get_download_model_command(self, file_id, file_name): 58 | """ Get wget download command for downloading the desired model and save to directory ../pretrained_models. """ 59 | url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=self.save_dir) 60 | return url 61 | 62 | 63 | def download_dlib_models(): 64 | if not os.path.exists("shape_predictor_68_face_landmarks.dat"): 65 | #print('Downloading files for aligning face image...') 66 | os.system('wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2') 67 | os.system('bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2') 68 | #print('Done.') 69 | 70 | 71 | def run_alignment(image_path): 72 | download_dlib_models() 73 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 74 | detector = dlib.get_frontal_face_detector() 75 | #print("Aligning image...") 76 | aligned_image = align_face(filepath=str(image_path), detector=detector, predictor=predictor) 77 | #print(f"Finished aligning image: {image_path}") 78 | return aligned_image 79 | 80 | 81 | def crop_image(image_path): 82 | download_dlib_models() 83 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 84 | detector = dlib.get_frontal_face_detector() 85 | #print("Cropping image...") 86 | cropped_image = crop_face(filepath=str(image_path), detector=detector, predictor=predictor) 87 | #print(f"Finished cropping image: {image_path}") 88 | return cropped_image 89 | 90 | 91 | def compute_transforms(aligned_path, cropped_path): 92 | download_dlib_models() 93 | predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 94 | detector = dlib.get_frontal_face_detector() 95 | #print("Computing landmarks-based transforms...") 96 | res = get_stylegan_transform(str(cropped_path), str(aligned_path), detector, predictor) 97 | #print("Done!") 98 | if res is None: 99 | print(f"Failed computing transforms on: {cropped_path}") 100 | return 101 | else: 102 | rotation_angle, translation, transform, inverse_transform = res 103 | return inverse_transform 104 | 
-------------------------------------------------------------------------------- /editing/interfacegan/helpers/manipulator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import svm 3 | 4 | 5 | """ 6 | Code is adopted from InterFaceGAN (https://github.com/genforce/interfacegan/blob/master/utils/manipulator.py) 7 | """ 8 | 9 | 10 | def train_boundary(latent_codes, scores, chosen_num_or_ratio=0.02, split_ratio=0.7, invalid_value=None): 11 | """Trains boundary in latent space with offline predicted attribute scores.""" 12 | 13 | if (not isinstance(latent_codes, np.ndarray) or 14 | not len(latent_codes.shape) == 2): 15 | raise ValueError(f'Input `latent_codes` should be of type ' 16 | f'`numpy.ndarray`, and shape [num_samples, ' 17 | f'latent_space_dim]!') 18 | num_samples = latent_codes.shape[0] 19 | latent_space_dim = latent_codes.shape[1] 20 | if (not isinstance(scores, np.ndarray) or not len(scores.shape) == 2 or 21 | not scores.shape[0] == num_samples or not scores.shape[1] == 1): 22 | raise ValueError(f'Input `scores` should be of type `numpy.ndarray`, and ' 23 | f'shape [num_samples, 1], where `num_samples` should be ' 24 | f'exactly the same as that of input `latent_codes`!') 25 | if chosen_num_or_ratio <= 0: 26 | raise ValueError(f'Input `chosen_num_or_ratio` should be positive, ' 27 | f'but {chosen_num_or_ratio} received!') 28 | 29 | print(f'Filtering training data.') 30 | if invalid_value is not None: 31 | latent_codes = latent_codes[scores[:, 0] != invalid_value] 32 | scores = scores[scores[:, 0] != invalid_value] 33 | 34 | print(f'Sorting scores to get positive and negative samples.') 35 | sorted_idx = np.argsort(scores, axis=0)[::-1, 0] 36 | latent_codes = latent_codes[sorted_idx] 37 | scores = scores[sorted_idx] 38 | num_samples = latent_codes.shape[0] 39 | if 0 < chosen_num_or_ratio <= 1: 40 | chosen_num = int(num_samples * chosen_num_or_ratio) 41 | else: 42 | chosen_num = int(chosen_num_or_ratio) 43 | chosen_num = min(chosen_num, num_samples // 2) 44 | 45 | print(f'Splitting training and validation sets:') 46 | train_num = int(chosen_num * split_ratio) 47 | val_num = chosen_num - train_num 48 | # Positive samples. 49 | positive_idx = np.arange(chosen_num) 50 | np.random.shuffle(positive_idx) 51 | positive_train = latent_codes[:chosen_num][positive_idx[:train_num]] 52 | positive_val = latent_codes[:chosen_num][positive_idx[train_num:]] 53 | # Negative samples. 54 | negative_idx = np.arange(chosen_num) 55 | np.random.shuffle(negative_idx) 56 | negative_train = latent_codes[-chosen_num:][negative_idx[:train_num]] 57 | negative_val = latent_codes[-chosen_num:][negative_idx[train_num:]] 58 | # Training set. 59 | train_data = np.concatenate([positive_train, negative_train], axis=0) 60 | train_label = np.concatenate([np.ones(train_num, dtype=int), 61 | np.zeros(train_num, dtype=int)], axis=0) 62 | print(f' Training: {train_num} positive, {train_num} negative.') 63 | # Validation set. 64 | val_data = np.concatenate([positive_val, negative_val], axis=0) 65 | val_label = np.concatenate([np.ones(val_num, dtype=int), 66 | np.zeros(val_num, dtype=int)], axis=0) 67 | print(f' Validation: {val_num} positive, {val_num} negative.') 68 | # Remaining set.
69 | remaining_num = num_samples - chosen_num * 2 70 | remaining_data = latent_codes[chosen_num:-chosen_num] 71 | remaining_scores = scores[chosen_num:-chosen_num] 72 | decision_value = (scores[0] + scores[-1]) / 2 73 | remaining_label = np.ones(remaining_num, dtype=int) 74 | remaining_label[remaining_scores.ravel() < decision_value] = 0 75 | remaining_positive_num = np.sum(remaining_label == 1) 76 | remaining_negative_num = np.sum(remaining_label == 0) 77 | print(f' Remaining: {remaining_positive_num} positive, ' 78 | f'{remaining_negative_num} negative.') 79 | 80 | print(f'Training boundary.') 81 | clf = svm.SVC(kernel='linear') 82 | classifier = clf.fit(train_data, train_label) 83 | print(f'Finish training.') 84 | 85 | if val_num: 86 | val_prediction = classifier.predict(val_data) 87 | correct_num = np.sum(val_label == val_prediction) 88 | print(f'Accuracy for validation set: ' 89 | f'{correct_num} / {val_num * 2} = ' 90 | f'{correct_num / (val_num * 2):.6f}') 91 | 92 | if remaining_num: 93 | remaining_prediction = classifier.predict(remaining_data) 94 | correct_num = np.sum(remaining_label == remaining_prediction) 95 | print(f'Accuracy for remaining set: ' 96 | f'{correct_num} / {remaining_num} = ' 97 | f'{correct_num / remaining_num:.6f}') 98 | 99 | a = classifier.coef_.reshape(1, latent_space_dim).astype(np.float32) 100 | return a / np.linalg.norm(a) 101 | -------------------------------------------------------------------------------- /inversion/video/video_handler.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | import cv2 5 | import dlib 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | from prepare_data.preparing_faces_parallel import SHAPE_PREDICTOR_PATH 10 | from utils.alignment_utils import align_face, get_alignment_transformation, get_alignment_positions 11 | 12 | 13 | class VideoHandler: 14 | """ Parses a given video and stores the raw, aligned, and cropped video frames. """ 15 | def __init__(self, video_path: Path, output_path: Path, raw_frames_path: Path = None, 16 | aligned_frames_path: Path = None, cropped_frames_path: Path = None): 17 | self.video = cv2.VideoCapture(str(video_path)) 18 | self.frame_count = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) 19 | self.fps = self.video.get(cv2.CAP_PROP_FPS) 20 | self.raw_frames_path = output_path / "raw_frames" if raw_frames_path is None else raw_frames_path 21 | self.aligned_frames_path = output_path / "aligned_frames" if aligned_frames_path is None else aligned_frames_path 22 | self.cropped_frames_path = output_path / "cropped_frames" if cropped_frames_path is None else cropped_frames_path 23 | self.raw_frames_path.mkdir(exist_ok=True, parents=True) 24 | self.cropped_frames_path.mkdir(exist_ok=True, parents=True) 25 | self.aligned_frames_path.mkdir(exist_ok=True, parents=True) 26 | 27 | def parse_video(self): 28 | """ Gets the raw, aligned, and cropped video frames. If they are already saved, uses the pre-saved images.
""" 29 | # get raw video frames 30 | if len(list(self.raw_frames_path.glob("*"))) == 0: 31 | frames_paths = self._parse_raw_video_frames() 32 | else: 33 | frames_paths = [self.raw_frames_path / f for f in self.raw_frames_path.iterdir()] 34 | # get aligned video frames 35 | if len(list(self.aligned_frames_path.glob("*"))) == 0: 36 | self._save_aligned_video_frames(frames_paths) 37 | else: 38 | print(f"Aligned video frames already saved to: {self.aligned_frames_path}") 39 | # get all cropped video frames 40 | if len(list(self.cropped_frames_path.glob("*"))) == 0: 41 | self._save_cropped_video_frames(frames_paths) 42 | else: 43 | print(f"Cropped video frames already saved to: {self.cropped_frames_path}") 44 | 45 | def get_input_paths(self): 46 | sorted_paths = sorted(self.aligned_frames_path.iterdir(), key=lambda x: int(str(x.name).replace(".jpg", ""))) 47 | file_names = [f.name for f in sorted_paths] 48 | aligned_paths = [self.aligned_frames_path / file_name for file_name in file_names] 49 | cropped_paths = [self.cropped_frames_path / file_name for file_name in file_names] 50 | return aligned_paths, cropped_paths 51 | 52 | @staticmethod 53 | def load_images(input_paths: List[Path]): 54 | input_images = [Image.open(input_path).convert("RGB") for input_path in input_paths] 55 | return input_images 56 | 57 | def _parse_raw_video_frames(self): 58 | frames_paths = [] 59 | print("Parsing video!") 60 | for frame_idx in tqdm(range(self.frame_count)): 61 | ret, frame = self.video.read() 62 | if not ret: 63 | continue 64 | im_path = self.raw_frames_path / f"{str(frame_idx).zfill(4)}.jpg" 65 | Image.fromarray(frame[:, :, ::-1]).convert("RGB").save(im_path) 66 | frames_paths.append(im_path) 67 | return frames_paths 68 | 69 | def _save_aligned_video_frames(self, frames_paths: List[Path]): 70 | print("Saving aligned video frames...") 71 | predictor = dlib.shape_predictor(str(SHAPE_PREDICTOR_PATH)) 72 | detector = dlib.get_frontal_face_detector() 73 | for path in tqdm(frames_paths): 74 | try: 75 | im = align_face(filepath=str(path), detector=detector, predictor=predictor).convert("RGB") 76 | im.save(self.aligned_frames_path / path.name) 77 | except Exception as e: 78 | print(e) 79 | continue 80 | 81 | def _save_cropped_video_frames(self, frames_paths: List[Path]): 82 | print("Saving cropped video frames...") 83 | detector = dlib.get_frontal_face_detector() 84 | predictor = dlib.shape_predictor(str(SHAPE_PREDICTOR_PATH)) 85 | # crops all the video frames according to the alignment of the first frame 86 | c, x, y = get_alignment_positions(str(frames_paths[0]), detector, predictor) 87 | alignment_transform, _ = get_alignment_transformation(c, x, y) 88 | alignment_transform = (alignment_transform + 0.5).flatten() 89 | for path in tqdm(frames_paths): 90 | try: 91 | curr_im = Image.open(path) 92 | curr_im = curr_im.transform((1024, 1024), Image.QUAD, alignment_transform, Image.BILINEAR) 93 | curr_im.save(self.cropped_frames_path / path.name) 94 | except Exception as e: 95 | print(e) 96 | continue 97 | -------------------------------------------------------------------------------- /inversion/scripts/create_inversion_animation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from typing import Optional, List 4 | 5 | import numpy as np 6 | import pyrallis 7 | import torch 8 | from PIL import Image 9 | from dataclasses import dataclass 10 | from tqdm import tqdm 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | 
from configs.paths_config import model_paths 16 | from models.stylegan3.model import SG3Generator 17 | from models.stylegan3.networks_stylegan3 import Generator 18 | from utils.common import tensor2im, generate_mp4 19 | 20 | 21 | RESIZE_AMOUNT = (1024, 1024) 22 | N_TRANSITIONS = 25 23 | SIZE = RESIZE_AMOUNT[0] 24 | 25 | 26 | @dataclass 27 | class RunConfig: 28 | # Where to save the animations 29 | output_path: Path 30 | # Path to directory of images to add to the animation 31 | data_path: Path 32 | # Path to `npy` file containing the inverted latents 33 | latents_path: Path 34 | # Path to StyleGAN3 generator 35 | generator_path: Path = Path(model_paths["stylegan3_ffhq"]) 36 | # Path to `npy` with the transformations used for generating the unaligned images 37 | landmarks_transforms_path: Optional[Path] = None 38 | # Number of images to include in animation. If None, run on all data 39 | n_images: Optional[int] = None 40 | # Fps of the generated animations 41 | fps: int = 15 42 | 43 | 44 | @pyrallis.wrap() 45 | def main(opts: RunConfig): 46 | decoder = SG3Generator(checkpoint_path=opts.generator_path).decoder 47 | 48 | latents = np.load(opts.latents_path, allow_pickle=True).item() 49 | landmarks_transforms = np.load(opts.landmarks_transforms_path, allow_pickle=True).item() 50 | image_names = list(latents.keys()) 51 | if opts.n_images is not None: 52 | image_names = np.random.choice(image_names, size=opts.n_images, replace=False) 53 | image_paths = [opts.data_path / image_name for image_name in image_names] 54 | 55 | in_images = [] 56 | all_vecs = [] 57 | all_landmarks_transforms = [] 58 | for image_path in image_paths: 59 | print(f'Working on {image_path.name}...') 60 | original_image = Image.open(image_path).convert("RGB") 61 | latent = latents[image_path.name][-1] 62 | landmark_transform = landmarks_transforms[image_path.name][-1] 63 | all_vecs.append([latent]) 64 | all_landmarks_transforms.append(landmark_transform) 65 | in_images.append(original_image.resize(RESIZE_AMOUNT)) 66 | 67 | image_paths.append(image_paths[0]) 68 | all_vecs.append(all_vecs[0]) 69 | all_landmarks_transforms.append(all_landmarks_transforms[0]) 70 | in_images.append(in_images[0]) 71 | 72 | all_images = [] 73 | for i in range(1, len(image_paths)): 74 | if i == 0: 75 | alpha_vals = [0] * 10 + np.linspace(0, 1, N_TRANSITIONS).tolist() + [1] * 5 76 | else: 77 | alpha_vals = [0] * 5 + np.linspace(0, 1, N_TRANSITIONS).tolist() + [1] * 5 78 | 79 | for alpha in tqdm(alpha_vals): 80 | image_a = np.array(in_images[i - 1]) 81 | image_b = np.array(in_images[i]) 82 | image_joint = np.zeros_like(image_a) 83 | up_to_row = int((SIZE - 1) * alpha) 84 | if up_to_row > 0: 85 | image_joint[:(up_to_row + 1), :, :] = image_b[((SIZE - 1) - up_to_row):, :, :] 86 | if up_to_row < (SIZE - 1): 87 | image_joint[up_to_row:, :, :] = image_a[:(SIZE - up_to_row), :, :] 88 | 89 | result_image = get_result_from_vecs(decoder, 90 | all_vecs[i - 1], all_vecs[i], 91 | all_landmarks_transforms[i - 1], all_landmarks_transforms[i], 92 | alpha)[0] 93 | 94 | output_im = tensor2im(result_image) 95 | res = np.concatenate([image_joint, np.array(output_im)], axis=1) 96 | all_images.append(res) 97 | 98 | kwargs = {'fps': opts.fps} 99 | opts.output_path.mkdir(exist_ok=True, parents=True) 100 | gif_path = opts.output_path / f"inversions_gif" 101 | generate_mp4(gif_path, all_images, kwargs) 102 | 103 | 104 | def get_result_from_vecs(generator: Generator, vectors_a: List[np.ndarray], vectors_b: List[np.ndarray], 105 | landmarks_a: np.ndarray, landmarks_b: 
np.ndarray, alpha: float): 106 | results = [] 107 | for i in range(len(vectors_a)): 108 | with torch.no_grad(): 109 | cur_vec = vectors_b[i] * alpha + vectors_a[i] * (1 - alpha) 110 | landmarks_transform = landmarks_b * alpha + landmarks_a * (1 - alpha) 111 | generator.synthesis.input.transform = torch.from_numpy(landmarks_transform).float().cuda().unsqueeze(0) 112 | res = generator.synthesis(torch.from_numpy(cur_vec).cuda().unsqueeze(0), noise_mode='const', force_fp32=True) 113 | results.append(res[0]) 114 | return results 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /inversion/video/video_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | 4 | from dataclasses import field, dataclass 5 | 6 | 7 | @dataclass 8 | class InterFaceGANEdit: 9 | direction: str 10 | start: int = -5 11 | end: int = 5 12 | 13 | 14 | @dataclass 15 | class StyleCLIPEdit: 16 | target_text: str 17 | alpha: float 18 | beta: float 19 | 20 | @property 21 | def save_name(self): 22 | return f'result_video_{"_".join(self.target_text.split())}_{self.alpha}_{self.beta}' 23 | 24 | 25 | @dataclass 26 | class VideoConfig: 27 | """ All arguments related to inverting and editing videos """ 28 | 29 | """ General input/output args """ 30 | # Path to the video to invert and edit 31 | video_path: Path 32 | # Path to the trained encoder to use for inversion 33 | checkpoint_path: Path 34 | # Path to the output directory 35 | output_path: Path 36 | # Path to pre-saved transforms for video (will extract and save to path if doesn't exist) 37 | landmarks_transforms_path: Optional[Path] = None 38 | # Optionally add a path to a generator to switch the generator (e.g., after training PTI) 39 | generator_path: Optional[Path] = None 40 | # Path to raw video frames to invert (will extract and save to path if doesn't exist) 41 | raw_frames_path: Optional[Path] = None 42 | # Path to aligned video frames to invert (will extract and save to path if doesn't exist) 43 | aligned_frames_path: Optional[Path] = None 44 | # Path to the cropped frames to invert (will extract and save to path if doesn't exist) 45 | cropped_frames_path: Optional[Path] = None 46 | 47 | """ Inference args """ 48 | # Number of ReStyle iterations to run per batch 49 | n_iters_per_batch: int = 3 50 | # Maximum number of images to invert in video. If None, inverts all images 51 | max_images: Optional[int] = None 52 | 53 | """ Field of view expansion args """ 54 | # Expansion amounts for field-of-view expansion given in [left, right, top, bottom] 55 | expansion_amounts: List[int] = field(default_factory=lambda: [0, 0, 0, 0]) 56 | 57 | """ Editing args """ 58 | # Comma-separated list of which edit directions to perform with InterFaceGAN 59 | interfacegan_directions: List[str] = field(default_factory=lambda: ['age']) 60 | # Comma-separated list of interfacegan ranges for each edit 61 | interfacegan_ranges: List[str] = field(default_factory=lambda: ['(-4_5)']) 62 | # Comma-separated list of which edit directions to perform with StyleCLIP 63 | styleclip_directions: List[str] = field(default_factory=lambda: ["a smiling face"]) 64 | # Comma-separated list of alpha and beta values for each edit. 
E.g., (4_0.13) -> alpha=4, beta=0.13 65 | styleclip_alpha_betas: List[str] = field(default_factory=lambda: ["(4_0.13)"]) 66 | interfacegan_edits = None 67 | styleclip_edits = None 68 | 69 | def __post_init__(self): 70 | self.interfacegan_edits = self._parse_interfacegan_edits() 71 | self.styleclip_edits = self._parse_styleclip_edits() 72 | 73 | def _parse_interfacegan_edits(self): 74 | factor_ranges = self._parse_factor_ranges() 75 | if len(self.interfacegan_directions) != len(factor_ranges): 76 | raise ValueError("Invalid edit directions and factor ranges. Please provide a single factor range for each " 77 | f"edit direction. Given: {self.interfacegan_directions} and {self.interfacegan_ranges}") 78 | interfacegan_edits = [] 79 | for edit_direction, factor_range in zip(self.interfacegan_directions, factor_ranges): 80 | edit = InterFaceGANEdit(direction=edit_direction, start=factor_range[0], end=factor_range[1]) 81 | interfacegan_edits.append(edit) 82 | return interfacegan_edits 83 | 84 | def _parse_factor_ranges(self): 85 | factor_ranges = [] 86 | for factor in self.interfacegan_ranges: 87 | start, end = factor.strip("()").split("_") 88 | factor_ranges.append((int(start), int(end))) 89 | return factor_ranges 90 | 91 | def _parse_styleclip_edits(self): 92 | alpha_betas = self._parse_styleclip_alpha_betas() 93 | if len(self.styleclip_directions) != len(alpha_betas): 94 | raise ValueError("Invalid edit directions and alpha-beta pairs. Please provide a single alpha-beta for each" 95 | f" edit direction. Given: {self.styleclip_directions} and {self.styleclip_alpha_betas}") 96 | styleclip_edits = [] 97 | for edit_direction, alpha_beta in zip(self.styleclip_directions, alpha_betas): 98 | alpha, beta = alpha_beta 99 | edit = StyleCLIPEdit(target_text=edit_direction, alpha=alpha, beta=beta) 100 | styleclip_edits.append(edit) 101 | return styleclip_edits 102 | 103 | def _parse_styleclip_alpha_betas(self): 104 | alpha_betas = [] 105 | for alpha_beta in self.styleclip_alpha_betas: 106 | alpha, beta = alpha_beta.strip("()").split("_") 107 | alpha_betas.append((float(alpha), float(beta))) 108 | return alpha_betas 109 | -------------------------------------------------------------------------------- /criteria/ms_ssim.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Taken from https://github.com/jorge-pessoa/pytorch-msssim 8 | """ 9 | 10 | 11 | def gaussian(window_size, sigma): 12 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 13 | return gauss/gauss.sum() 14 | 15 | 16 | def create_window(window_size, channel=1): 17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 18 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 19 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 20 | return window 21 | 22 | 23 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 24 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
25 | if val_range is None: 26 | if torch.max(img1) > 128: 27 | max_val = 255 28 | else: 29 | max_val = 1 30 | 31 | if torch.min(img1) < -0.5: 32 | min_val = -1 33 | else: 34 | min_val = 0 35 | L = max_val - min_val 36 | else: 37 | L = val_range 38 | 39 | padd = 0 40 | (_, channel, height, width) = img1.size() 41 | if window is None: 42 | real_size = min(window_size, height, width) 43 | window = create_window(real_size, channel=channel).to(img1.device) 44 | 45 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 46 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 47 | 48 | mu1_sq = mu1.pow(2) 49 | mu2_sq = mu2.pow(2) 50 | mu1_mu2 = mu1 * mu2 51 | 52 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 53 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 54 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 55 | 56 | C1 = (0.01 * L) ** 2 57 | C2 = (0.03 * L) ** 2 58 | 59 | v1 = 2.0 * sigma12 + C2 60 | v2 = sigma1_sq + sigma2_sq + C2 61 | cs = v1 / v2 # contrast sensitivity 62 | 63 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 64 | 65 | if size_average: 66 | cs = cs.mean() 67 | ret = ssim_map.mean() 68 | else: 69 | cs = cs.mean(1).mean(1).mean(1) 70 | ret = ssim_map.mean(1).mean(1).mean(1) 71 | 72 | if full: 73 | return ret, cs 74 | return ret 75 | 76 | 77 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None): 78 | device = img1.device 79 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 80 | levels = weights.size()[0] 81 | ssims = [] 82 | mcs = [] 83 | for _ in range(levels): 84 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 85 | 86 | # Relu normalize (not compliant with original definition) 87 | if normalize == "relu": 88 | ssims.append(torch.relu(sim)) 89 | mcs.append(torch.relu(cs)) 90 | else: 91 | ssims.append(sim) 92 | mcs.append(cs) 93 | 94 | img1 = F.avg_pool2d(img1, (2, 2)) 95 | img2 = F.avg_pool2d(img2, (2, 2)) 96 | 97 | ssims = torch.stack(ssims) 98 | mcs = torch.stack(mcs) 99 | 100 | # Simple normalize (not compliant with original definition) 101 | if normalize == "simple" or normalize == True: 102 | ssims = (ssims + 1) / 2 103 | mcs = (mcs + 1) / 2 104 | 105 | pow1 = mcs ** weights 106 | pow2 = ssims ** weights 107 | 108 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 109 | output = torch.prod(pow1[:-1]) * pow2[-1] 110 | return output 111 | 112 | 113 | # Classes to re-use window 114 | class SSIM(torch.nn.Module): 115 | def __init__(self, window_size=11, size_average=True, val_range=None): 116 | super(SSIM, self).__init__() 117 | self.window_size = window_size 118 | self.size_average = size_average 119 | self.val_range = val_range 120 | 121 | # Assume 1 channel for SSIM 122 | self.channel = 1 123 | self.window = create_window(window_size) 124 | 125 | def forward(self, img1, img2): 126 | (_, channel, _, _) = img1.size() 127 | 128 | if channel == self.channel and self.window.dtype == img1.dtype: 129 | window = self.window 130 | else: 131 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 132 | self.window = window 133 | self.channel = channel 134 | 135 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 136 | 137 | class MSSSIM(torch.nn.Module): 138 | def __init__(self, window_size=11, 
size_average=True, channel=3): 139 | super(MSSSIM, self).__init__() 140 | self.window_size = window_size 141 | self.size_average = size_average 142 | self.channel = channel 143 | 144 | def forward(self, img1, img2): 145 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) -------------------------------------------------------------------------------- /editing/styleclip_global_directions/preprocess/create_delta_i_c.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import clip 5 | import numpy as np 6 | import pyrallis 7 | import torch 8 | import torch.nn.functional as F 9 | from dataclasses import dataclass 10 | from torchvision.transforms import Normalize 11 | from tqdm import tqdm 12 | 13 | from configs.paths_config import model_paths 14 | from models.stylegan3.model import SG3Generator 15 | 16 | 17 | @dataclass 18 | class Options: 19 | """ Creating delta_i_c file for StyleCLIP's global directions """ 20 | 21 | """ StyleGAN Args """ 22 | # Path to StyleGAN model weights 23 | checkpoint_path: Path = Path(model_paths['stylegan3_ffhq_pt']) 24 | # Image resolution generated by the StyleGAN model 25 | stylegan_size: int = 1024 26 | # Is it the landscape model? If so, different init_kwargs are used to load the pretrained StyleGAN model 27 | is_landscape: bool = False 28 | 29 | """ Precomputed StyleSpace properties """ 30 | # Path to S latent codes precomputed by s_statistics.py 31 | latents_s_path: Path = Path("stats/S") 32 | # Path to StyleSpace statistics precomputed by s_statistics.py 33 | latents_statistics_path: Path = Path("stats/s_stats") 34 | 35 | """ Global Direction Args """ 36 | # Manipulation strength used to perturb the latent codes for obtaining directions in CLIP's space. Should be 5 for 37 | # FFHQ, and 10 for other domains 38 | manipulation_strength: int = 5 39 | 40 | """ General Args """ 41 | # Path to directory in which result files are saved 42 | results_path: Path = Path("delta_i_c") 43 | # Number of images used for computing delta_i_c.
We used 300 in our experiments 44 | num_samples: int = 1 45 | 46 | 47 | def generate_images(stylegan_model, latents_s, batch_size=1): 48 | all_images = [] 49 | for i in range(0, latents_s['input'].shape[0], batch_size): 50 | curr_latents_s = {l: latents_s[l][i:i + batch_size] for l in latents_s} 51 | with torch.no_grad(): 52 | curr_images = stylegan_model.synthesis(None, all_s=curr_latents_s, noise_mode='const') 53 | curr_images = F.interpolate(curr_images, size=(224, 224), mode='bicubic', align_corners=True) 54 | curr_images = (curr_images + 1) / 2 55 | curr_images = curr_images.clamp(0, 1) 56 | curr_images = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))(curr_images) 57 | all_images.append(curr_images) 58 | 59 | all_images = torch.cat(all_images) 60 | return all_images 61 | 62 | 63 | def get_clip_features(clip_model, images): 64 | images_reshaped = images.view(-1, 3, images.shape[-2], images.shape[-1]) 65 | processed_images = images_reshaped 66 | with torch.no_grad(): 67 | clip_features = clip_model.encode_image(processed_images) 68 | 69 | clip_features = clip_features.view(images.shape[0], 2, -1).unsqueeze(0) 70 | return clip_features 71 | 72 | 73 | def get_delta_i_c(clip_features): 74 | features_norm = np.linalg.norm(clip_features, axis=-1) 75 | normalized_features = clip_features / features_norm[:, :, :, None] 76 | delta_i_c = normalized_features[:, :, 1, :] - normalized_features[:, :, 0, :] 77 | normalized_delta_i_c = delta_i_c / np.linalg.norm(delta_i_c, axis=-1)[:, :, None] 78 | normalized_delta_i_c = normalized_delta_i_c.mean(axis=1) 79 | normalized_delta_i_c = normalized_delta_i_c / np.linalg.norm(normalized_delta_i_c, axis=-1)[:, None] 80 | return normalized_delta_i_c 81 | 82 | 83 | @pyrallis.wrap() 84 | def main(args: Options): 85 | args.results_path.mkdir(exist_ok=True, parents=True) 86 | 87 | G = SG3Generator(args.checkpoint_path, res=args.stylegan_size, 88 | config="landscape" if args.is_landscape else None).decoder 89 | 90 | clip_model, clip_preprocess = clip.load("ViT-B/32", device="cuda") 91 | 92 | latents_s = pickle.load(open(str(args.latents_s_path), "rb")) 93 | latents_s = {l: torch.from_numpy(latents_s[l][:args.num_samples]).float().cuda() for l in latents_s} 94 | transform, mean, std = pickle.load(open(str(args.latents_statistics_path), "rb")) 95 | 96 | all_clip_features = [] 97 | manipulated_images = torch.zeros(args.num_samples, 2, 3, 224, 224).cuda() 98 | 99 | for layer in latents_s.keys(): 100 | layer_channels = latents_s[layer].shape[1] 101 | for channel in tqdm(range(layer_channels)): 102 | for dir_idx, direction in enumerate([-args.manipulation_strength, args.manipulation_strength]): 103 | latents_s[layer][:, channel] = mean[layer][channel] + direction * std[layer][channel] 104 | curr_images = generate_images(G, latents_s, batch_size=1) 105 | manipulated_images[:, dir_idx] = curr_images 106 | clip_features = get_clip_features(clip_model, manipulated_images) 107 | all_clip_features.append(clip_features) 108 | 109 | all_clip_features = torch.cat(all_clip_features).detach().cpu().numpy() 110 | np.save(str(args.results_path / "clip_features.npy"), all_clip_features) 111 | 112 | delta_i_c = get_delta_i_c(all_clip_features) 113 | np.save(str(args.results_path / "delta_i_c.npy"), delta_i_c) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /inversion/models/psp3.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from configs.paths_config import model_paths 5 | from inversion.models.encoders import restyle_psp_encoders 6 | from models.stylegan3.model import SG3Generator 7 | from utils import common 8 | 9 | 10 | class pSp(nn.Module): 11 | 12 | def __init__(self, opts): 13 | super(pSp, self).__init__() 14 | self.set_opts(opts) 15 | # Define architecture 16 | self.n_styles = 16 17 | self.encoder = self.set_encoder() 18 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 19 | # Load weights if needed 20 | self.load_weights() 21 | 22 | def set_encoder(self): 23 | if self.opts.encoder_type == 'BackboneEncoder': 24 | encoder = restyle_psp_encoders.BackboneEncoder(50, 'ir_se', self.n_styles, self.opts) 25 | elif self.opts.encoder_type == 'ResNetBackboneEncoder': 26 | encoder = restyle_psp_encoders.ResNetBackboneEncoder(self.n_styles, self.opts) 27 | else: 28 | raise Exception(f'{self.opts.encoder_type} is not a valid encoders') 29 | return encoder 30 | 31 | def load_weights(self): 32 | if self.opts.checkpoint_path is not None: 33 | print(f'Loading ReStyle pSp from checkpoint: {self.opts.checkpoint_path}') 34 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 35 | self.encoder.load_state_dict(self._get_keys(ckpt, 'encoder'), strict=True) 36 | self.decoder = SG3Generator(checkpoint_path=None).decoder 37 | self.decoder.load_state_dict(self._get_keys(ckpt, 'decoder', remove=["synthesis.input.transform"]), strict=False) 38 | self._load_latent_avg(ckpt) 39 | else: 40 | encoder_ckpt = self._get_encoder_checkpoint() 41 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 42 | self.decoder = SG3Generator(checkpoint_path=self.opts.stylegan_weights).decoder 43 | self.latent_avg = self.decoder.mapping.w_avg 44 | 45 | def forward(self, x, latent=None, resize=True, input_code=False, landmarks_transform=None, 46 | return_latents=False, return_aligned_and_unaligned=False): 47 | 48 | images, unaligned_images = None, None 49 | 50 | if input_code: 51 | codes = x 52 | else: 53 | codes = self.encoder(x) 54 | # residual step 55 | if x.shape[1] == 6 and latent is not None: 56 | # learn error with respect to previous iteration 57 | codes = codes + latent 58 | else: 59 | # first iteration is with respect to the avg latent code 60 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 61 | 62 | # generate the aligned images 63 | identity_transform = common.get_identity_transform() 64 | identity_transform = torch.from_numpy(identity_transform).unsqueeze(0).repeat(x.shape[0], 1, 1).cuda().float() 65 | self.decoder.synthesis.input.transform = identity_transform 66 | images = self.decoder.synthesis(codes, noise_mode='const', force_fp32=True) 67 | 68 | if resize: 69 | images = self.face_pool(images) 70 | 71 | # generate the unaligned image using the user-specified transforms 72 | if landmarks_transform is not None: 73 | self.decoder.synthesis.input.transform = landmarks_transform.float() # size: [batch_size, 3, 3] 74 | unaligned_images = self.decoder.synthesis(codes, noise_mode='const', force_fp32=True) 75 | if resize: 76 | unaligned_images = self.face_pool(unaligned_images) 77 | 78 | if landmarks_transform is not None and return_aligned_and_unaligned: 79 | return images, unaligned_images, codes 80 | 81 | if return_latents: 82 | return images, codes 83 | else: 84 | return images 85 | 86 | def set_opts(self, opts): 87 | self.opts = opts 88 | 89 | def _load_latent_avg(self, 
ckpt, repeat=None): 90 | if 'latent_avg' in ckpt: 91 | self.latent_avg = ckpt['latent_avg'].to("cuda") 92 | if repeat is not None: 93 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 94 | else: 95 | self.latent_avg = None 96 | 97 | def _get_encoder_checkpoint(self): 98 | print('Loading encoders weights from irse50!') 99 | encoder_ckpt = torch.load(model_paths['ir_se50']) 100 | # Transfer the RGB input of the irse50 network to the first 3 input channels of pSp's encoder 101 | if self.opts.input_nc != 3: 102 | shape = encoder_ckpt['input_layer.0.weight'].shape 103 | altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32) 104 | altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight'] 105 | encoder_ckpt['input_layer.0.weight'] = altered_input_layer 106 | return encoder_ckpt 107 | 108 | @staticmethod 109 | def _get_keys(d, name, remove=[]): 110 | if 'state_dict' in d: 111 | d = d['state_dict'] 112 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() 113 | if k[:len(name)] == name and k[len(name) + 1:] not in remove} 114 | return d_filt 115 | --------------------------------------------------------------------------------
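
A rough inference sketch for the ReStyle pSp network defined in psp3.py above: on the first pass `forward()` adds the predicted offsets to the generator's average latent code, and on later passes (when the 6-channel input carries the previous reconstruction and `latent` is provided) it refines the previous latent. The option values, checkpoint path, and the use of a random tensor as a stand-in for an aligned face crop are assumptions for illustration; the repository's own scripts additionally concatenate the generator's average image on the first step and load options from the saved checkpoint.

from argparse import Namespace

import torch

from inversion.models.psp3 import pSp

# Assumed options; only the fields read by pSp and its encoder are set here.
opts = Namespace(encoder_type='ResNetBackboneEncoder',                     # handled by set_encoder()
                 checkpoint_path='pretrained_models/restyle_pSp_ffhq.pt',  # hypothetical location
                 input_nc=6)                                               # image + previous reconstruction
net = pSp(opts).cuda().eval()

x = torch.randn(1, 3, 256, 256).cuda()  # stand-in for an aligned, normalized face crop
y_hat, latent = None, None
with torch.no_grad():
    for _ in range(3):  # a few ReStyle refinement steps
        if y_hat is None:
            # First step: duplicate the input to fill the 6 channels (a simplification of the
            # repository's average-image initialization); offsets are added to the average latent.
            x_input = torch.cat([x, x], dim=1)
        else:
            # Later steps: condition on the previous reconstruction and refine the previous latent.
            x_input = torch.cat([x, y_hat], dim=1)
        y_hat, latent = net(x_input, latent=latent, resize=True, return_latents=True)
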