├── LICENSE ├── README.md ├── lpips ├── __init__.py ├── base_model.py ├── dist_model.py ├── lpips.py ├── networks_basic.py ├── pretrained_networks.py └── trainer.py ├── main.py ├── model.py ├── networks ├── FlowNetC.py ├── FlowNetFusion.py ├── FlowNetS.py ├── FlowNetSD.py ├── __init__.py ├── channelnorm_package │ ├── __init__.py │ ├── channelnorm.py │ ├── channelnorm_cuda.cc │ ├── channelnorm_kernel.cu │ ├── channelnorm_kernel.cuh │ └── setup.py ├── correlation_package │ ├── __init__.py │ ├── correlation.py │ ├── correlation_cuda.cc │ ├── correlation_cuda_kernel.cu │ ├── correlation_cuda_kernel.cuh │ └── setup.py ├── resample2d_package │ ├── __init__.py │ ├── resample2d.py │ ├── resample2d_cuda.cc │ ├── resample2d_kernel.cu │ ├── resample2d_kernel.cuh │ └── setup.py └── submodules.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── perceptual_model.py ├── read_image.py ├── stylegan.yaml └── stylegan_layers.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Naive_young 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GANInversion_with_ConsecutiveImgs 2 | Official code for our ICCV paper: "From Continuity to Editability: Inverting GANs with Consecutive Images" https://arxiv.org/pdf/2107.13812.pdf 3 | 4 | **1**. Build the environment with stylegan.yaml (Anaconda is required) \ 5 | **2**. Compile FlowNet2 dependencies (correlation, resample, and channel norm layers).\ 6 | Reference: https://github.com/phoenix104104/fast_blind_video_consistency. \ 7 | **3**. Download the StyleGAN weight and FlowNet weight from: https://drive.google.com/file/d/1g2gp4tR0wAc6uG24qkt82pM3afD-vfBT/view?usp=sharing. \ 8 | **4**. 
Python main.py 9 | 10 | -------------------------------------------------------------------------------- /lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | # from torch.autograd import Variable 9 | 10 | from lpips.trainer import * 11 | from lpips.lpips import * 12 | 13 | # class PerceptualLoss(torch.nn.Module): 14 | # def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | # super(PerceptualLoss, self).__init__() 17 | # print('Setting up Perceptual loss...') 18 | # self.use_gpu = use_gpu 19 | # self.spatial = spatial 20 | # self.gpu_ids = gpu_ids 21 | # self.model = dist_model.DistModel() 22 | # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 23 | # print('...[%s] initialized'%self.model.name()) 24 | # print('...Done') 25 | 26 | # def forward(self, pred, target, normalize=False): 27 | # """ 28 | # Pred and target are Variables. 29 | # If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | # If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | # Inputs pred and target are Nx3xHxW 33 | # Output pytorch Variable N long 34 | # """ 35 | 36 | # if normalize: 37 | # target = 2 * target - 1 38 | # pred = 2 * pred - 1 39 | 40 | # return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | from skimage.measure import compare_ssim 54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 
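    # undo the normalization from tensor2tensorlab: rescale by 100, then shift the L channel back to [0, 100]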
91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def load_image(path): 104 | if(path[-3:] == 'dng'): 105 | import rawpy 106 | with rawpy.imread(path) as raw: 107 | img = raw.postprocess() 108 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 109 | import cv2 110 | return cv2.imread(path)[:,:,::-1] 111 | else: 112 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 113 | 114 | return img 115 | 116 | def rgb2lab(input): 117 | from skimage import color 118 | return color.rgb2lab(input / 255.) 119 | 120 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 121 | image_numpy = image_tensor[0].cpu().float().numpy() 122 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 123 | return image_numpy.astype(imtype) 124 | 125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 126 | return torch.Tensor((image / factor - cent) 127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | def tensor2vec(vector_tensor): 130 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 131 | 132 | 133 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 134 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 135 | image_numpy = image_tensor[0].cpu().float().numpy() 136 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 137 | return image_numpy.astype(imtype) 138 | 139 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 140 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 141 | return torch.Tensor((image / factor - cent) 142 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 143 | 144 | 145 | 146 | def voc_ap(rec, prec, use_07_metric=False): 147 | """ ap = voc_ap(rec, prec, [use_07_metric]) 148 | Compute VOC AP given precision and recall. 149 | If use_07_metric is true, uses the 150 | VOC 07 11 point method (default:False). 151 | """ 152 | if use_07_metric: 153 | # 11 point metric 154 | ap = 0. 155 | for t in np.arange(0., 1.1, 0.1): 156 | if np.sum(rec >= t) == 0: 157 | p = 0 158 | else: 159 | p = np.max(prec[rec >= t]) 160 | ap = ap + p / 11. 
161 | else: 162 | # correct AP calculation 163 | # first append sentinel values at the end 164 | mrec = np.concatenate(([0.], rec, [1.])) 165 | mpre = np.concatenate(([0.], prec, [0.])) 166 | 167 | # compute the precision envelope 168 | for i in range(mpre.size - 1, 0, -1): 169 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 170 | 171 | # to calculate area under PR curve, look for points 172 | # where X axis (recall) changes value 173 | i = np.where(mrec[1:] != mrec[:-1])[0] 174 | 175 | # and sum (\Delta recall) * prec 176 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 177 | return ap 178 | 179 | -------------------------------------------------------------------------------- /lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . 
import networks_basic as networks 22 | import lpips as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." 
% self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
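        # judge lies in [0,1] (fraction of humans preferring p1); mapping it to [-1,1] gives the signed target for BCERankingLoss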
159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s predicted human judgment (h*) 67 | self.rankLoss = lpips.BCERankingLoss() 68 | self.parameters += list(self.rankLoss.net.parameters()) 69 | self.lr = lr 70 | self.old_lr = lr 71 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 72 | else: # test mode 73 | self.net.eval() 74 | 75 | if(use_gpu): 76 | self.net.to(gpu_ids[0]) 77 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 78 | if(self.is_train): 79 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 80 | 81 | if(printNet): 82 | print('---------- Networks initialized -------------') 83 | networks.print_network(self.net) 84 | print('-----------------------------------------------') 85 | 86 | def forward(self, in0, in1, retPerLayer=False): 87 | ''' Function computes the distance between image patches in0 and in1 88 | INPUTS 89 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 90 | OUTPUT 91 | computed distances between in0 and in1 92 | ''' 93 | 94 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 95 | 96 | # ***** TRAINING FUNCTIONS ***** 97 | def optimize_parameters(self): 98 | self.forward_train() 99 | self.optimizer_net.zero_grad() 100 | self.backward_train() 101 | self.optimizer_net.step() 102 | self.clamp_weights() 103 | 104 | def clamp_weights(self): 105 | for module in self.net.modules(): 106 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 107 | module.weight.data = torch.clamp(module.weight.data,min=0) 108 | 109 | def set_input(self, data): 110 | self.input_ref = data['ref'] 111 | self.input_p0 = data['p0'] 112 | self.input_p1 = data['p1'] 113 | self.input_judge = data['judge'] 114 | 115 | if(self.use_gpu): 116 | self.input_ref = 
self.input_ref.to(device=self.gpu_ids[0]) 117 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 118 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 119 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 120 | 121 | self.var_ref = Variable(self.input_ref,requires_grad=True) 122 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 123 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 124 | 125 | def forward_train(self): # run forward pass 126 | self.d0 = self.forward(self.var_ref, self.var_p0) 127 | self.d1 = self.forward(self.var_ref, self.var_p1) 128 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 129 | 130 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 131 | 132 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 133 | 134 | return self.loss_total 135 | 136 | def backward_train(self): 137 | torch.mean(self.loss_total).backward() 138 | 139 | def compute_accuracy(self,d0,d1,judge): 140 | ''' d0, d1 are Variables, judge is a Tensor ''' 141 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 197 | self.old_lr = lr 198 | 199 | 200 | def get_image_paths(self): 201 | return self.image_paths 202 | 203 | def save_done(self, flag=False): 204 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 205 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 206 | 207 | 208 | def score_2afc_dataset(data_loader, func, name=''): 209 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 210 | distance function 'func' in dataset 'data_loader' 211 | INPUTS 212 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 213 | func - callable distance function - calling d=func(in0,in1) should take 2 214 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 215 | OUTPUTS 216 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 217 | [1] - dictionary with following elements 218 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 219 | gts - N array in [0,1], preferred patch selected by human evaluators 220 | (closer to "0" for left patch p0, "1" for right patch p1, 221 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 222 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 223 | CONSTS 224 | N - number of test triplets in data_loader 225 | ''' 226 | 227 | d0s = [] 228 | d1s = [] 229 | gts = [] 230 | 231 | for data in tqdm(data_loader.load_data(), desc=name): 232 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 233 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 234 | gts+=data['judge'].cpu().numpy().flatten().tolist() 235 | 236 | d0s = np.array(d0s) 237 | d1s = np.array(d1s) 238 | gts = np.array(gts) 239 | scores = (d0s args.start: 167 | index = index-1 168 | 169 | Fwarp = flow_warping(Gimgs[args.start], flows_forward[index]) 170 | warps_Gforward.append(Fwarp) 171 | Bwarp = flow_warping(x, flows_backward[index]) 172 | warps_Gbackward.append(Bwarp) 173 | 174 | 175 | tc_losses = MSE_Loss(torch.stack(warps_Iforward), torch.stack(warps_Gforward)) + MSE_Loss(torch.stack(warps_Ibackward), torch.stack(warps_Gbackward)) 176 | 177 | losses = mse_losses + perceptual_losses + tc_losses #+ distance_losses + distance_loss_ours 178 | losses.backward(retain_graph=True) 179 | optimizer.step() 180 | 181 | w_pbar.set_description( 182 | ( 183 | f'loss: 
{losses.item():.4f}; perceptual: {perceptual_losses.item():.4f}; mse: {mse_losses.item():.4f}; tc: {tc_losses.item():.4f}' 184 | ) 185 | ) 186 | del losses,mse_losses,perceptual_losses,tc_losses 187 | 188 | save_image(Gimgs[args.start].squeeze(0).clamp(0,1),inv_path+"/{}.png".format(t_names[args.start])) 189 | np.save(param_path+"{}.npy".format(t_names[args.start]),dlatents[args.start].detach().cpu().numpy()) 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | def caluclate_loss(synth_img,img,perceptual_net,img_p,MSE_Loss,upsample2d): 199 | #calculate MSE Loss 200 | mse_loss=MSE_Loss(synth_img,img) # (lamda_mse/N)*||G(w)-I||^2 201 | 202 | #calculate Perceptual Loss 203 | real_0,real_1,real_2,real_3=perceptual_net(img_p) 204 | synth_p=upsample2d(synth_img) 205 | synth_0,synth_1,synth_2,synth_3=perceptual_net(synth_p) 206 | 207 | perceptual_loss=0 208 | perceptual_loss+=MSE_Loss(synth_0,real_0) 209 | perceptual_loss+=MSE_Loss(synth_1,real_1) 210 | perceptual_loss+=MSE_Loss(synth_2,real_2) 211 | perceptual_loss+=MSE_Loss(synth_3,real_3) 212 | 213 | return mse_loss,perceptual_loss 214 | 215 | 216 | 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer('kernel', kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer('kernel', kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer('kernel', kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, 
bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = F.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 128 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 166 | ) 167 | 168 | 169 | class ScaledLeakyReLU(nn.Module): 170 | def __init__(self, negative_slope=0.2): 171 | super().__init__() 172 | 173 | self.negative_slope = negative_slope 174 | 175 | def forward(self, input): 176 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 177 | 178 | return out * math.sqrt(2) 179 | 180 | 181 | class ModulatedConv2d(nn.Module): 182 | def __init__( 183 | self, 184 | in_channel, 185 | out_channel, 186 | kernel_size, 187 | style_dim, 188 | demodulate=True, 189 | upsample=False, 190 | downsample=False, 191 | blur_kernel=[1, 3, 3, 1], 192 | ): 193 | super().__init__() 194 | 195 | self.eps = 1e-8 196 | self.kernel_size = kernel_size 197 | self.in_channel = in_channel 198 | self.out_channel = out_channel 199 | self.upsample = upsample 200 | self.downsample = downsample 201 | 202 | if upsample: 203 | factor = 2 204 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 205 | pad0 = (p + 1) // 2 + factor - 1 206 | pad1 = p // 2 + 1 207 | 208 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 209 | 210 | if downsample: 211 | factor = 2 212 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 213 | pad0 = (p + 1) // 2 214 | pad1 = p // 2 215 | 216 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 217 | 218 | fan_in = in_channel * kernel_size ** 2 219 | self.scale = 1 / math.sqrt(fan_in) 220 | self.padding = kernel_size // 2 221 | 222 | self.weight = nn.Parameter( 223 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 224 | ) 225 | 226 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 227 | 228 | self.demodulate = demodulate 229 | 230 | def __repr__(self): 231 | return ( 232 | f'{self.__class__.__name__}({self.in_channel}, 
{self.out_channel}, {self.kernel_size}, ' 233 | f'upsample={self.upsample}, downsample={self.downsample})' 234 | ) 235 | 236 | def forward(self, input, style): 237 | batch, in_channel, height, width = input.shape 238 | 239 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 240 | weight = self.scale * self.weight * style 241 | 242 | if self.demodulate: 243 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 244 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 245 | 246 | weight = weight.view( 247 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 248 | ) 249 | 250 | if self.upsample: 251 | input = input.view(1, batch * in_channel, height, width) 252 | weight = weight.view( 253 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 254 | ) 255 | weight = weight.transpose(1, 2).reshape( 256 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 257 | ) 258 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 259 | _, _, height, width = out.shape 260 | out = out.view(batch, self.out_channel, height, width) 261 | out = self.blur(out) 262 | 263 | elif self.downsample: 264 | input = self.blur(input) 265 | _, _, height, width = input.shape 266 | input = input.view(1, batch * in_channel, height, width) 267 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 268 | _, _, height, width = out.shape 269 | out = out.view(batch, self.out_channel, height, width) 270 | 271 | else: 272 | input = input.view(1, batch * in_channel, height, width) 273 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 274 | _, _, height, width = out.shape 275 | out = out.view(batch, self.out_channel, height, width) 276 | 277 | return out 278 | 279 | 280 | class NoiseInjection(nn.Module): 281 | def __init__(self): 282 | super().__init__() 283 | 284 | self.weight = nn.Parameter(torch.zeros(1)) 285 | 286 | def forward(self, image, noise=None): 287 | if noise is None: 288 | batch, _, height, width = image.shape 289 | noise = image.new_empty(batch, 1, height, width).normal_() 290 | 291 | return image + self.weight * noise 292 | 293 | 294 | class ConstantInput(nn.Module): 295 | def __init__(self, channel, size=4): 296 | super().__init__() 297 | 298 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 299 | 300 | def forward(self, input): 301 | batch = input.shape[0] 302 | out = self.input.repeat(batch, 1, 1, 1) 303 | 304 | return out 305 | 306 | 307 | class StyledConv(nn.Module): 308 | def __init__( 309 | self, 310 | in_channel, 311 | out_channel, 312 | kernel_size, 313 | style_dim, 314 | upsample=False, 315 | blur_kernel=[1, 3, 3, 1], 316 | demodulate=True, 317 | ): 318 | super().__init__() 319 | 320 | self.conv = ModulatedConv2d( 321 | in_channel, 322 | out_channel, 323 | kernel_size, 324 | style_dim, 325 | upsample=upsample, 326 | blur_kernel=blur_kernel, 327 | demodulate=demodulate, 328 | ) 329 | 330 | self.noise = NoiseInjection() 331 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 332 | # self.activate = ScaledLeakyReLU(0.2) 333 | self.activate = FusedLeakyReLU(out_channel) 334 | 335 | def forward(self, input, style, noise=None): 336 | out = self.conv(input, style) 337 | out = self.noise(out, noise=noise) 338 | # out = out + self.bias 339 | out = self.activate(out) 340 | 341 | return out 342 | 343 | 344 | class ToRGB(nn.Module): 345 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 346 | super().__init__() 347 | 
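        # when upsample is set, the skip branch (the previous resolution's RGB output) is doubled in size before being added in forward()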
348 | if upsample: 349 | self.upsample = Upsample(blur_kernel) 350 | 351 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 352 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 353 | 354 | def forward(self, input, style, skip=None): 355 | out = self.conv(input, style) 356 | out = out + self.bias 357 | 358 | if skip is not None: 359 | skip = self.upsample(skip) 360 | 361 | out = out + skip 362 | 363 | return out 364 | 365 | 366 | class Generator(nn.Module): 367 | def __init__( 368 | self, 369 | size, 370 | style_dim, 371 | n_mlp, 372 | channel_multiplier=2, 373 | blur_kernel=[1, 3, 3, 1], 374 | lr_mlp=0.01, 375 | ): 376 | super().__init__() 377 | 378 | self.size = size 379 | 380 | self.style_dim = style_dim 381 | 382 | layers = [PixelNorm()] 383 | 384 | for i in range(n_mlp): 385 | layers.append( 386 | EqualLinear( 387 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' 388 | ) 389 | ) 390 | 391 | self.style = nn.Sequential(*layers) 392 | 393 | self.channels = { 394 | 4: 512, 395 | 8: 512, 396 | 16: 512, 397 | 32: 512, 398 | 64: 256 * channel_multiplier, 399 | 128: 128 * channel_multiplier, 400 | 256: 64 * channel_multiplier, 401 | 512: 32 * channel_multiplier, 402 | 1024: 16 * channel_multiplier, 403 | } 404 | 405 | self.input = ConstantInput(self.channels[4]) 406 | self.conv1 = StyledConv( 407 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 408 | ) 409 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 410 | 411 | self.log_size = int(math.log(size, 2)) 412 | self.num_layers = (self.log_size - 2) * 2 + 1 413 | 414 | self.convs = nn.ModuleList() 415 | self.upsamples = nn.ModuleList() 416 | self.to_rgbs = nn.ModuleList() 417 | self.noises = nn.Module() 418 | 419 | in_channel = self.channels[4] 420 | 421 | for layer_idx in range(self.num_layers): 422 | res = (layer_idx + 5) // 2 423 | shape = [1, 1, 2 ** res, 2 ** res] 424 | self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) 425 | 426 | for i in range(3, self.log_size + 1): 427 | out_channel = self.channels[2 ** i] 428 | 429 | self.convs.append( 430 | StyledConv( 431 | in_channel, 432 | out_channel, 433 | 3, 434 | style_dim, 435 | upsample=True, 436 | blur_kernel=blur_kernel, 437 | ) 438 | ) 439 | 440 | self.convs.append( 441 | StyledConv( 442 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 443 | ) 444 | ) 445 | 446 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 447 | 448 | in_channel = out_channel 449 | 450 | self.n_latent = self.log_size * 2 - 2 451 | 452 | def make_noise(self): 453 | device = self.input.input.device 454 | 455 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 456 | 457 | for i in range(3, self.log_size + 1): 458 | for _ in range(2): 459 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 460 | 461 | return noises 462 | 463 | def mean_latent(self, n_latent): 464 | latent_in = torch.randn( 465 | n_latent, self.style_dim, device=self.input.input.device 466 | ) 467 | latent = self.style(latent_in).mean(0, keepdim=True) 468 | 469 | return latent 470 | 471 | def get_latent(self, input): 472 | return self.style(input) 473 | 474 | def forward( 475 | self, 476 | styles, 477 | return_latents=False, 478 | inject_index=None, 479 | truncation=1, 480 | truncation_latent=None, 481 | input_is_latent=False, 482 | noise=None, 483 | randomize_noise=True, 484 | ): 485 | if not input_is_latent: 486 | styles = [self.style(s) for s in styles] 487 | a = styles 488 | if noise is None: 489 | if 
randomize_noise: 490 | noise = [None] * self.num_layers 491 | else: 492 | noise = [ 493 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) 494 | ] 495 | 496 | if truncation < 1: 497 | style_t = [] 498 | 499 | for style in styles: 500 | style_t.append( 501 | truncation_latent + truncation * (style - truncation_latent) 502 | ) 503 | 504 | styles = style_t 505 | 506 | if len(styles) < 2: 507 | inject_index = self.n_latent 508 | 509 | if styles[0].ndim < 3: 510 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 511 | 512 | else: 513 | latent = styles[0] 514 | 515 | else: 516 | if inject_index is None: 517 | inject_index = random.randint(1, self.n_latent - 1) 518 | 519 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 520 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 521 | 522 | latent = torch.cat([latent, latent2], 1) 523 | 524 | out = self.input(latent) 525 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 526 | 527 | skip = self.to_rgb1(out, latent[:, 1]) 528 | 529 | i = 1 530 | for conv1, conv2, noise1, noise2, to_rgb in zip( 531 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 532 | ): 533 | out = conv1(out, latent[:, i], noise=noise1) 534 | out = conv2(out, latent[:, i + 1], noise=noise2) 535 | skip = to_rgb(out, latent[:, i + 2], skip) 536 | 537 | i += 2 538 | 539 | image = skip 540 | 541 | if return_latents: 542 | return image, latent#[0] 543 | 544 | else: 545 | return image, None 546 | 547 | 548 | class ConvLayer(nn.Sequential): 549 | def __init__( 550 | self, 551 | in_channel, 552 | out_channel, 553 | kernel_size, 554 | downsample=False, 555 | blur_kernel=[1, 3, 3, 1], 556 | bias=True, 557 | activate=True, 558 | ): 559 | layers = [] 560 | 561 | if downsample: 562 | factor = 2 563 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 564 | pad0 = (p + 1) // 2 565 | pad1 = p // 2 566 | 567 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 568 | 569 | stride = 2 570 | self.padding = 0 571 | 572 | else: 573 | stride = 1 574 | self.padding = kernel_size // 2 575 | 576 | layers.append( 577 | EqualConv2d( 578 | in_channel, 579 | out_channel, 580 | kernel_size, 581 | padding=self.padding, 582 | stride=stride, 583 | bias=bias and not activate, 584 | ) 585 | ) 586 | 587 | if activate: 588 | if bias: 589 | layers.append(FusedLeakyReLU(out_channel)) 590 | 591 | else: 592 | layers.append(ScaledLeakyReLU(0.2)) 593 | 594 | super().__init__(*layers) 595 | 596 | 597 | class ResBlock(nn.Module): 598 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 599 | super().__init__() 600 | 601 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 602 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 603 | 604 | self.skip = ConvLayer( 605 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 606 | ) 607 | 608 | def forward(self, input): 609 | out = self.conv1(input) 610 | out = self.conv2(out) 611 | 612 | skip = self.skip(input) 613 | out = (out + skip) / math.sqrt(2) 614 | 615 | return out 616 | 617 | 618 | class Discriminator(nn.Module): 619 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 620 | super().__init__() 621 | 622 | channels = { 623 | 4: 512//2, 624 | 8: 512//2, 625 | 16: 512//2, 626 | 32: 512//2, 627 | 64: 256//2 * channel_multiplier, 628 | 128: 128//2 * channel_multiplier, 629 | 256: 64//2 * channel_multiplier, 630 | 512: 32//2 * channel_multiplier, 631 | 1024: 16//2 * channel_multiplier, 632 | } 633 | 634 | 
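        # from-RGB stem: a 1x1 conv maps the 3-channel image to the feature width of the highest resolution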
convs = [ConvLayer(3, channels[size], 1)] 635 | 636 | log_size = int(math.log(size, 2)) 637 | 638 | in_channel = channels[size] 639 | 640 | for i in range(log_size, 2, -1): 641 | out_channel = channels[2 ** (i - 1)] 642 | 643 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 644 | 645 | in_channel = out_channel 646 | 647 | self.convs = nn.Sequential(*convs) 648 | 649 | self.stddev_group = 4 650 | self.stddev_feat = 1 651 | 652 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 653 | self.final_linear = nn.Sequential( 654 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), 655 | EqualLinear(channels[4], 1), 656 | ) 657 | 658 | def forward(self, input): 659 | out = self.convs(input) 660 | 661 | batch, channel, height, width = out.shape 662 | group = min(batch, self.stddev_group) 663 | stddev = out.view( 664 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 665 | ) 666 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 667 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 668 | stddev = stddev.repeat(group, 1, height, width) 669 | out = torch.cat([out, stddev], 1) 670 | 671 | out = self.final_conv(out) 672 | 673 | out = out.view(batch, -1) 674 | out = self.final_linear(out) 675 | 676 | return out 677 | 678 | 679 | class Encoder(nn.Module): 680 | def __init__(self, size, w_dim=512): 681 | super().__init__() 682 | 683 | channels = { 684 | 4: 512//2, 685 | 8: 512//2, 686 | 16: 512//2, 687 | 32: 512//2, 688 | 64: 256//2, 689 | 128: 128//2, 690 | 256: 64//2, 691 | 512: 32//2, 692 | 1024: 16//2 693 | } 694 | self.w_dim = w_dim 695 | log_size = int(math.log(size, 2)) 696 | 697 | self.n_latents = log_size*2 - 2 698 | 699 | convs = [ConvLayer(1, channels[size], 1)] 700 | 701 | in_channel = channels[size] 702 | for i in range(log_size, 2, -1): 703 | out_channel = channels[2 ** (i - 1)] 704 | convs.append(ResBlock(in_channel, out_channel)) 705 | in_channel = out_channel 706 | 707 | convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False)) 708 | 709 | self.convs = nn.Sequential(*convs) 710 | 711 | def forward(self, input): 712 | out = self.convs(input) 713 | return out.view(len(input), self.n_latents, self.w_dim) 714 | -------------------------------------------------------------------------------- /networks/FlowNetC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .correlation_package.correlation import Correlation 9 | 10 | from .submodules import * 11 | 'Parameter count , 39,175,298 ' 12 | 13 | class FlowNetC(nn.Module): 14 | def __init__(self,args, batchNorm=True, div_flow = 20): 15 | super(FlowNetC,self).__init__() 16 | 17 | self.batchNorm = batchNorm 18 | self.div_flow = div_flow 19 | 20 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) 24 | 25 | if args.fp16: 26 | self.corr = nn.Sequential( 27 | tofp32(), 28 | Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1), 29 | tofp16()) 30 | else: 31 | self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 32 | 33 | 
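        # the correlation layer outputs 441 channels (a 21x21 displacement grid for max_displacement=20, stride2=2); with the 32-channel conv_redir features this forms the 473-channel input of conv3_1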
self.corr_activation = nn.LeakyReLU(0.1,inplace=True) 34 | self.conv3_1 = conv(self.batchNorm, 473, 256) 35 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 36 | self.conv4_1 = conv(self.batchNorm, 512, 512) 37 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 38 | self.conv5_1 = conv(self.batchNorm, 512, 512) 39 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 40 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 41 | 42 | self.deconv5 = deconv(1024,512) 43 | self.deconv4 = deconv(1026,256) 44 | self.deconv3 = deconv(770,128) 45 | self.deconv2 = deconv(386,64) 46 | 47 | self.predict_flow6 = predict_flow(1024) 48 | self.predict_flow5 = predict_flow(1026) 49 | self.predict_flow4 = predict_flow(770) 50 | self.predict_flow3 = predict_flow(386) 51 | self.predict_flow2 = predict_flow(194) 52 | 53 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 54 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 55 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 56 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 57 | 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | if m.bias is not None: 61 | init.uniform_(m.bias) 62 | init.xavier_uniform_(m.weight) 63 | 64 | if isinstance(m, nn.ConvTranspose2d): 65 | if m.bias is not None: 66 | init.uniform_(m.bias) 67 | init.xavier_uniform_(m.weight) 68 | # init_deconv_bilinear(m.weight) 69 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 70 | 71 | def forward(self, x): 72 | x1 = x[:,0:3,:,:] 73 | x2 = x[:,3::,:,:] 74 | 75 | out_conv1a = self.conv1(x1) 76 | out_conv2a = self.conv2(out_conv1a) 77 | out_conv3a = self.conv3(out_conv2a) 78 | 79 | # FlownetC bottom input stream 80 | out_conv1b = self.conv1(x2) 81 | 82 | out_conv2b = self.conv2(out_conv1b) 83 | out_conv3b = self.conv3(out_conv2b) 84 | 85 | # Merge streams 86 | out_corr = self.corr(out_conv3a, out_conv3b) # False 87 | out_corr = self.corr_activation(out_corr) 88 | 89 | # Redirect top input stream and concatenate 90 | out_conv_redir = self.conv_redir(out_conv3a) 91 | 92 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) 93 | 94 | # Merged conv layers 95 | out_conv3_1 = self.conv3_1(in_conv3_1) 96 | 97 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) 98 | 99 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 100 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 101 | 102 | flow6 = self.predict_flow6(out_conv6) 103 | flow6_up = self.upsampled_flow6_to_5(flow6) 104 | out_deconv5 = self.deconv5(out_conv6) 105 | 106 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 107 | 108 | flow5 = self.predict_flow5(concat5) 109 | flow5_up = self.upsampled_flow5_to_4(flow5) 110 | out_deconv4 = self.deconv4(concat5) 111 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 112 | 113 | flow4 = self.predict_flow4(concat4) 114 | flow4_up = self.upsampled_flow4_to_3(flow4) 115 | out_deconv3 = self.deconv3(concat4) 116 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) 117 | 118 | flow3 = self.predict_flow3(concat3) 119 | flow3_up = self.upsampled_flow3_to_2(flow3) 120 | out_deconv2 = self.deconv2(concat3) 121 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) 122 | 123 | flow2 = self.predict_flow2(concat2) 124 | 125 | if self.training: 126 | return flow2,flow3,flow4,flow5,flow6 127 | else: 128 | return flow2, 129 | -------------------------------------------------------------------------------- /networks/FlowNetFusion.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .submodules import * 9 | 'Parameter count = 581,226' 10 | 11 | class FlowNetFusion(nn.Module): 12 | def __init__(self,args, batchNorm=True): 13 | super(FlowNetFusion,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 11, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | 22 | self.deconv1 = deconv(128,32) 23 | self.deconv0 = deconv(162,16) 24 | 25 | self.inter_conv1 = i_conv(self.batchNorm, 162, 32) 26 | self.inter_conv0 = i_conv(self.batchNorm, 82, 16) 27 | 28 | self.predict_flow2 = predict_flow(128) 29 | self.predict_flow1 = predict_flow(32) 30 | self.predict_flow0 = predict_flow(16) 31 | 32 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 33 | self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | if m.bias is not None: 38 | init.uniform_(m.bias) 39 | init.xavier_uniform_(m.weight) 40 | 41 | if isinstance(m, nn.ConvTranspose2d): 42 | if m.bias is not None: 43 | init.uniform_(m.bias) 44 | init.xavier_uniform_(m.weight) 45 | # init_deconv_bilinear(m.weight) 46 | 47 | def forward(self, x): 48 | out_conv0 = self.conv0(x) 49 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 50 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 51 | 52 | flow2 = self.predict_flow2(out_conv2) 53 | flow2_up = self.upsampled_flow2_to_1(flow2) 54 | out_deconv1 = self.deconv1(out_conv2) 55 | 56 | concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1) 57 | out_interconv1 = self.inter_conv1(concat1) 58 | flow1 = self.predict_flow1(out_interconv1) 59 | flow1_up = self.upsampled_flow1_to_0(flow1) 60 | out_deconv0 = self.deconv0(concat1) 61 | 62 | concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1) 63 | out_interconv0 = self.inter_conv0(concat0) 64 | flow0 = self.predict_flow0(out_interconv0) 65 | 66 | return flow0 67 | 68 | -------------------------------------------------------------------------------- /networks/FlowNetS.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | import math 10 | import numpy as np 11 | 12 | from .submodules import * 13 | 'Parameter count : 38,676,504 ' 14 | 15 | class FlowNetS(nn.Module): 16 | def __init__(self, args, input_channels = 12, batchNorm=True): 17 | super(FlowNetS,self).__init__() 18 | 19 | self.batchNorm = batchNorm 20 | self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv3_1 = conv(self.batchNorm, 256, 256) 24 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 25 | self.conv4_1 = conv(self.batchNorm, 512, 512) 26 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 27 | self.conv5_1 = conv(self.batchNorm, 512, 512) 28 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 29 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 30 | 31 | self.deconv5 = deconv(1024,512) 32 | 
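        # decoder input widths follow the concatenations in forward(): e.g. 1026 = 512 (out_conv5) + 512 (out_deconv5) + 2 (flow6_up)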
self.deconv4 = deconv(1026,256) 33 | self.deconv3 = deconv(770,128) 34 | self.deconv2 = deconv(386,64) 35 | 36 | self.predict_flow6 = predict_flow(1024) 37 | self.predict_flow5 = predict_flow(1026) 38 | self.predict_flow4 = predict_flow(770) 39 | self.predict_flow3 = predict_flow(386) 40 | self.predict_flow2 = predict_flow(194) 41 | 42 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 43 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 44 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 45 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | if m.bias is not None: 50 | init.uniform_(m.bias) 51 | init.xavier_uniform_(m.weight) 52 | 53 | if isinstance(m, nn.ConvTranspose2d): 54 | if m.bias is not None: 55 | init.uniform_(m.bias) 56 | init.xavier_uniform_(m.weight) 57 | # init_deconv_bilinear(m.weight) 58 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 59 | 60 | def forward(self, x): 61 | out_conv1 = self.conv1(x) 62 | 63 | out_conv2 = self.conv2(out_conv1) 64 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 65 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 66 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 67 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 68 | 69 | flow6 = self.predict_flow6(out_conv6) 70 | flow6_up = self.upsampled_flow6_to_5(flow6) 71 | out_deconv5 = self.deconv5(out_conv6) 72 | 73 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 74 | flow5 = self.predict_flow5(concat5) 75 | flow5_up = self.upsampled_flow5_to_4(flow5) 76 | out_deconv4 = self.deconv4(concat5) 77 | 78 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 79 | flow4 = self.predict_flow4(concat4) 80 | flow4_up = self.upsampled_flow4_to_3(flow4) 81 | out_deconv3 = self.deconv3(concat4) 82 | 83 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 84 | flow3 = self.predict_flow3(concat3) 85 | flow3_up = self.upsampled_flow3_to_2(flow3) 86 | out_deconv2 = self.deconv2(concat3) 87 | 88 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 89 | flow2 = self.predict_flow2(concat2) 90 | 91 | if self.training: 92 | return flow2,flow3,flow4,flow5,flow6 93 | else: 94 | return flow2, 95 | 96 | -------------------------------------------------------------------------------- /networks/FlowNetSD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .submodules import * 9 | 'Parameter count = 45,371,666' 10 | 11 | class FlowNetSD(nn.Module): 12 | def __init__(self, args, batchNorm=True): 13 | super(FlowNetSD,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 6, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | self.conv3 = conv(self.batchNorm, 128, 256, stride=2) 22 | self.conv3_1 = conv(self.batchNorm, 256, 256) 23 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 24 | self.conv4_1 = conv(self.batchNorm, 512, 512) 25 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 26 | self.conv5_1 = conv(self.batchNorm, 512, 512) 27 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 28 | self.conv6_1 = conv(self.batchNorm,1024, 
1024) 29 | 30 | self.deconv5 = deconv(1024,512) 31 | self.deconv4 = deconv(1026,256) 32 | self.deconv3 = deconv(770,128) 33 | self.deconv2 = deconv(386,64) 34 | 35 | self.inter_conv5 = i_conv(self.batchNorm, 1026, 512) 36 | self.inter_conv4 = i_conv(self.batchNorm, 770, 256) 37 | self.inter_conv3 = i_conv(self.batchNorm, 386, 128) 38 | self.inter_conv2 = i_conv(self.batchNorm, 194, 64) 39 | 40 | self.predict_flow6 = predict_flow(1024) 41 | self.predict_flow5 = predict_flow(512) 42 | self.predict_flow4 = predict_flow(256) 43 | self.predict_flow3 = predict_flow(128) 44 | self.predict_flow2 = predict_flow(64) 45 | 46 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 47 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 48 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 49 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 50 | 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | if m.bias is not None: 54 | init.uniform_(m.bias) 55 | init.xavier_uniform_(m.weight) 56 | 57 | if isinstance(m, nn.ConvTranspose2d): 58 | if m.bias is not None: 59 | init.uniform_(m.bias) 60 | init.xavier_uniform_(m.weight) 61 | # init_deconv_bilinear(m.weight) 62 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 63 | 64 | 65 | 66 | def forward(self, x): 67 | out_conv0 = self.conv0(x) 68 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 69 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 70 | 71 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 72 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 73 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 74 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 75 | 76 | flow6 = self.predict_flow6(out_conv6) 77 | flow6_up = self.upsampled_flow6_to_5(flow6) 78 | out_deconv5 = self.deconv5(out_conv6) 79 | 80 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 81 | out_interconv5 = self.inter_conv5(concat5) 82 | flow5 = self.predict_flow5(out_interconv5) 83 | 84 | flow5_up = self.upsampled_flow5_to_4(flow5) 85 | out_deconv4 = self.deconv4(concat5) 86 | 87 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 88 | out_interconv4 = self.inter_conv4(concat4) 89 | flow4 = self.predict_flow4(out_interconv4) 90 | flow4_up = self.upsampled_flow4_to_3(flow4) 91 | out_deconv3 = self.deconv3(concat4) 92 | 93 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 94 | out_interconv3 = self.inter_conv3(concat3) 95 | flow3 = self.predict_flow3(out_interconv3) 96 | flow3_up = self.upsampled_flow3_to_2(flow3) 97 | out_deconv2 = self.deconv2(concat3) 98 | 99 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 100 | out_interconv2 = self.inter_conv2(concat2) 101 | flow2 = self.predict_flow2(out_interconv2) 102 | 103 | if self.training: 104 | return flow2,flow3,flow4,flow5,flow6 105 | else: 106 | return flow2, 107 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/InvertingGANs_with_ConsecutiveImgs/9078a48ec3474dacdd02693b051e3addef1c5697/networks/__init__.py -------------------------------------------------------------------------------- /networks/channelnorm_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/InvertingGANs_with_ConsecutiveImgs/9078a48ec3474dacdd02693b051e3addef1c5697/networks/channelnorm_package/__init__.py 
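The FlowNet variants dumped above (FlowNetFusion, FlowNetS, FlowNetSD) are plain PyTorch modules built from `submodules.py` and run without the compiled CUDA extensions that follow. Below is a minimal sketch, not part of the repository, of driving FlowNetSD on a dummy frame pair; the `args` argument is accepted but unused by its constructor, and the import assumes `networks/__init__.py` stays empty as it is in this repository.

```python
# Hypothetical smoke test (not part of the repository): run FlowNetSD on a
# dummy frame pair. Spatial dims should be divisible by 64 so the skip
# connections line up after the six stride-2 stages.
import torch
from networks.FlowNetSD import FlowNetSD

model = FlowNetSD(args=None).eval()      # `args` is unused by the constructor

frames = torch.randn(1, 6, 256, 320)     # two RGB frames stacked along channels
with torch.no_grad():
    flow, = model(frames)                # eval mode returns a 1-tuple

print(flow.shape)                        # flow at 1/4 resolution: torch.Size([1, 2, 64, 80])
```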
-------------------------------------------------------------------------------- /networks/channelnorm_package/channelnorm.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function, Variable 2 | from torch.nn.modules.module import Module 3 | import channelnorm_cuda 4 | 5 | class ChannelNormFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, norm_deg=2): 9 | assert input1.is_contiguous() 10 | b, _, h, w = input1.size() 11 | output = input1.new(b, 1, h, w).zero_() 12 | 13 | channelnorm_cuda.forward(input1, output, norm_deg) 14 | ctx.save_for_backward(input1, output) 15 | ctx.norm_deg = norm_deg 16 | 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | input1, output = ctx.saved_tensors 22 | 23 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 24 | 25 | channelnorm_cuda.backward(input1, output, grad_output.data, 26 | grad_input1.data, ctx.norm_deg) 27 | 28 | return grad_input1, None 29 | 30 | 31 | class ChannelNorm(Module): 32 | 33 | def __init__(self, norm_deg=2): 34 | super(ChannelNorm, self).__init__() 35 | self.norm_deg = norm_deg 36 | 37 | def forward(self, input1): 38 | return ChannelNormFunction.apply(input1, self.norm_deg) 39 | 40 | -------------------------------------------------------------------------------- /networks/channelnorm_package/channelnorm_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "channelnorm_kernel.cuh" 5 | 6 | int channelnorm_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& output, 9 | int norm_deg) { 10 | 11 | channelnorm_kernel_forward(input1, output, norm_deg); 12 | return 1; 13 | } 14 | 15 | 16 | int channelnorm_cuda_backward( 17 | at::Tensor& input1, 18 | at::Tensor& output, 19 | at::Tensor& gradOutput, 20 | at::Tensor& gradInput1, 21 | int norm_deg) { 22 | 23 | channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); 24 | return 1; 25 | } 26 | 27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 28 | m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); 29 | m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); 30 | } 31 | 32 | -------------------------------------------------------------------------------- /networks/channelnorm_package/channelnorm_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "channelnorm_kernel.cuh" 6 | 7 | #define CUDA_NUM_THREADS 512 8 | 9 | #define DIM0(TENSOR) ((TENSOR).x) 10 | #define DIM1(TENSOR) ((TENSOR).y) 11 | #define DIM2(TENSOR) ((TENSOR).z) 12 | #define DIM3(TENSOR) ((TENSOR).w) 13 | 14 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 15 | 16 | using at::Half; 17 | 18 | template 19 | __global__ void kernel_channelnorm_update_output( 20 | const int n, 21 | const scalar_t* __restrict__ input1, 22 | const long4 input1_size, 23 | const long4 input1_stride, 24 | scalar_t* __restrict__ output, 25 | const long4 output_size, 26 | const long4 output_stride, 27 | int norm_deg) { 28 | 29 | int index = blockIdx.x * blockDim.x + threadIdx.x; 30 | 31 | if (index >= n) { 32 | return; 33 | } 34 | 35 | int dim_b = DIM0(output_size); 36 | int dim_c = DIM1(output_size); 37 | int dim_h = DIM2(output_size); 38 | int dim_w = 
DIM3(output_size); 39 | int dim_chw = dim_c * dim_h * dim_w; 40 | 41 | int b = ( index / dim_chw ) % dim_b; 42 | int y = ( index / dim_w ) % dim_h; 43 | int x = ( index ) % dim_w; 44 | 45 | int i1dim_c = DIM1(input1_size); 46 | int i1dim_h = DIM2(input1_size); 47 | int i1dim_w = DIM3(input1_size); 48 | int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; 49 | int i1dim_hw = i1dim_h * i1dim_w; 50 | 51 | float result = 0.0; 52 | 53 | for (int c = 0; c < i1dim_c; ++c) { 54 | int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; 55 | scalar_t val = input1[i1Index]; 56 | result += static_cast(val * val); 57 | } 58 | result = sqrt(result); 59 | output[index] = static_cast(result); 60 | } 61 | 62 | 63 | template 64 | __global__ void kernel_channelnorm_backward_input1( 65 | const int n, 66 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 67 | const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, 68 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 69 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, 70 | int norm_deg) { 71 | 72 | int index = blockIdx.x * blockDim.x + threadIdx.x; 73 | 74 | if (index >= n) { 75 | return; 76 | } 77 | 78 | float val = 0.0; 79 | 80 | int dim_b = DIM0(gradInput_size); 81 | int dim_c = DIM1(gradInput_size); 82 | int dim_h = DIM2(gradInput_size); 83 | int dim_w = DIM3(gradInput_size); 84 | int dim_chw = dim_c * dim_h * dim_w; 85 | int dim_hw = dim_h * dim_w; 86 | 87 | int b = ( index / dim_chw ) % dim_b; 88 | int y = ( index / dim_w ) % dim_h; 89 | int x = ( index ) % dim_w; 90 | 91 | 92 | int outIndex = b * dim_hw + y * dim_w + x; 93 | val = static_cast(gradOutput[outIndex]) * static_cast(input1[index]) / (static_cast(output[outIndex])+1e-9); 94 | gradInput[index] = static_cast(val); 95 | 96 | } 97 | 98 | void channelnorm_kernel_forward( 99 | at::Tensor& input1, 100 | at::Tensor& output, 101 | int norm_deg) { 102 | 103 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 104 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 105 | 106 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 107 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 108 | 109 | int n = output.numel(); 110 | 111 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { 112 | 113 | kernel_channelnorm_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 114 | //at::globalContext().getCurrentCUDAStream() >>>( 115 | n, 116 | input1.data(), 117 | input1_size, 118 | input1_stride, 119 | output.data(), 120 | output_size, 121 | output_stride, 122 | norm_deg); 123 | 124 | })); 125 | 126 | // TODO: ATen-equivalent check 127 | 128 | // THCudaCheck(cudaGetLastError()); 129 | } 130 | 131 | void channelnorm_kernel_backward( 132 | at::Tensor& input1, 133 | at::Tensor& output, 134 | at::Tensor& gradOutput, 135 | at::Tensor& gradInput1, 136 | int norm_deg) { 137 | 138 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 139 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 140 | 141 | const long4 output_size = 
make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 142 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 143 | 144 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); 145 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); 146 | 147 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); 148 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); 149 | 150 | int n = gradInput1.numel(); 151 | 152 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { 153 | 154 | kernel_channelnorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 155 | //at::globalContext().getCurrentCUDAStream() >>>( 156 | n, 157 | input1.data(), 158 | input1_size, 159 | input1_stride, 160 | output.data(), 161 | output_size, 162 | output_stride, 163 | gradOutput.data(), 164 | gradOutput_size, 165 | gradOutput_stride, 166 | gradInput1.data(), 167 | gradInput1_size, 168 | gradInput1_stride, 169 | norm_deg 170 | ); 171 | 172 | })); 173 | 174 | // TODO: Add ATen-equivalent check 175 | 176 | // THCudaCheck(cudaGetLastError()); 177 | } 178 | -------------------------------------------------------------------------------- /networks/channelnorm_package/channelnorm_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void channelnorm_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& output, 8 | int norm_deg); 9 | 10 | 11 | void channelnorm_kernel_backward( 12 | at::Tensor& input1, 13 | at::Tensor& output, 14 | at::Tensor& gradOutput, 15 | at::Tensor& gradInput1, 16 | int norm_deg); 17 | -------------------------------------------------------------------------------- /networks/channelnorm_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_52,code=sm_52', 12 | '-gencode', 'arch=compute_60,code=sm_60', 13 | '-gencode', 'arch=compute_61,code=sm_61', 14 | '-gencode', 'arch=compute_70,code=sm_70', 15 | '-gencode', 'arch=compute_70,code=compute_70' 16 | ] 17 | 18 | setup( 19 | name='channelnorm_cuda', 20 | ext_modules=[ 21 | CUDAExtension('channelnorm_cuda', [ 22 | 'channelnorm_cuda.cc', 23 | 'channelnorm_kernel.cu' 24 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 25 | ], 26 | cmdclass={ 27 | 'build_ext': BuildExtension 28 | }) 29 | -------------------------------------------------------------------------------- /networks/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/InvertingGANs_with_ConsecutiveImgs/9078a48ec3474dacdd02693b051e3addef1c5697/networks/correlation_package/__init__.py -------------------------------------------------------------------------------- /networks/correlation_package/correlation.py: 
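Before the correlation layer, a brief note on how the channelnorm extension above is typically exercised: running `python setup.py install` inside `networks/channelnorm_package` builds the `channelnorm_cuda` module that `channelnorm.py` imports, after which `ChannelNorm` returns the per-pixel L2 norm across channels. A minimal sketch, assuming a CUDA device and an already-built extension:

```python
# Hypothetical check (not part of the repository) for the channelnorm_cuda
# extension built from the sources above.
import torch
from networks.channelnorm_package.channelnorm import ChannelNorm

x = torch.randn(2, 8, 64, 64, device="cuda", requires_grad=True)
norm = ChannelNorm(norm_deg=2)      # kernel computes sqrt(sum_c x_c^2) per pixel
y = norm(x)                         # shape (2, 1, 64, 64)
y.mean().backward()                 # exercises the custom backward kernel
print(y.shape, x.grad.shape)
```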
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function 4 | import correlation_cuda 5 | 6 | class CorrelationFunction(Function): 7 | 8 | @staticmethod 9 | def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): 10 | ctx.save_for_backward(input1, input2) 11 | 12 | ctx.pad_size = pad_size 13 | ctx.kernel_size = kernel_size 14 | ctx.max_displacement = max_displacement 15 | ctx.stride1 = stride1 16 | ctx.stride2 = stride2 17 | ctx.corr_multiply = corr_multiply 18 | 19 | with torch.cuda.device_of(input1): 20 | rbot1 = input1.new() 21 | rbot2 = input2.new() 22 | output = input1.new() 23 | 24 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 25 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) 26 | 27 | return output 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | input1, input2 = ctx.saved_tensors 32 | 33 | with torch.cuda.device_of(input1): 34 | rbot1 = input1.new() 35 | rbot2 = input2.new() 36 | 37 | grad_input1 = input1.new() 38 | grad_input2 = input2.new() 39 | 40 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 41 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) 42 | 43 | return grad_input1, grad_input2, None, None, None, None, None, None 44 | 45 | 46 | class Correlation(Module): 47 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 48 | super(Correlation, self).__init__() 49 | self.pad_size = pad_size 50 | self.kernel_size = kernel_size 51 | self.max_displacement = max_displacement 52 | self.stride1 = stride1 53 | self.stride2 = stride2 54 | self.corr_multiply = corr_multiply 55 | 56 | def forward(self, input1, input2): 57 | 58 | result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement, self.stride1, self.stride2, self.corr_multiply) 59 | 60 | return result 61 | 62 | -------------------------------------------------------------------------------- /networks/correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "correlation_cuda_kernel.cuh" 9 | 10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 11 | int pad_size, 12 | int kernel_size, 13 | int max_displacement, 14 | int stride1, 15 | int stride2, 16 | int corr_type_multiply) 17 | { 18 | 19 | int batchSize = input1.size(0); 20 | 21 | int nInputChannels = input1.size(1); 22 | int inputHeight = input1.size(2); 23 | int inputWidth = input1.size(3); 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 34 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 35 | 36 | rInput1.resize_({batchSize, 
paddedInputHeight, paddedInputWidth, nInputChannels}); 37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 39 | 40 | rInput1.fill_(0); 41 | rInput2.fill_(0); 42 | output.fill_(0); 43 | 44 | int success = correlation_forward_cuda_kernel( 45 | output, 46 | output.size(0), 47 | output.size(1), 48 | output.size(2), 49 | output.size(3), 50 | output.stride(0), 51 | output.stride(1), 52 | output.stride(2), 53 | output.stride(3), 54 | input1, 55 | input1.size(1), 56 | input1.size(2), 57 | input1.size(3), 58 | input1.stride(0), 59 | input1.stride(1), 60 | input1.stride(2), 61 | input1.stride(3), 62 | input2, 63 | input2.size(1), 64 | input2.stride(0), 65 | input2.stride(1), 66 | input2.stride(2), 67 | input2.stride(3), 68 | rInput1, 69 | rInput2, 70 | pad_size, 71 | kernel_size, 72 | max_displacement, 73 | stride1, 74 | stride2, 75 | corr_type_multiply, 76 | at::cuda::getCurrentCUDAStream() 77 | //at::globalContext().getCurrentCUDAStream() 78 | ); 79 | 80 | //check for errors 81 | if (!success) { 82 | AT_ERROR("CUDA call failed"); 83 | } 84 | 85 | return 1; 86 | 87 | } 88 | 89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 90 | at::Tensor& gradInput1, at::Tensor& gradInput2, 91 | int pad_size, 92 | int kernel_size, 93 | int max_displacement, 94 | int stride1, 95 | int stride2, 96 | int corr_type_multiply) 97 | { 98 | 99 | int batchSize = input1.size(0); 100 | int nInputChannels = input1.size(1); 101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 103 | 104 | int height = input1.size(2); 105 | int width = input1.size(3); 106 | 107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 109 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 110 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 111 | 112 | rInput1.fill_(0); 113 | rInput2.fill_(0); 114 | gradInput1.fill_(0); 115 | gradInput2.fill_(0); 116 | 117 | int success = correlation_backward_cuda_kernel(gradOutput, 118 | gradOutput.size(0), 119 | gradOutput.size(1), 120 | gradOutput.size(2), 121 | gradOutput.size(3), 122 | gradOutput.stride(0), 123 | gradOutput.stride(1), 124 | gradOutput.stride(2), 125 | gradOutput.stride(3), 126 | input1, 127 | input1.size(1), 128 | input1.size(2), 129 | input1.size(3), 130 | input1.stride(0), 131 | input1.stride(1), 132 | input1.stride(2), 133 | input1.stride(3), 134 | input2, 135 | input2.stride(0), 136 | input2.stride(1), 137 | input2.stride(2), 138 | input2.stride(3), 139 | gradInput1, 140 | gradInput1.stride(0), 141 | gradInput1.stride(1), 142 | gradInput1.stride(2), 143 | gradInput1.stride(3), 144 | gradInput2, 145 | gradInput2.size(1), 146 | gradInput2.stride(0), 147 | gradInput2.stride(1), 148 | gradInput2.stride(2), 149 | gradInput2.stride(3), 150 | rInput1, 151 | rInput2, 152 | pad_size, 153 | kernel_size, 154 | max_displacement, 155 | stride1, 156 | stride2, 157 | corr_type_multiply, 158 | at::cuda::getCurrentCUDAStream() 159 | //at::globalContext().getCurrentCUDAStream() 160 | ); 161 | 162 | if (!success) { 163 | AT_ERROR("CUDA call failed"); 164 | } 165 | 166 | return 1; 167 | } 168 | 169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 170 | m.def("forward", &correlation_forward_cuda, 
"Correlation forward (CUDA)"); 171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 172 | } 173 | 174 | -------------------------------------------------------------------------------- /networks/correlation_package/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "correlation_cuda_kernel.cuh" 4 | 5 | #define CUDA_NUM_THREADS 1024 6 | #define THREADS_PER_BLOCK 32 7 | #define FULL_MASK 0xffffffff 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using at::Half; 15 | 16 | template 17 | __forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) { 18 | for (int offset = 16; offset > 0; offset /= 2) 19 | val += __shfl_down_sync(FULL_MASK, val, offset); 20 | return val; 21 | } 22 | 23 | template 24 | __forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) { 25 | 26 | static __shared__ scalar_t shared[32]; 27 | int lane = threadIdx.x % warpSize; 28 | int wid = threadIdx.x / warpSize; 29 | 30 | val = warpReduceSum(val); 31 | 32 | if (lane == 0) 33 | shared[wid] = val; 34 | 35 | __syncthreads(); 36 | 37 | val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0; 38 | 39 | if (wid == 0) 40 | val = warpReduceSum(val); 41 | 42 | return val; 43 | } 44 | 45 | 46 | template 47 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) 48 | { 49 | 50 | // n (batch size), c (num of channels), y (height), x (width) 51 | int n = blockIdx.x; 52 | int y = blockIdx.y; 53 | int x = blockIdx.z; 54 | 55 | int ch_off = threadIdx.x; 56 | scalar_t value; 57 | 58 | int dimcyx = channels * height * width; 59 | int dimyx = height * width; 60 | 61 | int p_dimx = (width + 2 * pad_size); 62 | int p_dimy = (height + 2 * pad_size); 63 | int p_dimyxc = channels * p_dimy * p_dimx; 64 | int p_dimxc = p_dimx * channels; 65 | 66 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 67 | value = input[n * dimcyx + c * dimyx + y * width + x]; 68 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 69 | } 70 | } 71 | 72 | 73 | template 74 | __global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels, 75 | const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1, 76 | const int nInputChannels, const int inputHeight, const int inputWidth, 77 | const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size, 78 | const int max_displacement, const int stride1, const int stride2) { 79 | 80 | int32_t pInputWidth = inputWidth + 2 * pad_size; 81 | int32_t pInputHeight = inputHeight + 2 * pad_size; 82 | 83 | int32_t kernel_rad = (kernel_size - 1) / 2; 84 | 85 | int32_t displacement_rad = max_displacement / stride2; 86 | 87 | int32_t displacement_size = 2 * displacement_rad + 1; 88 | 89 | int32_t n = blockIdx.x; 90 | int32_t y1 = blockIdx.y * stride1 + max_displacement; 91 | int32_t x1 = blockIdx.z * stride1 + max_displacement; 92 | int32_t c = threadIdx.x; 93 | 94 | int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels; 95 | 96 | int32_t pdimxc = pInputWidth * nInputChannels; 97 | 98 | int32_t pdimc = nInputChannels; 99 | 100 | int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth; 101 | int32_t tdimyx = outputHeight * outputWidth; 102 | int32_t tdimx = outputWidth; 103 | 104 | int32_t nelems = kernel_size * kernel_size * pdimc; 105 | 106 | // element-wise product along 
channel axis 107 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { 108 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { 109 | int x2 = x1 + ti * stride2; 110 | int y2 = y1 + tj * stride2; 111 | 112 | float acc0 = 0.0f; 113 | 114 | for (int j = -kernel_rad; j <= kernel_rad; ++j) { 115 | for (int i = -kernel_rad; i <= kernel_rad; ++i) { 116 | // THREADS_PER_BLOCK 117 | #pragma unroll 118 | for (int ch = c; ch < pdimc; ch += blockDim.x) { 119 | 120 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc 121 | + (x1 + i) * pdimc + ch; 122 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc 123 | + (x2 + i) * pdimc + ch; 124 | acc0 += static_cast(rInput1[indx1] * rInput2[indx2]); 125 | } 126 | } 127 | } 128 | 129 | if (blockDim.x == warpSize) { 130 | __syncwarp(); 131 | acc0 = warpReduceSum(acc0); 132 | } else { 133 | __syncthreads(); 134 | acc0 = blockReduceSum(acc0); 135 | } 136 | 137 | if (threadIdx.x == 0) { 138 | 139 | int tc = (tj + displacement_rad) * displacement_size 140 | + (ti + displacement_rad); 141 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx 142 | + blockIdx.z; 143 | output[tindx] = static_cast(acc0 / nelems); 144 | } 145 | } 146 | } 147 | } 148 | 149 | 150 | template 151 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 152 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 153 | const scalar_t* __restrict__ rInput2, 154 | int pad_size, 155 | int kernel_size, 156 | int max_displacement, 157 | int stride1, 158 | int stride2) 159 | { 160 | // n (batch size), c (num of channels), y (height), x (width) 161 | 162 | int n = item; 163 | int y = blockIdx.x * stride1 + pad_size; 164 | int x = blockIdx.y * stride1 + pad_size; 165 | int c = blockIdx.z; 166 | int tch_off = threadIdx.x; 167 | 168 | int kernel_rad = (kernel_size - 1) / 2; 169 | int displacement_rad = max_displacement / stride2; 170 | int displacement_size = 2 * displacement_rad + 1; 171 | 172 | int xmin = (x - kernel_rad - max_displacement) / stride1; 173 | int ymin = (y - kernel_rad - max_displacement) / stride1; 174 | 175 | int xmax = (x + kernel_rad - max_displacement) / stride1; 176 | int ymax = (y + kernel_rad - max_displacement) / stride1; 177 | 178 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 179 | // assumes gradInput1 is pre-allocated and zero filled 180 | return; 181 | } 182 | 183 | if (xmin > xmax || ymin > ymax) { 184 | // assumes gradInput1 is pre-allocated and zero filled 185 | return; 186 | } 187 | 188 | xmin = max(0,xmin); 189 | xmax = min(outputWidth-1,xmax); 190 | 191 | ymin = max(0,ymin); 192 | ymax = min(outputHeight-1,ymax); 193 | 194 | int pInputWidth = inputWidth + 2 * pad_size; 195 | int pInputHeight = inputHeight + 2 * pad_size; 196 | 197 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 198 | int pdimxc = pInputWidth * nInputChannels; 199 | int pdimc = nInputChannels; 200 | 201 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 202 | int tdimyx = outputHeight * outputWidth; 203 | int tdimx = outputWidth; 204 | 205 | int odimcyx = nInputChannels * inputHeight* inputWidth; 206 | int odimyx = inputHeight * inputWidth; 207 | int odimx = inputWidth; 208 | 209 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 210 | 211 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 212 | prod_sum[tch_off] = 0; 213 | 214 | for (int tc = tch_off; tc < nOutputChannels; tc += 
THREADS_PER_BLOCK) { 215 | 216 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 217 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 218 | 219 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 220 | 221 | scalar_t val2 = rInput2[indx2]; 222 | 223 | for (int j = ymin; j <= ymax; ++j) { 224 | for (int i = xmin; i <= xmax; ++i) { 225 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 226 | prod_sum[tch_off] += gradOutput[tindx] * val2; 227 | } 228 | } 229 | } 230 | __syncthreads(); 231 | 232 | if(tch_off == 0) { 233 | scalar_t reduce_sum = 0; 234 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 235 | reduce_sum += prod_sum[idx]; 236 | } 237 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 238 | gradInput1[indx1] = reduce_sum / nelems; 239 | } 240 | 241 | } 242 | 243 | template 244 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, 245 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 246 | const scalar_t* __restrict__ rInput1, 247 | int pad_size, 248 | int kernel_size, 249 | int max_displacement, 250 | int stride1, 251 | int stride2) 252 | { 253 | // n (batch size), c (num of channels), y (height), x (width) 254 | 255 | int n = item; 256 | int y = blockIdx.x * stride1 + pad_size; 257 | int x = blockIdx.y * stride1 + pad_size; 258 | int c = blockIdx.z; 259 | 260 | int tch_off = threadIdx.x; 261 | 262 | int kernel_rad = (kernel_size - 1) / 2; 263 | int displacement_rad = max_displacement / stride2; 264 | int displacement_size = 2 * displacement_rad + 1; 265 | 266 | int pInputWidth = inputWidth + 2 * pad_size; 267 | int pInputHeight = inputHeight + 2 * pad_size; 268 | 269 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 270 | int pdimxc = pInputWidth * nInputChannels; 271 | int pdimc = nInputChannels; 272 | 273 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 274 | int tdimyx = outputHeight * outputWidth; 275 | int tdimx = outputWidth; 276 | 277 | int odimcyx = nInputChannels * inputHeight* inputWidth; 278 | int odimyx = inputHeight * inputWidth; 279 | int odimx = inputWidth; 280 | 281 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 282 | 283 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 284 | prod_sum[tch_off] = 0; 285 | 286 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 287 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 288 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 289 | 290 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 291 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 292 | 293 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 294 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 295 | 296 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 297 | // assumes gradInput2 is pre-allocated and zero filled 298 | continue; 299 | } 300 | 301 | if (xmin > xmax || ymin > ymax) { 302 | // assumes gradInput2 is pre-allocated and zero filled 303 | continue; 304 | } 305 | 306 | xmin = max(0,xmin); 307 | xmax = min(outputWidth-1,xmax); 308 | 309 | ymin = max(0,ymin); 310 | ymax = min(outputHeight-1,ymax); 311 | 312 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 313 | scalar_t val1 = rInput1[indx1]; 314 | 315 | for (int j = ymin; j <= ymax; ++j) { 316 | 
for (int i = xmin; i <= xmax; ++i) { 317 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 318 | prod_sum[tch_off] += gradOutput[tindx] * val1; 319 | } 320 | } 321 | } 322 | 323 | __syncthreads(); 324 | 325 | if(tch_off == 0) { 326 | scalar_t reduce_sum = 0; 327 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 328 | reduce_sum += prod_sum[idx]; 329 | } 330 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 331 | gradInput2[indx2] = reduce_sum / nelems; 332 | } 333 | 334 | } 335 | 336 | int correlation_forward_cuda_kernel(at::Tensor& output, 337 | int ob, 338 | int oc, 339 | int oh, 340 | int ow, 341 | int osb, 342 | int osc, 343 | int osh, 344 | int osw, 345 | 346 | at::Tensor& input1, 347 | int ic, 348 | int ih, 349 | int iw, 350 | int isb, 351 | int isc, 352 | int ish, 353 | int isw, 354 | 355 | at::Tensor& input2, 356 | int gc, 357 | int gsb, 358 | int gsc, 359 | int gsh, 360 | int gsw, 361 | 362 | at::Tensor& rInput1, 363 | at::Tensor& rInput2, 364 | int pad_size, 365 | int kernel_size, 366 | int max_displacement, 367 | int stride1, 368 | int stride2, 369 | int corr_type_multiply, 370 | cudaStream_t stream) 371 | { 372 | 373 | int batchSize = ob; 374 | 375 | int nInputChannels = ic; 376 | int inputWidth = iw; 377 | int inputHeight = ih; 378 | 379 | int nOutputChannels = oc; 380 | int outputWidth = ow; 381 | int outputHeight = oh; 382 | 383 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 384 | dim3 threads_block(THREADS_PER_BLOCK); 385 | 386 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { 387 | 388 | channels_first<<>>( 389 | input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); 390 | 391 | })); 392 | 393 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { 394 | 395 | channels_first<<>> ( 396 | input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); 397 | 398 | })); 399 | 400 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 401 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 402 | 403 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { 404 | 405 | correlation_forward<<>> 406 | (output.data(), nOutputChannels, outputHeight, outputWidth, 407 | rInput1.data(), nInputChannels, inputHeight, inputWidth, 408 | rInput2.data(), 409 | pad_size, 410 | kernel_size, 411 | max_displacement, 412 | stride1, 413 | stride2); 414 | 415 | })); 416 | 417 | cudaError_t err = cudaGetLastError(); 418 | 419 | 420 | // check for errors 421 | if (err != cudaSuccess) { 422 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 423 | return 0; 424 | } 425 | 426 | return 1; 427 | } 428 | 429 | 430 | int correlation_backward_cuda_kernel( 431 | at::Tensor& gradOutput, 432 | int gob, 433 | int goc, 434 | int goh, 435 | int gow, 436 | int gosb, 437 | int gosc, 438 | int gosh, 439 | int gosw, 440 | 441 | at::Tensor& input1, 442 | int ic, 443 | int ih, 444 | int iw, 445 | int isb, 446 | int isc, 447 | int ish, 448 | int isw, 449 | 450 | at::Tensor& input2, 451 | int gsb, 452 | int gsc, 453 | int gsh, 454 | int gsw, 455 | 456 | at::Tensor& gradInput1, 457 | int gisb, 458 | int gisc, 459 | int gish, 460 | int gisw, 461 | 462 | at::Tensor& gradInput2, 463 | int ggc, 464 | int ggsb, 465 | int ggsc, 466 | int ggsh, 467 | int ggsw, 468 | 469 | at::Tensor& rInput1, 470 | at::Tensor& rInput2, 471 | int pad_size, 472 | int kernel_size, 473 | int max_displacement, 474 | int stride1, 475 
| int stride2, 476 | int corr_type_multiply, 477 | cudaStream_t stream) 478 | { 479 | 480 | int batchSize = gob; 481 | int num = batchSize; 482 | 483 | int nInputChannels = ic; 484 | int inputWidth = iw; 485 | int inputHeight = ih; 486 | 487 | int nOutputChannels = goc; 488 | int outputWidth = gow; 489 | int outputHeight = goh; 490 | 491 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 492 | dim3 threads_block(THREADS_PER_BLOCK); 493 | 494 | 495 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { 496 | 497 | channels_first<<>>( 498 | input1.data(), 499 | rInput1.data(), 500 | nInputChannels, 501 | inputHeight, 502 | inputWidth, 503 | pad_size 504 | ); 505 | })); 506 | 507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 508 | 509 | channels_first<<>>( 510 | input2.data(), 511 | rInput2.data(), 512 | nInputChannels, 513 | inputHeight, 514 | inputWidth, 515 | pad_size 516 | ); 517 | })); 518 | 519 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 520 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 521 | 522 | for (int n = 0; n < num; ++n) { 523 | 524 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 525 | 526 | 527 | correlation_backward_input1<<>> ( 528 | n, gradInput1.data(), nInputChannels, inputHeight, inputWidth, 529 | gradOutput.data(), nOutputChannels, outputHeight, outputWidth, 530 | rInput2.data(), 531 | pad_size, 532 | kernel_size, 533 | max_displacement, 534 | stride1, 535 | stride2); 536 | })); 537 | } 538 | 539 | for(int n = 0; n < batchSize; n++) { 540 | 541 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { 542 | 543 | correlation_backward_input2<<>>( 544 | n, gradInput2.data(), nInputChannels, inputHeight, inputWidth, 545 | gradOutput.data(), nOutputChannels, outputHeight, outputWidth, 546 | rInput1.data(), 547 | pad_size, 548 | kernel_size, 549 | max_displacement, 550 | stride1, 551 | stride2); 552 | 553 | })); 554 | } 555 | 556 | // check for errors 557 | cudaError_t err = cudaGetLastError(); 558 | if (err != cudaSuccess) { 559 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 560 | return 0; 561 | } 562 | 563 | return 1; 564 | } 565 | -------------------------------------------------------------------------------- /networks/correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 
| int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /networks/correlation_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /networks/resample2d_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/InvertingGANs_with_ConsecutiveImgs/9078a48ec3474dacdd02693b051e3addef1c5697/networks/resample2d_package/__init__.py -------------------------------------------------------------------------------- /networks/resample2d_package/resample2d.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.autograd import Function, Variable 3 | import resample2d_cuda 4 | 5 | class Resample2dFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, input2, kernel_size=1, bilinear= True): 9 | assert input1.is_contiguous() 10 | assert input2.is_contiguous() 11 | 12 | ctx.save_for_backward(input1, input2) 13 | ctx.kernel_size = kernel_size 14 | ctx.bilinear = bilinear 15 | 16 | _, d, _, _ = input1.size() 17 | b, _, h, w = input2.size() 18 | output = input1.new(b, d, h, w).zero_() 19 | 20 | resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear) 21 | 22 | return output 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | grad_output = grad_output.contiguous() 27 | assert grad_output.is_contiguous() 28 | 29 | input1, input2 = ctx.saved_tensors 30 | 31 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 32 | grad_input2 = Variable(input1.new(input2.size()).zero_()) 33 | 34 | resample2d_cuda.backward(input1, input2, grad_output.data, 35 | grad_input1.data, grad_input2.data, 36 | ctx.kernel_size, ctx.bilinear) 37 | 38 | return grad_input1, grad_input2, None, None 39 | 40 | class Resample2d(Module): 41 | 42 | def __init__(self, kernel_size=1, bilinear = True): 43 | super(Resample2d, self).__init__() 44 | self.kernel_size = kernel_size 45 | self.bilinear = bilinear 46 | 47 | def forward(self, input1, input2): 48 | input1_c = input1.contiguous() 49 | return Resample2dFunction.apply(input1_c, 
input2, self.kernel_size, self.bilinear) 50 | -------------------------------------------------------------------------------- /networks/resample2d_package/resample2d_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "resample2d_kernel.cuh" 5 | 6 | int resample2d_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& input2, 9 | at::Tensor& output, 10 | int kernel_size, bool bilinear) { 11 | resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear); 12 | return 1; 13 | } 14 | 15 | int resample2d_cuda_backward( 16 | at::Tensor& input1, 17 | at::Tensor& input2, 18 | at::Tensor& gradOutput, 19 | at::Tensor& gradInput1, 20 | at::Tensor& gradInput2, 21 | int kernel_size, bool bilinear) { 22 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear); 23 | return 1; 24 | } 25 | 26 | 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); 30 | m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); 31 | } 32 | 33 | -------------------------------------------------------------------------------- /networks/resample2d_package/resample2d_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #define CUDA_NUM_THREADS 512 6 | #define THREADS_PER_BLOCK 64 7 | 8 | #define DIM0(TENSOR) ((TENSOR).x) 9 | #define DIM1(TENSOR) ((TENSOR).y) 10 | #define DIM2(TENSOR) ((TENSOR).z) 11 | #define DIM3(TENSOR) ((TENSOR).w) 12 | 13 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 14 | 15 | template 16 | __global__ void kernel_resample2d_update_output(const int n, 17 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 18 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 19 | scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size, bool bilinear) { 20 | int index = blockIdx.x * blockDim.x + threadIdx.x; 21 | 22 | if (index >= n) { 23 | return; 24 | } 25 | 26 | scalar_t val = 0.0f; 27 | 28 | int dim_b = DIM0(output_size); 29 | int dim_c = DIM1(output_size); 30 | int dim_h = DIM2(output_size); 31 | int dim_w = DIM3(output_size); 32 | int dim_chw = dim_c * dim_h * dim_w; 33 | int dim_hw = dim_h * dim_w; 34 | 35 | int b = ( index / dim_chw ) % dim_b; 36 | int c = ( index / dim_hw ) % dim_c; 37 | int y = ( index / dim_w ) % dim_h; 38 | int x = ( index ) % dim_w; 39 | 40 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 41 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 42 | 43 | scalar_t xf = static_cast(x) + dx; 44 | scalar_t yf = static_cast(y) + dy; 45 | scalar_t alpha = xf - floor(xf); // alpha 46 | scalar_t beta = yf - floor(yf); // beta 47 | 48 | if (bilinear) { 49 | int xL = max(min( int (floor(xf)), dim_w-1), 0); 50 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0); 51 | int yT = max(min( int (floor(yf)), dim_h-1), 0); 52 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0); 53 | 54 | for (int fy = 0; fy < kernel_size; fy += 1) { 55 | for (int fx = 0; fx < kernel_size; fx += 1) { 56 | val += static_cast((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx)); 57 | val += static_cast((alpha)*(1. 
- beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx)); 58 | val += static_cast((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx)); 59 | val += static_cast((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx)); 60 | } 61 | } 62 | 63 | output[index] = val; 64 | } 65 | else { 66 | int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0); 67 | int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0); 68 | 69 | output[index] = static_cast ( DIM3_INDEX(input1, b, c, yN, xN) ); 70 | } 71 | 72 | } 73 | 74 | 75 | template 76 | __global__ void kernel_resample2d_backward_input1( 77 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 78 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 79 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 80 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, bool bilinear) { 81 | 82 | int index = blockIdx.x * blockDim.x + threadIdx.x; 83 | 84 | if (index >= n) { 85 | return; 86 | } 87 | 88 | int dim_b = DIM0(gradOutput_size); 89 | int dim_c = DIM1(gradOutput_size); 90 | int dim_h = DIM2(gradOutput_size); 91 | int dim_w = DIM3(gradOutput_size); 92 | int dim_chw = dim_c * dim_h * dim_w; 93 | int dim_hw = dim_h * dim_w; 94 | 95 | int b = ( index / dim_chw ) % dim_b; 96 | int c = ( index / dim_hw ) % dim_c; 97 | int y = ( index / dim_w ) % dim_h; 98 | int x = ( index ) % dim_w; 99 | 100 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 101 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 102 | 103 | scalar_t xf = static_cast(x) + dx; 104 | scalar_t yf = static_cast(y) + dy; 105 | scalar_t alpha = xf - int(xf); // alpha 106 | scalar_t beta = yf - int(yf); // beta 107 | 108 | int idim_h = DIM2(input1_size); 109 | int idim_w = DIM3(input1_size); 110 | 111 | int xL = max(min( int (floor(xf)), idim_w-1), 0); 112 | int xR = max(min( int (floor(xf)+1), idim_w -1), 0); 113 | int yT = max(min( int (floor(yf)), idim_h-1), 0); 114 | int yB = max(min( int (floor(yf)+1), idim_h-1), 0); 115 | 116 | for (int fy = 0; fy < kernel_size; fy += 1) { 117 | for (int fx = 0; fx < kernel_size; fx += 1) { 118 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 119 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 120 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 121 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 122 | } 123 | } 124 | 125 | } 126 | 127 | template 128 | __global__ void kernel_resample2d_backward_input2( 129 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 130 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 131 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 132 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, bool bilinear) { 133 | 134 | int index = blockIdx.x * blockDim.x + threadIdx.x; 135 | 136 | if (index >= n) { 137 | return; 138 | } 139 | 140 | scalar_t output = 0.0; 141 | int kernel_rad = (kernel_size - 1)/2; 142 | 143 | int dim_b = DIM0(gradInput_size); 144 | int 
dim_c = DIM1(gradInput_size); 145 | int dim_h = DIM2(gradInput_size); 146 | int dim_w = DIM3(gradInput_size); 147 | int dim_chw = dim_c * dim_h * dim_w; 148 | int dim_hw = dim_h * dim_w; 149 | 150 | int b = ( index / dim_chw ) % dim_b; 151 | int c = ( index / dim_hw ) % dim_c; 152 | int y = ( index / dim_w ) % dim_h; 153 | int x = ( index ) % dim_w; 154 | 155 | int odim_c = DIM1(gradOutput_size); 156 | 157 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 158 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 159 | 160 | scalar_t xf = static_cast(x) + dx; 161 | scalar_t yf = static_cast(y) + dy; 162 | 163 | int xL = max(min( int (floor(xf)), dim_w-1), 0); 164 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0); 165 | int yT = max(min( int (floor(yf)), dim_h-1), 0); 166 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0); 167 | 168 | if (c % 2) { 169 | float gamma = 1 - (xf - floor(xf)); // alpha 170 | for (int i = 0; i <= 2*kernel_rad; ++i) { 171 | for (int j = 0; j <= 2*kernel_rad; ++j) { 172 | for (int ch = 0; ch < odim_c; ++ch) { 173 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); 174 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); 175 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); 176 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); 177 | } 178 | } 179 | } 180 | } 181 | else { 182 | float gamma = 1 - (yf - floor(yf)); // alpha 183 | for (int i = 0; i <= 2*kernel_rad; ++i) { 184 | for (int j = 0; j <= 2*kernel_rad; ++j) { 185 | for (int ch = 0; ch < odim_c; ++ch) { 186 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); 187 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); 188 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); 189 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); 190 | } 191 | } 192 | } 193 | 194 | } 195 | 196 | gradInput[index] = output; 197 | 198 | } 199 | 200 | void resample2d_kernel_forward( 201 | at::Tensor& input1, 202 | at::Tensor& input2, 203 | at::Tensor& output, 204 | int kernel_size, 205 | bool bilinear) { 206 | 207 | int n = output.numel(); 208 | 209 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 210 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 211 | 212 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); 213 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); 214 | 215 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 216 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 217 | 218 | // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF 219 | // AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { 220 | 221 | kernel_resample2d_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 222 | //at::globalContext().getCurrentCUDAStream() 
>>>( 223 | n, 224 | input1.data(), 225 | input1_size, 226 | input1_stride, 227 | input2.data(), 228 | input2_size, 229 | input2_stride, 230 | output.data(), 231 | output_size, 232 | output_stride, 233 | kernel_size, 234 | bilinear); 235 | 236 | // })); 237 | 238 | // TODO: ATen-equivalent check 239 | 240 | // THCudaCheck(cudaGetLastError()); 241 | 242 | } 243 | 244 | void resample2d_kernel_backward( 245 | at::Tensor& input1, 246 | at::Tensor& input2, 247 | at::Tensor& gradOutput, 248 | at::Tensor& gradInput1, 249 | at::Tensor& gradInput2, 250 | int kernel_size, 251 | bool bilinear) { 252 | 253 | int n = gradOutput.numel(); 254 | 255 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 256 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 257 | 258 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); 259 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); 260 | 261 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); 262 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); 263 | 264 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); 265 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); 266 | 267 | // AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { 268 | 269 | kernel_resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 270 | //at::globalContext().getCurrentCUDAStream() >>>( 271 | n, 272 | input1.data(), 273 | input1_size, 274 | input1_stride, 275 | input2.data(), 276 | input2_size, 277 | input2_stride, 278 | gradOutput.data(), 279 | gradOutput_size, 280 | gradOutput_stride, 281 | gradInput1.data(), 282 | gradInput1_size, 283 | gradInput1_stride, 284 | kernel_size, 285 | bilinear 286 | ); 287 | 288 | // })); 289 | 290 | const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); 291 | const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); 292 | 293 | n = gradInput2.numel(); 294 | 295 | // AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { 296 | 297 | 298 | kernel_resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 299 | //at::globalContext().getCurrentCUDAStream() >>>( 300 | n, 301 | input1.data(), 302 | input1_size, 303 | input1_stride, 304 | input2.data(), 305 | input2_size, 306 | input2_stride, 307 | gradOutput.data(), 308 | gradOutput_size, 309 | gradOutput_stride, 310 | gradInput2.data(), 311 | gradInput2_size, 312 | gradInput2_stride, 313 | kernel_size, 314 | bilinear 315 | ); 316 | 317 | // })); 318 | 319 | // TODO: Use the ATen equivalent to get last error 320 | 321 | // THCudaCheck(cudaGetLastError()); 322 | 323 | } 324 | -------------------------------------------------------------------------------- /networks/resample2d_package/resample2d_kernel.cuh: 
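The `Resample2d` module defined earlier in this package warps its first input with the per-pixel flow field given as its second input (bilinear sampling by default). A hedged sanity check, not part of the repository and assuming the `resample2d_cuda` extension has been built from the sources above: a zero flow field should reproduce the input frame.

```python
# Hypothetical sanity check (not part of the repository): warping with a
# zero flow field is the identity mapping.
import torch
from networks.resample2d_package.resample2d import Resample2d

warp = Resample2d()                                   # kernel_size=1, bilinear=True
frame = torch.randn(1, 3, 128, 160, device="cuda")    # image to warp
flow = torch.zeros(1, 2, 128, 160, device="cuda")     # (dx, dy) displacement per pixel
warped = warp(frame, flow)                            # same shape as `frame`
assert torch.allclose(warped, frame, atol=1e-5)
```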
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void resample2d_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& input2, 8 | at::Tensor& output, 9 | int kernel_size, 10 | bool bilinear); 11 | 12 | void resample2d_kernel_backward( 13 | at::Tensor& input1, 14 | at::Tensor& input2, 15 | at::Tensor& gradOutput, 16 | at::Tensor& gradInput1, 17 | at::Tensor& gradInput2, 18 | int kernel_size, 19 | bool bilinear); -------------------------------------------------------------------------------- /networks/resample2d_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='resample2d_cuda', 21 | ext_modules=[ 22 | CUDAExtension('resample2d_cuda', [ 23 | 'resample2d_cuda.cc', 24 | 'resample2d_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /networks/submodules.py: -------------------------------------------------------------------------------- 1 | # freda (todo) : 2 | 3 | import torch.nn as nn 4 | import torch 5 | import numpy as np 6 | 7 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): 8 | if batchNorm: 9 | return nn.Sequential( 10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), 11 | nn.BatchNorm2d(out_planes), 12 | nn.LeakyReLU(0.1,inplace=True) 13 | ) 14 | else: 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 17 | nn.LeakyReLU(0.1,inplace=True) 18 | ) 19 | 20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True): 21 | if batchNorm: 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 24 | nn.BatchNorm2d(out_planes), 25 | ) 26 | else: 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 29 | ) 30 | 31 | def predict_flow(in_planes): 32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 33 | 34 | def deconv(in_planes, out_planes): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 37 | nn.LeakyReLU(0.1,inplace=True) 38 | ) 39 | 40 | class tofp16(nn.Module): 41 | def __init__(self): 42 | super(tofp16, self).__init__() 43 | 44 | def forward(self, input): 45 | return input.half() 46 | 47 | 48 | class tofp32(nn.Module): 49 | def __init__(self): 50 | super(tofp32, self).__init__() 51 | 52 | def forward(self, input): 53 | return input.float() 54 | 55 | 56 | def init_deconv_bilinear(weight): 57 | f_shape = weight.size() 58 | heigh, width = f_shape[-2], f_shape[-1] 59 | f = 
np.ceil(width/2.0) 60 | c = (2 * f - 1 - f % 2) / (2.0 * f) 61 | bilinear = np.zeros([heigh, width]) 62 | for x in range(width): 63 | for y in range(heigh): 64 | value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) 65 | bilinear[x, y] = value 66 | weight.data.fill_(0.) 67 | for i in range(f_shape[0]): 68 | for j in range(f_shape[1]): 69 | weight.data[i,j,:,:] = torch.from_numpy(bilinear) 70 | 71 | 72 | def save_grad(grads, name): 73 | def hook(grad): 74 | grads[name] = grad 75 | return hook 76 | 77 | ''' 78 | def save_grad(grads, name): 79 | def hook(grad): 80 | grads[name] = grad 81 | return hook 82 | import torch 83 | from channelnorm_package.modules.channelnorm import ChannelNorm 84 | model = ChannelNorm().cuda() 85 | grads = {} 86 | a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True) 87 | a.register_hook(save_grad(grads, 'a')) 88 | b = model(a) 89 | y = torch.mean(b) 90 | y.backward() 91 | 92 | ''' 93 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | 
self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? 
x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | 
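# The gradient w.r.t. the input is itself an upfirdn2d call: the up and down factors
# swap roles, the spatially flipped kernel (grad_kernel, saved by UpFirDn2d.forward)
# supplies the filter taps, and the complementary padding g_pad computed in the forward
# pass restores the original spatial extent. The view below only reshapes the result
# back to the saved NCHW input size.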
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | 
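# Zero-insertion upsampling: the view above exposes a singleton axis after each row and
# column, the pad below fills it with (up - 1) zeros, and the following view flattens it
# back into an (in_h * up_y) x (in_w * up_x) grid. The remaining steps apply the edge
# padding (pad_*), crop any negative padding, filter with the flipped FIR kernel via
# F.conv2d, and finally keep every down-th sample in each direction.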
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t 
*x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 
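// Map this output pixel back into the shared-memory tile: rel_in_* is the first input
// tap inside sx, and kernel_* is the phase of the (pre-flipped) kernel held in sk for
// this up/down offset. The unrolled loops below then accumulate kernel_h/up_y x
// kernel_w/up_x taps, i.e. one polyphase filter evaluation per output pixel.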
186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * 
p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /perceptual_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import torch.nn as nn 4 | 5 | 6 | class VGG16_for_Perceptual(torch.nn.Module): 7 | def __init__(self,requires_grad=False,n_layers=[2,4,14,21]): 8 | super(VGG16_for_Perceptual,self).__init__() 9 | vgg_pretrained_features=models.vgg16(pretrained=True).features 10 | 11 | self.slice0=torch.nn.Sequential() 12 | self.slice1=torch.nn.Sequential() 13 | self.slice2=torch.nn.Sequential() 14 | self.slice3=torch.nn.Sequential() 15 | 16 | for x in range(n_layers[0]):#relu1_1 17 | self.slice0.add_module(str(x),vgg_pretrained_features[x]) 18 | for x in range(n_layers[0],n_layers[1]): #relu1_2 19 | self.slice1.add_module(str(x),vgg_pretrained_features[x]) 20 | for x in range(n_layers[1],n_layers[2]): #relu3_2 21 | self.slice2.add_module(str(x),vgg_pretrained_features[x]) 22 | 23 | for x in range(n_layers[2],n_layers[3]):#relu4_2 24 | self.slice3.add_module(str(x),vgg_pretrained_features[x]) 25 | 26 | 27 | if not requires_grad: 28 | for param in self.parameters(): 29 | param.requires_grad=False 30 | 31 | 32 | 33 | def forward(self,x): 34 | h0=self.slice0(x) 35 | h1=self.slice1(h0) 36 | h2=self.slice2(h1) 37 | h3=self.slice3(h2) 38 | 39 | return h0,h1,h2,h3 40 | 41 | 42 | -------------------------------------------------------------------------------- /read_image.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from torchvision import transforms 3 | from PIL import Image 4 | 5 | def image_reader(img_path,resize=None): 6 | 7 | with open(img_path,"rb") as f: 8 | image=Image.open(f) 9 | image=image.convert("RGB") 10 | if resize!=None: 11 | image=image.resize((resize,resize)) 12 | transform = transforms.Compose([ 13 | transforms.ToTensor() 14 | ]) 15 | 16 | image = transform(image) 17 | #print (image.shape) 18 | 19 | image=image.unsqueeze(0) 20 | 21 | return image 22 | -------------------------------------------------------------------------------- /stylegan.yaml: -------------------------------------------------------------------------------- 1 | name: stylegan 2 | channels: 3 | - 
https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 10 | - defaults 11 | dependencies: 12 | - _libgcc_mutex=0.1=conda_forge 13 | - _openmp_mutex=4.5=1_gnu 14 | - ca-certificates=2020.12.5=ha878542_0 15 | - certifi=2020.12.5=py37h89c1867_1 16 | - ld_impl_linux-64=2.35.1=hea4e1c9_1 17 | - libffi=3.3=h58526e2_2 18 | - libgcc-ng=9.3.0=h2828fa1_18 19 | - libgomp=9.3.0=h2828fa1_18 20 | - libstdcxx-ng=9.3.0=h6de172a_18 21 | - ncurses=6.2=h58526e2_4 22 | - openssl=1.1.1i=h7f98852_0 23 | - pip=21.0=pyhd8ed1ab_0 24 | - python=3.7.9=hffdb5ce_0_cpython 25 | - python_abi=3.7=1_cp37m 26 | - readline=8.0=he28a2e2_2 27 | - setuptools=49.6.0=py37h89c1867_3 28 | - sqlite=3.34.0=h74cdb3f_0 29 | - tk=8.6.10=h21135ba_1 30 | - wheel=0.36.2=pyhd3deb0d_0 31 | - xz=5.2.5=h516909a_1 32 | - zlib=1.2.11=h516909a_1010 33 | - pip: 34 | - astroid==2.4.2 35 | - dcnv2==0.1 36 | - deepdish==0.3.6 37 | - isort==5.7.0 38 | - lazy-object-proxy==1.4.3 39 | - mccabe==0.6.1 40 | - numexpr==2.7.2 41 | - numpy==1.19.5 42 | - opencv-python==4.5.1.48 43 | - pillow==8.1.0 44 | - pylint==2.6.0 45 | - scipy==1.6.0 46 | - six==1.15.0 47 | - tables==3.6.1 48 | - toml==0.10.2 49 | - torch==1.7.1+cu110 50 | - torchaudio==0.7.2 51 | - torchvision==0.8.2+cu110 52 | - tqdm==4.56.0 53 | - typed-ast==1.4.2 54 | - typing-extensions==3.7.4.3 55 | - wrapt==1.12.1 56 | prefix: /home/ubuntu/anaconda3/envs/stylegan 57 | 58 | --------------------------------------------------------------------------------
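The op package above is self-contained, so a quick sanity check of the two exported operators can be run on random tensors. The snippet below is an illustrative sketch rather than part of the repository: it assumes the repo root is on PYTHONPATH and that a CUDA toolchain is available, since op/fused_act.py and op/upfirdn2d.py JIT-compile their kernels at import time; the shapes and blur kernel values are arbitrary.

# sketch: exercise the two ops exported by op/__init__.py (assumes CUDA + repo root on PYTHONPATH)
import torch
from op import upfirdn2d, fused_leaky_relu

device = "cuda"  # the custom kernels run on GPU; both wrappers also have a CPU fallback path
x = torch.randn(1, 3, 8, 8, device=device)

# separable binomial blur kernel, normalized to sum to 1 (arbitrary choice for this demo)
k = torch.tensor([1., 3., 3., 1.], device=device)
k = k[:, None] * k[None, :]
k = k / k.sum()

# 2x zero-insertion upsampling followed by FIR filtering: 8x8 -> 16x16
up = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(up.shape)  # torch.Size([1, 3, 16, 16])

# bias + leaky ReLU + sqrt(2) gain, as implemented by the fused op
bias = torch.zeros(3, device=device)
y = fused_leaky_relu(x, bias)
print(y.shape)  # torch.Size([1, 3, 8, 8])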
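Similarly, the two small helpers at the end of this listing, read_image.py and perceptual_model.py, can be combined to compare two frames with a VGG feature distance. This is only a minimal sketch: the file names are placeholders, torchvision downloads the pretrained VGG16 weights on first use, and input normalization is kept as simple as in the repository code.

# sketch: VGG16 feature distance between two frames (file paths are placeholders)
import torch
import torch.nn.functional as F
from read_image import image_reader
from perceptual_model import VGG16_for_Perceptual

img0 = image_reader("frame_000.png", resize=256)  # 1x3x256x256 tensor in [0, 1]
img1 = image_reader("frame_001.png", resize=256)

vgg = VGG16_for_Perceptual(n_layers=[2, 4, 14, 21]).eval()  # relu1_1, relu1_2, relu3_2, relu4_2

with torch.no_grad():
    feats0 = vgg(img0)
    feats1 = vgg(img1)

# average the per-layer mean-squared feature differences
loss = sum(F.mse_loss(a, b) for a, b in zip(feats0, feats1)) / len(feats0)
print(loss.item())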