├── README.md
├── datasets
│   ├── __init__.py
│   └── pix2pix.py
├── imgs
│   ├── generated_epoch_00000212_iter00085000.png
│   ├── real_input.png
│   └── real_target.png
├── main_pix2pixgan.py
├── misc.py
├── models
│   ├── UNet.py
│   └── __init__.py
└── transforms
    ├── __init__.py
    └── pix2pix.py
/README.md:
--------------------------------------------------------------------------------
1 | # [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)
2 | # Install
3 | - Install the **fantastic** [pytorch](https://github.com/pytorch/pytorch) and [pytorch.vision](https://github.com/pytorch/vision)
4 | 
5 | # Datasets
6 | - [Download images from the authors' implementation](https://github.com/phillipi/pix2pix)
7 | - Suppose you downloaded the "facades" dataset to /path/to/facades
8 | 
9 | # Train with facades dataset (mode: B2A)
10 | - ```CUDA_VISIBLE_DEVICES=x python main_pix2pixgan.py --dataset pix2pix --dataroot /path/to/facades/train --valDataroot /path/to/facades/val --mode B2A --exp ./facades --display 5 --evalIter 500```
11 | - The resulting models are saved in the ./facades directory as net[D|G]_epoch_xx.pth
12 | # Train with edges2shoes dataset (mode: A2B)
13 | - ```CUDA_VISIBLE_DEVICES=x python main_pix2pixgan.py --dataset pix2pix --dataroot /path/to/edges2shoes/train --valDataroot /path/to/edges2shoes/val --mode A2B --exp ./edges2shoes --batchSize 4 --display 5```
14 | 
15 | # Results
16 | - Randomly selected input samples
17 | ![input](https://github.com/taey16/pix2pix.pytorch/blob/master/imgs/real_input.png)
18 | - Corresponding real target samples
19 | ![target](https://github.com/taey16/pix2pix.pytorch/blob/master/imgs/real_target.png)
20 | - **Corresponding generated samples**
21 | ![generated](https://github.com/taey16/pix2pix.pytorch/blob/master/imgs/generated_epoch_00000212_iter00085000.png)
22 | 
23 | # Note
24 | - We modified [pytorch.vision.folder](https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) and [transform.py](https://github.com/pytorch/vision/blob/master/torchvision/transforms.py) to follow the format of the paired training images in the datasets
25 | - Most of the hyperparameters are the same as in the paper.
26 | - You can easily reproduce the results of the paper with other datasets
27 | - Try B2A or A2B translation as needed
28 | 
29 | # Reference
30 | - [pix2pix.torch](https://github.com/phillipi/pix2pix)
31 | - [pix2pix-pytorch](https://github.com/mrzhu-cool/pix2pix-pytorch) (another PyTorch implementation of pix2pix)
32 | - [dcgan.pytorch](https://github.com/pytorch/examples/tree/master/dcgan)
33 | - **FANTASTIC pytorch** [pytorch doc](http://pytorch.org/docs/notes/autograd.html)
34 | - [ganhacks from soumith](https://github.com/soumith/ganhacks)
35 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taey16/pix2pix.pytorch/3cbd98cc9ec08d496b7c8fc170c5ea4dc15a9e00/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/pix2pix.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | 
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | 
8 | IMG_EXTENSIONS = [
9 |   '.jpg', '.JPG', '.jpeg', '.JPEG',
10 |   '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 | 
13 | def is_image_file(filename):
14 |   return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15 | 
16 | def make_dataset(dir):
17 |   images = []
18 |   if not os.path.isdir(dir):
19 |     raise Exception('Check dataroot')
20 |   for root, _, fnames in sorted(os.walk(dir)):
21 |     for fname in fnames:
22 |       if is_image_file(fname):
23 |         path = os.path.join(dir, fname)
24 |         item = path
25 |         images.append(item)
26 |   return images
27 | 
28 | def default_loader(path):
29 |   return Image.open(path).convert('RGB')
30 | 
31 | class pix2pix(data.Dataset):
32 |   def __init__(self, root, transform=None, loader=default_loader, seed=None):
33 |     imgs = make_dataset(root)
34 |     if len(imgs) == 0:
35 |       raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
36 |                          "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
37 |     self.root = root
38 |     self.imgs = imgs
39 |     self.transform = transform
40 |     self.loader = loader
41 | 
42 |     if seed is not None:
43 |       np.random.seed(seed)
44 | 
45 |   def __getitem__(self, index):
46 |     path = self.imgs[index]
47 |     img = self.loader(path)
48 |     # NOTE: img -> PIL Image
49 |     w, h = img.size
50 |     #w, h = 512, 256
51 |     #img = img.resize((w, h), Image.BILINEAR)
52 |     # NOTE: split a sample into imgA (left half) and imgB (right half)
53 |     imgA = img.crop((0, 0, w // 2, h))
54 |     imgB = img.crop((w // 2, 0, w, h))
55 |     if self.transform is not None:
56 |       # NOTE: preprocessing applied jointly to each pair of images
57 |       imgA, imgB = self.transform(imgA, imgB)
58 |     return imgA, imgB
59 | 
60 |   def __len__(self):
61 |     return len(self.imgs)
62 | 
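A minimal usage sketch for the paired dataset defined above. It assumes the facades images live under a placeholder path /path/to/facades/train and mirrors the train-split pipeline built in misc.getLoader; it shows how each side-by-side image is split into an (imgA, imgB) pair and transformed jointly:

```python
# Hypothetical stand-alone use of datasets/pix2pix.py with transforms/pix2pix.py.
import torch.utils.data
from datasets.pix2pix import pix2pix
import transforms.pix2pix as transforms

dataset = pix2pix(
    root='/path/to/facades/train',            # placeholder path; each file is an A|B image
    transform=transforms.Compose([
        transforms.Scale(286),                # resize so the smaller edge is 286
        transforms.RandomCrop(256),           # the same random 256x256 crop for both halves
        transforms.RandomHorizontalFlip(),    # the same flip decision for both halves
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    seed=101)

loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
imgA, imgB = next(iter(loader))   # with --mode B2A, imgA is the target and imgB the input
print(imgA.size(), imgB.size())   # torch.Size([1, 3, 256, 256]) each
```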
--------------------------------------------------------------------------------
/imgs/generated_epoch_00000212_iter00085000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taey16/pix2pix.pytorch/3cbd98cc9ec08d496b7c8fc170c5ea4dc15a9e00/imgs/generated_epoch_00000212_iter00085000.png
--------------------------------------------------------------------------------
/imgs/real_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taey16/pix2pix.pytorch/3cbd98cc9ec08d496b7c8fc170c5ea4dc15a9e00/imgs/real_input.png
--------------------------------------------------------------------------------
/imgs/real_target.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taey16/pix2pix.pytorch/3cbd98cc9ec08d496b7c8fc170c5ea4dc15a9e00/imgs/real_target.png
--------------------------------------------------------------------------------
/main_pix2pixgan.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import sys
5 | import random
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | cudnn.benchmark = True
11 | cudnn.fastest = True
12 | import torch.optim as optim
13 | import torchvision.utils as vutils
14 | from torch.autograd import Variable
15 | 
16 | import models.UNet as net
17 | from misc import *
18 | 
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--dataset', required=False,
21 |     default='pix2pix', help='')
22 | parser.add_argument('--dataroot', required=False,
23 |     default='', help='path to train dataset')
24 | parser.add_argument('--valDataroot', required=False,
25 |     default='', help='path to val dataset')
26 | parser.add_argument('--mode', type=str, default='B2A', help='B2A: facades, A2B: edges2shoes')
27 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
28 | parser.add_argument('--valBatchSize', type=int, default=64, help='validation batch size')
29 | parser.add_argument('--originalSize', type=int,
30 |     default=286, help='the height / width of the original input image')
31 | parser.add_argument('--imageSize', type=int,
32 |     default=256, help='the height / width of the cropped input image to network')
33 | parser.add_argument('--inputChannelSize', type=int,
34 |     default=3, help='size of the input channels')
35 | parser.add_argument('--outputChannelSize', type=int,
36 |     default=3, help='size of the output channels')
37 | parser.add_argument('--ngf', type=int, default=64)
38 | parser.add_argument('--ndf', type=int, default=64)
39 | parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
40 | parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
41 | parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
42 | parser.add_argument('--lambdaGAN', type=float, default=1, help='lambdaGAN, default=1')
43 | parser.add_argument('--lambdaIMG', type=float, default=100, help='lambdaIMG, default=100')
44 | parser.add_argument('--wd', type=float, default=0.0000, help='weight decay, default=0.0')
45 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam, default=0.5')
46 | parser.add_argument('--rmsprop', action='store_true', help='use RMSprop instead of Adam (default: Adam)')
47 | parser.add_argument('--netG', default='', help="path to netG (to continue training)")
48 | parser.add_argument('--netD', default='', help="path to netD (to continue training)")
49 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
50 | parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints')
51 | parser.add_argument('--display', type=int, default=5, help='interval for displaying train logs')
52 | parser.add_argument('--evalIter', type=int, default=500, help='interval for evaluating (generating) images from valDataroot')
53 | 
54 | opt = parser.parse_args()
55 | print(opt)
56 | 
57 | create_exp_dir(opt.exp)
58 | #opt.manualSeed = random.randint(1, 10000)
59 | opt.manualSeed = 101
60 | random.seed(opt.manualSeed)
61 | torch.manual_seed(opt.manualSeed)
62 | print("Random Seed: ", opt.manualSeed)
63 | 
64 | # get dataloader
65 | dataloader = getLoader(opt.dataset,
66 |     opt.dataroot,
67 |     opt.originalSize,
68 |     opt.imageSize,
69 |     opt.batchSize,
70 |     opt.workers,
71 |     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
72 |     split='train',
73 |     shuffle=True,
74 |     seed=opt.manualSeed)
75 | valDataloader = getLoader(opt.dataset,
76 |     opt.valDataroot,
77 |     opt.imageSize, #opt.originalSize,
78 |     opt.imageSize,
79 |     opt.valBatchSize,
80 |     opt.workers,
81 |     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
82 |     split='val',
83 |     shuffle=False,
84 |     seed=opt.manualSeed)
85 | 
86 | # get logger
87 | trainLogger = open('%s/train.log' % opt.exp, 'w')
88 | 
89 | ngf = opt.ngf
90 | ndf = opt.ndf
91 | inputChannelSize = opt.inputChannelSize
92 | outputChannelSize = opt.outputChannelSize
93 | 
94 | # get models
95 | netG = net.G(inputChannelSize, outputChannelSize, ngf)
96 | netG.apply(weights_init)
97 | if opt.netG != '':
98 |   netG.load_state_dict(torch.load(opt.netG))
99 | print(netG)
100 | netD = net.D(inputChannelSize + outputChannelSize, ndf)
101 | netD.apply(weights_init)
102 | if opt.netD != '':
103 |   netD.load_state_dict(torch.load(opt.netD))
104 | print(netD)
105 | 
106 | netG.train()
107 | netD.train()
108 | criterionBCE = nn.BCELoss()
109 | criterionCAE = nn.L1Loss()
110 | 
111 | target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize)
112 | input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
113 | val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize)
114 | val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
115 | label_d = torch.FloatTensor(opt.batchSize)
116 | # NOTE: size of the 2D output map of the discriminator (PatchGAN)
117 | sizePatchGAN = 30
118 | real_label = 1
119 | fake_label = 0
120 | # NOTE: weights for L_cGAN and L_L1 (i.e. Eq. (4) in the paper)
121 | lambdaGAN = opt.lambdaGAN
122 | lambdaIMG = opt.lambdaIMG
123 | 
124 | netD.cuda()
125 | netG.cuda()
126 | criterionBCE.cuda()
127 | criterionCAE.cuda()
128 | target, input, label_d = target.cuda(), input.cuda(), label_d.cuda()
129 | val_target, val_input = val_target.cuda(), val_input.cuda()
130 | 
131 | target = Variable(target)
132 | input = Variable(input)
133 | label_d = Variable(label_d)
134 | 
135 | # get randomly sampled validation images and save them
136 | val_iter = iter(valDataloader)
137 | data_val = next(val_iter)
138 | if opt.mode == 'B2A':
139 |   val_target_cpu, val_input_cpu = data_val
140 | elif opt.mode == 'A2B':
141 |   val_input_cpu, val_target_cpu = data_val
142 | val_target_cpu, val_input_cpu = val_target_cpu.cuda(), val_input_cpu.cuda()
143 | val_target.resize_as_(val_target_cpu).copy_(val_target_cpu)
144 | val_input.resize_as_(val_input_cpu).copy_(val_input_cpu)
145 | vutils.save_image(val_target, '%s/real_target.png' % opt.exp, normalize=True)
146 | vutils.save_image(val_input, '%s/real_input.png' % opt.exp, normalize=True)
147 | 
148 | if opt.rmsprop:
149 |   optimizerD = optim.RMSprop(netD.parameters(), lr = opt.lrD)
150 |   optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lrG)
151 | else:
152 |   optimizerD = optim.Adam(netD.parameters(), lr = opt.lrD, betas = (opt.beta1, 0.999), weight_decay=opt.wd)
153 |   optimizerG = optim.Adam(netG.parameters(), lr = opt.lrG, betas = (opt.beta1, 0.999), weight_decay=0.0)
154 | 
155 | # NOTE: training loop
156 | ganIterations = 0
157 | for epoch in range(opt.niter):
158 |   for i, data in enumerate(dataloader, 0):
159 |     if opt.mode == 'B2A':
160 |       target_cpu, input_cpu = data
161 |     elif opt.mode == 'A2B':
162 |       input_cpu, target_cpu = data
163 |     batch_size = target_cpu.size(0)
164 | 
165 |     target_cpu, input_cpu = target_cpu.cuda(), input_cpu.cuda()
166 |     # NOTE
167 |     # default: imgA: target, imgB: input
168 |     target.data.resize_as_(target_cpu).copy_(target_cpu)
169 |     input.data.resize_as_(input_cpu).copy_(input_cpu)
170 | 
171 |     # max_D first
172 |     for p in netD.parameters():
173 |       p.requires_grad = True
174 |     netD.zero_grad()
175 | 
176 |     # NOTE: compute L_cGAN in Eq. (2)
177 |     label_d.data.resize_((batch_size, 1, sizePatchGAN, sizePatchGAN)).fill_(real_label)
178 |     output = netD(torch.cat([target, input], 1)) # conditional
179 |     errD_real = criterionBCE(output, label_d)
180 |     errD_real.backward()
181 |     D_x = output.data.mean()
182 |     x_hat = netG(input)
183 |     fake = x_hat.detach()
184 |     label_d.data.fill_(fake_label)
185 |     output = netD(torch.cat([fake, input], 1)) # conditional
186 |     errD_fake = criterionBCE(output, label_d)
187 |     errD_fake.backward()
188 |     D_G_z1 = output.data.mean()
189 |     errD = errD_real + errD_fake
190 |     optimizerD.step() # update parameters
191 | 
192 |     # prevent computing gradients of weights in Discriminator
193 |     for p in netD.parameters():
194 |       p.requires_grad = False
195 |     netG.zero_grad() # start to update G
196 | 
197 |     # compute L_L1 (Eq. (4) in the paper)
198 |     L_img_ = criterionCAE(x_hat, target)
199 |     L_img = lambdaIMG * L_img_
200 |     if lambdaIMG != 0:
201 |       L_img.backward(retain_variables=True)
202 | 
203 |     # compute L_cGAN (Eq. (2) in the paper)
204 |     label_d.data.fill_(real_label)
205 |     output = netD(torch.cat([x_hat, input], 1))
206 |     errG_ = criterionBCE(output, label_d)
207 |     errG = lambdaGAN * errG_
208 |     if lambdaGAN != 0:
209 |       errG.backward()
210 |     D_G_z2 = output.data.mean()
211 | 
212 |     optimizerG.step()
213 |     ganIterations += 1
214 | 
215 |     if ganIterations % opt.display ==
0: 216 | print('[%d/%d][%d/%d] L_D: %f L_img: %f L_G: %f D(x): %f D(G(z)): %f / %f' 217 | % (epoch, opt.niter, i, len(dataloader), 218 | errD.data[0], L_img.data[0], errG.data[0], D_x, D_G_z1, D_G_z2)) 219 | sys.stdout.flush() 220 | trainLogger.write('%d\t%f\t%f\t%f\t%f\t%f\t%f\n' % \ 221 | (i, errD.data[0], errG.data[0], L_img.data[0], D_x, D_G_z1, D_G_z2)) 222 | trainLogger.flush() 223 | if ganIterations % opt.evalIter == 0: 224 | val_batch_output = torch.FloatTensor(val_input.size()).fill_(0) 225 | for idx in range(val_input.size(0)): 226 | single_img = val_input[idx,:,:,:].unsqueeze(0) 227 | val_inputv = Variable(single_img, volatile=True) 228 | x_hat_val = netG(val_inputv) 229 | val_batch_output[idx,:,:,:].copy_(x_hat_val.data) 230 | vutils.save_image(val_batch_output, '%s/generated_epoch_%08d_iter%08d.png' % \ 231 | (opt.exp, epoch, ganIterations), normalize=True) 232 | 233 | # do checkpointing 234 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.exp, epoch)) 235 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.exp, epoch)) 236 | trainLogger.close() 237 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | 5 | def create_exp_dir(exp): 6 | try: 7 | os.makedirs(exp) 8 | print('Creating exp dir: %s' % exp) 9 | except OSError: 10 | pass 11 | return True 12 | 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv') != -1: 17 | m.weight.data.normal_(0.0, 0.02) 18 | elif classname.find('BatchNorm') != -1: 19 | m.weight.data.normal_(1.0, 0.02) 20 | m.bias.data.fill_(0) 21 | 22 | 23 | def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4, 24 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None): 25 | 26 | #import pdb; pdb.set_trace() 27 | if datasetName == 'trans': 28 | from datasets.trans import trans as commonDataset 29 | import transforms.pix2pix as transforms 30 | elif datasetName == 'folder': 31 | from torchvision.datasets.folder import ImageFolder as commonDataset 32 | import torchvision.transforms as transforms 33 | elif datasetName == 'pix2pix': 34 | from datasets.pix2pix import pix2pix as commonDataset 35 | import transforms.pix2pix as transforms 36 | 37 | if datasetName != 'folder': 38 | if split == 'train': 39 | dataset = commonDataset(root=dataroot, 40 | transform=transforms.Compose([ 41 | transforms.Scale(originalSize), 42 | transforms.RandomCrop(imageSize), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean, std), 46 | ]), 47 | seed=seed) 48 | else: 49 | dataset = commonDataset(root=dataroot, 50 | transform=transforms.Compose([ 51 | transforms.Scale(originalSize), 52 | transforms.CenterCrop(imageSize), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean, std), 55 | ]), 56 | seed=seed) 57 | 58 | else: 59 | if split == 'train': 60 | dataset = commonDataset(root=dataroot, 61 | transform=transforms.Compose([ 62 | transforms.Scale(originalSize), 63 | transforms.RandomCrop(imageSize), 64 | transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | transforms.Normalize(mean, std), 67 | ])) 68 | else: 69 | dataset = commonDataset(root=dataroot, 70 | transform=transforms.Compose([ 71 | transforms.Scale(originalSize), 72 | transforms.CenterCrop(imageSize), 73 | transforms.ToTensor(), 74 | transforms.Normalize(mean, std), 75 | ])) 
76 | assert dataset 77 | dataloader = torch.utils.data.DataLoader(dataset, 78 | batch_size=batchSize, 79 | shuffle=shuffle, 80 | num_workers=int(workers)) 81 | return dataloader 82 | 83 | def check_cuda(opt): 84 | if torch.cuda.is_available() and not opt.cuda: 85 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 86 | opt.cuda = True 87 | return opt 88 | 89 | 90 | ################ 91 | def accuracy(output, target, topk=(1,)): 92 | """Computes the precision@k for the specified values of k""" 93 | maxk = max(topk) 94 | batch_size = target.size(0) 95 | 96 | _, pred = output.topk(maxk, 1, True, True) 97 | pred = pred.t() 98 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 99 | 100 | res = [] 101 | for k in topk: 102 | correct_k = correct[:k].view(-1).float().sum(0) 103 | res.append(correct_k.mul_(100.0 / batch_size)) 104 | return res 105 | 106 | 107 | class AverageMeter(object): 108 | """Computes and stores the average and current value""" 109 | def __init__(self): 110 | self.reset() 111 | 112 | def reset(self): 113 | self.val = 0 114 | self.avg = 0 115 | self.sum = 0 116 | self.count = 0 117 | 118 | def update(self, val, n=1): 119 | self.val = val 120 | self.sum += val * n 121 | self.count += n 122 | self.avg = self.sum / self.count 123 | 124 | 125 | import numpy as np 126 | class ImagePool: 127 | def __init__(self, pool_size=50): 128 | self.pool_size = pool_size 129 | if pool_size > 0: 130 | self.num_imgs = 0 131 | self.images = [] 132 | 133 | def query(self, image): 134 | if self.pool_size == 0: 135 | return image 136 | if self.num_imgs < self.pool_size: 137 | self.images.append(image.clone()) 138 | self.num_imgs += 1 139 | return image 140 | else: 141 | #import pdb; pdb.set_trace() 142 | if np.random.uniform(0,1) > 0.5: 143 | random_id = np.random.randint(self.pool_size, size=1)[0] 144 | tmp = self.images[random_id].clone() 145 | self.images[random_id] = image.clone() 146 | return tmp 147 | else: 148 | return image 149 | 150 | def adjust_learning_rate(optimizer, init_lr, epoch, factor, every): 151 | #import pdb; pdb.set_trace() 152 | lrd = init_lr / every 153 | old_lr = optimizer.param_groups[0]['lr'] 154 | lr = old_lr - lrd 155 | if lr < 0: lr = 0 156 | # optimizer.param_groups[0].keys() 157 | # ['betas', 'weight_decay', 'params', 'eps', 'lr'] 158 | # optimizer.param_groups[1].keys() 159 | # *** IndexError: list index out of range 160 | for param_group in optimizer.param_groups: 161 | param_group['lr'] = lr 162 | 163 | """ 164 | import cv2 165 | def visualize_attention(masks, resize, inputs=None): 166 | masks = masks.numpy() 167 | masks = masks.transpose(1,2,0) 168 | masks = cv2.resize(masks, (resize, resize)) 169 | if masks.ndim == 2: masks = masks[:,:,np.newaxis] 170 | masks = masks.transpose(2,0,1) 171 | masks = torch.from_numpy(masks).unsqueeze(1) 172 | if inputs is not None: 173 | return inputs * masks.expand_as(inputs) 174 | else: 175 | return masks 176 | """ 177 | -------------------------------------------------------------------------------- /models/UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import torch.nn as nn 4 | 5 | 6 | def blockUNet(in_c, out_c, name, transposed=False, bn=True, relu=True, dropout=False): 7 | block = nn.Sequential() 8 | if relu: 9 | block.add_module('%s.relu' % name, nn.ReLU(inplace=True)) 10 | else: 11 | block.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) 12 | if not transposed: 13 | 
block.add_module('%s.conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)) 14 | else: 15 | block.add_module('%s.tconv' % name, nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False)) 16 | if bn: 17 | block.add_module('%s.bn' % name, nn.BatchNorm2d(out_c)) 18 | if dropout: 19 | block.add_module('%s.dropout' % name, nn.Dropout2d(0.5, inplace=True)) 20 | return block 21 | 22 | 23 | class D(nn.Module): 24 | def __init__(self, nc, nf): 25 | super(D, self).__init__() 26 | 27 | main = nn.Sequential() 28 | # 256 29 | layer_idx = 1 30 | name = 'layer%d' % layer_idx 31 | main.add_module('%s.conv' % name, nn.Conv2d(nc, nf, 4, 2, 1, bias=False)) 32 | 33 | # 128 34 | layer_idx += 1 35 | name = 'layer%d' % layer_idx 36 | main.add_module(name, blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False)) 37 | 38 | # 64 39 | layer_idx += 1 40 | name = 'layer%d' % layer_idx 41 | nf = nf * 2 42 | main.add_module(name, blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False)) 43 | 44 | # 32 45 | layer_idx += 1 46 | name = 'layer%d' % layer_idx 47 | nf = nf * 2 48 | main.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) 49 | main.add_module('%s.conv' % name, nn.Conv2d(nf, nf*2, 4, 1, 1, bias=False)) 50 | main.add_module('%s.bn' % name, nn.BatchNorm2d(nf*2)) 51 | 52 | # 31 53 | layer_idx += 1 54 | name = 'layer%d' % layer_idx 55 | nf = nf * 2 56 | main.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) 57 | main.add_module('%s.conv' % name, nn.Conv2d(nf, 1, 4, 1, 1, bias=False)) 58 | main.add_module('%s.sigmoid' % name , nn.Sigmoid()) 59 | # 30 (sizePatchGAN=30) 60 | 61 | self.main = main 62 | 63 | def forward(self, x): 64 | output = self.main(x) 65 | return output 66 | 67 | 68 | class G(nn.Module): 69 | def __init__(self, input_nc, output_nc, nf): 70 | super(G, self).__init__() 71 | 72 | # input is 256 x 256 73 | layer_idx = 1 74 | name = 'layer%d' % layer_idx 75 | layer1 = nn.Sequential() 76 | layer1.add_module(name, nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False)) 77 | # input is 128 x 128 78 | layer_idx += 1 79 | name = 'layer%d' % layer_idx 80 | layer2 = blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False) 81 | # input is 64 x 64 82 | layer_idx += 1 83 | name = 'layer%d' % layer_idx 84 | layer3 = blockUNet(nf*2, nf*4, name, transposed=False, bn=True, relu=False, dropout=False) 85 | # input is 32 86 | layer_idx += 1 87 | name = 'layer%d' % layer_idx 88 | layer4 = blockUNet(nf*4, nf*8, name, transposed=False, bn=True, relu=False, dropout=False) 89 | # input is 16 90 | layer_idx += 1 91 | name = 'layer%d' % layer_idx 92 | layer5 = blockUNet(nf*8, nf*8, name, transposed=False, bn=True, relu=False, dropout=False) 93 | # input is 8 94 | layer_idx += 1 95 | name = 'layer%d' % layer_idx 96 | layer6 = blockUNet(nf*8, nf*8, name, transposed=False, bn=True, relu=False, dropout=False) 97 | # input is 4 98 | layer_idx += 1 99 | name = 'layer%d' % layer_idx 100 | layer7 = blockUNet(nf*8, nf*8, name, transposed=False, bn=True, relu=False, dropout=False) 101 | # input is 2 x 2 102 | layer_idx += 1 103 | name = 'layer%d' % layer_idx 104 | layer8 = blockUNet(nf*8, nf*8, name, transposed=False, bn=False, relu=False, dropout=False) 105 | 106 | ## NOTE: decoder 107 | # input is 1 108 | name = 'dlayer%d' % layer_idx 109 | d_inc = nf*8 110 | dlayer8 = blockUNet(d_inc, nf*8, name, transposed=True, bn=True, relu=True, dropout=True) 111 | 112 | #import pdb; pdb.set_trace() 113 | # input is 2 114 | layer_idx -= 1 115 | name = 
'dlayer%d' % layer_idx 116 | d_inc = nf*8*2 117 | dlayer7 = blockUNet(d_inc, nf*8, name, transposed=True, bn=True, relu=True, dropout=True) 118 | # input is 4 119 | layer_idx -= 1 120 | name = 'dlayer%d' % layer_idx 121 | d_inc = nf*8*2 122 | dlayer6 = blockUNet(d_inc, nf*8, name, transposed=True, bn=True, relu=True, dropout=True) 123 | # input is 8 124 | layer_idx -= 1 125 | name = 'dlayer%d' % layer_idx 126 | d_inc = nf*8*2 127 | dlayer5 = blockUNet(d_inc, nf*8, name, transposed=True, bn=True, relu=True, dropout=False) 128 | # input is 16 129 | layer_idx -= 1 130 | name = 'dlayer%d' % layer_idx 131 | d_inc = nf*8*2 132 | dlayer4 = blockUNet(d_inc, nf*4, name, transposed=True, bn=True, relu=True, dropout=False) 133 | # input is 32 134 | layer_idx -= 1 135 | name = 'dlayer%d' % layer_idx 136 | d_inc = nf*4*2 137 | dlayer3 = blockUNet(d_inc, nf*2, name, transposed=True, bn=True, relu=True, dropout=False) 138 | # input is 64 139 | layer_idx -= 1 140 | name = 'dlayer%d' % layer_idx 141 | d_inc = nf*2*2 142 | dlayer2 = blockUNet(d_inc, nf, name, transposed=True, bn=True, relu=True, dropout=False) 143 | # input is 128 144 | layer_idx -= 1 145 | name = 'dlayer%d' % layer_idx 146 | dlayer1 = nn.Sequential() 147 | d_inc = nf*2 148 | dlayer1.add_module('%s.relu' % name, nn.ReLU(inplace=True)) 149 | dlayer1.add_module('%s.tconv' % name, nn.ConvTranspose2d(d_inc, output_nc, 4, 2, 1, bias=False)) 150 | dlayer1.add_module('%s.tanh' % name, nn.Tanh()) 151 | 152 | self.layer1 = layer1 153 | self.layer2 = layer2 154 | self.layer3 = layer3 155 | self.layer4 = layer4 156 | self.layer5 = layer5 157 | self.layer6 = layer6 158 | self.layer7 = layer7 159 | self.layer8 = layer8 160 | self.dlayer8 = dlayer8 161 | self.dlayer7 = dlayer7 162 | self.dlayer6 = dlayer6 163 | self.dlayer5 = dlayer5 164 | self.dlayer4 = dlayer4 165 | self.dlayer3 = dlayer3 166 | self.dlayer2 = dlayer2 167 | self.dlayer1 = dlayer1 168 | 169 | def forward(self, x): 170 | out1 = self.layer1(x) 171 | out2 = self.layer2(out1) 172 | out3 = self.layer3(out2) 173 | out4 = self.layer4(out3) 174 | out5 = self.layer5(out4) 175 | out6 = self.layer6(out5) 176 | out7 = self.layer7(out6) 177 | out8 = self.layer8(out7) 178 | dout8 = self.dlayer8(out8) 179 | dout8_out7 = torch.cat([dout8, out7], 1) 180 | dout7 = self.dlayer7(dout8_out7) 181 | dout7_out6 = torch.cat([dout7, out6], 1) 182 | dout6 = self.dlayer6(dout7_out6) 183 | dout6_out5 = torch.cat([dout6, out5], 1) 184 | dout5 = self.dlayer5(dout6_out5) 185 | dout5_out4 = torch.cat([dout5, out4], 1) 186 | dout4 = self.dlayer4(dout5_out4) 187 | dout4_out3 = torch.cat([dout4, out3], 1) 188 | dout3 = self.dlayer3(dout4_out3) 189 | dout3_out2 = torch.cat([dout3, out2], 1) 190 | dout2 = self.dlayer2(dout3_out2) 191 | dout2_out1 = torch.cat([dout2, out1], 1) 192 | dout1 = self.dlayer1(dout2_out1) 193 | return dout1 194 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taey16/pix2pix.pytorch/3cbd98cc9ec08d496b7c8fc170c5ea4dc15a9e00/models/__init__.py -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taey16/pix2pix.pytorch/3cbd98cc9ec08d496b7c8fc170c5ea4dc15a9e00/transforms/__init__.py -------------------------------------------------------------------------------- 
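Before moving on to the paired transforms, here is a quick shape check for the U-Net generator G and the PatchGAN discriminator D defined in models/UNet.py. It is only a sketch: it assumes the same early PyTorch API the repo targets (blockUNet registers sub-modules with dots in their names, which newer PyTorch versions reject), and the dummy sizes simply match the defaults in main_pix2pixgan.py:

```python
# Sanity-check the U-Net generator and the PatchGAN discriminator output shapes.
import torch
from torch.autograd import Variable  # kept for consistency with the repo; a no-op wrapper on newer PyTorch

import models.UNet as net

netG = net.G(input_nc=3, output_nc=3, nf=64)
netD = net.D(nc=6, nf=64)                    # D sees input and target concatenated on channels

x = Variable(torch.randn(1, 3, 256, 256))    # dummy 256x256 RGB input
y_hat = netG(x)                              # -> (1, 3, 256, 256), tanh output in [-1, 1]
patch = netD(torch.cat([y_hat, x], 1))       # -> (1, 1, 30, 30), i.e. sizePatchGAN = 30
print(y_hat.size(), patch.size())
```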
/transforms/pix2pix.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 10 | class Compose(object): 11 | """Composes several transforms together. 12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB): 24 | for t in self.transforms: 25 | imgA, imgB = t(imgA, imgB) 26 | return imgA, imgB 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB): 33 | pics = [picA, picB] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 52 | output.append(img) 53 | return output[0], output[1] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB): 61 | pics = [picA, picB] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB): 87 | tensors = [tensorA, tensorB] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 
99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB): 109 | imgs = [imgA, imgB] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB): 139 | imgs = [imgA, imgB] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | return output[0], output[1] 148 | 149 | class Pad(object): 150 | """Pads the given PIL.Image on all sides with the given "pad" value""" 151 | def __init__(self, padding, fill=0): 152 | assert isinstance(padding, numbers.Number) 153 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 154 | self.padding = padding 155 | self.fill = fill 156 | 157 | def __call__(self, imgA, imgB): 158 | imgs = [imgA, imgB] 159 | output = [] 160 | for img in imgs: 161 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 162 | return output[0], output[1] 163 | 164 | class Lambda(object): 165 | """Applies a lambda as a transform.""" 166 | def __init__(self, lambd): 167 | assert isinstance(lambd, types.LambdaType) 168 | self.lambd = lambd 169 | 170 | def __call__(self, imgA, imgB): 171 | imgs = [imgA, imgB] 172 | output = [] 173 | for img in imgs: 174 | output.append(self.lambd(img)) 175 | return output[0], output[1] 176 | 177 | class RandomCrop(object): 178 | """Crops the given PIL.Image at a random location to have a region of 179 | the given size. 
size can be a tuple (target_height, target_width) 180 | or an integer, in which case the target will be of a square shape (size, size) 181 | """ 182 | def __init__(self, size, padding=0): 183 | if isinstance(size, numbers.Number): 184 | self.size = (int(size), int(size)) 185 | else: 186 | self.size = size 187 | self.padding = padding 188 | 189 | def __call__(self, imgA, imgB): 190 | imgs = [imgA, imgB] 191 | output = [] 192 | x1 = -1 193 | y1 = -1 194 | for img in imgs: 195 | if self.padding > 0: 196 | img = ImageOps.expand(img, border=self.padding, fill=0) 197 | 198 | w, h = img.size 199 | th, tw = self.size 200 | if w == tw and h == th: 201 | output.append(img) 202 | continue 203 | 204 | if x1 == -1 and y1 == -1: 205 | x1 = random.randint(0, w - tw) 206 | y1 = random.randint(0, h - th) 207 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 208 | return output[0], output[1] 209 | 210 | class RandomHorizontalFlip(object): 211 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 212 | """ 213 | def __call__(self, imgA, imgB): 214 | imgs = [imgA, imgB] 215 | output = [] 216 | flag = random.random() < 0.5 217 | for img in imgs: 218 | if flag: 219 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 220 | else: 221 | output.append(img) 222 | return output[0], output[1] 223 | --------------------------------------------------------------------------------
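The repository does not ship a separate test script, so the sketch below shows one way to reuse the pieces above for single-image inference. The checkpoint file, epoch number, and image path are placeholders, mode B2A is assumed (the right half imgB of a combined validation image is fed to the generator), and the preprocessing mirrors the val branch of misc.getLoader:

```python
# Hypothetical single-image inference with a generator checkpoint saved by main_pix2pixgan.py.
import torch
from torch.autograd import Variable
import torchvision.utils as vutils

import models.UNet as net
import transforms.pix2pix as transforms
from datasets.pix2pix import default_loader

netG = net.G(input_nc=3, output_nc=3, nf=64)
netG.load_state_dict(torch.load('./facades/netG_epoch_200.pth'))  # placeholder checkpoint
netG.cuda()
netG.eval()  # note: the paper applies dropout at test time; drop this line to mimic that

img = default_loader('/path/to/facades/val/1.jpg')     # combined A|B image (placeholder path)
w, h = img.size
imgA, imgB = img.crop((0, 0, w // 2, h)), img.crop((w // 2, 0, w, h))

preprocess = transforms.Compose([                      # same pipeline as the val split in misc.getLoader
    transforms.Scale(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
imgA, imgB = preprocess(imgA, imgB)

fake = netG(Variable(imgB.unsqueeze(0).cuda(), volatile=True))  # B2A: generate A from B
vutils.save_image(fake.data, 'generated.png', normalize=True)
```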