├── .gitignore ├── README.md ├── datas ├── apple │ ├── trainA │ │ └── real_scale.png │ └── trainB │ │ └── real_scale2.png └── art │ ├── trainA │ └── real_scale.png │ └── trainB │ └── real_scale2.png ├── imgs ├── a ├── animalface.jpg ├── artstyle.jpg ├── comparisons.jpg ├── dog.jpg ├── examples.jpg ├── photostyle.jpg ├── style.jpg └── trees.jpg ├── models ├── TuiGAN.py └── model.py ├── options └── config.py ├── requirements.txt ├── train.py └── utils ├── functions.py ├── imresize.py └── manipulate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Don't track content of these folders 2 | Output/ 3 | TrainedModels/ 4 | __pycache__/ 5 | .idea 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TuiGAN-PyTorch 2 | Official PyTorch Implementation of "[TuiGAN: Learning Versatile Image-to-Image Translation with Two Unpaired Images](https://arxiv.org/abs/2004.04634)" (ECCV 2020 Spotlight) 3 | 4 | ## TuiGAN's applications 5 | TuiGAN can be used for various computer vision tasks, ranging from image style transfer to object transformation and appearance transformation: 6 | ![](imgs/examples.jpg) 7 | 8 | ## Usage 9 | 10 | ### Install dependencies 11 | 12 | ``` 13 | python -m pip install -r requirements.txt 14 | ``` 15 | 16 | Our code was tested with Python 3.6 and PyTorch 1.0.0 or 1.2.0. 17 | 18 | ### Train 19 | To train a TuiGAN model on two unpaired images, put the first training image under `datas/task_name/trainA` and the second training image under `datas/task_name/trainB`, then run 20 | 21 | ``` 22 | python train.py --input_name <task_name> --root datas/<task_name> 23 | ``` 24 | For example, 25 | ``` 26 | python train.py --input_name apple --root datas/apple 27 | ``` 28 | ## Comparison Results 29 | 30 | ### General Unsupervised Image-to-Image Translation 31 | ![](imgs/comparisons.jpg) 32 | ### Image Style Transfer 33 | ![](imgs/style.jpg) 34 | ### Animal Face Translation 35 | ![](imgs/dog.jpg) 36 | ### Painting-to-Image Translation 37 | ![](imgs/trees.jpg) 38 | 39 | ## More Results 40 | 41 | ### Art Style Transfer 42 | ![](imgs/artstyle.jpg) 43 | 44 | ### Photorealistic Style Transfer 45 | ![](imgs/photostyle.jpg) 46 | 47 | ### Animal Face Translation 48 | ![](imgs/animalface.jpg) 49 | 50 | -------------------------------------------------------------------------------- /datas/apple/trainA/real_scale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/datas/apple/trainA/real_scale.png -------------------------------------------------------------------------------- /datas/apple/trainB/real_scale2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/datas/apple/trainB/real_scale2.png -------------------------------------------------------------------------------- /datas/art/trainA/real_scale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/datas/art/trainA/real_scale.png -------------------------------------------------------------------------------- /datas/art/trainB/real_scale2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/datas/art/trainB/real_scale2.png -------------------------------------------------------------------------------- /imgs/a: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /imgs/animalface.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/animalface.jpg -------------------------------------------------------------------------------- /imgs/artstyle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/artstyle.jpg -------------------------------------------------------------------------------- /imgs/comparisons.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/comparisons.jpg -------------------------------------------------------------------------------- /imgs/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/dog.jpg -------------------------------------------------------------------------------- /imgs/examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/examples.jpg -------------------------------------------------------------------------------- /imgs/photostyle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/photostyle.jpg -------------------------------------------------------------------------------- /imgs/style.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/style.jpg -------------------------------------------------------------------------------- /imgs/trees.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linjx-ustc1106/TuiGAN-PyTorch/a63f06ed043766bc3c9eda0e5d56d06ce92cdde4/imgs/trees.jpg -------------------------------------------------------------------------------- /models/TuiGAN.py: -------------------------------------------------------------------------------- 1 | import utils.functions as functions 2 | import models.model as models 3 | import os 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.utils.data 7 | import math 8 | import matplotlib.pyplot as plt 9 | from utils.imresize import imresize 10 | import itertools 11 | from torchvision.utils import save_image 12 | from torchvision import transforms as T 13 | from PIL import Image 14 | import numpy as np 15 | 16 | 17 | def denorm(x): 18 | """Convert the range from [-1, 1] to [0, 1].""" 19 | out = (x + 1) / 2 20 | return out.clamp_(0, 1) 21 | 22 | 23 | class TVLoss(nn.Module): 24 | def 
__init__(self,TVLoss_weight=1): 25 | super(TVLoss,self).__init__() 26 | self.TVLoss_weight = TVLoss_weight 27 | 28 | def forward(self,x): 29 | batch_size = x.size()[0] 30 | h_x = x.size()[2] 31 | w_x = x.size()[3] 32 | count_h = self._tensor_size(x[:,:,1:,:]) 33 | count_w = self._tensor_size(x[:,:,:,1:]) 34 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 35 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 36 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 37 | 38 | def _tensor_size(self,t): 39 | return t.size()[1]*t.size()[2]*t.size()[3] 40 | 41 | 42 | def train(opt,Gs,Zs,reals,NoiseAmp, Gs2,Zs2,reals2,NoiseAmp2): 43 | real_, real_2 = functions.read_two_domains(opt) 44 | in_s = 0 45 | in_s2 = 0 46 | scale_num = 0 47 | real = imresize(real_,opt.scale1,opt) 48 | real2 = imresize(real_2,opt.scale1,opt) 49 | reals = functions.creat_reals_pyramid(real,reals,opt) 50 | reals2 = functions.creat_reals_pyramid(real2,reals2,opt) 51 | nfc_prev = 0 52 | 53 | while scale_num 0: 287 | if mode == 'rec': 288 | count = 0 289 | for G,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Zs,reals,reals[1:],NoiseAmp): 290 | G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]] 291 | G_z = m_image(G_z) 292 | z_in = m_image(real_curr) 293 | G_z = G(z_in.detach(),G_z) 294 | G_z = imresize(G_z.detach(),1/opt.scale_factor,opt) 295 | G_z = G_z[:,:,0:real_next.shape[2],0:real_next.shape[3]] 296 | count += 1 297 | return G_z 298 | 299 | 300 | def cycle_rec(Gs,Gs2,Zs,reals,NoiseAmp,in_s,m_noise,m_image,opt, epoch): 301 | x_ab = in_s 302 | x_aba = in_s 303 | if len(Gs) > 0: 304 | count = 0 305 | for G,G2,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Gs2,Zs,reals,reals[1:],NoiseAmp): 306 | z = functions.generate_noise([3, Z_opt.shape[2] , Z_opt.shape[3] ], device=opt.device) 307 | z = z.expand(opt.bsz, 3, z.shape[2], z.shape[3]) 308 | z = m_noise(z) 309 | x_ab = x_ab[:,:,0:real_curr.shape[2],0:real_curr.shape[3]] 310 | x_ab = m_image(x_ab) 311 | z_in = noise_amp*z+m_image(real_curr) 312 | x_ab = G(z_in.detach(),x_ab) 313 | 314 | x_aba = G2(x_ab,x_aba) 315 | x_ab = imresize(x_ab.detach(),1/opt.scale_factor,opt) 316 | x_ab = x_ab[:,:,0:real_next.shape[2],0:real_next.shape[3]] 317 | x_aba = imresize(x_aba.detach(),1/opt.scale_factor,opt) 318 | x_aba = x_aba[:,:,0:real_next.shape[2],0:real_next.shape[3]] 319 | count += 1 320 | return x_ab, x_aba 321 | 322 | 323 | def init_models(opt): 324 | # generator initialization 325 | netG = models.GeneratorConcatSkip2CleanAddAlpha(opt).to(opt.device) 326 | netG.apply(models.weights_init) 327 | if opt.netG != '': 328 | netG.load_state_dict(torch.load(opt.netG)) 329 | print(netG) 330 | 331 | # discriminator initialization 332 | netD = models.WDiscriminator(opt).to(opt.device) 333 | netD.apply(models.weights_init) 334 | if opt.netD != '': 335 | netD.load_state_dict(torch.load(opt.netD)) 336 | print(netD) 337 | 338 | # generator2 initialization 339 | netG2 = models.GeneratorConcatSkip2CleanAddAlpha(opt).to(opt.device) 340 | netG2.apply(models.weights_init) 341 | if opt.netG2 != '': 342 | netG2.load_state_dict(torch.load(opt.netG2)) 343 | print(netG2) 344 | 345 | # discriminator2 initialization 346 | netD2 = models.WDiscriminator(opt).to(opt.device) 347 | netD2.apply(models.weights_init) 348 | if opt.netD2 != '': 349 | netD2.load_state_dict(torch.load(opt.netD2)) 350 | print(netD2) 351 | return netD, netG, netD2, netG2 352 | -------------------------------------------------------------------------------- /models/model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class ConvBlock(nn.Sequential): 8 | def __init__(self, in_channel, out_channel, ker_size, padd, stride): 9 | super(ConvBlock,self).__init__() 10 | self.add_module('conv',nn.Conv2d(in_channel ,out_channel,kernel_size=ker_size,stride=stride,padding=padd)), 11 | self.add_module('norm',nn.BatchNorm2d(out_channel)), 12 | self.add_module('LeakyRelu',nn.LeakyReLU(0.2, inplace=True)) 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv2d') != -1: 17 | m.weight.data.normal_(0.0, 0.02) 18 | elif classname.find('Norm') != -1: 19 | m.weight.data.normal_(1.0, 0.02) 20 | m.bias.data.fill_(0) 21 | 22 | class WDiscriminator(nn.Module): 23 | def __init__(self, opt): 24 | super(WDiscriminator, self).__init__() 25 | self.is_cuda = torch.cuda.is_available() 26 | N = int(opt.nfc) 27 | self.head = ConvBlock(opt.nc_im,N,opt.ker_size,opt.padd_size,1) 28 | self.body = nn.Sequential() 29 | for i in range(opt.num_layer-2): 30 | N = int(opt.nfc/pow(2,(i+1))) 31 | block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1) 32 | self.body.add_module('block%d'%(i+1),block) 33 | self.tail = nn.Conv2d(max(N,opt.min_nfc),1,kernel_size=opt.ker_size,stride=1,padding=opt.padd_size) 34 | 35 | def forward(self,x): 36 | x = self.head(x) 37 | x = self.body(x) 38 | x = self.tail(x) 39 | return x 40 | 41 | class GeneratorConcatSkip2CleanAddAlpha(nn.Module): 42 | def __init__(self, opt): 43 | super(GeneratorConcatSkip2CleanAddAlpha, self).__init__() 44 | self.is_cuda = torch.cuda.is_available() 45 | N = opt.nfc 46 | self.head = ConvBlock(opt.nc_im,N,opt.ker_size,opt.padd_size,1) 47 | self.body = nn.Sequential() 48 | for i in range(opt.num_layer-2): 49 | N = int(opt.nfc/pow(2,(i+1))) 50 | block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1) 51 | self.body.add_module('block%d'%(i+1),block) 52 | self.tail = nn.Sequential( 53 | nn.Conv2d(max(N,opt.min_nfc),opt.nc_im,kernel_size=opt.ker_size,stride =1,padding=opt.padd_size), 54 | nn.Tanh() 55 | ) 56 | N = opt.nfc 57 | self.head2 = ConvBlock(opt.nc_im*3,N,opt.ker_size,opt.padd_size,1) 58 | self.body2 = nn.Sequential() 59 | for i in range(2): 60 | N = int(opt.nfc/pow(2,(i+1))) 61 | block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1) 62 | self.body.add_module('block%d'%(i+1),block) 63 | self.tail2 = nn.Sequential( 64 | nn.Conv2d(max(N,opt.min_nfc),opt.nc_im,kernel_size=opt.ker_size,stride =1,padding=opt.padd_size), 65 | nn.Sigmoid() 66 | ) 67 | def forward(self,x,y2): 68 | y1 = self.head(x) 69 | y1 = self.body(y1) 70 | y1 = self.tail(y1) 71 | ind = int((y2.shape[2]-y1.shape[2])/2) 72 | y2 = y2[:,:,ind:(y2.shape[2]-ind),ind:(y2.shape[3]-ind)] 73 | x_c = torch.cat((x,y1,y2),1) 74 | 75 | a1 = self.head2(x_c) 76 | a1 = self.body2(a1) 77 | a1 = self.tail2(a1) 78 | 79 | return a1*y2 + (1-a1)*y1 80 | -------------------------------------------------------------------------------- /options/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_arguments(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--not_cuda', action='store_true', help='disables cuda', default=0) 7 | 8 | #load, input, save configurations: 9 | parser.add_argument('--netG', default='', help="path to netG 
(to continue training)") 10 | parser.add_argument('--netD', default='', help="path to netD (to continue training)") 11 | parser.add_argument('--netG2', default='', help="path to netG2 (to continue training)") 12 | parser.add_argument('--netD2', default='', help="path to netD2 (to continue training)") 13 | parser.add_argument('--manualSeed', type=int, help='manual seed') 14 | parser.add_argument('--nc_z',type=int,help='noise # channels',default=3) 15 | parser.add_argument('--nc_im',type=int,help='image # channels',default=3) 16 | parser.add_argument('--num_shot',type=int,help='shot number',default=5) 17 | parser.add_argument('--out',help='output folder',default='Results') 18 | 19 | #networks hyper parameters: 20 | parser.add_argument('--nfc', type=int, default=32) 21 | parser.add_argument('--min_nfc', type=int, default=32) 22 | parser.add_argument('--ker_size',type=int,help='kernel size',default=3) 23 | parser.add_argument('--num_layer',type=int,help='number of layers',default=5) 24 | parser.add_argument('--stride',help='stride',default=1) 25 | parser.add_argument('--padd_size',type=int,help='net pad size',default=1) 26 | 27 | #pyramid parameters: 28 | parser.add_argument('--scale_factor',type=float,help='pyramid scale factor',default=0.75) 29 | parser.add_argument('--noise_amp',type=float,help='addative noise cont weight',default=0.1) 30 | parser.add_argument('--noise_amp2',type=float,help='addative noise cont weight',default=0.1) 31 | parser.add_argument('--min_size',type=int,help='image minimal size at the coarser scale',default=100) 32 | parser.add_argument('--max_size', type=int,help='image minimal size at the coarser scale', default=250) 33 | 34 | #optimization hyper parameters: 35 | parser.add_argument('--niter', type=int, default=4000, help='number of epochs to train per scale') 36 | parser.add_argument('--gamma',type=float,help='scheduler gamma',default=0.1) 37 | parser.add_argument('--lr_g', type=float, default=0.0005, help='learning rate, default=0.0005') 38 | parser.add_argument('--lr_d', type=float, default=0.0005, help='learning rate, default=0.0005') 39 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 40 | parser.add_argument('--Gsteps',type=int, help='Generator inner steps',default=3) 41 | parser.add_argument('--Dsteps',type=int, help='Discriminator inner steps',default=3) 42 | parser.add_argument('--lambda_grad',type=float, help='gradient penalty weight',default=0.1) 43 | parser.add_argument('--lambda_idt',type=float, help='identity loss weight',default=1) 44 | parser.add_argument('--lambda_cyc',type=float, help='cycle-consistency loss weight',default=1) 45 | parser.add_argument('--lambda_per',type=float, help='perceptual loss weight',default=1) 46 | parser.add_argument('--lambda_tv',type=float, help='total variation loss weight',default=0.1) 47 | 48 | return parser 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | scikit-image 3 | scikit-learn 4 | scipy 5 | numpy 6 | torch 7 | torchvision 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from options.config import get_arguments 2 | from utils.manipulate import * 3 | from models.TuiGAN import * 4 | import utils.functions as functions 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = get_arguments() 9 | parser.add_argument('--root', help='input image dir', default='datas/apple') 10 | parser.add_argument('--input_name', help='input image name', required=True) 11 | parser.add_argument('--mode', help='task to be done', default='train') 12 | opt = parser.parse_args() 13 | opt = functions.post_config(opt) 14 | Gs = [] 15 | Zs = [] 16 | reals = [] 17 | NoiseAmp = [] 18 | Gs2 = [] 19 | Zs2 = [] 20 | reals2 = [] 21 | NoiseAmp2 = [] 22 | dir2save = functions.generate_dir2save(opt) 23 | 24 | try: 25 | os.makedirs(dir2save) 26 | except OSError: 27 | pass 28 | realA, realB = functions.read_two_domains(opt) 29 | functions.adjust_scales2image(realA, opt) 30 | train(opt, Gs, Zs, reals, NoiseAmp, Gs2, Zs2, reals2, NoiseAmp2) 31 | TuiGAN_generate(Gs, Zs, reals, NoiseAmp, Gs2, Zs2, reals2, NoiseAmp2, opt) 32 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import matplotlib.patches as patches 4 | import numpy as np 5 | import torch.nn as nn 6 | import scipy.io as sio 7 | import math 8 | from skimage import io as img 9 | from skimage import color, morphology, filters 10 | from utils.imresize import imresize 11 | import os 12 | import random 13 | from sklearn.cluster import KMeans 14 | from glob import glob 15 | 16 | 17 | 18 | def denorm(x): 19 | out = (x + 1) / 2 20 | return out.clamp(0, 1) 21 | 22 | def norm(x): 23 | out = (x -0.5) *2 24 | return out.clamp(-1, 1) 25 | 26 | 27 | def convert_image_np(inp): 28 | if inp.shape[1]==3: 29 | inp = denorm(inp) 30 | inp = move_to_cpu(inp[-1,:,:,:]) 31 | inp = inp.numpy().transpose((1,2,0)) 32 | else: 33 | inp = denorm(inp) 34 | inp = move_to_cpu(inp[-1,-1,:,:]) 35 | inp = inp.numpy().transpose((0,1)) 36 | 37 | inp = np.clip(inp,0,1) 38 | return inp 39 | 40 | 41 | def save_image(real_cpu,receptive_feild,ncs,epoch_num,file_name): 42 | fig,ax = plt.subplots(1) 43 | if ncs==1: 44 | ax.imshow(real_cpu.view(real_cpu.size(2),real_cpu.size(3)),cmap='gray') 45 | else: 46 | ax.imshow(convert_image_np(real_cpu.cpu())) 47 | rect = 
patches.Rectangle((0,0),receptive_feild,receptive_feild,linewidth=5,edgecolor='r',facecolor='none') 48 | ax.add_patch(rect) 49 | ax.axis('off') 50 | plt.savefig(file_name) 51 | plt.close(fig) 52 | 53 | 54 | def convert_image_np_2d(inp): 55 | inp = denorm(inp) 56 | inp = inp.numpy() 57 | return inp 58 | 59 | 60 | def generate_noise(size,num_samp=1,device='cuda',type='gaussian', scale=1): 61 | if type == 'gaussian': 62 | noise = torch.randn(num_samp, size[0], round(size[1]/scale), round(size[2]/scale), device=device) 63 | noise = upsampling(noise,size[1], size[2]) 64 | if type =='gaussian_mixture': 65 | noise1 = torch.randn(num_samp, size[0], size[1], size[2], device=device)+5 66 | noise2 = torch.randn(num_samp, size[0], size[1], size[2], device=device) 67 | noise = noise1+noise2 68 | if type == 'uniform': 69 | noise = torch.randn(num_samp, size[0], size[1], size[2], device=device) 70 | return noise 71 | 72 | 73 | def upsampling(im,sx,sy): 74 | m = nn.Upsample(size=[round(sx),round(sy)],mode='bilinear',align_corners=True) 75 | return m(im) 76 | 77 | 78 | def reset_grads(model,require_grad): 79 | for p in model.parameters(): 80 | p.requires_grad_(require_grad) 81 | return model 82 | 83 | 84 | def move_to_gpu(t): 85 | if (torch.cuda.is_available()): 86 | t = t.to(torch.device('cuda')) 87 | return t 88 | 89 | 90 | def move_to_cpu(t): 91 | t = t.to(torch.device('cpu')) 92 | return t 93 | 94 | def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA, device): 95 | alpha = torch.rand(1, 1) 96 | alpha = alpha.expand(real_data.size()) 97 | alpha = alpha.to(device) 98 | 99 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 100 | interpolates = interpolates.to(device) 101 | interpolates = torch.autograd.Variable(interpolates, requires_grad=True) 102 | disc_interpolates = netD(interpolates) 103 | 104 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates, 105 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 106 | create_graph=True, retain_graph=True, only_inputs=True)[0] 107 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA 108 | return gradient_penalty 109 | 110 | 111 | def read_two_domains(opt): 112 | paths_trainA = glob(os.path.join(opt.root, 'trainA/*')) 113 | paths_trainB = glob(os.path.join(opt.root, 'trainB/*')) 114 | for i in range(len(paths_trainA)): 115 | x = img.imread(paths_trainA[i]) 116 | x = np2torch(x, opt) 117 | x = x[:, 0:3, :, :] 118 | if i == 0: 119 | trainA = x 120 | else: 121 | trainA = torch.cat((trainA,x),0) 122 | for i in range(len(paths_trainB)): 123 | y = img.imread(paths_trainB[i]) 124 | y = np2torch(y,opt) 125 | y = y[:,0:3,:,:] 126 | if i== 0: 127 | trainB = y 128 | else: 129 | trainB = torch.cat((trainB,y),0) 130 | return trainA, trainB 131 | 132 | 133 | def np2torch(x,opt): 134 | if opt.nc_im == 3: 135 | x = x[:,:,:,None] 136 | x = x.transpose((3, 2, 0, 1))/255 137 | else: 138 | x = color.rgb2gray(x) 139 | x = x[:,:,None,None] 140 | x = x.transpose(3, 2, 0, 1) 141 | x = torch.from_numpy(x) 142 | if not(opt.not_cuda): 143 | x = move_to_gpu(x) 144 | x = x.type(torch.cuda.FloatTensor) if not(opt.not_cuda) else x.type(torch.FloatTensor) 145 | x = norm(x) 146 | return x 147 | 148 | def torch2uint8(x): 149 | x = x[0,:,:,:] 150 | x = x.permute((1,2,0)) 151 | x = 255*denorm(x) 152 | x = x.cpu().numpy() 153 | x = x.astype(np.uint8) 154 | return x 155 | 156 | def read_image2np(opt): 157 | x = img.imread('%s/%s' % (opt.input_dir,opt.input_name)) 158 | x = x[:, :, 0:3] 159 | return x 160 | 161 | def 
save_networks(netG,netD,z, netG2,netD2,z2, opt): 162 | torch.save(netG.state_dict(), '%s/netG.pth' % (opt.outf)) 163 | torch.save(netD.state_dict(), '%s/netD.pth' % (opt.outf)) 164 | torch.save(z, '%s/z_opt.pth' % (opt.outf)) 165 | 166 | torch.save(netG2.state_dict(), '%s/netG2.pth' % (opt.outf)) 167 | torch.save(netD2.state_dict(), '%s/netD2.pth' % (opt.outf)) 168 | torch.save(z2, '%s/z2_opt.pth' % (opt.outf)) 169 | 170 | def adjust_scales2image(real_,opt): 171 | opt.num_scales = math.ceil((math.log(math.pow(opt.min_size / (min(real_.shape[2], real_.shape[3])), 1), opt.scale_factor_init))) + 1 172 | scale2stop = math.ceil(math.log(min([opt.max_size, max([real_.shape[2], real_.shape[3]])]) / max([real_.shape[2], real_.shape[3]]),opt.scale_factor_init)) 173 | opt.stop_scale = opt.num_scales - scale2stop 174 | opt.scale1 = min(opt.max_size / max([real_.shape[2], real_.shape[3]]),1) 175 | real = imresize(real_, opt.scale1, opt) 176 | opt.scale_factor = math.pow(opt.min_size/(min(real.shape[2],real.shape[3])),1/(opt.stop_scale)) 177 | scale2stop = math.ceil(math.log(min([opt.max_size, max([real_.shape[2], real_.shape[3]])]) / max([real_.shape[2], real_.shape[3]]),opt.scale_factor_init)) 178 | opt.stop_scale = opt.num_scales - scale2stop 179 | return real 180 | 181 | 182 | 183 | def creat_reals_pyramid(real,reals,opt): 184 | real = real[:,0:3,:,:] 185 | for i in range(0,opt.stop_scale+1,1): 186 | scale = math.pow(opt.scale_factor,opt.stop_scale-i) 187 | curr_real = imresize(real,scale,opt) 188 | reals.append(curr_real) 189 | return reals 190 | 191 | 192 | def load_trained_pyramid(opt, mode_='train'): 193 | mode = opt.mode 194 | opt.mode = 'train' 195 | dir = generate_dir2save(opt) 196 | 197 | if os.path.exists(dir): 198 | Gs = torch.load('%s/Gs.pth' % dir) 199 | Zs = torch.load('%s/Zs.pth' % dir) 200 | reals = torch.load('%s/reals.pth' % dir) 201 | NoiseAmp = torch.load('%s/NoiseAmp.pth' % dir) 202 | else: 203 | print('no appropriate trained model is exist, please train first') 204 | opt.mode = mode 205 | return Gs,Zs,reals,NoiseAmp 206 | 207 | 208 | def load_trained_two_pyramid(opt, mode_='train'): 209 | mode = opt.mode 210 | opt.mode = 'train' 211 | dir = generate_dir2save(opt) 212 | if(os.path.exists(dir)): 213 | Gs = torch.load('%s/Gs.pth' % dir) 214 | Zs = torch.load('%s/Zs.pth' % dir) 215 | reals = torch.load('%s/reals.pth' % dir) 216 | NoiseAmp = torch.load('%s/NoiseAmp.pth' % dir) 217 | Gs2 = torch.load('%s/Gs2.pth' % dir) 218 | Zs2 = torch.load('%s/Zs2.pth' % dir) 219 | reals2 = torch.load('%s/reals2.pth' % dir) 220 | NoiseAmp2 = torch.load('%s/NoiseAmp2.pth' % dir) 221 | else: 222 | print('no appropriate trained model is exist, please train first') 223 | opt.mode = mode 224 | return Gs,Zs,reals,NoiseAmp, Gs2,Zs2,reals2,NoiseAmp2 225 | 226 | 227 | def generate_dir2save(opt): 228 | dir2save = None 229 | if opt.mode == 'train': 230 | dir2save = 'Checkpoints/%s/scale_factor=%.3f, noise_amp=%.4f, lambda_cyc=%.3f, lambda_idt=%.3f' % (opt.input_name,opt.scale_factor_init,opt.noise_amp,opt.lambda_cyc,opt.lambda_idt) 231 | return dir2save 232 | 233 | 234 | def post_config(opt): 235 | # init fixed parameters 236 | opt.device = torch.device("cpu" if opt.not_cuda else "cuda:0") 237 | opt.niter_init = opt.niter 238 | opt.noise_amp_init = opt.noise_amp 239 | opt.nfc_init = opt.nfc 240 | opt.min_nfc_init = opt.min_nfc 241 | opt.scale_factor_init = opt.scale_factor 242 | opt.out_ = 'TrainedModels/%s/scale_factor=%f/' % (opt.input_name[:-4], opt.scale_factor) 243 | 244 | if opt.manualSeed is None: 245 
| opt.manualSeed = random.randint(1, 10000) 246 | print("Random Seed: ", opt.manualSeed) 247 | 248 | random.seed(opt.manualSeed) 249 | torch.manual_seed(opt.manualSeed) 250 | if torch.cuda.is_available() and opt.not_cuda: 251 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 252 | return opt 253 | 254 | 255 | def calc_init_scale(opt): 256 | in_scale = math.pow(1/2,1/3) 257 | iter_num = round(math.log(1 / opt.sr_factor, in_scale)) 258 | in_scale = pow(opt.sr_factor, 1 / iter_num) 259 | return in_scale,iter_num 260 | 261 | 262 | 263 | 264 | 265 | -------------------------------------------------------------------------------- /utils/imresize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import filters, measurements, interpolation 3 | from skimage import color 4 | from math import pi 5 | import torch 6 | 7 | 8 | def denorm(x): 9 | out = (x + 1) / 2 10 | return out.clamp(0, 1) 11 | 12 | def norm(x): 13 | out = (x - 0.5) * 2 14 | return out.clamp(-1, 1) 15 | 16 | def move_to_gpu(t): 17 | if (torch.cuda.is_available()): 18 | t = t.to(torch.device('cuda')) 19 | return t 20 | 21 | def np2torch(x,opt): 22 | if opt.nc_im == 3: 23 | x = x[:,:,:,None] 24 | x = x.transpose((3, 2, 0, 1))/255 25 | else: 26 | x = color.rgb2gray(x) 27 | x = x[:,:,None,None] 28 | x = x.transpose(3, 2, 0, 1) 29 | x = torch.from_numpy(x) 30 | if not (opt.not_cuda): 31 | x = move_to_gpu(x) 32 | x = x.type(torch.cuda.FloatTensor) if not(opt.not_cuda) else x.type(torch.FloatTensor) 33 | #x = x.type(torch.cuda.FloatTensor) 34 | x = norm(x) 35 | return x 36 | 37 | def torch2uint8(x): 38 | x = x[0,:,:,:] 39 | x = x.permute((1,2,0)) 40 | x = 255*denorm(x) 41 | x = x.cpu().numpy() 42 | x = x.astype(np.uint8) 43 | return x 44 | 45 | 46 | def imresize(ims,scale,opt): 47 | #s = im.shape 48 | for i in range(ims.shape[0]): 49 | im = ims[i:i+1,:,:,:] 50 | im = torch2uint8(im) 51 | im = imresize_in(im, scale_factor=scale) 52 | im = np2torch(im,opt) 53 | if i == 0: 54 | outs = im 55 | else: 56 | outs =torch.cat((outs,im),0) 57 | #im = im[:, :, 0:int(scale * s[2]), 0:int(scale * s[3])] 58 | return outs 59 | 60 | def imresize_to_shape(im,output_shape,opt): 61 | #s = im.shape 62 | im = torch2uint8(im) 63 | im = imresize_in(im, output_shape=output_shape) 64 | im = np2torch(im,opt) 65 | #im = im[:, :, 0:int(scale * s[2]), 0:int(scale * s[3])] 66 | return im 67 | 68 | 69 | def imresize_in(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 70 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 71 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 72 | 73 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only) 74 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 75 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 76 | 77 | # Choose interpolation method, each method has the matching kernel size 78 | method, kernel_width = { 79 | "cubic": (cubic, 4.0), 80 | "lanczos2": (lanczos2, 4.0), 81 | "lanczos3": (lanczos3, 6.0), 82 | "box": (box, 1.0), 83 | "linear": (linear, 2.0), 84 | None: (cubic, 4.0) # set default interpolation method as cubic 85 | }.get(kernel) 86 | 87 | # Antialiasing is only used when downscaling 88 | antialiasing *= (scale_factor[0] < 1) 89 | 90 | # Sort indices of dimensions according to 
scale of each dimension. since we are going dim by dim this is efficient 91 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 92 | 93 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 94 | out_im = np.copy(im) 95 | for dim in sorted_dims: 96 | # No point doing calculations for scale-factor 1. nothing will happen anyway 97 | if scale_factor[dim] == 1.0: 98 | continue 99 | 100 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 101 | # weights that multiply the values there to get its result. 102 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 103 | method, kernel_width, antialiasing) 104 | 105 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 106 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 107 | 108 | return out_im 109 | 110 | 111 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 112 | # First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the 113 | # same size as the number of input dimensions) 114 | if scale_factor is not None: 115 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 116 | if np.isscalar(scale_factor): 117 | scale_factor = [scale_factor, scale_factor] 118 | 119 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 120 | scale_factor = list(scale_factor) 121 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 122 | 123 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 124 | # to all the unspecified dimensions 125 | if output_shape is not None: 126 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 127 | 128 | # Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is 129 | # sub-optimal, because there can be different scales to the same output-shape. 130 | if scale_factor is None: 131 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 132 | 133 | # Dealing with missing output-shape. calculating according to scale-factor 134 | if output_shape is None: 135 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 136 | 137 | return scale_factor, output_shape 138 | 139 | 140 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 141 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 142 | # such that each position from the field_of_view will be multiplied with a matching filter from the 143 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 144 | # around it. This is only done for one dimension of the image. 145 | 146 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 147 | # 1/sf. this means filtering is more 'low-pass filter'. 
148 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 149 | kernel_width *= 1.0 / scale if antialiasing else 1.0 150 | 151 | # These are the coordinates of the output image 152 | out_coordinates = np.arange(1, out_length+1) 153 | 154 | # These are the matching positions of the output-coordinates on the input image coordinates. 155 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 156 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 157 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 158 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 159 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 160 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 161 | # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means: 162 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 163 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale) 164 | 165 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter 166 | left_boundary = np.floor(match_coordinates - kernel_width / 2) 167 | 168 | # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers 169 | # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them) 170 | expanded_kernel_width = np.ceil(kernel_width) + 2 171 | 172 | # Determine a set of field_of_view for each each output position, these are the pixels in the input image 173 | # that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the 174 | # vertical dim is the pixels it 'sees' (kernel_size + 2) 175 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)) 176 | 177 | # Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the 178 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in 179 | # 'field_of_view') 180 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1) 181 | 182 | # Normalize weights to sum up to 1. 
be careful from dividing by 0 183 | sum_weights = np.sum(weights, axis=1) 184 | sum_weights[sum_weights == 0] = 1.0 185 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1) 186 | 187 | # We use this mirror structure as a trick for reflection padding at the boundaries 188 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1)))) 189 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])] 190 | 191 | # Get rid of weights and pixel positions that are of zero weight 192 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0)) 193 | weights = np.squeeze(weights[:, non_zero_out_pixels]) 194 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels]) 195 | 196 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size 197 | return weights, field_of_view 198 | 199 | 200 | def resize_along_dim(im, dim, weights, field_of_view): 201 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize 202 | tmp_im = np.swapaxes(im, dim, 0) 203 | 204 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for 205 | # tmp_im[field_of_view.T], (bsxfun style) 206 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1]) 207 | 208 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1. 209 | # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim 210 | # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with 211 | # the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style: 212 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the 213 | # same number 214 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0) 215 | 216 | # Finally we swap back the axes to the original order 217 | return np.swapaxes(tmp_out_im, dim, 0) 218 | 219 | 220 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag): 221 | # See kernel_shift function to understand what this is 222 | if kernel_shift_flag: 223 | kernel = kernel_shift(kernel, scale_factor) 224 | 225 | # First run a correlation (convolution with flipped kernel) 226 | out_im = np.zeros_like(im) 227 | for channel in range(np.ndim(im)): 228 | out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel) 229 | 230 | # Then subsample and return 231 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None], 232 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :] 233 | 234 | 235 | def kernel_shift(kernel, sf): 236 | # There are two reasons for shifting the kernel: 237 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know 238 | # the degradation process included shifting so we always assume center of mass is center of the kernel. 239 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first 240 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the 241 | # top left corner of the first pixel. that is why different shift size needed between od and even size. 
242 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows: 243 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth. 244 | 245 | # First calculate the current center of mass for the kernel 246 | current_center_of_mass = measurements.center_of_mass(kernel) 247 | 248 | # The second ("+ 0.5 * ....") is for applying condition 2 from the comments above 249 | wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2)) 250 | 251 | # Define the shift vector for the kernel shifting (x,y) 252 | shift_vec = wanted_center_of_mass - current_center_of_mass 253 | 254 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift 255 | # (biggest shift among dims + 1 for safety) 256 | kernel = np.pad(kernel, np.int(np.ceil(np.max(shift_vec))) + 1, 'constant') 257 | 258 | # Finally shift the kernel and return 259 | return interpolation.shift(kernel, shift_vec) 260 | 261 | 262 | # These next functions are all interpolation methods. x is the distance from the left pixel center 263 | 264 | 265 | def cubic(x): 266 | absx = np.abs(x) 267 | absx2 = absx ** 2 268 | absx3 = absx ** 3 269 | return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) + 270 | (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2))) 271 | 272 | 273 | def lanczos2(x): 274 | return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) / 275 | ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)) 276 | * (abs(x) < 2)) 277 | 278 | 279 | def box(x): 280 | return ((-0.5 <= x) & (x < 0.5)) * 1.0 281 | 282 | 283 | def lanczos3(x): 284 | return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) / 285 | ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)) 286 | * (abs(x) < 3)) 287 | 288 | 289 | def linear(x): 290 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) 291 | -------------------------------------------------------------------------------- /utils/manipulate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import utils.functions 3 | import models.model as models 4 | import argparse 5 | import os 6 | from utils.imresize import imresize 7 | import torch.nn as nn 8 | import numpy as np 9 | import math 10 | import matplotlib.pyplot as plt 11 | from models.TuiGAN import * 12 | from options.config import get_arguments 13 | 14 | def TuiGAN_generate(Gs,Zs,reals,NoiseAmp, Gs2,Zs2,reals2,NoiseAmp2, opt,in_s=None,gen_start_scale=0): 15 | if in_s is None: 16 | in_s = torch.full(reals[0].shape, 0, device=opt.device) 17 | x_ab = in_s 18 | x_aba = in_s 19 | count = 0 20 | if opt.mode == 'train': 21 | dir2save = '%s/%s/gen_start_scale=%d' % (opt.out, opt.input_name, gen_start_scale) 22 | else: 23 | dir2save = functions.generate_dir2save(opt) 24 | try: 25 | os.makedirs(dir2save) 26 | except OSError: 27 | pass 28 | for G,G2,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Gs2,Zs,reals,reals[1:],NoiseAmp): 29 | z = functions.generate_noise([3, Z_opt.shape[2] , Z_opt.shape[3] ], device=opt.device) 30 | z = z.expand(real_curr.shape[0], 3, z.shape[2], z.shape[3]) 31 | x_ab = x_ab[:,:,0:real_curr.shape[2],0:real_curr.shape[3]] 32 | z_in = noise_amp*z+real_curr 33 | x_ab = G(z_in.detach(),x_ab) 34 | 35 | x_aba = G2(x_ab,x_aba) 36 | x_ab = imresize(x_ab.detach(),1/opt.scale_factor,opt) 37 | x_ab = x_ab[:,:,0:real_next.shape[2],0:real_next.shape[3]] 38 | x_aba = 
imresize(x_aba.detach(),1/opt.scale_factor,opt) 39 | x_aba = x_aba[:,:,0:real_next.shape[2],0:real_next.shape[3]] 40 | count += 1 41 | plt.imsave('%s/x_ab_%d.png' % (dir2save,count), functions.convert_image_np(x_ab.detach()), vmin=0,vmax=1) 42 | return x_ab.detach() 43 | 44 | --------------------------------------------------------------------------------
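Note on the scale pyramid (an illustrative sketch, not part of the repository): the least obvious arithmetic above lives in `utils/functions.py` (`adjust_scales2image` and `creat_reals_pyramid`), which decide how many scales are trained and how large each scale is. The snippet below mirrors that arithmetic in isolation for a hypothetical 300x400 input under the default options (`--min_size 100`, `--max_size 250`, `--scale_factor 0.75`); the `round()` calls only approximate the repository's `imresize`, so the actual per-scale sizes can differ by a pixel.

```python
import math

def pyramid_sizes(h, w, min_size=100, max_size=250, scale_factor_init=0.75):
    """Mirror of the sizing logic in adjust_scales2image / creat_reals_pyramid."""
    # Number of scales needed to shrink the short side down to roughly min_size.
    num_scales = math.ceil(math.log(min_size / min(h, w), scale_factor_init)) + 1
    # Scales skipped at the top because the long side is capped at max_size.
    scale2stop = math.ceil(math.log(min(max_size, max(h, w)) / max(h, w), scale_factor_init))
    stop_scale = num_scales - scale2stop
    # Finest scale: resize so the longer side is at most max_size.
    scale1 = min(max_size / max(h, w), 1)
    h1, w1 = round(h * scale1), round(w * scale1)
    # scale_factor is re-derived so the coarsest scale lands near min_size.
    scale_factor = (min_size / min(h1, w1)) ** (1 / stop_scale)
    sizes = [(round(h1 * scale_factor ** (stop_scale - i)),
              round(w1 * scale_factor ** (stop_scale - i)))
             for i in range(stop_scale + 1)]
    return stop_scale, scale_factor, sizes

if __name__ == "__main__":
    stop_scale, scale_factor, sizes = pyramid_sizes(300, 400)
    print(stop_scale, round(scale_factor, 3), sizes)
    # Approximately: 3 0.81 [(100, 133), (123, 164), (152, 203), (188, 250)]
```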