├── README.md ├── criteria ├── __init__.py ├── cx_style_loss.py └── id_loss.py ├── demo └── teaser.jpg ├── examples ├── img │ ├── 001.jpg │ ├── 002.jpg │ ├── 003.jpg │ ├── 004.jpg │ ├── 005.jpg │ └── 006.jpg ├── pair_reenact.txt ├── pair_swap.txt ├── reenact │ └── pair.txt └── swap │ └── pair.txt ├── generate_reenact.py ├── generate_swap.py ├── install.sh ├── train.py ├── train_reenact.py ├── train_swap.py ├── training ├── __init__.py ├── base_dataset.py ├── dataset.py ├── dataset_ddp.py ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── model.py ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── pose.py └── vgg.py └── utils ├── common.py └── flow_utils.py /README.md: -------------------------------------------------------------------------------- 1 | ## Designing One Unified Framework for High-Fidelity Face Reenactment and Swapping 2 | This repository contains the official PyTorch implementation of the paper *Designing One Unified Framework for High-Fidelity Face Reenactment and Swapping* (ECCV 2022). 3 | 4 | ![avatar](demo/teaser.jpg) 5 | 6 | ## Using the Code 7 | 8 | ### Requirements 9 | ``` 10 | conda create -y -n uniface python=3.6.12 11 | conda activate uniface 12 | ./install.sh 13 | ``` 14 | 15 | ### Data 16 | Please download CelebA-HQ into `data`. 17 | 18 | Please download VoxCeleb2 into `data` and follow the instructions in the official [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) repository to perform preprocessing. 19 | 20 | ### Inference 21 | Please put the test images in `examples` and create `pair.txt` to indicate the source and target file names. For example, `001_002` means the source file name is `001` and the target is `002` (a short parsing sketch is given at the end of this README). 22 | Please put the pre-trained models in `session`. We release the separately trained [reenactment](https://drive.google.com/file/d/1Y-Sm-_HmvPSBwz16Ol8apwnBUCYU4_tH/view?usp=sharing) and [swapping](https://drive.google.com/file/d/1H3DHfld_M5F940bZMuoT_bDkW3tJZvbq/view?usp=sharing) models; the unified model will be released once we open-source the journal version. 23 | 24 | ``` 25 | git clone https://github.com/xc-csc101/UniFace 26 | python generate_swap.py # test for swapping 27 | python generate_reenact.py # test for reenactment 28 | ``` 29 | 30 | ### Train 31 | ``` 32 | python train_reenact.py # train for reenactment 33 | python train_swap.py # train for swapping 34 | ``` 35 | 36 | ### Acknowledgements 37 | Our project is built on [StyleMapGAN](https://github.com/naver-ai/StyleMapGAN), and some code is borrowed from [pSp](https://github.com/eladrich/pixel2style2pixel). We thank the authors for their excellent work.
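### Pair file parsing (illustrative sketch)
The snippet below is a minimal, hypothetical sketch of how the `pair.txt` naming convention described under *Inference* can be resolved into image paths. The helper name `load_pairs` is an assumption for illustration only; the repository's own test datasets (e.g. `Dataset_for_test` in `training/dataset.py`) perform this mapping internally.
```python
import os

def load_pairs(txt_path, img_dir, suffix=".jpg"):
    """Read lines like '001_002' and return (source_path, target_path) tuples.

    Hypothetical helper for illustration; not part of the repository code.
    """
    pairs = []
    with open(txt_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # '001_002' -> source name '001', target name '002'
            src_name, trg_name = line.split("_")
            pairs.append((os.path.join(img_dir, src_name + suffix),
                          os.path.join(img_dir, trg_name + suffix)))
    return pairs

# Example: load_pairs("examples/pair_reenact.txt", "examples/img")
# -> [("examples/img/001.jpg", "examples/img/002.jpg"),
#     ("examples/img/003.jpg", "examples/img/004.jpg")]
```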
-------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/cx_style_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class CXLoss(nn.Module): 9 | 10 | def __init__(self, sigma=0.1, b=1.0, similarity="consine"): 11 | super(CXLoss, self).__init__() 12 | self.similarity = similarity 13 | self.sigma = sigma 14 | self.b = b 15 | 16 | def center_by_T(self, featureI, featureT): 17 | # Calculate mean channel vector for feature map. 18 | meanT = featureT.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 19 | return featureI - meanT, featureT - meanT 20 | 21 | def l2_normalize_channelwise(self, features): 22 | # Normalize on channel dimension (axis=1) 23 | norms = features.norm(p=2, dim=1, keepdim=True) 24 | features = features.div(norms) 25 | return features 26 | 27 | def patch_decomposition(self, features): 28 | N, C, H, W = features.shape 29 | assert N == 1 30 | P = H * W 31 | # NCHW --> 1x1xCxHW --> HWxCx1x1 32 | patches = features.view(1, 1, C, P).permute((3, 2, 0, 1)) 33 | return patches 34 | 35 | def calc_relative_distances(self, raw_dist, axis=1): 36 | epsilon = 1e-5 37 | div = torch.min(raw_dist, dim=axis, keepdim=True)[0] 38 | relative_dist = raw_dist / (div + epsilon) 39 | return relative_dist 40 | 41 | def calc_CX(self, dist, axis=1): 42 | W = torch.exp((self.b - dist) / self.sigma) 43 | W_sum = W.sum(dim=axis, keepdim=True) 44 | return W.div(W_sum) 45 | 46 | def forward(self, featureT, featureI): 47 | ''' 48 | :param featureT: target 49 | :param featureI: inference 50 | :return: 51 | ''' 52 | # NCHW 53 | # print(featureI.shape) 54 | 55 | featureI, featureT = self.center_by_T(featureI, featureT) 56 | 57 | featureI = self.l2_normalize_channelwise(featureI) 58 | featureT = self.l2_normalize_channelwise(featureT) 59 | 60 | dist = [] 61 | N = featureT.size()[0] 62 | for i in range(N): 63 | # NCHW 64 | featureT_i = featureT[i, :, :, :].unsqueeze(0) 65 | # NCHW 66 | featureI_i = featureI[i, :, :, :].unsqueeze(0) 67 | featureT_patch = self.patch_decomposition(featureT_i) 68 | # Calculate cosine similarity 69 | # See the torch document for functional.conv2d 70 | dist_i = F.conv2d(featureI_i, featureT_patch) 71 | dist.append(dist_i) 72 | 73 | # NCHW 74 | dist = torch.cat(dist, dim=0) 75 | 76 | raw_dist = (1. - dist) / 2. 
77 | 78 | relative_dist = self.calc_relative_distances(raw_dist) 79 | 80 | CX = self.calc_CX(relative_dist) 81 | 82 | CX = CX.max(dim=3)[0].max(dim=2)[0] 83 | CX = CX.mean(1) 84 | CX = -torch.log(CX) 85 | CX = torch.mean(CX) 86 | return CX -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from criteria.encoders.model_irse import Backbone 4 | 5 | 6 | class IDLoss(nn.Module): 7 | def __init__(self, model_path='pretrained_models/model_ir_se50.pth'): 8 | super(IDLoss, self).__init__() 9 | print('Loading ResNet ArcFace') 10 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 11 | self.facenet.load_state_dict(torch.load(model_path)) 12 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 13 | self.facenet.eval() 14 | 15 | def extract_feats(self, x): 16 | x = x[:, :, 35:223, 32:220] # Crop interesting region 17 | x = self.face_pool(x) 18 | x_feats = self.facenet(x) 19 | return x_feats 20 | 21 | def forward(self, y_hat, y, x): 22 | n_samples = x.shape[0] 23 | x_feats = self.extract_feats(x) 24 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 25 | y_hat_feats = self.extract_feats(y_hat) 26 | y_feats = y_feats.detach() 27 | loss = 0 28 | sim_improvement = 0 29 | id_logs = [] 30 | count = 0 31 | for i in range(n_samples): 32 | diff_target = y_hat_feats[i].dot(y_feats[i]) 33 | diff_input = y_hat_feats[i].dot(x_feats[i]) 34 | diff_views = y_feats[i].dot(x_feats[i]) 35 | id_logs.append({'diff_target': float(diff_target), 36 | 'diff_input': float(diff_input), 37 | 'diff_views': float(diff_views)}) 38 | loss += 1 - diff_target 39 | id_diff = float(diff_target) - float(diff_views) 40 | sim_improvement += id_diff 41 | count += 1 42 | 43 | return loss / count, sim_improvement / count, id_logs 44 | -------------------------------------------------------------------------------- /demo/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/demo/teaser.jpg -------------------------------------------------------------------------------- /examples/img/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/examples/img/001.jpg -------------------------------------------------------------------------------- /examples/img/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/examples/img/002.jpg -------------------------------------------------------------------------------- /examples/img/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/examples/img/003.jpg -------------------------------------------------------------------------------- /examples/img/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/examples/img/004.jpg -------------------------------------------------------------------------------- /examples/img/005.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/examples/img/005.jpg -------------------------------------------------------------------------------- /examples/img/006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/examples/img/006.jpg -------------------------------------------------------------------------------- /examples/pair_reenact.txt: -------------------------------------------------------------------------------- 1 | 001_002 2 | 003_004 -------------------------------------------------------------------------------- /examples/pair_swap.txt: -------------------------------------------------------------------------------- 1 | 006_005 -------------------------------------------------------------------------------- /examples/reenact/pair.txt: -------------------------------------------------------------------------------- 1 | 001_002 2 | 003_004 -------------------------------------------------------------------------------- /examples/swap/pair.txt: -------------------------------------------------------------------------------- 1 | 001_002 2 | 003_004 -------------------------------------------------------------------------------- /generate_reenact.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | import argparse 11 | import pickle 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | from torch.utils import data 16 | from torchvision import utils, transforms 17 | import numpy as np 18 | from torchvision.datasets import ImageFolder 19 | from training.dataset import * 20 | from scipy import linalg 21 | import random 22 | import time 23 | import os 24 | from tqdm import tqdm 25 | from copy import deepcopy 26 | import cv2 27 | from PIL import Image 28 | from itertools import combinations 29 | # need to modify 30 | from training.model import Generator_32 as Generator 31 | from training.model import Encoder_32 as Encoder 32 | from training.pose import Encoder_Pose 33 | from utils.flow_utils import flow_to_image, resize_flow 34 | 35 | random.seed(0) 36 | torch.manual_seed(0) 37 | torch.cuda.manual_seed_all(0) 38 | 39 | 40 | def save_image(img, path, normalize=True, range=(-1, 1)): 41 | utils.save_image( 42 | img, 43 | path, 44 | normalize=normalize, 45 | range=range, 46 | ) 47 | 48 | def save_image_list(img, path, normalize=True, range=(-1, 1)): 49 | nrow = len(img) 50 | utils.save_image( 51 | img, 52 | path, 53 | nrow=nrow, 54 | normalize=normalize, 55 | range=range, 56 | padding=0 57 | ) 58 | 59 | def save_images(imgs, paths, normalize=True, range=(-1, 1)): 60 | for img, path in zip(imgs, paths): 61 | save_image(img, path, normalize=normalize, range=range) 62 | 63 | 64 | def make_noise(batch, latent_channel_size, device): 65 | return torch.randn(batch, latent_channel_size, device=device) 66 | 67 | 68 | def data_sampler(dataset, shuffle): 69 | if shuffle: 70 | return data.RandomSampler(dataset) 71 | else: 72 | return data.SequentialSampler(dataset) 73 | 74 | 75 | class Model(nn.Module): 76 | def __init__(self, device="cuda"): 77 | super(Model, self).__init__() 78 | self.g_ema = Generator( 79 | args.size, 80 | args.latent_channel_size, 81 | args.latent_spatial_size, 82 | lr_mul=args.lr_mul, 83 | channel_multiplier=args.channel_multiplier, 84 | normalize_mode=args.normalize_mode, 85 | small_generator=args.small_generator, 86 | ) 87 | self.e_ema = Encoder( 88 | args.size, 89 | args.latent_channel_size, 90 | args.latent_spatial_size, 91 | channel_multiplier=args.channel_multiplier, 92 | ) 93 | 94 | self.e_ema_p = Encoder_Pose() 95 | 96 | def forward(self, input): 97 | src = input[0] 98 | drv = input[1] 99 | 100 | src_w = self.e_ema(src) 101 | flow, pose = self.e_ema_p(drv) 102 | fake_img = self.g_ema([src_w, flow]) 103 | return src, drv, fake_img, resize_flow(flow, (256, 256)) 104 | 105 | 106 | if __name__ == "__main__": 107 | # python generate.py --ckpt expr/checkpoints/celeba_hq_256_8x8.pt --mixing_type local_editing --test_lmdb data/celeba_hq/LMDB_test --local_editing_part nose 108 | device = "cuda" 109 | 110 | parser = argparse.ArgumentParser() 111 | 112 | parser.add_argument( 113 | "--mixing_type", 114 | type=str, 115 | default='examples' 116 | ) 117 | parser.add_argument("--inter", type=str, default='pair') 118 | parser.add_argument("--ckpt", type=str, default='session/reenactment/checkpoints/1000000.pt') 119 | parser.add_argument("--test_path", type=str, default='examples/img') 120 | parser.add_argument("--txt_path", type=str, default='examples/pair_reenact.txt') 121 | parser.add_argument("--batch", type=int, default=1) 122 | parser.add_argument("--num_workers", type=int, default=1) 123 | parser.add_argument("--save_image_dir", type=str, default="expr") 124 | 125 | args = parser.parse_args() 126 | 127 | ckpt = torch.load(args.ckpt) 128 
| train_args = ckpt["train_args"] 129 | for key in vars(train_args): 130 | if not (key in vars(args)): 131 | setattr(args, key, getattr(train_args, key)) 132 | print(args) 133 | 134 | dataset_name = args.inter 135 | args.save_image_pair_dir = os.path.join( 136 | args.save_image_dir, args.mixing_type, dataset_name, 'pair' 137 | ) 138 | os.makedirs(args.save_image_pair_dir, exist_ok=True) 139 | 140 | args.save_image_single_dir = os.path.join( 141 | args.save_image_dir, args.mixing_type, dataset_name, 'single' 142 | ) 143 | os.makedirs(args.save_image_single_dir, exist_ok=True) 144 | 145 | model = Model().to(device) 146 | model.g_ema.load_state_dict(ckpt["g_ema"]) 147 | model.e_ema.load_state_dict(ckpt["e_ema"]) 148 | model.e_ema_p.load_state_dict(ckpt["e_ema_p"]) 149 | model.eval() 150 | 151 | batch = args.batch 152 | 153 | device = "cuda" 154 | transform = transforms.Compose( 155 | [ 156 | transforms.Resize((256, 256)), 157 | transforms.ToTensor(), 158 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 159 | ] 160 | ) 161 | 162 | test_dataset = Dataset_for_test(args.test_path, mode='test', root_txt=args.txt_path, suffix='.jpg', transforms=transform) 163 | n_sample = len(test_dataset) 164 | sampler = data_sampler(test_dataset, shuffle=False) 165 | 166 | loader = data.DataLoader( 167 | test_dataset, 168 | batch, 169 | sampler=sampler, 170 | num_workers=args.num_workers, 171 | pin_memory=True, 172 | drop_last=False, 173 | ) 174 | 175 | with torch.no_grad(): 176 | for i, (imgs, img_paths, pair) in enumerate(tqdm(loader, mininterval=1)): 177 | src_img = imgs[0].to(device) 178 | drv_img = imgs[1].to(device) 179 | 180 | filenames = img_paths[1] 181 | # print(filenames) 182 | 183 | img_s, img_d, img_r, flow = model([src_img, drv_img]) 184 | 185 | for i_b, (ims, imd, imr, sn, f) in enumerate(zip(img_s, img_d, img_r, pair, flow)): 186 | # print(f'******{sn}') 187 | f = f.cpu().numpy().transpose([1, 2, 0]) 188 | f_show = flow_to_image(f) 189 | 190 | save_tmp = f"{args.save_image_pair_dir}" 191 | os.makedirs(save_tmp, exist_ok=True) 192 | save_image_list( 193 | [ims, imd, imr], 194 | # f"{args.save_image_pair_dir}/{trg_img_n[i_b]}_{src_img_n[i_b]}.png", 195 | f"{save_tmp}/{sn}.png" 196 | ) 197 | save_image(imr, f"{args.save_image_single_dir}/{sn}.png",) 198 | im = cv2.imread(f"{save_tmp}/{sn}.png") 199 | im_flow = np.hstack((im, f_show)) 200 | cv2.imwrite(f"{save_tmp}/{sn}.png", im_flow) -------------------------------------------------------------------------------- /generate_swap.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | 11 | import argparse 12 | import pickle 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.utils import data 17 | from torchvision import utils, transforms 18 | import numpy as np 19 | from torchvision.datasets import ImageFolder 20 | from training.dataset import * 21 | from scipy import linalg 22 | import random 23 | import time 24 | import os 25 | from tqdm import tqdm 26 | from copy import deepcopy 27 | import cv2 28 | from PIL import Image 29 | from itertools import combinations 30 | # need to modify 31 | from training.model import Generator_globalatt_return_32 as Generator 32 | from training.model import Encoder_return_32 as Encoder 33 | 34 | random.seed(0) 35 | torch.manual_seed(0) 36 | torch.cuda.manual_seed_all(0) 37 | 38 | cmap = np.array([(0, 0, 0), (255, 0, 0), (76, 153, 0), 39 | (204, 204, 0), (51, 51, 255), (204, 0, 204), (0, 255, 255), 40 | (51, 255, 255), (102, 51, 0), (255, 0, 0), (102, 204, 0), 41 | (255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153), 42 | (0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0)], 43 | dtype=np.uint8) 44 | 45 | class Colorize(object): 46 | def __init__(self, n=19): 47 | self.cmap = cmap 48 | self.cmap = torch.from_numpy(self.cmap[:n]) 49 | 50 | def __call__(self, gray_image): 51 | size = gray_image.size() 52 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 53 | 54 | for label in range(0, len(self.cmap)): 55 | mask = (label == gray_image[0]).cpu() 56 | color_image[0][mask] = self.cmap[label][0] 57 | color_image[1][mask] = self.cmap[label][1] 58 | color_image[2][mask] = self.cmap[label][2] 59 | 60 | return color_image 61 | 62 | def save_image(img, path, normalize=True, range=(-1, 1)): 63 | utils.save_image( 64 | img, 65 | path, 66 | normalize=normalize, 67 | range=range, 68 | ) 69 | 70 | def save_image_list(img, path, normalize=True, range=(-1, 1)): 71 | nrow = len(img) 72 | utils.save_image( 73 | img, 74 | path, 75 | nrow=nrow, 76 | normalize=normalize, 77 | range=range, 78 | ) 79 | 80 | def save_images(imgs, paths, normalize=True, range=(-1, 1)): 81 | for img, path in zip(imgs, paths): 82 | save_image(img, path, normalize=normalize, range=range) 83 | 84 | 85 | def make_noise(batch, latent_channel_size, device): 86 | return torch.randn(batch, latent_channel_size, device=device) 87 | 88 | 89 | def data_sampler(dataset, shuffle): 90 | if shuffle: 91 | return data.RandomSampler(dataset) 92 | else: 93 | return data.SequentialSampler(dataset) 94 | 95 | 96 | class Model(nn.Module): 97 | def __init__(self, device="cuda"): 98 | super(Model, self).__init__() 99 | self.g_ema = Generator( 100 | args.size, 101 | args.latent_channel_size, 102 | args.latent_spatial_size, 103 | lr_mul=args.lr_mul, 104 | channel_multiplier=args.channel_multiplier, 105 | normalize_mode=args.normalize_mode, 106 | small_generator=args.small_generator, 107 | ) 108 | self.e_ema = Encoder( 109 | args.size, 110 | args.latent_channel_size, 111 | args.latent_spatial_size, 112 | channel_multiplier=args.channel_multiplier, 113 | ) 114 | 115 | def tensor2label(self, label_tensor, n_label): 116 | label_tensor = label_tensor.cpu().float() 117 | if label_tensor.size()[0] > 1: 118 | label_tensor = label_tensor.max(0, keepdim=True)[1] 119 | label_tensor = Colorize(n_label)(label_tensor) 120 | label_numpy = label_tensor.numpy() 121 | 122 | return label_numpy 123 | 124 | 125 | def forward(self, input): 126 | trg = input[0] 127 | src = input[1] 128 | 129 | trg_src = torch.cat([trg, src], dim=0) 130 | # w = 
self.e_ema(trg_src) 131 | 132 | w, w_feat = self.e_ema(trg_src) 133 | w_feat_tgt = [torch.chunk(f, 2, dim=0)[0] for f in w_feat][::-1] 134 | 135 | trg_w, src_w = torch.chunk(w, 2, dim=0) 136 | 137 | fake_img = self.g_ema([trg_w, src_w, w_feat_tgt]) 138 | 139 | 140 | return trg, src, fake_img 141 | 142 | 143 | if __name__ == "__main__": 144 | device = "cuda" 145 | 146 | parser = argparse.ArgumentParser() 147 | 148 | parser.add_argument( 149 | "--mixing_type", 150 | type=str, 151 | default='examples' 152 | ) 153 | parser.add_argument("--inter", type=str, default='pair') 154 | parser.add_argument("--ckpt", type=str, default='session/swap/checkpoints/500000.pt') 155 | parser.add_argument("--test_path", type=str, default='examples/img/') 156 | parser.add_argument("--test_txt_path", type=str, default='examples/pair_swap.txt') 157 | parser.add_argument("--batch", type=int, default=1) 158 | parser.add_argument("--num_workers", type=int, default=1) 159 | parser.add_argument("--save_image_dir", type=str, default="expr") 160 | 161 | args = parser.parse_args() 162 | 163 | ckpt = torch.load(args.ckpt) 164 | train_args = ckpt["train_args"] 165 | for key in vars(train_args): 166 | if not (key in vars(args)): 167 | setattr(args, key, getattr(train_args, key)) 168 | print(args) 169 | 170 | dataset_name = args.inter 171 | args.save_image_pair_dir = os.path.join( 172 | args.save_image_dir, args.mixing_type, dataset_name, 'pair' 173 | ) 174 | os.makedirs(args.save_image_pair_dir, exist_ok=True) 175 | 176 | args.save_image_single_dir = os.path.join( 177 | args.save_image_dir, args.mixing_type, dataset_name, 'single' 178 | ) 179 | os.makedirs(args.save_image_single_dir, exist_ok=True) 180 | 181 | model = Model().half().to(device) 182 | model.g_ema.load_state_dict(ckpt["g_ema"]) 183 | model.e_ema.load_state_dict(ckpt["e_ema"]) 184 | model.eval() 185 | 186 | batch = args.batch 187 | 188 | device = "cuda" 189 | transform = transforms.Compose( 190 | [ 191 | transforms.Resize((256, 256)), 192 | transforms.ToTensor(), 193 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 194 | ] 195 | ) 196 | 197 | test_dataset = SwapTestTxtDataset(args.test_path, args.test_txt_path, transform, suffix='.jpg') 198 | n_sample = len(test_dataset) 199 | sampler = data_sampler(test_dataset, shuffle=False) 200 | 201 | loader = data.DataLoader( 202 | test_dataset, 203 | batch, 204 | sampler=sampler, 205 | num_workers=args.num_workers, 206 | pin_memory=True, 207 | drop_last=False, 208 | ) 209 | 210 | with torch.no_grad(): 211 | for i, ([trg_img, trg_name],[src_img, src_name]) in enumerate(tqdm(loader, mininterval=1)): 212 | trg_img = trg_img.half().to(device) 213 | src_img = src_img.half().to(device) 214 | trg_img_n = trg_name 215 | src_img_n = src_name 216 | 217 | img_t, img_s, img_r1 = model([trg_img, src_img]) 218 | 219 | for i_b, (imt, ims, imr1) in enumerate(zip(img_t, img_s, img_r1)): 220 | 221 | save_image_list( 222 | [imt, ims, imr1], 223 | f"{args.save_image_pair_dir}/{trg_img_n[i_b]}_{src_img_n[i_b]}.jpg", 224 | ) 225 | # imr1_resize = F.interpolate(imr1.unsqueeze(0), (1024, 1024)).squeeze(0) 226 | save_image(imr1, f"{args.save_image_single_dir}/{trg_img_n[i_b]}.jpg") 227 | 228 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda install -y pytorch=1.4.0 torchvision=0.5.0 -c pytorch 4 | conda install -y numpy=1.18.1 scikit-image=0.16.2 tqdm 5 | conda install 
-y -c anaconda ipython=7.13.0 6 | pip install lmdb==0.98 opencv-python==4.2.0.34 munch==2.5.0 7 | pip install -U scikit-image==0.15.0 scipy==1.2.1 matplotlib scikit-learn 8 | pip install flask==1.0.2 pillow==7.0.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | This work is licensed under the Creative Commons Attribution-NonCommercial 5 | 4.0 International License. To view a copy of this license, visit 6 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 7 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 8 | """ 9 | 10 | import argparse 11 | import os 12 | import torch 13 | from torch import nn, autograd, optim 14 | from torch.nn import functional as F 15 | from torch.utils import data 16 | import torch.distributed as dist 17 | import torch.multiprocessing as mp 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | from torchvision import transforms 20 | from training import lpips 21 | from training.model import Generator, Discriminator, Encoder 22 | from training.dataset_ddp import MultiResolutionDataset 23 | from tqdm import tqdm 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | 28 | def setup(rank, world_size): 29 | os.environ["MASTER_ADDR"] = "localhost" 30 | os.environ["MASTER_PORT"] = "12355" 31 | 32 | # initialize the process group 33 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 34 | 35 | 36 | def gather_grad(params, world_size): 37 | for param in params: 38 | if param.grad is not None: 39 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 40 | param.grad.data.div_(world_size) 41 | 42 | 43 | def requires_grad(model, flag=True): 44 | for p in model.parameters(): 45 | p.requires_grad = flag 46 | 47 | 48 | def accumulate(model1, model2, decay=0.999): 49 | with torch.no_grad(): 50 | par1 = dict(model1.named_parameters()) 51 | par2 = dict(model2.named_parameters()) 52 | 53 | for k in par1.keys(): 54 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 55 | 56 | 57 | def copy_norm_params(model_tgt, model_src): 58 | with torch.no_grad(): 59 | src_state_dict = model_src.state_dict() 60 | tgt_state_dict = model_tgt.state_dict() 61 | names = [name for name, _ in model_tgt.named_parameters()] 62 | 63 | for n in names: 64 | del src_state_dict[n] 65 | 66 | tgt_state_dict.update(src_state_dict) 67 | model_tgt.load_state_dict(tgt_state_dict) 68 | 69 | 70 | def sample_data(loader): 71 | while True: 72 | for batch in loader: 73 | yield batch 74 | 75 | 76 | def d_logistic_loss(real_pred, fake_pred): 77 | real_loss = F.softplus(-real_pred) 78 | fake_loss = F.softplus(fake_pred) 79 | 80 | return real_loss.mean() + fake_loss.mean() 81 | 82 | 83 | def d_r1_loss(real_pred, real_img): 84 | (grad_real,) = autograd.grad( 85 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 86 | ) 87 | grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 88 | 89 | return grad_penalty 90 | 91 | 92 | def g_nonsaturating_loss(fake_pred): 93 | loss = F.softplus(-fake_pred).mean() 94 | 95 | return loss 96 | 97 | 98 | def make_noise(batch, latent_channel_size, device): 99 | return torch.randn(batch, latent_channel_size, device=device) 100 | 101 | 102 | class DDPModel(nn.Module): 103 | def __init__(self, device, args): 104 | super(DDPModel, self).__init__() 105 | self.generator = Generator( 106 | args.size, 107 | 
args.mapping_layer_num, 108 | args.latent_channel_size, 109 | args.latent_spatial_size, 110 | lr_mul=args.lr_mul, 111 | channel_multiplier=args.channel_multiplier, 112 | normalize_mode=args.normalize_mode, 113 | small_generator=args.small_generator, 114 | ) 115 | self.g_ema = Generator( 116 | args.size, 117 | args.mapping_layer_num, 118 | args.latent_channel_size, 119 | args.latent_spatial_size, 120 | lr_mul=args.lr_mul, 121 | channel_multiplier=args.channel_multiplier, 122 | normalize_mode=args.normalize_mode, 123 | small_generator=args.small_generator, 124 | ) 125 | 126 | self.discriminator = Discriminator( 127 | args.size, channel_multiplier=args.channel_multiplier 128 | ) 129 | self.encoder = Encoder( 130 | args.size, 131 | args.latent_channel_size, 132 | args.latent_spatial_size, 133 | channel_multiplier=args.channel_multiplier, 134 | ) 135 | 136 | self.l1_loss = nn.L1Loss(size_average=True) 137 | self.mse_loss = nn.MSELoss(size_average=True) 138 | self.e_ema = Encoder( 139 | args.size, 140 | args.latent_channel_size, 141 | args.latent_spatial_size, 142 | channel_multiplier=args.channel_multiplier, 143 | ) 144 | self.percept = lpips.exportPerceptualLoss( 145 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 146 | ) 147 | 148 | self.device = device 149 | self.args = args 150 | 151 | def forward(self, real_img, mode): 152 | if mode == "G": 153 | z = make_noise( 154 | self.args.batch_per_gpu, 155 | self.args.latent_channel_size, 156 | self.device, 157 | ) 158 | 159 | fake_img, stylecode = self.generator(z, return_stylecode=True) 160 | fake_pred = self.discriminator(fake_img) 161 | adv_loss = g_nonsaturating_loss(fake_pred) 162 | fake_img = fake_img.detach() 163 | stylecode = stylecode.detach() 164 | fake_stylecode = self.encoder(fake_img) 165 | w_rec_loss = self.mse_loss(stylecode, fake_stylecode) 166 | 167 | return adv_loss, w_rec_loss, stylecode 168 | 169 | elif mode == "D": 170 | with torch.no_grad(): 171 | z = make_noise( 172 | self.args.batch_per_gpu, 173 | self.args.latent_channel_size, 174 | self.device, 175 | ) 176 | fake_img, _ = self.generator(z) 177 | fake_stylecode = self.encoder(real_img) 178 | fake_img_from_E, _ = self.generator( 179 | fake_stylecode, input_is_stylecode=True 180 | ) 181 | 182 | real_pred = self.discriminator(real_img) 183 | fake_pred = self.discriminator(fake_img) 184 | d_loss = d_logistic_loss(real_pred, fake_pred) 185 | fake_pred_from_E = self.discriminator(fake_img_from_E) 186 | indomainGAN_D_loss = F.softplus(fake_pred_from_E).mean() 187 | 188 | return ( 189 | d_loss, 190 | indomainGAN_D_loss, 191 | real_pred.mean(), 192 | fake_pred.mean(), 193 | ) 194 | 195 | elif mode == "D_reg": 196 | real_img.requires_grad = True 197 | real_pred = self.discriminator(real_img) 198 | r1_loss = d_r1_loss(real_pred, real_img) 199 | d_reg_loss = ( 200 | self.args.r1 / 2 * r1_loss * self.args.d_reg_every + 0 * real_pred[0] 201 | ) 202 | 203 | return d_reg_loss, r1_loss 204 | 205 | elif mode == "E_x_rec": 206 | fake_stylecode = self.encoder(real_img) 207 | fake_img, _ = self.generator(fake_stylecode, input_is_stylecode=True) 208 | x_rec_loss = self.mse_loss(real_img, fake_img) 209 | perceptual_loss = self.percept(real_img, fake_img).mean() 210 | fake_pred_from_E = self.discriminator(fake_img) 211 | indomainGAN_E_loss = F.softplus(-fake_pred_from_E).mean() 212 | 213 | return x_rec_loss, perceptual_loss, indomainGAN_E_loss 214 | 215 | elif mode == "cal_mse_lpips": 216 | fake_stylecode = self.e_ema(real_img) 217 | fake_img, _ = self.g_ema(fake_stylecode, 
input_is_stylecode=True) 218 | x_rec_loss = self.mse_loss(real_img, fake_img) 219 | perceptual_loss = self.percept(real_img, fake_img).mean() 220 | 221 | return x_rec_loss, perceptual_loss 222 | 223 | 224 | def run(ddp_fn, world_size, args): 225 | print("world size", world_size) 226 | mp.spawn(ddp_fn, args=(world_size, args), nprocs=world_size, join=True) 227 | 228 | 229 | def ddp_main(rank, world_size, args): 230 | print(f"Running DDP model on rank {rank}.") 231 | setup(rank, world_size) 232 | map_location = f"cuda:{rank}" 233 | torch.cuda.set_device(map_location) 234 | 235 | if args.ckpt: # ignore current arguments 236 | ckpt = torch.load(args.ckpt, map_location=map_location) 237 | train_args = ckpt["train_args"] 238 | print("load model:", args.ckpt) 239 | train_args.start_iter = int(args.ckpt.split("/")[-1].replace(".pt", "")) 240 | print(f"continue training from {train_args.start_iter} iter") 241 | args = train_args 242 | args.ckpt = True 243 | else: 244 | args.start_iter = 0 245 | 246 | # create model and move it to GPU with id rank 247 | model = DDPModel(device=map_location, args=args).to(map_location) 248 | model = DDP(model, device_ids=[rank], find_unused_parameters=True) 249 | model.train() 250 | 251 | g_module = model.module.generator 252 | g_ema_module = model.module.g_ema 253 | g_ema_module.eval() 254 | accumulate(g_ema_module, g_module, 0) 255 | 256 | e_module = model.module.encoder 257 | e_ema_module = model.module.e_ema 258 | e_ema_module.eval() 259 | accumulate(e_ema_module, e_module, 0) 260 | 261 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 262 | 263 | g_optim = optim.Adam( 264 | g_module.parameters(), 265 | lr=args.lr, 266 | betas=(0, 0.99), 267 | ) 268 | 269 | d_optim = optim.Adam( 270 | model.module.discriminator.parameters(), 271 | lr=args.lr * d_reg_ratio, 272 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 273 | ) 274 | 275 | e_optim = optim.Adam( 276 | e_module.parameters(), 277 | lr=args.lr, 278 | betas=(0, 0.99), 279 | ) 280 | 281 | accum = 0.999 282 | 283 | if args.ckpt: 284 | model.module.generator.load_state_dict(ckpt["generator"]) 285 | model.module.discriminator.load_state_dict(ckpt["discriminator"]) 286 | model.module.g_ema.load_state_dict(ckpt["g_ema"]) 287 | g_optim.load_state_dict(ckpt["g_optim"]) 288 | d_optim.load_state_dict(ckpt["d_optim"]) 289 | 290 | model.module.encoder.load_state_dict(ckpt["encoder"]) 291 | e_optim.load_state_dict(ckpt["e_optim"]) 292 | model.module.e_ema.load_state_dict(ckpt["e_ema"]) 293 | 294 | del ckpt # free GPU memory 295 | 296 | transform = transforms.Compose( 297 | [ 298 | transforms.RandomHorizontalFlip(), 299 | transforms.ToTensor(), 300 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 301 | ] 302 | ) 303 | 304 | save_dir = "expr" 305 | os.makedirs(save_dir, 0o777, exist_ok=True) 306 | os.makedirs(save_dir + "/checkpoints", 0o777, exist_ok=True) 307 | 308 | train_dataset = MultiResolutionDataset(args.train_lmdb, transform, args.size) 309 | val_dataset = MultiResolutionDataset(args.val_lmdb, transform, args.size) 310 | 311 | print(f"train_dataset: {len(train_dataset)}, val_dataset: {len(val_dataset)}") 312 | 313 | val_sampler = torch.utils.data.distributed.DistributedSampler( 314 | val_dataset, num_replicas=world_size, rank=rank, shuffle=True 315 | ) 316 | 317 | val_loader = data.DataLoader( 318 | val_dataset, 319 | batch_size=args.batch_per_gpu, 320 | drop_last=True, 321 | sampler=val_sampler, 322 | num_workers=args.num_workers, 323 | pin_memory=True, 324 | ) 325 | 326 | train_sampler 
= torch.utils.data.distributed.DistributedSampler( 327 | train_dataset, num_replicas=world_size, rank=rank, shuffle=True 328 | ) 329 | 330 | train_loader = data.DataLoader( 331 | train_dataset, 332 | batch_size=args.batch_per_gpu, 333 | drop_last=True, 334 | sampler=train_sampler, 335 | num_workers=args.num_workers, 336 | pin_memory=True, 337 | ) 338 | 339 | train_loader = sample_data(train_loader) 340 | pbar = range(args.start_iter, args.iter) 341 | pbar = tqdm(pbar, initial=args.start_iter, mininterval=1) 342 | 343 | requires_grad(model.module.discriminator, False) 344 | epoch = -1 345 | gpu_group = dist.new_group(list(range(args.ngpus))) 346 | 347 | for i in pbar: 348 | if i > args.iter: 349 | print("Done!") 350 | break 351 | elif i % (len(train_dataset) // args.batch) == 0: 352 | epoch += 1 353 | val_sampler.set_epoch(epoch) 354 | train_sampler.set_epoch(epoch) 355 | print("epoch: ", epoch) 356 | 357 | real_img = next(train_loader) 358 | real_img = real_img.to(map_location) 359 | 360 | adv_loss, w_rec_loss, stylecode = model(None, "G") 361 | adv_loss = adv_loss.mean() 362 | 363 | with torch.no_grad(): 364 | latent_std = stylecode.std().mean().item() 365 | latent_channel_std = stylecode.std(dim=1).mean().item() 366 | latent_spatial_std = stylecode.std(dim=(2, 3)).mean().item() 367 | 368 | g_loss = adv_loss * args.lambda_adv_loss 369 | g_loss_val = g_loss.item() 370 | adv_loss_val = adv_loss.item() 371 | 372 | g_optim.zero_grad() 373 | g_loss.backward() 374 | gather_grad( 375 | g_module.parameters(), world_size 376 | ) # Explicitly synchronize Generator parameters. There is a gradient sync bug in G. 377 | g_optim.step() 378 | 379 | w_rec_loss = w_rec_loss.mean() 380 | w_rec_loss_val = w_rec_loss.item() 381 | e_optim.zero_grad() 382 | (w_rec_loss * args.lambda_w_rec_loss).backward() 383 | e_optim.step() 384 | 385 | requires_grad(model.module.discriminator, True) 386 | # D adv 387 | d_loss, indomainGAN_D_loss, real_score, fake_score = model(real_img, "D") 388 | d_loss = d_loss.mean() 389 | indomainGAN_D_loss = indomainGAN_D_loss.mean() 390 | indomainGAN_D_loss_val = indomainGAN_D_loss.item() 391 | 392 | d_loss_val = d_loss.item() 393 | 394 | d_optim.zero_grad() 395 | 396 | ( 397 | d_loss * args.lambda_d_loss 398 | + indomainGAN_D_loss * args.lambda_indomainGAN_D_loss 399 | ).backward() 400 | d_optim.step() 401 | 402 | real_score_val = real_score.mean().item() 403 | fake_score_val = fake_score.mean().item() 404 | 405 | # D reg 406 | d_regularize = i % args.d_reg_every == 0 407 | if d_regularize: 408 | d_reg_loss, r1_loss = model(real_img, "D_reg") 409 | d_reg_loss = d_reg_loss.mean() 410 | d_optim.zero_grad() 411 | d_reg_loss.backward() 412 | d_optim.step() 413 | r1_val = r1_loss.mean().item() 414 | 415 | requires_grad(model.module.discriminator, False) 416 | 417 | # E_x_rec 418 | x_rec_loss, perceptual_loss, indomainGAN_E_loss = model(real_img, "E_x_rec") 419 | x_rec_loss = x_rec_loss.mean() 420 | perceptual_loss = perceptual_loss.mean() 421 | 422 | if indomainGAN_E_loss is not None: 423 | indomainGAN_E_loss = indomainGAN_E_loss.mean() 424 | indomainGAN_E_loss_val = indomainGAN_E_loss.item() 425 | else: 426 | indomainGAN_E_loss = 0 427 | indomainGAN_E_loss_val = 0 428 | 429 | e_optim.zero_grad() 430 | g_optim.zero_grad() 431 | 432 | encoder_loss = ( 433 | x_rec_loss * args.lambda_x_rec_loss 434 | + perceptual_loss * args.lambda_perceptual_loss 435 | + indomainGAN_E_loss * args.lambda_indomainGAN_E_loss 436 | ) 437 | 438 | encoder_loss.backward() 439 | e_optim.step() 440 | g_optim.step() 
441 | 442 | x_rec_loss_val = x_rec_loss.item() 443 | perceptual_loss_val = perceptual_loss.item() 444 | 445 | pbar.set_description( 446 | (f"g: {g_loss_val:.4f}; d: {d_loss_val:.4f}; r1: {r1_val:.4f};") 447 | ) 448 | 449 | with torch.no_grad(): 450 | accumulate(g_ema_module, g_module, accum) 451 | accumulate(e_ema_module, e_module, accum) 452 | 453 | if i % args.save_network_interval == 0: 454 | copy_norm_params(g_ema_module, g_module) 455 | copy_norm_params(e_ema_module, e_module) 456 | x_rec_loss_avg, perceptual_loss_avg = 0, 0 457 | iter_num = 0 458 | 459 | for test_image in tqdm(val_loader): 460 | test_image = test_image.to(map_location) 461 | x_rec_loss, perceptual_loss = model(test_image, "cal_mse_lpips") 462 | x_rec_loss_avg += x_rec_loss.mean() 463 | perceptual_loss_avg += perceptual_loss.mean() 464 | iter_num += 1 465 | 466 | x_rec_loss_avg /= iter_num 467 | perceptual_loss_avg /= iter_num 468 | 469 | dist.reduce( 470 | x_rec_loss_avg, dst=0, op=dist.ReduceOp.SUM, group=gpu_group 471 | ) 472 | dist.reduce( 473 | perceptual_loss_avg, 474 | dst=0, 475 | op=dist.ReduceOp.SUM, 476 | group=gpu_group, 477 | ) 478 | 479 | if rank == 0: 480 | x_rec_loss_avg = x_rec_loss_avg / args.ngpus 481 | perceptual_loss_avg = perceptual_loss_avg / args.ngpus 482 | x_rec_loss_avg_val = x_rec_loss_avg.item() 483 | perceptual_loss_avg_val = perceptual_loss_avg.item() 484 | 485 | print( 486 | f"x_rec_loss_avg: {x_rec_loss_avg_val}, perceptual_loss_avg: {perceptual_loss_avg_val}" 487 | ) 488 | 489 | print( 490 | f"step={i}, epoch={epoch}, x_rec_loss_avg_val={x_rec_loss_avg_val}, perceptual_loss_avg_val={perceptual_loss_avg_val}, d_loss_val={d_loss_val}, indomainGAN_D_loss_val={indomainGAN_D_loss_val}, indomainGAN_E_loss_val={indomainGAN_E_loss_val}, x_rec_loss_val={x_rec_loss_val}, perceptual_loss_val={perceptual_loss_val}, g_loss_val={g_loss_val}, adv_loss_val={adv_loss_val}, w_rec_loss_val={w_rec_loss_val}, r1_val={r1_val}, real_score_val={real_score_val}, fake_score_val={fake_score_val}, latent_std={latent_std}, latent_channel_std={latent_channel_std}, latent_spatial_std={latent_spatial_std}" 491 | ) 492 | 493 | torch.save( 494 | { 495 | "generator": model.module.generator.state_dict(), 496 | "discriminator": model.module.discriminator.state_dict(), 497 | "encoder": model.module.encoder.state_dict(), 498 | "g_ema": g_ema_module.state_dict(), 499 | "e_ema": e_ema_module.state_dict(), 500 | "train_args": args, 501 | "e_optim": e_optim.state_dict(), 502 | "g_optim": g_optim.state_dict(), 503 | "d_optim": d_optim.state_dict(), 504 | }, 505 | f"{save_dir}/checkpoints/{str(i).zfill(6)}.pt", 506 | ) 507 | 508 | 509 | if __name__ == "__main__": 510 | parser = argparse.ArgumentParser() 511 | 512 | parser.add_argument("--train_lmdb", type=str) 513 | parser.add_argument("--val_lmdb", type=str) 514 | parser.add_argument("--ckpt", type=str) 515 | parser.add_argument( 516 | "--dataset", 517 | type=str, 518 | default="celeba_hq", 519 | choices=[ 520 | "celeba_hq", 521 | "afhq", 522 | "ffhq", 523 | "lsun/church_outdoor", 524 | "lsun/car", 525 | "lsun/bedroom", 526 | ], 527 | ) 528 | parser.add_argument("--iter", type=int, default=1400000) 529 | parser.add_argument("--save_network_interval", type=int, default=10000) 530 | parser.add_argument("--small_generator", action="store_true") 531 | parser.add_argument("--batch", type=int, default=16, help="total batch sizes") 532 | parser.add_argument("--size", type=int, choices=[128, 256, 512, 1024], default=256) 533 | parser.add_argument("--r1", type=float, default=10) 534 | 
parser.add_argument("--d_reg_every", type=int, default=16) 535 | parser.add_argument("--lr", type=float, default=0.002) 536 | parser.add_argument("--lr_mul", type=float, default=0.01) 537 | parser.add_argument("--channel_multiplier", type=int, default=2) 538 | parser.add_argument("--latent_channel_size", type=int, default=64) 539 | parser.add_argument("--latent_spatial_size", type=int, default=8) 540 | parser.add_argument("--num_workers", type=int, default=2) 541 | parser.add_argument( 542 | "--normalize_mode", 543 | type=str, 544 | choices=["LayerNorm", "InstanceNorm2d", "BatchNorm2d", "GroupNorm"], 545 | default="LayerNorm", 546 | ) 547 | parser.add_argument("--mapping_layer_num", type=int, default=8) 548 | 549 | parser.add_argument("--lambda_x_rec_loss", type=float, default=1) 550 | parser.add_argument("--lambda_adv_loss", type=float, default=1) 551 | parser.add_argument("--lambda_w_rec_loss", type=float, default=1) 552 | parser.add_argument("--lambda_d_loss", type=float, default=1) 553 | parser.add_argument("--lambda_perceptual_loss", type=float, default=1) 554 | parser.add_argument("--lambda_indomainGAN_D_loss", type=float, default=1) 555 | parser.add_argument("--lambda_indomainGAN_E_loss", type=float, default=1) 556 | 557 | input_args = parser.parse_args() 558 | 559 | ngpus = torch.cuda.device_count() 560 | print("{} GPUS!".format(ngpus)) 561 | 562 | assert input_args.batch % ngpus == 0 563 | input_args.batch_per_gpu = input_args.batch // ngpus 564 | input_args.ngpus = ngpus 565 | print("{} batch per gpu!".format(input_args.batch_per_gpu)) 566 | 567 | run(ddp_main, ngpus, input_args) 568 | -------------------------------------------------------------------------------- /train_reenact.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | import argparse 11 | import os 12 | import yaml 13 | import matplotlib.pyplot as plt 14 | import torch 15 | from torch import autograd, nn, optim 16 | import torch.distributed as dist 17 | import torch.multiprocessing as mp 18 | from torch.nn import functional as F 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | from torch.utils import data 21 | from torchvision import transforms 22 | from tqdm import tqdm 23 | 24 | from training import lpips 25 | from training.model import Discriminator, Encoder_32, Generator_32_att 26 | from training.pose import Encoder_Pose 27 | from training.dataset import Dataset_scale_trans 28 | from utils import common 29 | 30 | torch.backends.cudnn.benchmark = True 31 | 32 | 33 | def log_images_v3(step, logdir, name, im_data, subscript=None, log_latest=False): 34 | fig = common.vis_faces_v3(im_data) 35 | if log_latest: 36 | step = 0 37 | if subscript: 38 | path = os.path.join(logdir, name, '{}_{:04d}.jpg'.format(subscript, step)) 39 | else: 40 | path = os.path.join(logdir, name, '{:04d}.jpg'.format(step)) 41 | os.makedirs(os.path.dirname(path), exist_ok=True) 42 | fig.savefig(path) 43 | plt.close(fig) 44 | 45 | 46 | def parse_and_log_images_v3(step, logdir, x, y_scale, y_gt, y_hat, title, subscript=None, display_count=2): 47 | im_data = [] 48 | for i in range(display_count): 49 | cur_im_data = { 50 | 'input_face': common.log_input_image(x[i]), 51 | 'target_face_scale': common.tensor2im(y_scale[i]), 52 | 'target_face_gt': common.tensor2im(y_gt[i]), 53 | 'output_face': common.tensor2im(y_hat[i]), 54 | } 55 | im_data.append(cur_im_data) 56 | log_images_v3(step, logdir, title, im_data=im_data, subscript=subscript) 57 | 58 | 59 | def save_args(path, args): 60 | args_dict = args.__dict__ 61 | with open(path, 'w') as f: 62 | yaml.dump(args_dict, f) 63 | 64 | 65 | def setup(rank, world_size): 66 | os.environ["MASTER_ADDR"] = "localhost" 67 | os.environ["MASTER_PORT"] = "12345" # 12369 68 | 69 | # initialize the process group 70 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 71 | 72 | 73 | def gather_grad(params, world_size): 74 | for param in params: 75 | if param.grad is not None: 76 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 77 | param.grad.data.div_(world_size) 78 | 79 | 80 | def requires_grad(model, flag=True): 81 | for p in model.parameters(): 82 | p.requires_grad = flag 83 | 84 | 85 | def accumulate(model1, model2, decay=0.999): 86 | with torch.no_grad(): 87 | par1 = dict(model1.named_parameters()) 88 | par2 = dict(model2.named_parameters()) 89 | 90 | for k in par1.keys(): 91 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 92 | 93 | 94 | def copy_norm_params(model_tgt, model_src): 95 | with torch.no_grad(): 96 | src_state_dict = model_src.state_dict() 97 | tgt_state_dict = model_tgt.state_dict() 98 | names = [name for name, _ in model_tgt.named_parameters()] 99 | 100 | for n in names: 101 | del src_state_dict[n] 102 | 103 | tgt_state_dict.update(src_state_dict) 104 | model_tgt.load_state_dict(tgt_state_dict) 105 | 106 | 107 | def sample_data(loader): 108 | while True: 109 | for batch in loader: 110 | yield batch 111 | 112 | 113 | def d_logistic_loss(real_pred, fake_pred): 114 | real_loss = F.softplus(-real_pred) 115 | fake_loss = F.softplus(fake_pred) 116 | 117 | return real_loss.mean() + fake_loss.mean() 118 | 119 | 120 | def d_r1_loss(real_pred, real_img): 121 | (grad_real,) = autograd.grad( 122 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 123 | ) 124 | grad_penalty = 
grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 125 | 126 | return grad_penalty 127 | 128 | 129 | def g_nonsaturating_loss(fake_pred): 130 | loss = F.softplus(-fake_pred).mean() 131 | return loss 132 | 133 | 134 | def make_noise(batch, latent_channel_size, device): 135 | return torch.randn(batch, latent_channel_size, device=device) 136 | 137 | 138 | class DDPModel(nn.Module): 139 | def __init__(self, device, args): 140 | super(DDPModel, self).__init__() 141 | 142 | self.generator = Generator_32_att( 143 | args.size, 144 | args.latent_channel_size, 145 | args.latent_spatial_size, 146 | lr_mul=args.lr_mul, 147 | channel_multiplier=args.channel_multiplier, 148 | normalize_mode=args.normalize_mode, 149 | small_generator=args.small_generator 150 | ) 151 | self.g_ema = Generator_32_att( 152 | args.size, 153 | args.latent_channel_size, 154 | args.latent_spatial_size, 155 | lr_mul=args.lr_mul, 156 | channel_multiplier=args.channel_multiplier, 157 | normalize_mode=args.normalize_mode, 158 | small_generator=args.small_generator 159 | ) 160 | 161 | self.discriminator = Discriminator( 162 | args.size, channel_multiplier=args.channel_multiplier 163 | ) 164 | 165 | self.encoder = Encoder_32( 166 | args.size, 167 | args.latent_channel_size, 168 | args.latent_spatial_size, 169 | in_ch=3, 170 | channel_multiplier=args.channel_multiplier, 171 | ) 172 | self.e_ema = Encoder_32( 173 | args.size, 174 | args.latent_channel_size, 175 | args.latent_spatial_size, 176 | in_ch=3, 177 | channel_multiplier=args.channel_multiplier, 178 | ) 179 | 180 | self.encoder_p = Encoder_Pose() 181 | self.e_ema_p = Encoder_Pose() 182 | 183 | self.mse_loss = nn.MSELoss(size_average=True) 184 | self.percept = lpips.exportPerceptualLoss( 185 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 186 | ) 187 | 188 | self.device = device 189 | self.args = args 190 | 191 | def forward(self, real_img, mode): 192 | if mode == "D": 193 | # real img is list 194 | with torch.no_grad(): 195 | src = real_img[0] 196 | drv_gt = real_img[1] 197 | drv_input = real_img[2] 198 | 199 | src_w = self.encoder(src) 200 | flow, pose = self.encoder_p(drv_input) 201 | fake_img = self.generator([src_w, flow]) 202 | 203 | real_pred = self.discriminator(src) 204 | fake_pred = self.discriminator(fake_img) 205 | 206 | d_loss = d_logistic_loss(real_pred, fake_pred) 207 | return d_loss 208 | 209 | 210 | elif mode == "D_reg": 211 | # real_img is tensor 212 | real_img.requires_grad = True 213 | real_pred = self.discriminator(real_img) 214 | r1_loss = d_r1_loss(real_pred, real_img) 215 | d_reg_loss = ( 216 | self.args.r1 / 2 * r1_loss * self.args.d_reg_every + 0 * real_pred[0] 217 | ) 218 | 219 | return d_reg_loss, r1_loss 220 | 221 | elif mode == "E_x_rec": 222 | src = real_img[0] 223 | drv_gt = real_img[1] 224 | drv_input = real_img[2] 225 | 226 | src_w = self.encoder(src) 227 | flow, pose = self.encoder_p(drv_input) 228 | fake_img = self.generator([src_w, flow]) 229 | 230 | x_rec_loss = self.mse_loss(drv_gt, fake_img) 231 | perceptual_loss = self.percept(drv_gt, fake_img).mean() 232 | 233 | fake_pred_from_E = self.discriminator(fake_img) 234 | 235 | indomainGAN_E_loss = F.softplus(-fake_pred_from_E).mean() 236 | 237 | return x_rec_loss, perceptual_loss, indomainGAN_E_loss, fake_img 238 | 239 | elif mode == "cal_mse_lpips": 240 | # real img is list 241 | src = real_img[0] 242 | drv_gt = real_img[1] 243 | drv_input = real_img[2] 244 | 245 | src_w = self.e_ema(src) 246 | flow, pose = self.e_ema_p(drv_input) 247 | fake_img = self.g_ema([src_w, 
flow]) 248 | 249 | x_rec_loss = self.mse_loss(drv_gt, fake_img) 250 | perceptual_loss = self.percept(drv_gt, fake_img).mean() 251 | 252 | return x_rec_loss, perceptual_loss, fake_img 253 | 254 | 255 | def run(ddp_fn, world_size, args): 256 | print("world size", world_size) 257 | mp.spawn(ddp_fn, args=(world_size, args), nprocs=world_size, join=True) 258 | 259 | 260 | def ddp_main(rank, world_size, args): 261 | print(f"Running DDP model on rank {rank}.") 262 | setup(rank, world_size) 263 | map_location = f"cuda:{rank}" 264 | torch.cuda.set_device(map_location) 265 | 266 | if args.ckpt: # ignore current arguments 267 | ckpt = torch.load(args.ckpt, map_location=map_location) 268 | train_args = ckpt["train_args"] 269 | print("load model:", args.ckpt) 270 | train_args.start_iter = int(args.ckpt.split("/")[-1].replace(".pt", "")) 271 | print(f"continue training from {train_args.start_iter} iter") 272 | args.ckpt = True 273 | args.start_iter = train_args.start_iter 274 | else: 275 | args.start_iter = 0 276 | 277 | # create model and move it to GPU with id rank 278 | model = DDPModel(device=map_location, args=args).to(map_location) 279 | model = DDP(model, device_ids=[rank], find_unused_parameters=True) 280 | model.train() 281 | 282 | g_module = model.module.generator 283 | g_ema_module = model.module.g_ema 284 | g_ema_module.eval() 285 | accumulate(g_ema_module, g_module, 0) 286 | 287 | e_module = model.module.encoder 288 | e_ema_module = model.module.e_ema 289 | e_ema_module.eval() 290 | accumulate(e_ema_module, e_module, 0) 291 | 292 | e_module_p = model.module.encoder_p 293 | e_ema_module_p = model.module.e_ema_p 294 | e_ema_module_p.eval() 295 | accumulate(e_ema_module_p, e_module_p, 0) 296 | 297 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 298 | r1_val = 0 299 | g_optim = optim.Adam( 300 | g_module.parameters(), 301 | lr=args.lr, 302 | betas=(0, 0.99), 303 | ) 304 | 305 | d_optim = optim.Adam( 306 | model.module.discriminator.parameters(), 307 | lr=args.lr * d_reg_ratio, 308 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 309 | ) 310 | 311 | e_optim = optim.Adam( 312 | e_module.parameters(), 313 | lr=args.lr, 314 | betas=(0, 0.99), 315 | ) 316 | 317 | e_optim_p = optim.Adam( 318 | e_module_p.parameters(), 319 | lr=args.lr * 0.01, 320 | betas=(0, 0.99), 321 | ) 322 | 323 | accum = 0.999 324 | 325 | if args.ckpt: 326 | model.module.generator.load_state_dict(ckpt["generator"]) 327 | model.module.g_ema.load_state_dict(ckpt["g_ema"]) 328 | model.module.discriminator.load_state_dict(ckpt["discriminator"]) 329 | 330 | model.module.encoder.load_state_dict(ckpt["encoder"]) 331 | model.module.e_ema.load_state_dict(ckpt["e_ema"]) 332 | 333 | model.module.encoder_p.load_state_dict(ckpt["encoder_p"]) 334 | model.module.e_ema_p.load_state_dict(ckpt["e_ema_p"]) 335 | 336 | transform = transforms.Compose([ 337 | transforms.Resize((256, 256)), 338 | transforms.ToTensor(), 339 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 340 | ]) 341 | 342 | save_dir = os.path.join("./session", args.checkname) 343 | os.makedirs(save_dir, 0o777, exist_ok=True) 344 | os.makedirs(save_dir + "/checkpoints", 0o777, exist_ok=True) 345 | os.makedirs(save_dir + "/imgs", 0o777, exist_ok=True) 346 | save_args(os.path.join('session', args.checkname, 'args.yaml'), args) 347 | 348 | train_dataset = Dataset_scale_trans(args.train_path, train=True, transforms=transform) 349 | val_dataset = Dataset_scale_trans(args.val_path, train=False, transforms=transform) 350 | 351 | print(f"train_dataset: 
{len(train_dataset)}, val_dataset: {len(val_dataset)}") 352 | 353 | val_sampler = torch.utils.data.distributed.DistributedSampler( 354 | val_dataset, num_replicas=world_size, rank=rank, shuffle=True 355 | ) 356 | 357 | val_loader = data.DataLoader( 358 | val_dataset, 359 | batch_size=args.batch_per_gpu, 360 | drop_last=True, 361 | sampler=val_sampler, 362 | num_workers=args.num_workers, 363 | pin_memory=True, 364 | ) 365 | 366 | train_sampler = torch.utils.data.distributed.DistributedSampler( 367 | train_dataset, num_replicas=world_size, rank=rank, shuffle=True 368 | ) 369 | 370 | train_loader = data.DataLoader( 371 | train_dataset, 372 | batch_size=args.batch_per_gpu, 373 | drop_last=True, 374 | sampler=train_sampler, 375 | num_workers=args.num_workers, 376 | pin_memory=True, 377 | ) 378 | 379 | train_loader = sample_data(train_loader) 380 | pbar = range(args.start_iter, args.iter) 381 | pbar = tqdm(pbar, initial=args.start_iter, mininterval=1) 382 | 383 | # requires_grad(model.module.vgg, False) 384 | epoch = -1 385 | gpu_group = dist.new_group(list(range(args.ngpus))) 386 | 387 | for i in pbar: 388 | if i > args.iter: 389 | print("Done!") 390 | break 391 | elif i % (len(train_dataset) // args.batch) == 0: 392 | epoch += 1 393 | val_sampler.set_epoch(epoch) 394 | train_sampler.set_epoch(epoch) 395 | print("epoch: ", epoch) 396 | 397 | [src_img, drv_img_gt, drv_img_input], _ = next(train_loader) 398 | src_img = src_img.to(map_location) 399 | drv_img_gt = drv_img_gt.to(map_location) 400 | drv_img_input = drv_img_input.to(map_location) 401 | 402 | requires_grad(model.module.discriminator, True) 403 | # D adv 404 | d_loss = model([src_img, drv_img_gt, drv_img_input], "D") 405 | d_loss = d_loss.mean() 406 | 407 | d_optim.zero_grad() 408 | ( 409 | d_loss * args.lambda_d_loss 410 | ).backward() 411 | d_optim.step() 412 | 413 | # D reg 414 | d_regularize = i % args.d_reg_every == 0 415 | if d_regularize: 416 | d_reg_loss, r1_loss = model(src_img, "D_reg") # r1 loss 417 | d_reg_loss = d_reg_loss.mean() 418 | d_optim.zero_grad() 419 | d_reg_loss.backward() 420 | d_optim.step() 421 | r1_val = r1_loss.mean().item() 422 | 423 | requires_grad(model.module.discriminator, False) 424 | 425 | # E_x_rec, 426 | x_rec_loss, perceptual_loss, indomainGAN_E_loss, fake_img = model([src_img, drv_img_gt, drv_img_input], "E_x_rec") 427 | 428 | x_rec_loss = x_rec_loss.mean() * args.lambda_x_rec_loss 429 | perceptual_loss = perceptual_loss.mean() * args.lambda_perceptual_loss 430 | indomainGAN_E_loss = indomainGAN_E_loss.mean() * args.lambda_indomainGAN_E_loss 431 | 432 | e_optim.zero_grad() 433 | e_optim_p.zero_grad() 434 | g_optim.zero_grad() 435 | 436 | encoder_loss = ( 437 | x_rec_loss 438 | + perceptual_loss 439 | + indomainGAN_E_loss 440 | ) 441 | 442 | encoder_loss.backward() 443 | 444 | e_optim.step() 445 | e_optim_p.step() 446 | g_optim.step() 447 | 448 | pbar.set_description( 449 | ( 450 | f"g: {indomainGAN_E_loss.item():.4f}; d: {d_loss.item():.4f}; r1: {r1_val:.4f}; rec: {x_rec_loss.item():.4f}; per: {perceptual_loss.item():.4f}") 451 | ) 452 | 453 | # log image 454 | if rank == 0: 455 | if i % args.image_interval == 0 or (i < 1000 and i % 25 == 0): 456 | parse_and_log_images_v3(i, save_dir, src_img, drv_img_input, drv_img_gt, fake_img, title='imgs/train/') 457 | 458 | with torch.no_grad(): 459 | accumulate(g_ema_module, g_module, accum) 460 | accumulate(e_ema_module, e_module, accum) 461 | accumulate(e_ema_module_p, e_module_p, accum) 462 | 463 | if i % args.save_img_interval == 0: 464 | 
copy_norm_params(g_ema_module, g_module) 465 | copy_norm_params(e_ema_module, e_module) 466 | copy_norm_params(e_ema_module_p, e_module_p) 467 | 468 | x_rec_loss_avg, perceptual_loss_avg = 0, 0 469 | iter_num = 0 470 | for ([src_img, drv_img_gt, drv_img_input], [s1, d1]) in tqdm(val_loader): 471 | 472 | src_img = src_img.to(map_location) 473 | drv_img_gt = drv_img_gt.to(map_location) 474 | drv_img_input = drv_img_input.to(map_location) 475 | 476 | x_rec_loss, perceptual_loss, fake_img = model([src_img, drv_img_gt, drv_img_input], "cal_mse_lpips") 477 | 478 | x_rec_loss_avg += x_rec_loss.mean() 479 | perceptual_loss_avg += perceptual_loss.mean() 480 | iter_num += 1 481 | 482 | # log images 483 | if rank == 0: 484 | parse_and_log_images_v3(i, save_dir, src_img, drv_img_input, drv_img_gt, fake_img, 485 | title='imgs/test/') 486 | 487 | x_rec_loss_avg /= iter_num 488 | perceptual_loss_avg /= iter_num 489 | 490 | dist.reduce( 491 | x_rec_loss_avg, dst=0, op=dist.ReduceOp.SUM, group=gpu_group 492 | ) 493 | dist.reduce( 494 | perceptual_loss_avg, 495 | dst=0, 496 | op=dist.ReduceOp.SUM, 497 | group=gpu_group, 498 | ) 499 | if i % args.save_network_interval == 0: 500 | if rank == 0: 501 | x_rec_loss_avg = x_rec_loss_avg / args.ngpus 502 | perceptual_loss_avg = perceptual_loss_avg / args.ngpus 503 | x_rec_loss_avg_val = x_rec_loss_avg.item() 504 | perceptual_loss_avg_val = perceptual_loss_avg.item() 505 | 506 | print( 507 | f"x_rec_loss_avg: {x_rec_loss_avg_val}, perceptual_loss_avg: {perceptual_loss_avg_val}" 508 | ) 509 | 510 | torch.save( 511 | { 512 | "generator": model.module.generator.state_dict(), 513 | "g_ema": g_ema_module.state_dict(), 514 | "discriminator": model.module.discriminator.state_dict(), 515 | "encoder": model.module.encoder.state_dict(), 516 | "e_ema": e_ema_module.state_dict(), 517 | "encoder_p": model.module.encoder_p.state_dict(), 518 | "e_ema_p": e_ema_module_p.state_dict(), 519 | "train_args": args, 520 | }, 521 | f"{save_dir}/checkpoints/{str(i).zfill(6)}.pt", 522 | ) 523 | 524 | 525 | if __name__ == "__main__": 526 | parser = argparse.ArgumentParser() 527 | parser.add_argument("--checkname", type=str, default='exp-reenact-rec5-p5') 528 | parser.add_argument("--describ", type=str, default='aug fix') 529 | parser.add_argument("--train_path", type=str, default='data/voxceleb/vox2-png/') 530 | parser.add_argument("--val_path", type=str, default='data/voxceleb/vox2-png/') 531 | parser.add_argument("--ckpt", type=str, default='') 532 | parser.add_argument( 533 | "--dataset", 534 | type=str, 535 | default="danbooru", 536 | choices=[ 537 | "celeba_hq", 538 | "afhq", 539 | "ffhq", 540 | "lsun/church_outdoor", 541 | "lsun/car", 542 | "lsun/bedroom", 543 | ], 544 | ) 545 | parser.add_argument("--iter", type=int, default=500000) 546 | parser.add_argument("--save_network_interval", type=int, default=5000) 547 | parser.add_argument("--save_img_interval", type=int, default=1000) 548 | parser.add_argument("--small_generator", action="store_true") 549 | parser.add_argument("--batch", type=int, default=8, help="total batch sizes") 550 | parser.add_argument("--size", type=int, choices=[128, 256, 512, 1024], default=256) 551 | parser.add_argument("--r1", type=float, default=10) 552 | parser.add_argument("--d_reg_every", type=int, default=16) 553 | parser.add_argument("--lr", type=float, default=0.001) 554 | parser.add_argument("--lr_mul", type=float, default=1) 555 | parser.add_argument("--channel_multiplier", type=int, default=2) 556 | parser.add_argument("--latent_channel_size", 
type=int, default=512) 557 | parser.add_argument("--latent_spatial_size", type=int, default=32) 558 | parser.add_argument("--num_workers", type=int, default=8) 559 | parser.add_argument("--image_interval", type=int, default=50) 560 | parser.add_argument( 561 | "--normalize_mode", 562 | type=str, 563 | choices=["LayerNorm", "InstanceNorm2d", "BatchNorm2d", "GroupNorm"], 564 | default="LayerNorm", 565 | ) 566 | parser.add_argument("--mapping_layer_num", type=int, default=8) 567 | 568 | parser.add_argument("--lambda_x_rec_loss", type=float, default=5) 569 | parser.add_argument("--lambda_d_loss", type=float, default=1) 570 | parser.add_argument("--lambda_perceptual_loss", type=float, default=5) 571 | parser.add_argument("--lambda_indomainGAN_D_loss", type=float, default=1) 572 | parser.add_argument("--lambda_indomainGAN_E_loss", type=float, default=1) 573 | 574 | input_args = parser.parse_args() 575 | ngpus = 2 576 | # ngpus = torch.cuda.device_count() 577 | print("{} GPUS!".format(ngpus)) 578 | 579 | assert input_args.batch % ngpus == 0 580 | input_args.batch_per_gpu = input_args.batch // ngpus 581 | input_args.ngpus = ngpus 582 | print("{} batch per gpu!".format(input_args.batch_per_gpu)) 583 | 584 | run(ddp_main, ngpus, input_args) -------------------------------------------------------------------------------- /train_swap.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | import argparse 11 | import yaml 12 | import matplotlib.pyplot as plt 13 | from torch import autograd, nn, optim 14 | import torch.distributed as dist 15 | import torch.multiprocessing as mp 16 | from torch.nn import functional as F 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from torch.utils import data 19 | from torchvision import transforms 20 | import torchvision 21 | from tqdm import tqdm 22 | 23 | from training import lpips 24 | from training.model import Discriminator, Encoder_return_32, Generator_globalatt_flow 25 | from training.pose import Encoder_Pose 26 | from training.dataset import * 27 | from training.vgg import VGG 28 | from criteria import id_loss 29 | from criteria.cx_style_loss import CXLoss 30 | from utils import common 31 | 32 | torch.backends.cudnn.benchmark = True 33 | 34 | 35 | def log_images(step, logdir, name, im_data, subscript=None, log_latest=False): 36 | fig = common.vis_faces(im_data) 37 | if log_latest: 38 | step = 0 39 | if subscript: 40 | path = os.path.join(logdir, name, '{}_{:04d}.jpg'.format(subscript, step)) 41 | else: 42 | path = os.path.join(logdir, name, '{:04d}.jpg'.format(step)) 43 | os.makedirs(os.path.dirname(path), exist_ok=True) 44 | fig.savefig(path) 45 | plt.close(fig) 46 | 47 | 48 | def parse_and_log_images(step, logdir, x, y, y_hat, title, subscript=None, display_count=2): 49 | im_data = [] 50 | for i in range(display_count): 51 | cur_im_data = { 52 | 'input_face': common.log_input_image(x[i]), 53 | 'target_face': common.tensor2im(y[i]), 54 | 'output_face': common.tensor2im(y_hat[i]), 55 | } 56 | im_data.append(cur_im_data) 57 | log_images(step, logdir, title, im_data=im_data, subscript=subscript) 58 | 59 | 60 | def save_args(path, args): 61 | args_dict = args.__dict__ 62 | with open(path, 'w') as f: 63 | yaml.dump(args_dict, f) 64 | 65 | 66 | def setup(rank, world_size): 67 | os.environ["MASTER_ADDR"] = "localhost" 68 | os.environ["MASTER_PORT"] = "12345" 69 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 70 | 71 | 72 | def gather_grad(params, world_size): 73 | for param in params: 74 | if param.grad is not None: 75 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 76 | param.grad.data.div_(world_size) 77 | 78 | 79 | def requires_grad(model, flag=True): 80 | for p in model.parameters(): 81 | p.requires_grad = flag 82 | 83 | 84 | def accumulate(model1, model2, decay=0.999): 85 | with torch.no_grad(): 86 | par1 = dict(model1.named_parameters()) 87 | par2 = dict(model2.named_parameters()) 88 | 89 | for k in par1.keys(): 90 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 91 | 92 | 93 | def copy_norm_params(model_tgt, model_src): 94 | with torch.no_grad(): 95 | src_state_dict = model_src.state_dict() 96 | tgt_state_dict = model_tgt.state_dict() 97 | names = [name for name, _ in model_tgt.named_parameters()] 98 | 99 | for n in names: 100 | del src_state_dict[n] 101 | 102 | tgt_state_dict.update(src_state_dict) 103 | model_tgt.load_state_dict(tgt_state_dict) 104 | 105 | 106 | def sample_data(loader): 107 | while True: 108 | for batch in loader: 109 | yield batch 110 | 111 | 112 | def d_logistic_loss(real_pred, fake_pred): 113 | real_loss = F.softplus(-real_pred) 114 | fake_loss = F.softplus(fake_pred) 115 | 116 | return real_loss.mean() + fake_loss.mean() 117 | 118 | 119 | def d_r1_loss(real_pred, real_img): 120 | (grad_real,) = autograd.grad( 121 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 122 | ) 123 | grad_penalty = 
grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 124 | 125 | return grad_penalty 126 | 127 | 128 | def g_nonsaturating_loss(fake_pred): 129 | loss = F.softplus(-fake_pred).mean() 130 | return loss 131 | 132 | 133 | def make_noise(batch, latent_channel_size, device): 134 | return torch.randn(batch, latent_channel_size, device=device) 135 | 136 | 137 | class DDPModel(nn.Module): 138 | def __init__(self, device, args): 139 | super(DDPModel, self).__init__() 140 | 141 | self.generator = Generator_globalatt_flow( 142 | args.size, 143 | args.latent_channel_size, 144 | args.latent_spatial_size, 145 | lr_mul=args.lr_mul, 146 | channel_multiplier=args.channel_multiplier, 147 | normalize_mode=args.normalize_mode, 148 | small_generator=args.small_generator 149 | ) 150 | self.g_ema = Generator_globalatt_flow( 151 | args.size, 152 | args.latent_channel_size, 153 | args.latent_spatial_size, 154 | lr_mul=args.lr_mul, 155 | channel_multiplier=args.channel_multiplier, 156 | normalize_mode=args.normalize_mode, 157 | small_generator=args.small_generator 158 | ) 159 | 160 | self.discriminator = Discriminator( 161 | args.size, channel_multiplier=args.channel_multiplier 162 | ) 163 | 164 | self.encoder = Encoder_return_32( 165 | args.size, 166 | args.latent_channel_size, 167 | args.latent_spatial_size, 168 | in_ch=3, 169 | channel_multiplier=args.channel_multiplier, 170 | ) 171 | 172 | self.e_ema = Encoder_return_32( 173 | args.size, 174 | args.latent_channel_size, 175 | args.latent_spatial_size, 176 | in_ch=3, 177 | channel_multiplier=args.channel_multiplier, 178 | ) 179 | 180 | self.encoder_p = Encoder_Pose() 181 | self.e_ema_p = Encoder_Pose() 182 | 183 | self.l1_loss = nn.L1Loss(size_average=True) 184 | self.mse_loss = nn.MSELoss(size_average=True) 185 | self.id_loss = id_loss.IDLoss().eval() 186 | self.percept = lpips.exportPerceptualLoss( 187 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 188 | ) 189 | # cx loss 190 | self.vgg = VGG() 191 | self.vgg.load_state_dict(torch.load('pretrained_models/vgg19_conv.pth')) 192 | requires_grad(self.vgg, False) 193 | print('vgg for cx loss ready') 194 | 195 | self.cx_loss = CXLoss(sigma=0.5) 196 | 197 | self.device = device 198 | self.args = args 199 | 200 | def calc_cx(self, real, fake): 201 | style_layer = ['r32', 'r42'] 202 | vgg_style = self.vgg(real, style_layer) 203 | vgg_fake = self.vgg(fake, style_layer) 204 | cx_style_loss = 0 205 | 206 | for i, val in enumerate(vgg_fake): 207 | cx_style_loss += self.cx_loss(vgg_style[i], vgg_fake[i]) 208 | return cx_style_loss 209 | 210 | def forward(self, real_img, mode): 211 | if mode == "D": 212 | # real img is list 213 | with torch.no_grad(): 214 | trg = real_img[0] 215 | src = real_img[1] 216 | 217 | trg_src = torch.cat([trg, src], dim=0) 218 | w, w_feat = self.encoder(trg_src) 219 | 220 | w_feat_tgt = [torch.chunk(f, 2, dim=0)[0] for f in w_feat][::-1] 221 | 222 | trg_w, src_w = torch.chunk(w, 2, dim=0) 223 | 224 | flow, _ = self.encoder_p(trg) 225 | 226 | fake_img = self.generator([trg_w, src_w, w_feat_tgt, flow]) 227 | 228 | real_pred = self.discriminator(trg) 229 | fake_pred = self.discriminator(fake_img) 230 | d_loss = d_logistic_loss(real_pred, fake_pred) 231 | 232 | return ( 233 | d_loss, 234 | real_pred.mean(), 235 | fake_pred.mean(), 236 | ) 237 | 238 | elif mode == "D_reg": 239 | # real_img is tensor 240 | real_img.requires_grad = True 241 | real_pred = self.discriminator(real_img) 242 | r1_loss = d_r1_loss(real_pred, real_img) 243 | d_reg_loss = ( 244 | self.args.r1 / 2 * r1_loss * 
self.args.d_reg_every + 0 * real_pred[0] 245 | ) 246 | 247 | return d_reg_loss, r1_loss 248 | 249 | elif mode == "E_x_rec": 250 | # real img is list 251 | trg = real_img[0] 252 | src = real_img[1] 253 | same = real_img[2] 254 | 255 | trg_src = torch.cat([trg, src], dim=0) 256 | w, w_feat = self.encoder(trg_src) 257 | w_feat_tgt = [torch.chunk(f, 2, dim=0)[0] for f in w_feat][::-1] 258 | 259 | trg_w, src_w = torch.chunk(w, 2, dim=0) 260 | 261 | flow, _ = self.encoder_p(trg) 262 | 263 | fake_img = self.generator([trg_w, src_w, w_feat_tgt, flow]) 264 | 265 | # l2 loss for same id 266 | same = same.unsqueeze(-1).unsqueeze(-1) 267 | same = same.expand(trg.shape) 268 | 269 | x_rec_loss = self.mse_loss(torch.mul(trg, same), torch.mul(fake_img, same)) 270 | perceptual_loss = self.percept(trg, fake_img).mean() 271 | 272 | # id loss 273 | id_loss, sim_improvement, id_logs = self.id_loss(fake_img, src, trg) 274 | 275 | # contextual loss 276 | cx_loss = self.calc_cx(trg, fake_img) 277 | 278 | fake_pred_from_E = self.discriminator(fake_img) 279 | indomainGAN_E_loss = F.softplus(-fake_pred_from_E).mean() 280 | 281 | return x_rec_loss, perceptual_loss, indomainGAN_E_loss, id_loss, cx_loss, fake_img 282 | 283 | elif mode == "cal_mse_lpips": 284 | # real img is list 285 | trg = real_img[0] 286 | src = real_img[1] 287 | same = real_img[2] 288 | 289 | trg_src = torch.cat([trg, src], dim=0) 290 | w, w_feat = self.e_ema(trg_src) 291 | w_feat_tgt = [torch.chunk(f, 2, dim=0)[0] for f in w_feat][::-1] 292 | 293 | trg_w, src_w = torch.chunk(w, 2, dim=0) 294 | 295 | flow, _ = self.e_ema_p(trg) 296 | fake_img = self.g_ema([trg_w, src_w, w_feat_tgt, flow]) 297 | 298 | same = same.unsqueeze(-1).unsqueeze(-1) 299 | same = same.expand(trg.shape) 300 | 301 | x_rec_loss = self.mse_loss(torch.mul(trg, same), torch.mul(fake_img, same)) 302 | perceptual_loss = self.percept(trg, fake_img).mean() 303 | cx_loss = self.calc_cx(trg, fake_img) 304 | 305 | return x_rec_loss, perceptual_loss, cx_loss, fake_img 306 | 307 | 308 | def run(ddp_fn, world_size, args): 309 | print("world size", world_size) 310 | mp.spawn(ddp_fn, args=(world_size, args), nprocs=world_size, join=True) 311 | 312 | 313 | def ddp_main(rank, world_size, args): 314 | print(f"Running DDP model on rank {rank}.") 315 | setup(rank, world_size) 316 | map_location = f"cuda:{rank}" 317 | torch.cuda.set_device(map_location) 318 | 319 | if args.ckpt: # ignore current arguments 320 | ckpt = torch.load(args.ckpt, map_location=map_location) 321 | ckpt_p = torch.load('session/reenactment/checkpoints/500000.pt', map_location=map_location) 322 | 323 | train_args = ckpt["train_args"] 324 | print("load model:", args.ckpt) 325 | train_args.start_iter = int(args.ckpt.split("/")[-1].replace(".pt", "")) 326 | print(f"continue training from {train_args.start_iter} iter") 327 | args.ckpt = True 328 | args.start_iter = train_args.start_iter 329 | else: 330 | args.start_iter = 0 331 | 332 | # create model and move it to GPU with id rank 333 | model = DDPModel(device=map_location, args=args).to(map_location) 334 | model = DDP(model, device_ids=[rank], find_unused_parameters=True) 335 | model.train() 336 | 337 | ## let loss model in eval mode 338 | model.module.id_loss.eval() 339 | model.module.percept.eval() 340 | 341 | g_module = model.module.generator 342 | g_ema_module = model.module.g_ema 343 | g_ema_module.eval() 344 | accumulate(g_ema_module, g_module, 0) 345 | 346 | e_module = model.module.encoder 347 | e_ema_module = model.module.e_ema 348 | e_ema_module.eval() 349 | 
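    # accumulate(target, source, 0) performs a hard copy: with decay = 0 the update
    # p_ema <- decay * p_ema + (1 - decay) * p reduces to p_ema <- p, so each EMA module
    # (g_ema, e_ema, e_ema_p) starts from the freshly initialized online weights before
    # the 0.999 moving average (accum) takes over inside the training loop.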
accumulate(e_ema_module, e_module, 0) 350 | 351 | e_module_p = model.module.encoder_p 352 | e_ema_module_p = model.module.e_ema_p 353 | e_ema_module_p.eval() 354 | accumulate(e_ema_module_p, e_module_p, 0) 355 | 356 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 357 | r1_val = 0 358 | g_optim = optim.Adam( 359 | g_module.parameters(), 360 | lr=args.lr, 361 | betas=(0, 0.99), 362 | ) 363 | 364 | d_optim = optim.Adam( 365 | model.module.discriminator.parameters(), 366 | lr=args.lr * d_reg_ratio, 367 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 368 | ) 369 | 370 | e_optim = optim.Adam( 371 | e_module.parameters(), 372 | lr=args.lr, 373 | betas=(0, 0.99), 374 | ) 375 | 376 | e_optim_p = optim.Adam( 377 | e_module_p.parameters(), 378 | lr=args.lr * 0.01, 379 | betas=(0, 0.99), 380 | ) 381 | 382 | accum = 0.999 383 | 384 | if args.ckpt: 385 | model.module.generator.load_state_dict(ckpt["generator"]) 386 | model.module.discriminator.load_state_dict(ckpt["discriminator"]) 387 | model.module.g_ema.load_state_dict(ckpt["g_ema"]) 388 | 389 | model.module.encoder.load_state_dict(ckpt["encoder"]) 390 | model.module.e_ema.load_state_dict(ckpt["e_ema"]) 391 | 392 | # load pose encoder 393 | print('load pose') 394 | model.module.encoder_p.load_state_dict(ckpt_p["encoder_p"]) 395 | model.module.e_ema_p.load_state_dict(ckpt_p["e_ema_p"]) 396 | 397 | 398 | transform = transforms.Compose([ 399 | transforms.Resize((256, 256)), 400 | transforms.ToTensor(), 401 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 402 | ]) 403 | 404 | save_dir = os.path.join("./session", args.checkname) 405 | os.makedirs(save_dir, 0o777, exist_ok=True) 406 | os.makedirs(save_dir + "/checkpoints", 0o777, exist_ok=True) 407 | os.makedirs(save_dir + "/imgs", 0o777, exist_ok=True) 408 | save_args(os.path.join('session',args.checkname,'args.yaml') , args) 409 | 410 | train_dataset = SwapTrainDataset(args.train_img_path, transform) 411 | val_dataset = SwapValDataset(args.val_img_path, transform) 412 | 413 | print(f"train_dataset: {len(train_dataset)}, val_dataset: {len(val_dataset)}") 414 | 415 | val_sampler = torch.utils.data.distributed.DistributedSampler( 416 | val_dataset, num_replicas=world_size, rank=rank, shuffle=True 417 | ) 418 | 419 | val_loader = data.DataLoader( 420 | val_dataset, 421 | batch_size=args.batch_per_gpu, 422 | drop_last=True, 423 | sampler=val_sampler, 424 | num_workers=args.num_workers, 425 | pin_memory=True, 426 | ) 427 | 428 | train_sampler = torch.utils.data.distributed.DistributedSampler( 429 | train_dataset, num_replicas=world_size, rank=rank, shuffle=True 430 | ) 431 | 432 | train_loader = data.DataLoader( 433 | train_dataset, 434 | batch_size=args.batch_per_gpu, 435 | drop_last=True, 436 | sampler=train_sampler, 437 | num_workers=args.num_workers, 438 | pin_memory=True, 439 | ) 440 | 441 | train_loader = sample_data(train_loader) 442 | pbar = range(args.start_iter, args.iter) 443 | pbar = tqdm(pbar, initial=args.start_iter, mininterval=1) 444 | 445 | epoch = -1 446 | gpu_group = dist.new_group(list(range(args.ngpus))) 447 | 448 | for i in pbar: 449 | if i > args.iter: 450 | print("Done!") 451 | break 452 | elif i % (len(train_dataset) // args.batch) == 0: 453 | epoch += 1 454 | val_sampler.set_epoch(epoch) 455 | train_sampler.set_epoch(epoch) 456 | print("epoch: ", epoch) 457 | 458 | trg_img, src_img, same = next(train_loader) 459 | trg_img = trg_img.to(map_location) 460 | src_img = src_img.to(map_location) 461 | same = same.to(map_location) 462 | 463 | 
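        # One swap-training iteration alternates three updates:
        #   1) Discriminator step ("D" mode): logistic (softplus) loss on the real target
        #      vs. the swapped fake, scaled by lambda_d_loss.
        #   2) Lazy R1 step ("D_reg" mode) every d_reg_every (default 16) iterations, with
        #      the gradient penalty scaled by r1/2 * d_reg_every.
        #   3) Encoder/generator step ("E_x_rec" mode). Reading off the default lambdas in
        #      the argparse block below, this optimizes
        #         encoder_loss = 1.0 * x_rec + 1.0 * perceptual + 1.0 * indomainGAN_E
        #                        + 2.5 * id + 0.5 * cx
        #      where x_rec is an MSE masked by `same` (pixel reconstruction is only enforced
        #      when source and target are the same image), perceptual is LPIPS, id comes from
        #      criteria/id_loss, and cx is the contextual style loss on VGG 'r32'/'r42' features.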
requires_grad(model.module.discriminator, True) 464 | # D adv 465 | d_loss, real_score, fake_score = model([trg_img, src_img], "D") 466 | d_loss = d_loss.mean() 467 | 468 | d_optim.zero_grad() 469 | 470 | ( 471 | d_loss * args.lambda_d_loss 472 | ).backward() 473 | d_optim.step() 474 | 475 | # D reg 476 | 477 | d_regularize = i % args.d_reg_every == 0 478 | if d_regularize: 479 | d_reg_loss, r1_loss = model(trg_img, "D_reg") # r1 loss 480 | d_reg_loss = d_reg_loss.mean() 481 | d_optim.zero_grad() 482 | d_reg_loss.backward() 483 | d_optim.step() 484 | r1_val = r1_loss.mean().item() 485 | 486 | requires_grad(model.module.discriminator, False) 487 | 488 | # E_x_rec, 489 | x_rec_loss, perceptual_loss, indomainGAN_E_loss, id_loss, cx_loss, fake_img = model([trg_img, src_img, same], "E_x_rec") 490 | 491 | x_rec_loss = x_rec_loss.mean() * args.lambda_x_rec_loss 492 | perceptual_loss = perceptual_loss.mean() * args.lambda_perceptual_loss 493 | indomainGAN_E_loss = indomainGAN_E_loss.mean() * args.lambda_indomainGAN_E_loss 494 | id_loss = id_loss * args.lambda_id_loss 495 | cx_loss = cx_loss * args.lambda_cx_loss 496 | 497 | e_optim.zero_grad() 498 | g_optim.zero_grad() 499 | e_optim_p.zero_grad() 500 | 501 | encoder_loss = ( 502 | x_rec_loss 503 | + perceptual_loss 504 | + indomainGAN_E_loss 505 | + id_loss 506 | + cx_loss 507 | ) 508 | 509 | encoder_loss.backward() 510 | 511 | e_optim.step() 512 | g_optim.step() 513 | e_optim_p.step() 514 | 515 | pbar.set_description( 516 | (f"g: {indomainGAN_E_loss.item():.4f}; d: {d_loss.item():.4f}; r1: {r1_val:.4f}; rec: {x_rec_loss.item():.4f}; id: {id_loss.item():.4f}; cx: {cx_loss.item():.4f}; per: {perceptual_loss.item():.4f}") 517 | ) 518 | 519 | # log image 520 | if rank == 0: 521 | if i % args.image_interval == 0 or (i < 1000 and i % 25 == 0): 522 | parse_and_log_images(i, save_dir, trg_img, src_img, fake_img, title='imgs/train/') 523 | 524 | with torch.no_grad(): 525 | accumulate(g_ema_module, g_module, accum) 526 | accumulate(e_ema_module, e_module, accum) 527 | accumulate(e_ema_module_p, e_module_p, accum) 528 | 529 | if i % args.save_img_interval == 0: 530 | copy_norm_params(g_ema_module, g_module) 531 | copy_norm_params(e_ema_module, e_module) 532 | copy_norm_params(e_ema_module_p, e_module_p) 533 | 534 | x_rec_loss_avg, perceptual_loss_avg, cx_loss_avg = 0, 0, 0 535 | iter_num = 0 536 | for (trg_img, src_img, same) in tqdm(val_loader): 537 | trg_img = trg_img.to(map_location) 538 | src_img = src_img.to(map_location) 539 | same = same.to(map_location) 540 | 541 | x_rec_loss, perceptual_loss, cx_loss, fake_img = model([trg_img, src_img, same], "cal_mse_lpips") 542 | 543 | x_rec_loss_avg += x_rec_loss.mean() 544 | perceptual_loss_avg += perceptual_loss.mean() 545 | cx_loss_avg += cx_loss.mean() 546 | iter_num += 1 547 | 548 | # log images 549 | if rank == 0: 550 | parse_and_log_images(i, save_dir, trg_img, src_img, fake_img, title='imgs/test/') 551 | 552 | x_rec_loss_avg /= iter_num 553 | perceptual_loss_avg /= iter_num 554 | cx_loss_avg /= iter_num 555 | 556 | dist.reduce( 557 | x_rec_loss_avg, dst=0, op=dist.ReduceOp.SUM, group=gpu_group 558 | ) 559 | dist.reduce( 560 | perceptual_loss_avg, 561 | dst=0, 562 | op=dist.ReduceOp.SUM, 563 | group=gpu_group, 564 | ) 565 | if i % args.save_network_interval == 0: 566 | if rank == 0: 567 | x_rec_loss_avg = x_rec_loss_avg / args.ngpus 568 | perceptual_loss_avg = perceptual_loss_avg / args.ngpus 569 | x_rec_loss_avg_val = x_rec_loss_avg.item() 570 | perceptual_loss_avg_val = perceptual_loss_avg.item() 571 | 
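                    # Note: unlike x_rec_loss_avg and perceptual_loss_avg, cx_loss_avg is not
                    # passed through dist.reduce above and is not divided by args.ngpus, so the
                    # value printed below is rank 0's local validation average of the
                    # contextual loss.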
cx_loss_avg_val = cx_loss_avg.item() 572 | 573 | print( 574 | f"x_rec_loss_avg: {x_rec_loss_avg_val}, perceptual_loss_avg: {perceptual_loss_avg_val}, cx_loss_avg: {cx_loss_avg_val}" 575 | ) 576 | torch.save( 577 | { 578 | "generator": model.module.generator.state_dict(), 579 | "discriminator": model.module.discriminator.state_dict(), 580 | "encoder": model.module.encoder.state_dict(), 581 | "encoder_p": model.module.encoder_p.state_dict(), 582 | "g_ema": g_ema_module.state_dict(), 583 | "e_ema": e_ema_module.state_dict(), 584 | "e_ema_p": e_ema_module_p.state_dict(), 585 | "train_args": args, 586 | }, 587 | f"{save_dir}/checkpoints/{str(i).zfill(6)}.pt", 588 | ) 589 | 590 | 591 | if __name__ == "__main__": 592 | parser = argparse.ArgumentParser() 593 | parser.add_argument("--checkname", type=str, default='exp-swap-id2.5-cx0.5') 594 | parser.add_argument("--describ", type=str, default='no') 595 | parser.add_argument("--train_img_path", type=str, default='data/CelebA-HQ/train/images/') 596 | parser.add_argument("--val_img_path", type=str, default='data/CelebA-HQ/val/images/') 597 | parser.add_argument( 598 | "--dataset", 599 | type=str, 600 | default="danbooru", 601 | choices=[ 602 | "celeba_hq", 603 | "afhq", 604 | "ffhq", 605 | "lsun/church_outdoor", 606 | "lsun/car", 607 | "lsun/bedroom", 608 | ], 609 | ) 610 | parser.add_argument("--iter", type=int, default=500000) 611 | parser.add_argument("--save_network_interval", type=int, default=5000) 612 | parser.add_argument("--save_img_interval", type=int, default=1000) 613 | parser.add_argument("--small_generator", action="store_true") 614 | parser.add_argument("--batch", type=int, default=8, help="total batch sizes") 615 | parser.add_argument("--size", type=int, choices=[128, 256, 512, 1024], default=256) 616 | parser.add_argument("--r1", type=float, default=10) 617 | parser.add_argument("--d_reg_every", type=int, default=16) 618 | parser.add_argument("--lr", type=float, default=0.001) 619 | parser.add_argument("--lr_mul", type=float, default=1) 620 | parser.add_argument("--channel_multiplier", type=int, default=2) 621 | parser.add_argument("--latent_channel_size", type=int, default=512) 622 | parser.add_argument("--latent_spatial_size", type=int, default=32) 623 | parser.add_argument("--num_workers", type=int, default=8) 624 | parser.add_argument("--image_interval", type=int, default=50) 625 | parser.add_argument( 626 | "--normalize_mode", 627 | type=str, 628 | choices=["LayerNorm", "InstanceNorm2d", "BatchNorm2d", "GroupNorm"], 629 | default="LayerNorm", 630 | ) 631 | parser.add_argument("--mapping_layer_num", type=int, default=8) 632 | 633 | parser.add_argument("--lambda_x_rec_loss", type=float, default=1) 634 | parser.add_argument("--lambda_d_loss", type=float, default=1) 635 | parser.add_argument("--lambda_id_loss", type=float, default=2.5) 636 | parser.add_argument("--lambda_cx_loss", type=float, default=0.5) 637 | parser.add_argument("--lambda_perceptual_loss", type=float, default=1) 638 | parser.add_argument("--lambda_indomainGAN_D_loss", type=float, default=1) 639 | parser.add_argument("--lambda_indomainGAN_E_loss", type=float, default=1) 640 | 641 | input_args = parser.parse_args() 642 | ngpus = 2 643 | print("{} GPUS!".format(ngpus)) 644 | 645 | assert input_args.batch % ngpus == 0 646 | input_args.batch_per_gpu = input_args.batch // ngpus 647 | input_args.ngpus = ngpus 648 | print("{} batch per gpu!".format(input_args.batch_per_gpu)) 649 | 650 | run(ddp_main, ngpus, input_args) 
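# Note that ngpus is hard-coded to 2 here (train_reenact.py keeps the
# torch.cuda.device_count() variant commented out). A minimal sketch for training on all
# visible GPUs instead, reusing the divisibility check above, would be:
#
#     ngpus = torch.cuda.device_count()
#     assert input_args.batch % ngpus == 0   # total batch must split evenly across ranks
#     input_args.batch_per_gpu = input_args.batch // ngpus
#     input_args.ngpus = ngpus
#     run(ddp_main, ngpus, input_args)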
-------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/__init__.py -------------------------------------------------------------------------------- /training/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def face_augmentation(self, img, crop_size): 12 | img = self._color_transfer(img) 13 | img = self._reshape(img, crop_size) 14 | img = self._blur_and_sharp(img) 15 | return img 16 | 17 | def aug(self, img, crop_size): 18 | img = img[None] 19 | img = self.face_augmentation(img, crop_size) 20 | img = img[0] 21 | img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 22 | return img 23 | 24 | def _blur_and_sharp(self, img): 25 | blur = np.random.randint(0, 2) 26 | img2 = img.copy() 27 | output = [] 28 | for i in range(len(img2)): 29 | if blur: 30 | ksize = np.random.choice([3, 5, 7, 9]) 31 | output.append(cv2.medianBlur(img2[i], ksize)) 32 | else: 33 | kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) 34 | output.append(cv2.filter2D(img2[i], -1, kernel)) 35 | output = np.stack(output) 36 | return output 37 | 38 | def _color_transfer(self, img): 39 | 40 | transfer_c = np.random.uniform(0.3, 1.6) 41 | 42 | start_channel = np.random.randint(0, 2) 43 | end_channel = np.random.randint(start_channel + 1, 4) 44 | 45 | img2 = img.copy() 46 | 47 | img2[:, :, :, start_channel:end_channel] = np.minimum(np.maximum(img[:, :, :, start_channel:end_channel] * transfer_c, 48 | np.zeros(img[:, :, :, start_channel:end_channel].shape)), 49 | np.ones(img[:, :, :, start_channel:end_channel].shape) * 255) 50 | return img2 51 | 52 | def perspective_transform(self, img, crop_size=224, pers_size=10, enlarge_size=-10): 53 | h, w, _ = img.shape 54 | dst = np.array([ 55 | [-enlarge_size, -enlarge_size], 56 | [-enlarge_size + pers_size, w + enlarge_size], 57 | [h + enlarge_size, -enlarge_size], 58 | [h + enlarge_size - pers_size, w + enlarge_size], ], dtype=np.float32) 59 | src = np.array([[-enlarge_size, -enlarge_size], [-enlarge_size, w + enlarge_size], 60 | [h + enlarge_size, -enlarge_size], [h + enlarge_size, w + enlarge_size]]).astype(np.float32()) 61 | M = cv2.getPerspectiveTransform(src, dst) 62 | warped = cv2.warpPerspective(img, M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE) 63 | return warped, M 64 | 65 | def _reshape(self, img, crop_size): 66 | reshape = np.random.randint(0, 2) 67 | reshape_size = np.random.randint(15, 25) 68 | extra_padding_size = np.random.randint(0, reshape_size // 2) 69 | pers_size = np.random.randint(20, 30) * pow(-1, np.random.randint(2)) 70 | 71 | enlarge_size = np.random.randint(20, 40) * pow(-1, np.random.randint(2)) 72 | shape = img[0].shape 73 | img2 = img.copy() 74 | output = [] 75 | for i in range(len(img2)): 76 | if reshape: 77 | im = cv2.resize(img2[i], (shape[0] - reshape_size*2, shape[1] + reshape_size*2)) 78 | im = cv2.copyMakeBorder(im, 0, 0, reshape_size + extra_padding_size, reshape_size + extra_padding_size, cv2.cv2.BORDER_REFLECT) 79 | im = im[reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :, :] 80 | im, _ = 
self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size) 81 | output.append(im) 82 | else: 83 | im = cv2.resize(img2[i], (shape[0] + reshape_size*2, shape[1] - reshape_size*2)) 84 | im = cv2.copyMakeBorder(im, reshape_size + extra_padding_size, reshape_size + extra_padding_size, 0, 0, cv2.cv2.BORDER_REFLECT) 85 | im = im[:, reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :] 86 | im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size) 87 | output.append(im) 88 | output = np.stack(output) 89 | return output 90 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | """ 4 | StyleMapGAN 5 | Copyright (c) 2021-present NAVER Corp. 6 | 7 | This work is licensed under the Creative Commons Attribution-NonCommercial 8 | 4.0 International License. To view a copy of this license, visit 9 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 10 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 11 | """ 12 | import random 13 | from PIL import Image 14 | from torch.utils.data import Dataset 15 | import os 16 | import torch 17 | from training.base_dataset import BaseDataset 18 | import collections 19 | import cv2 20 | 21 | # class MultiResolutionDataset(Dataset): 22 | # def __init__(self, path, transform, resolution=256): 23 | # self.env = lmdb.open( 24 | # path, 25 | # max_readers=32, 26 | # readonly=True, 27 | # lock=False, 28 | # readahead=False, 29 | # meminit=False, 30 | # ) 31 | # 32 | # if not self.env: 33 | # raise IOError("Cannot open lmdb dataset", path) 34 | # 35 | # with self.env.begin(write=False) as txn: 36 | # self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 37 | # 38 | # self.resolution = resolution 39 | # self.transform = transform 40 | # 41 | # def __len__(self): 42 | # return self.length 43 | # 44 | # def __getitem__(self, index): 45 | # with self.env.begin(write=False) as txn: 46 | # key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 47 | # img_bytes = txn.get(key) 48 | # 49 | # buffer = BytesIO(img_bytes) 50 | # img = Image.open(buffer) 51 | # img = self.transform(img) 52 | # 53 | # return img 54 | 55 | 56 | # class GTMaskDataset(Dataset): 57 | # def __init__(self, dataset_folder, transform, resolution=256): 58 | # 59 | # self.env = lmdb.open( 60 | # f"{dataset_folder}/LMDB_test", 61 | # max_readers=32, 62 | # readonly=True, 63 | # lock=False, 64 | # readahead=False, 65 | # meminit=False, 66 | # ) 67 | # 68 | # if not self.env: 69 | # raise IOError("Cannot open lmdb dataset", f"{dataset_folder}/LMDB_test") 70 | # 71 | # with self.env.begin(write=False) as txn: 72 | # self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 73 | # 74 | # self.resolution = resolution 75 | # self.transform = transform 76 | # 77 | # # convert filename to celeba_hq index 78 | # CelebA_HQ_to_CelebA = ( 79 | # f"{dataset_folder}/local_editing/CelebA-HQ-to-CelebA-mapping.txt" 80 | # ) 81 | # CelebA_to_CelebA_HQ_dict = {} 82 | # 83 | # original_test_path = f"{dataset_folder}/raw_images/test/images" 84 | # mask_label_path = f"{dataset_folder}/local_editing/GT_labels" 85 | # 86 | # with open(CelebA_HQ_to_CelebA, "r") as fp: 87 | # read_line = fp.readline() 88 | # attrs = re.sub(" +", " ", read_line).strip().split(" ") 89 | # while True: 90 | # read_line 
= fp.readline() 91 | # 92 | # if not read_line: 93 | # break 94 | # 95 | # idx, orig_idx, orig_file = ( 96 | # re.sub(" +", " ", read_line).strip().split(" ") 97 | # ) 98 | # 99 | # CelebA_to_CelebA_HQ_dict[orig_file] = idx 100 | # 101 | # self.mask = [] 102 | # 103 | # for filename in os.listdir(original_test_path): 104 | # CelebA_HQ_filename = CelebA_to_CelebA_HQ_dict[filename] 105 | # CelebA_HQ_filename = CelebA_HQ_filename + ".png" 106 | # self.mask.append(os.path.join(mask_label_path, CelebA_HQ_filename)) 107 | # 108 | # def __len__(self): 109 | # return self.length 110 | # 111 | # def __getitem__(self, index): 112 | # with self.env.begin(write=False) as txn: 113 | # key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 114 | # img_bytes = txn.get(key) 115 | # 116 | # buffer = BytesIO(img_bytes) 117 | # img = Image.open(buffer) 118 | # img = self.transform(img) 119 | # 120 | # mask = Image.open(self.mask[index]) 121 | # 122 | # mask = mask.resize((self.resolution, self.resolution), Image.NEAREST) 123 | # mask = transforms.ToTensor()(mask) 124 | # 125 | # mask = mask.squeeze() 126 | # mask *= 255 127 | # mask = mask.long() 128 | # 129 | # assert mask.shape == (self.resolution, self.resolution) 130 | # return img, mask 131 | # 132 | # 133 | # class DataSetFromDir(Dataset): 134 | # def __init__(self, main_dir, transform): 135 | # self.main_dir = main_dir 136 | # self.transform = transform 137 | # all_imgs = os.listdir(main_dir) 138 | # self.total_imgs = [] 139 | # 140 | # for img in all_imgs: 141 | # if ".png" in img: 142 | # self.total_imgs.append(img) 143 | # 144 | # def __len__(self): 145 | # return len(self.total_imgs) 146 | # 147 | # def __getitem__(self, idx): 148 | # img_loc = os.path.join(self.main_dir, self.total_imgs[idx]) 149 | # image = Image.open(img_loc).convert("RGB") 150 | # tensor_image = self.transform(image) 151 | # return tensor_image 152 | # 153 | # 154 | # class DataSetTestLocalEditing(Dataset): 155 | # def __init__(self, main_dir, transform): 156 | # self.main_dir = main_dir 157 | # self.transform = transform 158 | # 159 | # all_imgs = os.listdir(os.path.join(main_dir, "mask")) 160 | # self.total_imgs = [] 161 | # 162 | # for img in all_imgs: 163 | # if ".png" in img: 164 | # self.total_imgs.append(img) 165 | # 166 | # def __len__(self): 167 | # return len(self.total_imgs) 168 | # 169 | # def __getitem__(self, idx): 170 | # image_mask = self.transform( 171 | # Image.open( 172 | # os.path.join(self.main_dir, "mask", self.total_imgs[idx]) 173 | # ).convert("RGB") 174 | # ) 175 | # image_reference = self.transform( 176 | # Image.open( 177 | # os.path.join(self.main_dir, "reference_image", self.total_imgs[idx]) 178 | # ).convert("RGB") 179 | # ) 180 | # # image_reference_recon = self.transform(Image.open(os.path.join(self.main_dir, 'reference_image', self.total_imgs[idx].replace('.png', '_recon_img.png'))).convert("RGB")) 181 | # 182 | # image_source = self.transform( 183 | # Image.open( 184 | # os.path.join(self.main_dir, "source_image", self.total_imgs[idx]) 185 | # ).convert("RGB") 186 | # ) 187 | # # image_source_recon = self.transform(Image.open(os.path.join(self.main_dir, 'source_image', self.total_imgs[idx].replace('.png', '_recon_img.png'))).convert("RGB")) 188 | # 189 | # image_synthesized = self.transform( 190 | # Image.open( 191 | # os.path.join(self.main_dir, "synthesized_image", self.total_imgs[idx]) 192 | # ).convert("RGB") 193 | # ) 194 | # 195 | # return image_mask, image_reference, image_source, image_synthesized 196 | 197 | 198 | class 
SwapTrainDataset(Dataset): 199 | def __init__(self, root, transform=None): 200 | super(SwapTrainDataset, self).__init__() 201 | self.root = root 202 | self.files = [ 203 | os.path.join(path, filename) 204 | for path, dirs, files in os.walk(root) 205 | for filename in files 206 | if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg") 207 | ] 208 | self.transform = transform 209 | 210 | def __getitem__(self, index): 211 | l = len(self.files) 212 | s_idx = index % l 213 | if index >= 4 * l: 214 | f_idx = s_idx 215 | 216 | else: 217 | f_idx = random.randrange(l) 218 | 219 | if f_idx == s_idx: 220 | same = torch.ones(1) 221 | else: 222 | same = torch.zeros(1) 223 | 224 | f_img = Image.open(self.files[f_idx]) 225 | s_img = Image.open(self.files[s_idx]) 226 | 227 | f_img = f_img.convert('RGB') 228 | s_img = s_img.convert('RGB') 229 | 230 | if self.transform is not None: 231 | f_img = self.transform(f_img) 232 | s_img = self.transform(s_img) 233 | 234 | return f_img, s_img, same 235 | 236 | def __len__(self): 237 | return len(self.files) * 5 238 | 239 | 240 | class SwapValDataset(Dataset): 241 | def __init__(self, root, transform=None): 242 | super(SwapValDataset, self).__init__() 243 | self.root = root 244 | self.files = [ 245 | os.path.join(path, filename) 246 | for path, dirs, files in os.walk(root) 247 | for filename in files 248 | if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg") 249 | ] 250 | self.transfrom = transform 251 | 252 | def __getitem__(self, index): 253 | l = len(self.files) 254 | 255 | f_idx = index // l 256 | s_idx = index % l 257 | 258 | if f_idx == s_idx: 259 | same = torch.ones(1) 260 | else: 261 | same = torch.zeros(1) 262 | 263 | f_img = Image.open(self.files[f_idx]) 264 | s_img = Image.open(self.files[s_idx]) 265 | 266 | f_img = f_img.convert('RGB') 267 | s_img = s_img.convert('RGB') 268 | 269 | if self.transfrom is not None: 270 | f_img = self.transfrom(f_img) 271 | s_img = self.transfrom(s_img) 272 | 273 | return f_img, s_img, same 274 | 275 | def __len__(self): 276 | return len(self.files) * len(self.files) 277 | 278 | 279 | class SwapTestTxtDataset(Dataset): 280 | def __init__(self, root, root_txt, transform=None, suffix='.png'): 281 | super(SwapTestTxtDataset, self).__init__() 282 | self.root = root 283 | self.txt = root_txt 284 | f = open(root_txt) 285 | file_pair = [s.strip() for s in f.readlines()] 286 | self.file_trg = [root + s.replace(suffix, '').split('_')[0] + suffix for s in file_pair] 287 | self.file_src = [root + s.replace(suffix, '').split('_')[1] + suffix for s in file_pair] 288 | 289 | self.transform = transform 290 | 291 | def __getitem__(self, index): 292 | 293 | f_img = Image.open(self.file_trg[index]) 294 | s_img = Image.open(self.file_src[index]) 295 | 296 | f_img = f_img.convert('RGB') 297 | s_img = s_img.convert('RGB') 298 | 299 | f_img_n = self.file_trg[index].split('/')[-1].split('.')[0] 300 | s_img_n = self.file_src[index].split('/')[-1].split('.')[0] 301 | 302 | if self.transform is not None: 303 | f_img = self.transform(f_img) 304 | s_img = self.transform(s_img) 305 | 306 | return [f_img, f_img_n], [s_img, s_img_n] 307 | 308 | def __len__(self): 309 | return len(self.file_trg) 310 | 311 | 312 | class Dataset_scale_trans(BaseDataset): 313 | def __init__(self, data_root, train=True, scale=True, transforms=None): 314 | self.data_root = data_root 315 | mode = 'test' if train else 'val' 316 | self.root = os.path.join(self.data_root, mode) 317 | 318 | videos = 
sorted(os.listdir(self.root)) 319 | 320 | self.video_items, self.person_ids = self.get_video_index(videos) 321 | self.idx_by_person_id = self.group_by_key(self.video_items, key='person_id') 322 | 323 | self.person_ids_woaug = self.person_ids 324 | self.person_ids = self.person_ids * 100 325 | 326 | self.transforms = transforms 327 | self.scale = scale 328 | 329 | def default_loader(self, path): 330 | return Image.open(path).convert('RGB') 331 | 332 | def get_video_index(self, videos): 333 | video_items = [] 334 | for video in videos: 335 | video_items.append(self.Video_Item(video)) 336 | 337 | person_ids = sorted(list({video.split('#')[0] for video in videos})) 338 | return video_items, person_ids 339 | 340 | def group_by_key(self, video_list, key): 341 | return_dict = collections.defaultdict(list) 342 | for index, video_item in enumerate(video_list): 343 | return_dict[video_item[key]].append(index) 344 | return return_dict 345 | 346 | def Video_Item(self, video_name): 347 | video_item = {} 348 | video_item['video_name'] = video_name 349 | video_item['person_id'] = video_name.split('#')[0] 350 | video_item['num_frame'] = [int(float(t[:4])) for t in os.listdir(os.path.join(self.root, video_name))] 351 | 352 | return video_item 353 | 354 | def random_select_frames(self, video_item, k): 355 | num_frame = video_item['num_frame'] 356 | frame_idx = random.choices(num_frame, k=k) 357 | return frame_idx 358 | 359 | def __len__(self): 360 | return len(self.person_ids) 361 | 362 | def __getitem__(self, index): 363 | # sample pairs 364 | person_id_s = self.person_ids[index] 365 | video_item_s = self.video_items[random.choices(self.idx_by_person_id[person_id_s], k=1)[0]] 366 | 367 | [frame_source_1, frame_source_2] = self.random_select_frames(video_item_s, 2) 368 | 369 | img_s1_path = os.path.join(self.root, video_item_s['video_name'], str(frame_source_1).zfill(4) + '.png') 370 | img_s2_path = os.path.join(self.root, video_item_s['video_name'], str(frame_source_2).zfill(4) + '.png') 371 | 372 | img_s1, img_s2 = self.default_loader(img_s1_path), self.default_loader(img_s2_path) 373 | if self.transforms: 374 | img_s1, img_s2_gt = self.transforms(img_s1), self.transforms(img_s2) 375 | 376 | if self.scale: 377 | img_s2_scale = self.transforms(self.aug(cv2.imread(img_s2_path), 256)) 378 | 379 | return [img_s1, img_s2_gt, img_s2_scale], [img_s1_path, img_s2_path] 380 | 381 | 382 | class Dataset_for_test(Dataset): 383 | def __init__(self, data_root, mode='test', root_txt='', suffix='.jpg', transforms=None): 384 | self.data_root = data_root 385 | self.root = os.path.join(self.data_root) 386 | 387 | f = open(root_txt) 388 | file_pair = [s.strip() for s in f.readlines()] 389 | self.file_pair = file_pair 390 | 391 | self.file_src = [self.root + '/' + s.replace(suffix, '').split('_')[0] + suffix for s in file_pair] 392 | self.file_drv = [self.root + '/' + s.replace(suffix, '').split('_')[1] + suffix for s in file_pair] 393 | 394 | self.transforms = transforms 395 | 396 | def default_loader(self, path): 397 | return Image.open(path).convert('RGB') 398 | 399 | def __len__(self): 400 | return len(self.file_src) 401 | 402 | def __getitem__(self, index): 403 | # sample pairs 404 | img_s_path = self.file_src[index] 405 | img_d_path = self.file_drv[index] 406 | 407 | img_s, img_d = self.default_loader(img_s_path), self.default_loader(img_d_path) 408 | if self.transforms: 409 | img_s, img_d = self.transforms(img_s), self.transforms(img_d) 410 | 411 | return [img_s, img_d], [img_s_path, img_d_path], self.file_pair[index] 
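# Pair files are parsed by splitting each line on '_' (after removing any occurrence of
# the suffix). For example, a line "001_002" with suffix='.jpg' makes Dataset_for_test load
#   file_src[i] = <data_root>/001.jpg   (source image)
#   file_drv[i] = <data_root>/002.jpg   (driving image)
# SwapTestTxtDataset (suffix='.png' by default) maps the first token to the target image
# and the second to the source image; its `root` argument is joined by plain string
# concatenation, so it should end with a path separator.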
412 | -------------------------------------------------------------------------------- /training/dataset_ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | StyleMapGAN 3 | Copyright (c) 2021-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | # Dataset code for the DDP training setting. 12 | 13 | from io import BytesIO 14 | from PIL import Image 15 | from torch.utils.data import Dataset 16 | import lmdb 17 | 18 | 19 | class MultiResolutionDataset(Dataset): 20 | def __init__(self, path, transform, resolution=256): 21 | self.path = path 22 | self.resolution = resolution 23 | self.transform = transform 24 | self.length = None 25 | 26 | def _open(self): 27 | self.env = lmdb.open( 28 | self.path, 29 | max_readers=32, 30 | readonly=True, 31 | lock=False, 32 | readahead=False, 33 | meminit=False, 34 | ) 35 | 36 | if not self.env: 37 | raise IOError(f"Cannot open lmdb dataset {self.path}") 38 | 39 | with self.env.begin(write=False) as txn: 40 | self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 41 | 42 | def _close(self): 43 | if self.env is not None: 44 | self.env.close() 45 | self.env = None 46 | 47 | def __len__(self): 48 | if self.length is None: 49 | self._open() 50 | self._close() 51 | 52 | return self.length 53 | 54 | def __getitem__(self, index): 55 | if self.env is None: 56 | self._open() 57 | 58 | with self.env.begin(write=False) as txn: 59 | key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 60 | img_bytes = txn.get(key) 61 | 62 | buffer = BytesIO(img_bytes) 63 | img = Image.open(buffer) 64 | img = self.transform(img) 65 | 66 | return img 67 | -------------------------------------------------------------------------------- /training/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/__init__.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/__init__.py 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | from skimage.measure import compare_ssim 11 | import torch 12 | from torch.autograd import Variable 13 | 14 | from . 
import dist_model 15 | 16 | 17 | class exportPerceptualLoss(torch.nn.Module): 18 | def __init__( 19 | self, model="net-lin", net="alex", colorspace="rgb", spatial=False, use_gpu=True 20 | ): # VGG using our perceptually-learned weights (LPIPS metric) 21 | super(exportPerceptualLoss, self).__init__() 22 | print("Setting up Perceptual loss...") 23 | self.use_gpu = use_gpu 24 | self.spatial = spatial 25 | self.model = dist_model.exportModel() 26 | self.model.initialize( 27 | model=model, 28 | net=net, 29 | use_gpu=use_gpu, 30 | colorspace=colorspace, 31 | spatial=self.spatial, 32 | ) 33 | print("...[%s] initialized" % self.model.name()) 34 | print("...Done") 35 | 36 | def forward(self, pred, target): 37 | return self.model.forward(target, pred) 38 | 39 | 40 | class PerceptualLoss(torch.nn.Module): 41 | def __init__( 42 | self, 43 | model="net-lin", 44 | net="alex", 45 | colorspace="rgb", 46 | spatial=False, 47 | use_gpu=True, 48 | gpu_ids=[0], 49 | ): # VGG using our perceptually-learned weights (LPIPS metric) 50 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 51 | super(PerceptualLoss, self).__init__() 52 | print("Setting up Perceptual loss...") 53 | self.use_gpu = use_gpu 54 | self.spatial = spatial 55 | self.gpu_ids = gpu_ids 56 | self.model = dist_model.DistModel() 57 | self.model.initialize( 58 | model=model, 59 | net=net, 60 | use_gpu=use_gpu, 61 | colorspace=colorspace, 62 | spatial=self.spatial, 63 | gpu_ids=gpu_ids, 64 | ) 65 | print("...[%s] initialized" % self.model.name()) 66 | print("...Done") 67 | 68 | def forward(self, pred, target, normalize=False): 69 | """ 70 | Pred and target are Variables. 71 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 72 | If normalize is False, assumes the images are already between [-1,+1] 73 | 74 | Inputs pred and target are Nx3xHxW 75 | Output pytorch Variable N long 76 | """ 77 | 78 | if normalize: 79 | target = 2 * target - 1 80 | pred = 2 * pred - 1 81 | 82 | return self.model.forward(target, pred) 83 | 84 | 85 | def normalize_tensor(in_feat, eps=1e-10): 86 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 87 | return in_feat / (norm_factor + eps) 88 | 89 | 90 | def l2(p0, p1, range=255.0): 91 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 92 | 93 | 94 | def psnr(p0, p1, peak=255.0): 95 | return 10 * np.log10(peak ** 2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 96 | 97 | 98 | def dssim(p0, p1, range=255.0): 99 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 100 | 101 | 102 | def rgb2lab(in_img, mean_cent=False): 103 | from skimage import color 104 | 105 | img_lab = color.rgb2lab(in_img) 106 | if mean_cent: 107 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 108 | return img_lab 109 | 110 | 111 | def tensor2np(tensor_obj): 112 | # change dimension of a tensor object into a numpy array 113 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 114 | 115 | 116 | def np2tensor(np_obj): 117 | # change dimenion of np array into tensor array 118 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 119 | 120 | 121 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 122 | # image tensor to lab tensor 123 | from skimage import color 124 | 125 | img = tensor2im(image_tensor) 126 | img_lab = color.rgb2lab(img) 127 | if mc_only: 128 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 129 | if to_norm and not mc_only: 130 | img_lab[:, :, 0] = img_lab[:, :, 
0] - 50 131 | img_lab = img_lab / 100.0 132 | 133 | return np2tensor(img_lab) 134 | 135 | 136 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 137 | from skimage import color 138 | import warnings 139 | 140 | warnings.filterwarnings("ignore") 141 | 142 | lab = tensor2np(lab_tensor) * 100.0 143 | lab[:, :, 0] = lab[:, :, 0] + 50 144 | 145 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 146 | if return_inbnd: 147 | # convert back to lab, see if we match 148 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 149 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 150 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 151 | return (im2tensor(rgb_back), mask) 152 | else: 153 | return im2tensor(rgb_back) 154 | 155 | 156 | # def rgb2lab(input): 157 | # from skimage import color 158 | # 159 | # return color.rgb2lab(input / 255.0) 160 | 161 | 162 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 163 | image_numpy = image_tensor[0].cpu().float().numpy() 164 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 165 | return image_numpy.astype(imtype) 166 | 167 | 168 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 169 | return torch.Tensor( 170 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 171 | ) 172 | 173 | 174 | def tensor2vec(vector_tensor): 175 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 176 | 177 | 178 | def voc_ap(rec, prec, use_07_metric=False): 179 | """ap = voc_ap(rec, prec, [use_07_metric]) 180 | Compute VOC AP given precision and recall. 181 | If use_07_metric is true, uses the 182 | VOC 07 11 point method (default:False). 183 | """ 184 | if use_07_metric: 185 | # 11 point metric 186 | ap = 0.0 187 | for t in np.arange(0.0, 1.1, 0.1): 188 | if np.sum(rec >= t) == 0: 189 | p = 0 190 | else: 191 | p = np.max(prec[rec >= t]) 192 | ap = ap + p / 11.0 193 | else: 194 | # correct AP calculation 195 | # first append sentinel values at the end 196 | mrec = np.concatenate(([0.0], rec, [1.0])) 197 | mpre = np.concatenate(([0.0], prec, [0.0])) 198 | 199 | # compute the precision envelope 200 | for i in range(mpre.size - 1, 0, -1): 201 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 202 | 203 | # to calculate area under PR curve, look for points 204 | # where X axis (recall) changes value 205 | i = np.where(mrec[1:] != mrec[:-1])[0] 206 | 207 | # and sum (\Delta recall) * prec 208 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 209 | return ap 210 | 211 | 212 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 213 | # # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 214 | # image_numpy = image_tensor[0].cpu().float().numpy() 215 | # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 216 | # return image_numpy.astype(imtype) 217 | # 218 | # 219 | # def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 220 | # # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 221 | # return torch.Tensor( 222 | # (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 223 | # ) 224 | -------------------------------------------------------------------------------- /training/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/base_model.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/trainer.py 4 | """ 5 
| import os 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class BaseModel: 11 | def __init__(self): 12 | pass 13 | 14 | def name(self): 15 | return "BaseModel" 16 | 17 | def initialize(self, use_gpu=True, gpu_ids=[0]): 18 | self.use_gpu = use_gpu 19 | self.gpu_ids = gpu_ids 20 | 21 | def forward(self): 22 | pass 23 | 24 | # def get_image_paths(self): 25 | # pass 26 | 27 | def optimize_parameters(self): 28 | pass 29 | 30 | def get_current_visuals(self): 31 | return self.input 32 | 33 | def get_current_errors(self): 34 | return {} 35 | 36 | def save(self, label): 37 | pass 38 | 39 | # helper saving function that can be used by subclasses 40 | def save_network(self, network, path, network_label, epoch_label): 41 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 42 | save_path = os.path.join(path, save_filename) 43 | torch.save(network.state_dict(), save_path) 44 | 45 | # helper loading function that can be used by subclasses 46 | def load_network(self, network, network_label, epoch_label): 47 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 48 | save_path = os.path.join(self.save_dir, save_filename) 49 | print("Loading network from %s" % save_path) 50 | network.load_state_dict(torch.load(save_path)) 51 | 52 | def update_learning_rate(self): 53 | pass 54 | 55 | def get_image_paths(self): 56 | return self.image_paths 57 | 58 | def save_done(self, flag=False): 59 | np.save(os.path.join(self.save_dir, "done_flag"), flag) 60 | np.savetxt( 61 | os.path.join(self.save_dir, "done_flag"), 62 | [ 63 | flag, 64 | ], 65 | fmt="%i", 66 | ) 67 | -------------------------------------------------------------------------------- /training/lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/dist_model.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/trainer.py 4 | """ 5 | 6 | from __future__ import absolute_import 7 | import os 8 | import torch 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | from .base_model import BaseModel 12 | from scipy.ndimage import zoom 13 | from tqdm import tqdm 14 | import numpy as np 15 | 16 | from . 
import networks_basic as networks 17 | from training import lpips as util 18 | 19 | 20 | class exportModel(torch.nn.Module): 21 | def name(self): 22 | return self.model_name 23 | 24 | def initialize( 25 | self, 26 | model="net-lin", 27 | net="vgg", 28 | colorspace="Lab", 29 | pnet_rand=False, 30 | pnet_tune=False, 31 | model_path=None, 32 | use_gpu=True, 33 | printNet=False, 34 | spatial=False, 35 | is_train=False, 36 | lr=0.0001, 37 | beta1=0.5, 38 | version="0.1", 39 | ): 40 | 41 | self.model = model 42 | self.net = net 43 | self.is_train = is_train 44 | self.spatial = spatial 45 | self.use_gpu = use_gpu 46 | self.model_name = "%s [%s]" % (model, net) 47 | 48 | assert self.model == "net-lin" # pretrained net + linear layer 49 | self.net = networks.PNetLin( 50 | pnet_rand=pnet_rand, 51 | pnet_tune=pnet_tune, 52 | pnet_type=net, 53 | use_dropout=True, 54 | spatial=spatial, 55 | version=version, 56 | lpips=True, 57 | ) 58 | kw = {} 59 | if not use_gpu: 60 | kw["map_location"] = "cpu" 61 | if model_path is None: 62 | import inspect 63 | 64 | model_path = os.path.abspath( 65 | os.path.join( 66 | inspect.getfile(self.initialize), 67 | "..", 68 | "weights/v%s/%s.pth" % (version, net), 69 | ) 70 | ) 71 | 72 | assert not is_train 73 | print("Loading model from: %s" % model_path) 74 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 75 | self.net.eval() 76 | 77 | if printNet: 78 | print("---------- Networks initialized -------------") 79 | networks.print_network(self.net) 80 | print("-----------------------------------------------") 81 | 82 | def forward(self, in0, in1, retPerLayer=False): 83 | 84 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 85 | 86 | 87 | class DistModel(BaseModel): 88 | def name(self): 89 | return self.model_name 90 | 91 | def initialize( 92 | self, 93 | model="net-lin", 94 | net="alex", 95 | colorspace="Lab", 96 | pnet_rand=False, 97 | pnet_tune=False, 98 | model_path=None, 99 | use_gpu=True, 100 | printNet=False, 101 | spatial=False, 102 | is_train=False, 103 | lr=0.0001, 104 | beta1=0.5, 105 | version="0.1", 106 | gpu_ids=[0], 107 | ): 108 | 109 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 110 | 111 | self.model = model 112 | self.net = net 113 | self.is_train = is_train 114 | self.spatial = spatial 115 | self.gpu_ids = gpu_ids 116 | self.model_name = "%s [%s]" % (model, net) 117 | 118 | if self.model == "net-lin": # pretrained net + linear layer 119 | self.net = networks.PNetLin( 120 | pnet_rand=pnet_rand, 121 | pnet_tune=pnet_tune, 122 | pnet_type=net, 123 | use_dropout=True, 124 | spatial=spatial, 125 | version=version, 126 | lpips=True, 127 | ) 128 | kw = {} 129 | if not use_gpu: 130 | kw["map_location"] = "cpu" 131 | if model_path is None: 132 | import inspect 133 | 134 | model_path = os.path.abspath( 135 | os.path.join( 136 | inspect.getfile(self.initialize), 137 | "..", 138 | "weights/v%s/%s.pth" % (version, net), 139 | ) 140 | ) 141 | 142 | if not is_train: 143 | print("Loading model from: %s" % model_path) 144 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 145 | 146 | elif self.model == "net": # pretrained network 147 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 148 | elif self.model in ["L2", "l2"]: 149 | self.net = networks.L2( 150 | use_gpu=use_gpu, colorspace=colorspace 151 | ) # not really a network, only for testing 152 | self.model_name = "L2" 153 | elif self.model in ["DSSIM", "dssim", "SSIM", "ssim"]: 154 | self.net = 
networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace) 155 | self.model_name = "SSIM" 156 | else: 157 | raise ValueError("Model [%s] not recognized." % self.model) 158 | 159 | self.parameters = list(self.net.parameters()) 160 | 161 | if self.is_train: # training mode 162 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 163 | self.rankLoss = networks.BCERankingLoss() 164 | self.parameters += list(self.rankLoss.net.parameters()) 165 | self.lr = lr 166 | self.old_lr = lr 167 | self.optimizer_net = torch.optim.Adam( 168 | self.parameters, lr=lr, betas=(beta1, 0.999) 169 | ) 170 | else: # test mode 171 | self.net.eval() 172 | 173 | if use_gpu: 174 | self.net.to(gpu_ids[0]) 175 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 176 | if self.is_train: 177 | self.rankLoss = self.rankLoss.to( 178 | device=gpu_ids[0] 179 | ) # just put this on GPU0 180 | 181 | if printNet: 182 | print("---------- Networks initialized -------------") 183 | networks.print_network(self.net) 184 | print("-----------------------------------------------") 185 | 186 | def forward(self, in0, in1, retPerLayer=False): 187 | 188 | 189 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 190 | 191 | # ***** training FUNCTIONS ***** 192 | def optimize_parameters(self): 193 | self.forward_train() 194 | self.optimizer_net.zero_grad() 195 | self.backward_train() 196 | self.optimizer_net.step() 197 | self.clamp_weights() 198 | 199 | def clamp_weights(self): 200 | for module in self.net.modules(): 201 | if hasattr(module, "weight") and module.kernel_size == (1, 1): 202 | module.weight.data = torch.clamp(module.weight.data, min=0) 203 | 204 | def set_input(self, data): 205 | self.input_ref = data["ref"] 206 | self.input_p0 = data["p0"] 207 | self.input_p1 = data["p1"] 208 | self.input_judge = data["judge"] 209 | 210 | if self.use_gpu: 211 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 212 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 213 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 214 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 215 | 216 | self.var_ref = Variable(self.input_ref, requires_grad=True) 217 | self.var_p0 = Variable(self.input_p0, requires_grad=True) 218 | self.var_p1 = Variable(self.input_p1, requires_grad=True) 219 | 220 | def forward_train(self): # run forward pass 221 | # print(self.net.module.scaling_layer.shift) 222 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 223 | 224 | self.d0 = self.forward(self.var_ref, self.var_p0) 225 | self.d1 = self.forward(self.var_ref, self.var_p1) 226 | self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) 227 | 228 | self.var_judge = Variable(1.0 * self.input_judge).view(self.d0.size()) 229 | 230 | self.loss_total = self.rankLoss.forward( 231 | self.d0, self.d1, self.var_judge * 2.0 - 1.0 232 | ) 233 | 234 | return self.loss_total 235 | 236 | def backward_train(self): 237 | torch.mean(self.loss_total).backward() 238 | 239 | def compute_accuracy(self, d0, d1, judge): 240 | """ d0, d1 are Variables, judge is a Tensor """ 241 | d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() 242 | judge_per = judge.cpu().numpy().flatten() 243 | return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) 244 | 245 | def get_current_errors(self): 246 | retDict = OrderedDict( 247 | [("loss_total", self.loss_total.data.cpu().numpy()), ("acc_r", self.acc_r)] 248 | ) 249 | 250 | for key in 
retDict.keys(): 251 | retDict[key] = np.mean(retDict[key]) 252 | 253 | return retDict 254 | 255 | def get_current_visuals(self): 256 | zoom_factor = 256 / self.var_ref.data.size()[2] 257 | 258 | ref_img = util.tensor2im(self.var_ref.data) 259 | p0_img = util.tensor2im(self.var_p0.data) 260 | p1_img = util.tensor2im(self.var_p1.data) 261 | 262 | ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) 263 | p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) 264 | p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) 265 | 266 | return OrderedDict( 267 | [("ref", ref_img_vis), ("p0", p0_img_vis), ("p1", p1_img_vis)] 268 | ) 269 | 270 | def save(self, path, label): 271 | if self.use_gpu: 272 | self.save_network(self.net.module, path, "", label) 273 | else: 274 | self.save_network(self.net, path, "", label) 275 | self.save_network(self.rankLoss.net, path, "rank", label) 276 | 277 | def update_learning_rate(self, nepoch_decay): 278 | lrd = self.lr / nepoch_decay 279 | lr = self.old_lr - lrd 280 | 281 | for param_group in self.optimizer_net.param_groups: 282 | param_group["lr"] = lr 283 | 284 | print("update lr [%s] decay: %f -> %f" % (type, self.old_lr, lr)) 285 | self.old_lr = lr 286 | 287 | 288 | def score_2afc_dataset(data_loader, func, name=""): 289 | 290 | 291 | d0s = [] 292 | d1s = [] 293 | gts = [] 294 | 295 | for data in tqdm(data_loader.load_data(), desc=name): 296 | d0s += func(data["ref"], data["p0"]).data.cpu().numpy().flatten().tolist() 297 | d1s += func(data["ref"], data["p1"]).data.cpu().numpy().flatten().tolist() 298 | gts += data["judge"].cpu().numpy().flatten().tolist() 299 | 300 | d0s = np.array(d0s) 301 | d1s = np.array(d1s) 302 | gts = np.array(gts) 303 | scores = (d0s < d1s) * (1.0 - gts) + (d1s < d0s) * gts + (d1s == d0s) * 0.5 304 | 305 | return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) 306 | 307 | 308 | def score_jnd_dataset(data_loader, func, name=""): 309 | 310 | 311 | ds = [] 312 | gts = [] 313 | 314 | for data in tqdm(data_loader.load_data(), desc=name): 315 | ds += func(data["p0"], data["p1"]).data.cpu().numpy().tolist() 316 | gts += data["same"].cpu().numpy().flatten().tolist() 317 | 318 | sames = np.array(gts) 319 | ds = np.array(ds) 320 | 321 | sorted_inds = np.argsort(ds) 322 | sames_sorted = sames[sorted_inds] 323 | 324 | TPs = np.cumsum(sames_sorted) 325 | FPs = np.cumsum(1 - sames_sorted) 326 | FNs = np.sum(sames_sorted) - TPs 327 | 328 | precs = TPs / (TPs + FPs) 329 | recs = TPs / (TPs + FNs) 330 | score = util.voc_ap(recs, precs) 331 | 332 | return (score, dict(ds=ds, sames=sames)) 333 | -------------------------------------------------------------------------------- /training/lpips/networks_basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/networks_basic.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 4 | """ 5 | from __future__ import absolute_import 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | from . 
import pretrained_networks as pn 10 | 11 | from training import lpips as util 12 | 13 | 14 | def spatial_average(in_tens, keepdim=True): 15 | return in_tens.mean([2, 3], keepdim=keepdim) 16 | 17 | 18 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 19 | in_H = in_tens.shape[2] 20 | scale_factor = 1.0 * out_H / in_H 21 | 22 | return nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)( 23 | in_tens 24 | ) 25 | 26 | 27 | # Learned perceptual metric 28 | class PNetLin(nn.Module): 29 | def __init__( 30 | self, 31 | pnet_type="vgg", 32 | pnet_rand=False, 33 | pnet_tune=False, 34 | use_dropout=True, 35 | spatial=False, 36 | version="0.1", 37 | lpips=True, 38 | ): 39 | super(PNetLin, self).__init__() 40 | 41 | self.pnet_type = pnet_type 42 | self.pnet_tune = pnet_tune 43 | self.pnet_rand = pnet_rand 44 | self.spatial = spatial 45 | self.lpips = lpips 46 | self.version = version 47 | self.scaling_layer = ScalingLayer() 48 | 49 | if self.pnet_type in ["vgg", "vgg16"]: 50 | net_type = pn.vgg16 51 | self.chns = [64, 128, 256, 512, 512] 52 | elif self.pnet_type == "alex": 53 | net_type = pn.alexnet 54 | self.chns = [64, 192, 384, 256, 256] 55 | elif self.pnet_type == "squeeze": 56 | net_type = pn.squeezenet 57 | self.chns = [64, 128, 256, 384, 384, 512, 512] 58 | self.L = len(self.chns) 59 | 60 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 61 | 62 | if lpips: 63 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 64 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 65 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 66 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 67 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 68 | self.lins = nn.ModuleList( 69 | [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 70 | ) 71 | 72 | if self.pnet_type == "squeeze": # 7 layers for squeezenet 73 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 74 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 75 | self.lins.extend([self.lin5, self.lin6]) 76 | 77 | def forward(self, in0, in1, retPerLayer=False): 78 | # v0.0 - original release had a bug, where input was not scaled 79 | in0_input, in1_input = ( 80 | (self.scaling_layer(in0), self.scaling_layer(in1)) 81 | if self.version == "0.1" 82 | else (in0, in1) 83 | ) 84 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 85 | feats0, feats1, diffs = {}, {}, {} 86 | 87 | for kk in range(self.L): 88 | feats0[kk], feats1[kk] = ( 89 | util.normalize_tensor(outs0[kk]), 90 | util.normalize_tensor(outs1[kk]), 91 | ) 92 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 93 | 94 | if self.lpips: 95 | if self.spatial: 96 | res = [ 97 | upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) 98 | for kk in range(self.L) 99 | ] 100 | else: 101 | res = [ 102 | spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) 103 | for kk in range(self.L) 104 | ] 105 | else: 106 | if self.spatial: 107 | res = [ 108 | upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) 109 | for kk in range(self.L) 110 | ] 111 | else: 112 | res = [ 113 | spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) 114 | for kk in range(self.L) 115 | ] 116 | 117 | val = res[0] 118 | for l in range(1, self.L): 119 | val += res[l] 120 | 121 | if retPerLayer: 122 | return (val, res) 123 | else: 124 | return val 125 | 126 | 127 | class ScalingLayer(nn.Module): 128 | def 
__init__(self): 129 | super(ScalingLayer, self).__init__() 130 | self.register_buffer( 131 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 132 | ) 133 | self.register_buffer( 134 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 135 | ) 136 | 137 | def forward(self, inp): 138 | return (inp - self.shift) / self.scale 139 | 140 | 141 | class NetLinLayer(nn.Module): 142 | """ A single linear layer which does a 1x1 conv """ 143 | 144 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 145 | super(NetLinLayer, self).__init__() 146 | 147 | layers = ( 148 | [ 149 | nn.Dropout(), 150 | ] 151 | if (use_dropout) 152 | else [] 153 | ) 154 | layers += [ 155 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 156 | ] 157 | self.model = nn.Sequential(*layers) 158 | 159 | 160 | class Dist2LogitLayer(nn.Module): 161 | """ takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) """ 162 | 163 | def __init__(self, chn_mid=32, use_sigmoid=True): 164 | super(Dist2LogitLayer, self).__init__() 165 | 166 | layers = [ 167 | nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), 168 | ] 169 | layers += [ 170 | nn.LeakyReLU(0.2, True), 171 | ] 172 | layers += [ 173 | nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), 174 | ] 175 | layers += [ 176 | nn.LeakyReLU(0.2, True), 177 | ] 178 | layers += [ 179 | nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), 180 | ] 181 | if use_sigmoid: 182 | layers += [ 183 | nn.Sigmoid(), 184 | ] 185 | self.model = nn.Sequential(*layers) 186 | 187 | def forward(self, d0, d1, eps=0.1): 188 | return self.model.forward( 189 | torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1) 190 | ) 191 | 192 | 193 | class BCERankingLoss(nn.Module): 194 | def __init__(self, chn_mid=32): 195 | super(BCERankingLoss, self).__init__() 196 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 197 | # self.parameters = list(self.net.parameters()) 198 | self.loss = torch.nn.BCELoss() 199 | 200 | def forward(self, d0, d1, judge): 201 | per = (judge + 1.0) / 2.0 202 | self.logit = self.net.forward(d0, d1) 203 | return self.loss(self.logit, per) 204 | 205 | 206 | # L2, DSSIM training 207 | class FakeNet(nn.Module): 208 | def __init__(self, use_gpu=True, colorspace="Lab"): 209 | super(FakeNet, self).__init__() 210 | self.use_gpu = use_gpu 211 | self.colorspace = colorspace 212 | 213 | 214 | class L2(FakeNet): 215 | def forward(self, in0, in1, retPerLayer=None): 216 | assert in0.size()[0] == 1 # currently only supports batchSize 1 217 | 218 | if self.colorspace == "RGB": 219 | (N, C, X, Y) = in0.size() 220 | value = torch.mean( 221 | torch.mean( 222 | torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2 223 | ).view(N, 1, 1, Y), 224 | dim=3, 225 | ).view(N) 226 | return value 227 | elif self.colorspace == "Lab": 228 | value = util.l2( 229 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 230 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 231 | range=100.0, 232 | ).astype("float") 233 | ret_var = Variable(torch.Tensor((value,))) 234 | if self.use_gpu: 235 | ret_var = ret_var.cuda() 236 | return ret_var 237 | 238 | 239 | class DSSIM(FakeNet): 240 | def forward(self, in0, in1, retPerLayer=None): 241 | assert in0.size()[0] == 1 # currently only supports batchSize 1 242 | 243 | if self.colorspace == "RGB": 244 | value = util.dssim( 245 | 1.0 * util.tensor2im(in0.data), 246 | 1.0 * util.tensor2im(in1.data), 247 | range=255.0, 248 | 
).astype("float") 249 | elif self.colorspace == "Lab": 250 | value = util.dssim( 251 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 252 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 253 | range=100.0, 254 | ).astype("float") 255 | ret_var = Variable(torch.Tensor((value,))) 256 | if self.use_gpu: 257 | ret_var = ret_var.cuda() 258 | return ret_var 259 | 260 | 261 | def print_network(net): 262 | num_params = 0 263 | for param in net.parameters(): 264 | num_params += param.numel() 265 | print("Network", net) 266 | print("Total number of parameters: %d" % num_params) 267 | -------------------------------------------------------------------------------- /training/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/lpips/pretrained_networks.py 3 | Refer to https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py 4 | """ 5 | from collections import namedtuple 6 | import torch 7 | from torchvision import models as tv 8 | 9 | 10 | class squeezenet(torch.nn.Module): 11 | def __init__(self, requires_grad=False, pretrained=True): 12 | super(squeezenet, self).__init__() 13 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 14 | self.slice1 = torch.nn.Sequential() 15 | self.slice2 = torch.nn.Sequential() 16 | self.slice3 = torch.nn.Sequential() 17 | self.slice4 = torch.nn.Sequential() 18 | self.slice5 = torch.nn.Sequential() 19 | self.slice6 = torch.nn.Sequential() 20 | self.slice7 = torch.nn.Sequential() 21 | self.N_slices = 7 22 | for x in range(2): 23 | self.slice1.add_module(str(x), pretrained_features[x]) 24 | for x in range(2, 5): 25 | self.slice2.add_module(str(x), pretrained_features[x]) 26 | for x in range(5, 8): 27 | self.slice3.add_module(str(x), pretrained_features[x]) 28 | for x in range(8, 10): 29 | self.slice4.add_module(str(x), pretrained_features[x]) 30 | for x in range(10, 11): 31 | self.slice5.add_module(str(x), pretrained_features[x]) 32 | for x in range(11, 12): 33 | self.slice6.add_module(str(x), pretrained_features[x]) 34 | for x in range(12, 13): 35 | self.slice7.add_module(str(x), pretrained_features[x]) 36 | if not requires_grad: 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, X): 41 | h = self.slice1(X) 42 | h_relu1 = h 43 | h = self.slice2(h) 44 | h_relu2 = h 45 | h = self.slice3(h) 46 | h_relu3 = h 47 | h = self.slice4(h) 48 | h_relu4 = h 49 | h = self.slice5(h) 50 | h_relu5 = h 51 | h = self.slice6(h) 52 | h_relu6 = h 53 | h = self.slice7(h) 54 | h_relu7 = h 55 | vgg_outputs = namedtuple( 56 | "SqueezeOutputs", 57 | ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], 58 | ) 59 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 60 | 61 | return out 62 | 63 | 64 | class alexnet(torch.nn.Module): 65 | def __init__(self, requires_grad=False, pretrained=True): 66 | super(alexnet, self).__init__() 67 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 68 | self.slice1 = torch.nn.Sequential() 69 | self.slice2 = torch.nn.Sequential() 70 | self.slice3 = torch.nn.Sequential() 71 | self.slice4 = torch.nn.Sequential() 72 | self.slice5 = torch.nn.Sequential() 73 | self.N_slices = 5 74 | for x in range(2): 75 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 76 | for x in range(2, 5): 77 | self.slice2.add_module(str(x), 
alexnet_pretrained_features[x]) 78 | for x in range(5, 8): 79 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 80 | for x in range(8, 10): 81 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 82 | for x in range(10, 12): 83 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 84 | if not requires_grad: 85 | for param in self.parameters(): 86 | param.requires_grad = False 87 | 88 | def forward(self, X): 89 | h = self.slice1(X) 90 | h_relu1 = h 91 | h = self.slice2(h) 92 | h_relu2 = h 93 | h = self.slice3(h) 94 | h_relu3 = h 95 | h = self.slice4(h) 96 | h_relu4 = h 97 | h = self.slice5(h) 98 | h_relu5 = h 99 | alexnet_outputs = namedtuple( 100 | "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] 101 | ) 102 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 103 | 104 | return out 105 | 106 | 107 | class vgg16(torch.nn.Module): 108 | def __init__(self, requires_grad=False, pretrained=True): 109 | super(vgg16, self).__init__() 110 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 111 | self.slice1 = torch.nn.Sequential() 112 | self.slice2 = torch.nn.Sequential() 113 | self.slice3 = torch.nn.Sequential() 114 | self.slice4 = torch.nn.Sequential() 115 | self.slice5 = torch.nn.Sequential() 116 | self.N_slices = 5 117 | for x in range(4): 118 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 119 | for x in range(4, 9): 120 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 121 | for x in range(9, 16): 122 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 123 | for x in range(16, 23): 124 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 125 | for x in range(23, 30): 126 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 127 | if not requires_grad: 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | 131 | def forward(self, X): 132 | h = self.slice1(X) 133 | h_relu1_2 = h 134 | h = self.slice2(h) 135 | h_relu2_2 = h 136 | h = self.slice3(h) 137 | h_relu3_3 = h 138 | h = self.slice4(h) 139 | h_relu4_3 = h 140 | h = self.slice5(h) 141 | h_relu5_3 = h 142 | vgg_outputs = namedtuple( 143 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 144 | ) 145 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 146 | 147 | return out 148 | 149 | 150 | class resnet(torch.nn.Module): 151 | def __init__(self, requires_grad=False, pretrained=True, num=18): 152 | super(resnet, self).__init__() 153 | if num == 18: 154 | self.net = tv.resnet18(pretrained=pretrained) 155 | elif num == 34: 156 | self.net = tv.resnet34(pretrained=pretrained) 157 | elif num == 50: 158 | self.net = tv.resnet50(pretrained=pretrained) 159 | elif num == 101: 160 | self.net = tv.resnet101(pretrained=pretrained) 161 | elif num == 152: 162 | self.net = tv.resnet152(pretrained=pretrained) 163 | self.N_slices = 5 164 | 165 | self.conv1 = self.net.conv1 166 | self.bn1 = self.net.bn1 167 | self.relu = self.net.relu 168 | self.maxpool = self.net.maxpool 169 | self.layer1 = self.net.layer1 170 | self.layer2 = self.net.layer2 171 | self.layer3 = self.net.layer3 172 | self.layer4 = self.net.layer4 173 | 174 | def forward(self, X): 175 | h = self.conv1(X) 176 | h = self.bn1(h) 177 | h = self.relu(h) 178 | h_relu1 = h 179 | h = self.maxpool(h) 180 | h = self.layer1(h) 181 | h_conv2 = h 182 | h = self.layer2(h) 183 | h_conv3 = h 184 | h = self.layer3(h) 185 | h_conv4 = h 186 | h = self.layer4(h) 187 | h_conv5 = h 188 | 189 | outputs = 
namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) 190 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 191 | 192 | return out 193 | -------------------------------------------------------------------------------- /training/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /training/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xc-csc101/UniFace/05f15e57668e7ce30399233c2456bb0f4cb35055/training/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /training/op/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/__init__.py 3 | """ 4 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 5 | from .upfirdn2d import upfirdn2d 6 | -------------------------------------------------------------------------------- /training/op/fused_act.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py 3 | """ 4 | import os 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load 9 | 10 | 11 | module_path = os.path.dirname(__file__) 12 | fused = load( 13 | "fused", 14 | sources=[ 15 | os.path.join(module_path, "fused_bias_act.cpp"), 16 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 17 | ], 18 | ) 19 | 20 | 21 | class FusedLeakyReLUFunctionBackward(Function): 22 | @staticmethod 23 | def forward(ctx, grad_output, out, negative_slope, scale): 24 | ctx.save_for_backward(out) 25 | ctx.negative_slope = negative_slope 26 | ctx.scale = scale 27 | 28 | empty = grad_output.new_empty(0) 29 | 30 | grad_input = fused.fused_bias_act( 31 | 
grad_output, empty, out, 3, 1, negative_slope, scale 32 | ) 33 | 34 | dim = [0] 35 | 36 | if grad_input.ndim > 2: 37 | dim += list(range(2, grad_input.ndim)) 38 | 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | return grad_input, grad_bias 42 | 43 | @staticmethod 44 | def backward(ctx, gradgrad_input, gradgrad_bias): 45 | (out,) = ctx.saved_tensors 46 | gradgrad_out = fused.fused_bias_act( 47 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 48 | ) 49 | 50 | return gradgrad_out, None, None, None 51 | 52 | 53 | class FusedLeakyReLUFunction(Function): 54 | @staticmethod 55 | def forward(ctx, input, bias, negative_slope, scale): 56 | empty = input.new_empty(0) 57 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 58 | ctx.save_for_backward(out) 59 | ctx.negative_slope = negative_slope 60 | ctx.scale = scale 61 | 62 | return out 63 | 64 | @staticmethod 65 | def backward(ctx, grad_output): 66 | (out,) = ctx.saved_tensors 67 | 68 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 69 | grad_output, out, ctx.negative_slope, ctx.scale 70 | ) 71 | 72 | return grad_input, grad_bias, None, None 73 | 74 | 75 | class FusedLeakyReLU(nn.Module): 76 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 77 | super().__init__() 78 | 79 | self.bias = nn.Parameter(torch.zeros(channel)) 80 | self.negative_slope = negative_slope 81 | self.scale = scale 82 | 83 | def forward(self, input): 84 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 85 | 86 | 87 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 88 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 89 | -------------------------------------------------------------------------------- /training/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | 3 | #include <torch/extension.h> 4 | 5 | 6 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 7 | int act, int grad, float alpha, float scale); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 14 | int act, int grad, float alpha, float scale) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(bias); 17 | 18 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 23 | } -------------------------------------------------------------------------------- /training/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /training/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | 3 | #include <torch/extension.h> 4 | 5 | torch::Tensor 6 | upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, 7 | int up_x, int up_y, int down_x, int down_y, 8 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) \ 13 | CHECK_CUDA(x); \ 14 | CHECK_CONTIGUOUS(x) 15 | 16 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 17 | int up_x, int up_y, int down_x, int down_y, 18 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) 19 | { 20 | CHECK_CUDA(input); 21 | CHECK_CUDA(kernel); 22 | 23 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 24 | } 25 | 26 | 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 27 | { 28 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 29 | } -------------------------------------------------------------------------------- /training/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py 3 | """ 4 | import os 5 | 6 | import torch 7 | from torch.autograd import Function 8 | from torch.utils.cpp_extension import load 9 | 10 | 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_op = load( 13 | "upfirdn2d", 14 | sources=[ 15 | os.path.join(module_path, "upfirdn2d.cpp"), 16 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 17 | ], 18 | ) 19 | 20 | 21 | class UpFirDn2dBackward(Function): 22 | @staticmethod 23 | def forward( 24 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 25 | ): 26 | 27 | up_x, up_y = up 28 | down_x, down_y = down 29 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 30 | 31 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 32 | 33 | grad_input = upfirdn2d_op.upfirdn2d( 34 | grad_output, 35 | grad_kernel, 36 | down_x, 37 | down_y, 38 | up_x, 39 | up_y, 40 | g_pad_x0, 41 | g_pad_x1, 42 | g_pad_y0, 43 | g_pad_y1, 44 | ) 45 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 46 | 47 | ctx.save_for_backward(kernel) 48 | 49 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 50 | 51 | ctx.up_x = up_x 52 | ctx.up_y = up_y 53 | ctx.down_x = down_x 54 | ctx.down_y = down_y 55 | ctx.pad_x0 = pad_x0 56 | ctx.pad_x1 = pad_x1 57 | ctx.pad_y0 = pad_y0 58 | ctx.pad_y1 = pad_y1 59 | ctx.in_size = in_size 60 | ctx.out_size = out_size 61 | 62 | return grad_input 63 | 64 | @staticmethod 65 | def backward(ctx, gradgrad_input): 66 | (kernel,) = ctx.saved_tensors 67 | 68 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 69 | 70 | gradgrad_out = upfirdn2d_op.upfirdn2d( 71 | gradgrad_input, 72 | kernel, 73 | ctx.up_x, 74 | ctx.up_y, 75 | ctx.down_x, 76 | ctx.down_y, 77 | ctx.pad_x0, 78 | ctx.pad_x1, 79 | ctx.pad_y0, 80 | ctx.pad_y1, 81 | ) 82 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 83 | gradgrad_out = gradgrad_out.view( 84 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 85 | ) 86 | 87 | return gradgrad_out, None, None, None, None, None, None, None, None 88 | 89 | 90 | class UpFirDn2d(Function): 91 | @staticmethod 92 | def forward(ctx, input, kernel, up, down, pad): 93 | up_x, up_y = up 94 | down_x, down_y = down 95 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 96 | 97 | kernel_h, kernel_w = kernel.shape 98 | batch, channel, in_h, in_w = input.shape 99 | ctx.in_size = input.shape 100 | 101 | input = input.reshape(-1, in_h, in_w, 1) 102 | 103 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 104 | 105 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 106 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 107 | ctx.out_size = (out_h, out_w) 108 | 109 | ctx.up = (up_x, up_y) 110 | ctx.down = (down_x, down_y) 111 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 112 | 113 | g_pad_x0 = kernel_w - pad_x0 - 1 114 | g_pad_y0 = kernel_h - pad_y0 - 1 115 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 116 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 117 | 118 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 119 | 120 | out = upfirdn2d_op.upfirdn2d( 121 
| input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 122 | ) 123 | # out = out.view(major, out_h, out_w, minor) 124 | out = out.view(-1, channel, out_h, out_w) 125 | 126 | return out 127 | 128 | @staticmethod 129 | def backward(ctx, grad_output): 130 | kernel, grad_kernel = ctx.saved_tensors 131 | 132 | grad_input = UpFirDn2dBackward.apply( 133 | grad_output, 134 | kernel, 135 | grad_kernel, 136 | ctx.up, 137 | ctx.down, 138 | ctx.pad, 139 | ctx.g_pad, 140 | ctx.in_size, 141 | ctx.out_size, 142 | ) 143 | 144 | return grad_input, None, None, None, None 145 | 146 | 147 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 148 | out = UpFirDn2d.apply( 149 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 150 | ) 151 | 152 | return out 153 | 154 | 155 | def upfirdn2d_native( 156 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 157 | ): 158 | _, in_h, in_w, minor = input.shape 159 | kernel_h, kernel_w = kernel.shape 160 | 161 | out = input.view(-1, in_h, 1, in_w, 1, minor) 162 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 163 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 164 | 165 | out = F.pad( 166 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 167 | ) 168 | out = out[ 169 | :, 170 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 171 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 172 | :, 173 | ] 174 | 175 | out = out.permute(0, 3, 1, 2) 176 | out = out.reshape( 177 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 178 | ) 179 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 180 | out = F.conv2d(out, w) 181 | out = out.reshape( 182 | -1, 183 | minor, 184 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 185 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 186 | ) 187 | out = out.permute(0, 2, 3, 1) 188 | 189 | return out[:, ::down_y, ::down_x, :] 190 | -------------------------------------------------------------------------------- /training/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int 
rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | 
upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /training/pose.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import init 3 | import torch 4 | 5 | 6 | class BaseNetwork(nn.Module): 7 | def __init__(self): 8 | super(BaseNetwork, self).__init__() 9 | 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train): 12 | return parser 13 | 14 | def print_network(self): 15 | if isinstance(self, list): 16 | self = self[0] 17 | num_params = 0 18 | for param in self.parameters(): 19 | num_params += param.numel() 20 | print( 21 | 'Network [%s] was created. Total number of parameters: %.1f million. ' 22 | 'To see the architecture, do print(network).' % 23 | (type(self).__name__, num_params / 1000000)) 24 | 25 | def init_weights(self, init_type='normal', gain=0.02): 26 | def init_func(m): 27 | classname = m.__class__.__name__ 28 | if classname.find('BatchNorm2d') != -1: 29 | if hasattr(m, 'weight') and m.weight is not None: 30 | init.normal_(m.weight.data, 1.0, gain) 31 | if hasattr(m, 'bias') and m.bias is not None: 32 | init.constant_(m.bias.data, 0.0) 33 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 34 | or classname.find('Linear') != -1): 35 | if init_type == 'normal': 36 | init.normal_(m.weight.data, 0.0, gain) 37 | elif init_type == 'xavier': 38 | init.xavier_normal_(m.weight.data, gain=gain) 39 | elif init_type == 'xavier_uniform': 40 | init.xavier_uniform_(m.weight.data, gain=1.0) 41 | elif init_type == 'kaiming': 42 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 43 | elif init_type == 'orthogonal': 44 | init.orthogonal_(m.weight.data, gain=gain) 45 | elif init_type == 'none': # uses pytorch's default init method 46 | m.reset_parameters() 47 | else: 48 | raise NotImplementedError( 49 | 'initialization method [%s] is not implemented' % 50 | init_type) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | self.apply(init_func) 55 | 56 | # propagate to children 57 | for m in self.children(): 58 | if hasattr(m, 'init_weights'): 59 | m.init_weights(init_type, gain) 60 | 61 | 62 | def weights_init_xavier(m, gain=0.02): 63 | classname = m.__class__.__name__ 64 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 65 | or classname.find('Linear') != -1): 66 | nn.init.xavier_normal_(m.weight.data, gain=gain) 67 | if hasattr(m, 'bias'): 68 | nn.init.constant_(m.bias.data, 0.0) 69 | elif classname.find('BatchNorm2d') != -1: 70 | nn.init.normal_(m.weight.data, 1.0, 0.02) 71 | nn.init.constant_(m.bias.data, 0.0) 72 | 73 | 74 | class Encoder_Pose(BaseNetwork): 75 | def __init__(self, 76 | norm_layer=nn.InstanceNorm2d, 77 | use_dropout=False): 78 | super(Encoder_Pose, self).__init__() 79 | input_nc = 3 80 | output_nc = 2 81 | num_downs = 8 82 | ngf = 64 83 | 
embedding_dim = 512 84 | 85 | self.num_downs = num_downs 86 | use_bias = norm_layer == nn.InstanceNorm2d 87 | 88 | # === down sample === # 89 | self.down0 = nn.Sequential( 90 | nn.Conv2d(input_nc, 91 | ngf * 1, 92 | kernel_size=4, 93 | stride=2, 94 | padding=1, 95 | bias=use_bias), nn.LeakyReLU(0.2, True), norm_layer(ngf)) 96 | nf_mult = 1 97 | for i in range(1, num_downs - 1): 98 | nf_mult_prev = nf_mult 99 | nf_mult = min(2 * nf_mult_prev, 8) 100 | layer = nn.Sequential( 101 | nn.Conv2d(ngf * nf_mult_prev, 102 | ngf * nf_mult, 103 | kernel_size=4, 104 | stride=2, 105 | padding=1, 106 | bias=use_bias), nn.LeakyReLU(0.2, True), 107 | norm_layer(ngf * nf_mult)) 108 | setattr(self, 'down' + str(i), layer) 109 | self.down7 = nn.Sequential( 110 | nn.Conv2d(ngf * nf_mult, 111 | embedding_dim, 112 | kernel_size=4, 113 | stride=2, 114 | padding=1, 115 | bias=use_bias)) 116 | 117 | # === up sample === # 118 | nf_mult = 1 119 | for i in range(1, num_downs - 1): 120 | nf_mult_prev = nf_mult 121 | nf_mult = min(2 * nf_mult_prev, 8) 122 | layer = nn.Sequential( 123 | nn.ReLU(True), nn.Upsample(scale_factor=2.0, mode='bilinear'), 124 | nn.Conv2d(ngf * nf_mult, 125 | ngf * nf_mult_prev, 126 | kernel_size=3, 127 | stride=1, 128 | padding=1, 129 | bias=use_bias), norm_layer(ngf * nf_mult_prev)) 130 | setattr(self, 'up' + str(i), layer) 131 | self.up7 = nn.Sequential( 132 | nn.ReLU(True), nn.Upsample(scale_factor=2.0, mode='bilinear'), 133 | nn.Conv2d(embedding_dim, 134 | ngf * nf_mult, 135 | kernel_size=3, 136 | stride=1, 137 | padding=1, 138 | bias=use_bias), norm_layer(ngf * nf_mult)) 139 | 140 | self.out2 = nn.Sequential( 141 | nn.ReLU(False), nn.Upsample(scale_factor=2.0, mode='bilinear'), 142 | nn.Conv2d(ngf * 4, 143 | output_nc, 144 | kernel_size=3, 145 | stride=1, 146 | padding=1, 147 | bias=use_bias)) 148 | self.tanh = nn.Tanh() 149 | 150 | def forward(self, x): 151 | # === down sampling === # 152 | down0 = self.down0(x) 153 | down1 = self.down1(down0) 154 | down2 = self.down2(down1) 155 | hid = down2 156 | for i in range(3, self.num_downs): 157 | hid = getattr(self, 'down' + str(i))(hid) 158 | pose_code = hid 159 | 160 | # === up sampling === # 161 | for i in range(3, self.num_downs)[::-1]: 162 | hid = getattr(self, 'up' + str(i))(hid) 163 | up_3 = hid 164 | 165 | out_64 = self.out2(up_3) 166 | 167 | return self.tanh(out_64), pose_code.flatten(1) 168 | 169 | def init_weights(self, init_type='normal', gain=0.02): 170 | self.apply(weights_init_xavier) 171 | 172 | 173 | class Encoder_Pose_v2(BaseNetwork): 174 | # not encoder to vector 175 | def __init__(self, 176 | norm_layer=nn.InstanceNorm2d, 177 | use_dropout=False): 178 | super(Encoder_Pose_v2, self).__init__() 179 | input_nc = 3 180 | output_nc = 2 181 | num_downs = 5 182 | ngf = 64 183 | 184 | self.num_downs = num_downs 185 | use_bias = norm_layer == nn.InstanceNorm2d 186 | 187 | # === down sample === # 188 | self.down0 = nn.Sequential( 189 | nn.Conv2d(input_nc, 190 | ngf * 1, 191 | kernel_size=4, 192 | stride=2, 193 | padding=1, 194 | bias=use_bias), nn.LeakyReLU(0.2, True), norm_layer(ngf)) 195 | nf_mult = 1 196 | for i in range(1, num_downs): 197 | nf_mult_prev = nf_mult 198 | nf_mult = min(2 * nf_mult_prev, 8) 199 | layer = nn.Sequential( 200 | nn.Conv2d(ngf * nf_mult_prev, 201 | ngf * nf_mult, 202 | kernel_size=4, 203 | stride=2, 204 | padding=1, 205 | bias=use_bias), nn.LeakyReLU(0.2, True), 206 | norm_layer(ngf * nf_mult)) 207 | setattr(self, 'down' + str(i), layer) 208 | 209 | self.trans1 = nn.Sequential( 210 | nn.Conv2d(ngf * nf_mult, 
211 | ngf * nf_mult, 212 | kernel_size=3, 213 | stride=1, 214 | padding=1, 215 | bias=use_bias), nn.LeakyReLU(0.2, True), 216 | norm_layer(ngf * nf_mult)) 217 | 218 | self.trans2 = nn.Sequential( 219 | nn.Conv2d(ngf * nf_mult, 220 | ngf * nf_mult, 221 | kernel_size=3, 222 | stride=1, 223 | padding=1, 224 | bias=use_bias), nn.LeakyReLU(0.2, True), 225 | norm_layer(ngf * nf_mult)) 226 | 227 | # === up sample === # 228 | nf_mult = 2 229 | for i in range(2, num_downs): 230 | nf_mult_prev = nf_mult 231 | nf_mult = min(2 * nf_mult_prev, 8) 232 | layer = nn.Sequential( 233 | nn.ReLU(True), nn.Upsample(scale_factor=2.0, mode='bilinear'), 234 | nn.Conv2d(ngf * nf_mult, 235 | ngf * nf_mult_prev, 236 | kernel_size=3, 237 | stride=1, 238 | padding=1, 239 | bias=use_bias), norm_layer(ngf * nf_mult_prev)) 240 | setattr(self, 'up' + str(i), layer) 241 | 242 | self.out2 = nn.Sequential( 243 | nn.ReLU(False), nn.Upsample(scale_factor=2.0, mode='bilinear'), 244 | nn.Conv2d(ngf * 4, 245 | output_nc, 246 | kernel_size=3, 247 | stride=1, 248 | padding=1, 249 | bias=use_bias)) 250 | self.tanh = nn.Tanh() 251 | 252 | def forward(self, x): 253 | # === down sampling === # 254 | down0 = self.down0(x) 255 | down1 = self.down1(down0) 256 | down2 = self.down2(down1) 257 | hid = down2 258 | for i in range(3, self.num_downs): 259 | hid = getattr(self, 'down' + str(i))(hid) 260 | pose_code = hid 261 | 262 | # trans 263 | hid = self.trans2(self.trans1(hid)) 264 | 265 | # === up sampling === # 266 | for i in range(3, self.num_downs)[::-1]: 267 | hid = getattr(self, 'up' + str(i))(hid) 268 | up_3 = hid 269 | 270 | out_64 = self.out2(up_3) 271 | 272 | return self.tanh(out_64), pose_code.flatten(1) 273 | 274 | def init_weights(self, init_type='normal', gain=0.02): 275 | self.apply(weights_init_xavier) 276 | 277 | 278 | if __name__ == '__main__': 279 | net = Encoder_Pose_v2() 280 | img = torch.randn(1, 3, 256, 256) 281 | out = net(img) 282 | print(out.shape) 283 | -------------------------------------------------------------------------------- /training/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # gram matrix and loss 7 | class GramMatrix(nn.Module): 8 | def forward(self, input): 9 | b, c, h, w = input.size() 10 | F = input.view(b, c, h * w) 11 | G = torch.bmm(F, F.transpose(1, 2)) 12 | G.div_(h * w) 13 | return G 14 | 15 | 16 | class GramMSELoss(nn.Module): 17 | def forward(self, input, target): 18 | out = nn.MSELoss()(GramMatrix()(input), target) 19 | return (out) 20 | 21 | 22 | # vgg definition that conveniently let's you grab the outputs from any layer 23 | class VGG(nn.Module): 24 | def __init__(self, pool='max'): 25 | super(VGG, self).__init__() 26 | # vgg modules 27 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 28 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 29 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 30 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 34 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 38 | 
self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 42 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | if pool == 'max': 44 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 45 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 46 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 47 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 48 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 49 | elif pool == 'avg': 50 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 51 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 52 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 53 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 54 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 55 | 56 | def forward(self, x, out_keys): 57 | out = {} 58 | out['r11'] = F.relu(self.conv1_1(x)) 59 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 60 | out['p1'] = self.pool1(out['r12']) 61 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 62 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 63 | out['p2'] = self.pool2(out['r22']) 64 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 65 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 66 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 67 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 68 | out['p3'] = self.pool3(out['r34']) 69 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 70 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 71 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 72 | out['r44'] = F.relu(self.conv4_4(out['r43'])) 73 | out['p4'] = self.pool4(out['r44']) 74 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 75 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 76 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 77 | out['r54'] = F.relu(self.conv5_4(out['r53'])) 78 | out['p5'] = self.pool5(out['r54']) 79 | return [out[key] for key in out_keys] 80 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | # Log images 8 | def log_input_image(x, label_nc=0): 9 | if label_nc == 0: 10 | return tensor2im(x) 11 | elif label_nc == 1: 12 | return tensor2sketch(x) 13 | else: 14 | return tensor2map(x) 15 | 16 | 17 | def tensor2im(var): 18 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 19 | var = ((var + 1) / 2) 20 | var[var < 0] = 0 21 | var[var > 1] = 1 22 | var = var * 255 23 | return Image.fromarray(var.astype('uint8')) 24 | 25 | 26 | def tensor2map(var): 27 | mask = np.argmax(var.data.cpu().numpy(), axis=0) 28 | colors = get_colors() 29 | mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) 30 | for class_idx in np.unique(mask): 31 | mask_image[mask == class_idx] = colors[class_idx] 32 | mask_image = mask_image.astype('uint8') 33 | return Image.fromarray(mask_image) 34 | 35 | 36 | def tensor2sketch(var): 37 | im = var[0].cpu().detach().numpy() 38 | im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) 39 | im = (im * 255).astype(np.uint8) 40 | return Image.fromarray(im) 41 | 42 | 43 | # Visualization utils 44 | def get_colors(): 45 | # currently support up to 19 classes (for the celebs-hq-mask dataset) 46 | colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


# Log images
def log_input_image(x, label_nc=0):
    if label_nc == 0:
        return tensor2im(x)
    elif label_nc == 1:
        return tensor2sketch(x)
    else:
        return tensor2map(x)


def tensor2im(var):
    # Map a CHW tensor in [-1, 1] to an HWC uint8 PIL image.
    var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
    var = ((var + 1) / 2)
    var[var < 0] = 0
    var[var > 1] = 1
    var = var * 255
    return Image.fromarray(var.astype('uint8'))


def tensor2map(var):
    mask = np.argmax(var.data.cpu().numpy(), axis=0)
    colors = get_colors()
    mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3))
    for class_idx in np.unique(mask):
        mask_image[mask == class_idx] = colors[class_idx]
    mask_image = mask_image.astype('uint8')
    return Image.fromarray(mask_image)


def tensor2sketch(var):
    im = var[0].cpu().detach().numpy()
    im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
    im = (im * 255).astype(np.uint8)
    return Image.fromarray(im)


# Visualization utils
def get_colors():
    # currently supports up to 19 classes (for the CelebAMask-HQ dataset)
    colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255],
              [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204],
              [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
    return colors


def vis_faces(log_hooks):
    display_count = len(log_hooks)
    fig = plt.figure(figsize=(8, 4 * display_count))
    gs = fig.add_gridspec(display_count, 3)
    for i in range(display_count):
        hooks_dict = log_hooks[i]
        fig.add_subplot(gs[i, 0])
        if 'diff_input' in hooks_dict:
            vis_faces_with_id(hooks_dict, fig, gs, i)
        else:
            vis_faces_no_id(hooks_dict, fig, gs, i)
    plt.tight_layout()
    return fig


def vis_faces_v2(log_hooks):
    display_count = len(log_hooks)
    fig = plt.figure(figsize=(12, 4 * display_count))
    gs = fig.add_gridspec(display_count, 5)
    for i in range(display_count):
        hooks_dict = log_hooks[i]
        fig.add_subplot(gs[i, 0])
        if 'diff_input' in hooks_dict:
            vis_faces_with_id(hooks_dict, fig, gs, i)
        else:
            vis_faces_no_id_v2(hooks_dict, fig, gs, i)
    plt.tight_layout()
    return fig


def vis_faces_v3(log_hooks):
    display_count = len(log_hooks)
    fig = plt.figure(figsize=(12, 4 * display_count))
    gs = fig.add_gridspec(display_count, 4)
    for i in range(display_count):
        hooks_dict = log_hooks[i]
        fig.add_subplot(gs[i, 0])
        if 'diff_input' in hooks_dict:
            vis_faces_with_id(hooks_dict, fig, gs, i)
        else:
            vis_faces_no_id_v3(hooks_dict, fig, gs, i)
    plt.tight_layout()
    return fig


def vis_faces_v4(log_hooks):
    display_count = len(log_hooks)
    fig = plt.figure(figsize=(14, 4 * display_count))
    gs = fig.add_gridspec(display_count, 6)
    for i in range(display_count):
        hooks_dict = log_hooks[i]
        fig.add_subplot(gs[i, 0])
        if 'diff_input' in hooks_dict:
            vis_faces_with_id(hooks_dict, fig, gs, i)
        else:
            vis_faces_no_id_v4(hooks_dict, fig, gs, i)
    plt.tight_layout()
    return fig


def vis_faces_with_id(hooks_dict, fig, gs, i):
    plt.imshow(hooks_dict['input_face'])
    plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
    fig.add_subplot(gs[i, 1])
    plt.imshow(hooks_dict['target_face'])
    plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
                                                     float(hooks_dict['diff_target'])))
    fig.add_subplot(gs[i, 2])
    plt.imshow(hooks_dict['output_face'])
    plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))


def vis_faces_no_id(hooks_dict, fig, gs, i):
    plt.imshow(hooks_dict['input_face'], cmap="gray")
    plt.title('Input')
    fig.add_subplot(gs[i, 1])
    plt.imshow(hooks_dict['target_face'])
    plt.title('Target')
    fig.add_subplot(gs[i, 2])
    plt.imshow(hooks_dict['output_face'])
    plt.title('Output')


def vis_faces_no_id_v2(hooks_dict, fig, gs, i):
    plt.imshow(hooks_dict['source_face'], cmap="gray")
    plt.title('source_face')
    fig.add_subplot(gs[i, 1])
    plt.imshow(hooks_dict['drv1_face'])
    plt.title('drv1_face')
    fig.add_subplot(gs[i, 2])
    plt.imshow(hooks_dict['drv2_face'])
    plt.title('drv2_face')
    fig.add_subplot(gs[i, 3])
    plt.imshow(hooks_dict['output1'])
    plt.title('output1')
    fig.add_subplot(gs[i, 4])
    plt.imshow(hooks_dict['output2'])
    plt.title('output2')


def vis_faces_no_id_v3(hooks_dict, fig, gs, i):
    plt.imshow(hooks_dict['input_face'], cmap="gray")
    plt.title('input_face')

    fig.add_subplot(gs[i, 1])
    plt.imshow(hooks_dict['target_face_scale'])
    plt.title('target_face_scale')

    fig.add_subplot(gs[i, 2])
    plt.imshow(hooks_dict['target_face_gt'])
    plt.title('target_face_gt')

    fig.add_subplot(gs[i, 3])
    plt.imshow(hooks_dict['output_face'])
    plt.title('output_face')


def vis_faces_no_id_v4(hooks_dict, fig, gs, i):
    plt.imshow(hooks_dict['source_face'], cmap="gray")
    plt.title('source_face')

    fig.add_subplot(gs[i, 1])
    plt.imshow(hooks_dict['drv1_face_input'])
    plt.title('drv1_face_input')

    fig.add_subplot(gs[i, 2])
    plt.imshow(hooks_dict['drv1_face_gt'])
    plt.title('drv1_face_gt')

    fig.add_subplot(gs[i, 3])
    plt.imshow(hooks_dict['output1'])
    plt.title('output1')

    fig.add_subplot(gs[i, 4])
    plt.imshow(hooks_dict['drv2_face'])
    plt.title('drv2_face')

    fig.add_subplot(gs[i, 5])
    plt.imshow(hooks_dict['output2'])
    plt.title('output2')
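As a quick illustration of how these helpers fit together, the sketch below converts CHW tensors in `[-1, 1]` to PIL images with `tensor2im` and lays them out with `vis_faces_v2`, which reads the `source_face` / `drv1_face` / `drv2_face` / `output1` / `output2` keys via `vis_faces_no_id_v2`. The random tensors and the output filename are placeholders, not real training outputs.

```
import torch

from utils.common import tensor2im, vis_faces_v2


def fake_face():
    # Placeholder for a network output: a 3xHxW tensor in [-1, 1].
    return tensor2im(torch.rand(3, 256, 256) * 2 - 1)


log_hooks = [
    {
        'source_face': fake_face(),
        'drv1_face': fake_face(),
        'drv2_face': fake_face(),
        'output1': fake_face(),
        'output2': fake_face(),
    }
    for _ in range(2)  # two display rows
]

fig = vis_faces_v2(log_hooks)
fig.savefig('vis_example.jpg')  # or hand the figure to your experiment logger
```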
--------------------------------------------------------------------------------
/utils/flow_utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from matplotlib.colors import hsv_to_rgb


def flow_to_image(flow, max_flow=None):
    # Render an (H, W, 2) flow field as an RGB image: hue encodes direction,
    # saturation encodes magnitude relative to max_flow.
    if max_flow is not None:
        max_flow = max(max_flow, 1.)
    else:
        max_flow = np.max(flow)

    n = 8
    u, v = flow[:, :, 0], flow[:, :, 1]
    mag = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(v, u)
    im_h = np.mod(angle / (2 * np.pi) + 1, 1)
    im_s = np.clip(mag * n / max_flow, a_min=0, a_max=1)
    im_v = np.clip(n - im_s, a_min=0, a_max=1)
    im = hsv_to_rgb(np.stack([im_h, im_s, im_v], 2))
    return (im * 255).astype(np.uint8)


def resize_flow(flow, new_shape):
    # Resize a (B, 2, H, W) flow tensor and rescale the displacement values
    # so they remain valid pixel offsets at the new resolution.
    _, _, h, w = flow.shape
    new_h, new_w = new_shape
    flow = torch.nn.functional.interpolate(flow, (new_h, new_w),
                                           mode='bilinear', align_corners=True)
    scale_h, scale_w = h / float(new_h), w / float(new_w)
    flow[:, 0] /= scale_w
    flow[:, 1] /= scale_h
    return flow
--------------------------------------------------------------------------------
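Finally, a minimal sketch of visualising a dense flow field with the two utilities above. The random tensor stands in for a real `(B, 2, H, W)` flow prediction, and the output filename is arbitrary.

```
import torch
from PIL import Image

from utils.flow_utils import flow_to_image, resize_flow

flow = torch.randn(1, 2, 64, 64) * 4.0  # placeholder (B, 2, H, W) flow in pixels
flow = resize_flow(flow, (256, 256))    # upsample and rescale the offsets

# flow_to_image expects a single (H, W, 2) array.
flow_np = flow[0].permute(1, 2, 0).numpy()
Image.fromarray(flow_to_image(flow_np)).save('flow_vis.jpg')
```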