├── .gitignore
├── LICENSE
├── README.md
├── download_dataset.sh
├── examples.jpg
├── label_to_facades.png
├── main.py
├── model.py
├── ops.py
├── requirements.txt
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
sample/*
logs/*
test/*
datasets/*
checkpoint/*
val/*
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License

Copyright (c) 2016-2018 Yen-Chen Lin http://yclin.me/

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pix2pix-tensorflow

TensorFlow implementation of [Image-to-Image Translation Using Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf), which learns a mapping from input images to output images.

Here are some results generated by the authors of the paper:

## Setup

### Prerequisites
- Linux
- Python with numpy
- NVIDIA GPU + CUDA 8.0 + cuDNN v5.1
- TensorFlow 1.x (the code below uses the TF 1.0-style `tf.concat` and `tf.summary` APIs)

### Getting Started
- Clone this repo:
```bash
git clone git@github.com:yenchenlin/pix2pix-tensorflow.git
cd pix2pix-tensorflow
```
- Download the dataset (script borrowed from the [torch code](https://github.com/phillipi/pix2pix/blob/master/datasets/download_dataset.sh)):
```bash
bash ./download_dataset.sh facades
```
- Train the model:
```bash
python main.py --phase train
```
- Test the model:
```bash
python main.py --phase test
```

## Results
Here are the results generated by this implementation:

- Facades:

More results on other datasets coming soon!

**Note**: To avoid fast convergence of the discriminator (D) network, the generator (G) network is updated twice for each D update. This differs from the original paper but follows [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow), on which this project is based.

## Train
The code currently supports the [CMP Facades](http://cmp.felk.cvut.cz/~tylecr1/facade/) dataset. Reproducing the results presented above takes 200 epochs of training; exact training time depends on your hardware.
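For reference, a fully spelled-out training invocation using the flags defined in `main.py` might look like the following; the values shown are just the defaults, so this is equivalent to the short command above:

```bash
python main.py --phase train \
               --dataset_name facades \
               --epoch 200 \
               --batch_size 1 \
               --L1_lambda 100.0
```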
## Test
Test the model on the validation set of the [CMP Facades](http://cmp.felk.cvut.cz/~tylecr1/facade/) dataset. It generates images synthesized from the corresponding labels and saves them under the `./test` directory.


## Acknowledgments
Code borrows heavily from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py). Thanks for their excellent work!

## License
MIT
--------------------------------------------------------------------------------
/download_dataset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
mkdir -p datasets
FILE=$1
URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./datasets/$FILE.tar.gz
TARGET_DIR=./datasets/$FILE/
wget -N "$URL" -O "$TAR_FILE"
mkdir -p "$TARGET_DIR"
tar -zxvf "$TAR_FILE" -C ./datasets/
rm "$TAR_FILE"
--------------------------------------------------------------------------------
/examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yenchenlin/pix2pix-tensorflow/ba40020706ad3a1fbefa1da7bc7a05b7b031fb9e/examples.jpg
--------------------------------------------------------------------------------
/label_to_facades.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yenchenlin/pix2pix-tensorflow/ba40020706ad3a1fbefa1da7bc7a05b7b031fb9e/label_to_facades.png
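The training and sampling code below globs `./datasets/<dataset_name>/train/*.jpg` and `./datasets/<dataset_name>/val/*.jpg`, so a quick sanity check after downloading (a hypothetical helper, not part of the repo) could be:

```python
from glob import glob

# model.py globs exactly these patterns during training and sampling.
for split in ("train", "val"):
    files = glob("./datasets/facades/{}/*.jpg".format(split))
    print("{}: {} images".format(split, len(files)))
    assert files, "no images found -- did download_dataset.sh succeed?"
```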
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import scipy.misc
import numpy as np

from model import pix2pix
import tensorflow as tf

parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset_name', dest='dataset_name', default='facades', help='name of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epochs')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train')
parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discrim filters in first conv layer')
parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iters at starting learning rate')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--flip', dest='flip', type=bool, default=True, help='if flip the images for data augmentation')
parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=50, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)')
parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every save_latest_freq sgd iterations (overwrites the previous latest model)')
parser.add_argument('--print_freq', dest='print_freq', type=int, default=50, help='print the debug information every print_freq iterations')
parser.add_argument('--continue_train', dest='continue_train', type=bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
parser.add_argument('--serial_batches', dest='serial_batches', type=bool, default=False, help='if 1, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=bool, default=True, help='iter into serial image list')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='samples are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test samples are saved here')
parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=100.0, help='weight on L1 term in objective')

args = parser.parse_args()

def main(_):
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)

    with tf.Session() as sess:
        model = pix2pix(sess, image_size=args.fine_size, batch_size=args.batch_size,
                        output_size=args.fine_size, dataset_name=args.dataset_name,
                        checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir)

        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)

if __name__ == '__main__':
    tf.app.run()
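One caveat worth flagging: `argparse` with `type=bool` does not do what the help strings suggest — any non-empty string (including `"False"` or `"0"`) parses as `True`. A common workaround, sketched here as a hypothetical patch rather than part of the original code, is a small converter:

```python
def str2bool(v):
    # argparse passes the raw command-line string; bool("False") would be True.
    if isinstance(v, bool):
        return v
    return v.lower() in ("yes", "true", "t", "1")

# e.g. parser.add_argument('--flip', dest='flip', type=str2bool, default=True, ...)
```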
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange

from ops import *
from utils import *

class pix2pix(object):
    def __init__(self, sess, image_size=256,
                 batch_size=1, sample_size=1, output_size=256,
                 gf_dim=64, df_dim=64, L1_lambda=100,
                 input_c_dim=3, output_c_dim=3, dataset_name='facades',
                 checkpoint_dir=None, sample_dir=None):
        """

        Args:
            sess: TensorFlow session
            batch_size: The size of batch. Should be specified before training.
            output_size: (optional) The resolution in pixels of the images. [256]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            input_c_dim: (optional) Dimension of input image color. For grayscale input, set to 1. [3]
            output_c_dim: (optional) Dimension of output image color. For grayscale output, set to 1. [3]
        """
        self.sess = sess
        self.is_grayscale = (input_c_dim == 1)
        self.batch_size = batch_size
        self.image_size = image_size
        self.sample_size = sample_size
        self.output_size = output_size

        self.gf_dim = gf_dim
        self.df_dim = df_dim

        self.input_c_dim = input_c_dim
        self.output_c_dim = output_c_dim

        self.L1_lambda = L1_lambda

        # batch normalization: deals with poor initialization and helps gradient flow
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')

        self.g_bn_e2 = batch_norm(name='g_bn_e2')
        self.g_bn_e3 = batch_norm(name='g_bn_e3')
        self.g_bn_e4 = batch_norm(name='g_bn_e4')
        self.g_bn_e5 = batch_norm(name='g_bn_e5')
        self.g_bn_e6 = batch_norm(name='g_bn_e6')
        self.g_bn_e7 = batch_norm(name='g_bn_e7')
        self.g_bn_e8 = batch_norm(name='g_bn_e8')

        self.g_bn_d1 = batch_norm(name='g_bn_d1')
        self.g_bn_d2 = batch_norm(name='g_bn_d2')
        self.g_bn_d3 = batch_norm(name='g_bn_d3')
        self.g_bn_d4 = batch_norm(name='g_bn_d4')
        self.g_bn_d5 = batch_norm(name='g_bn_d5')
        self.g_bn_d6 = batch_norm(name='g_bn_d6')
        self.g_bn_d7 = batch_norm(name='g_bn_d7')

        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.build_model()

    def build_model(self):
        self.real_data = tf.placeholder(tf.float32,
                                        [self.batch_size, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')

        self.real_B = self.real_data[:, :, :, :self.input_c_dim]
        self.real_A = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]

        self.fake_B = self.generator(self.real_A)

        self.real_AB = tf.concat([self.real_A, self.real_B], 3)
        self.fake_AB = tf.concat([self.real_A, self.fake_B], 3)
        self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
        self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)

        self.fake_B_sample = self.sampler(self.real_A)

        self.d_sum = tf.summary.histogram("d", self.D)
        self.d__sum = tf.summary.histogram("d_", self.D_)
        self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)

        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
                      + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))

        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)

        self.d_loss = self.d_loss_real + self.d_loss_fake

        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        self.saver = tf.train.Saver()
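    # For intuition: the objective assembled above is the paper's conditional-GAN loss
    # plus an L1 term -- D is pushed toward 1 on real (A, B) pairs and 0 on fake pairs,
    # while G both fools D and stays L1-close to the target, weighted by L1_lambda (100).
    # A standalone numpy sketch (illustrative only) of the numerically stable form
    # that tf.nn.sigmoid_cross_entropy_with_logits computes:
    #
    #     import numpy as np
    #
    #     def sigmoid_xent(logits, labels):
    #         # max(x, 0) - x*z + log(1 + exp(-|x|))
    #         return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    #
    #     logits = np.array([2.0, -1.0])
    #     sigmoid_xent(logits, np.ones_like(logits))   # d_loss_real / adversarial part of g_loss
    #     sigmoid_xent(logits, np.zeros_like(logits))  # d_loss_fake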
    def load_random_samples(self):
        data = np.random.choice(glob('./datasets/{}/val/*.jpg'.format(self.dataset_name)), self.batch_size)
        sample = [load_data(sample_file) for sample_file in data]

        if (self.is_grayscale):
            sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_images = np.array(sample).astype(np.float32)
        return sample_images

    def sample_model(self, sample_dir, epoch, idx):
        sample_images = self.load_random_samples()
        samples, d_loss, g_loss = self.sess.run(
            [self.fake_B_sample, self.d_loss, self.g_loss],
            feed_dict={self.real_data: sample_images}
        )
        save_images(samples, [self.batch_size, 1],
                    './{}/train_{:02d}_{:04d}.png'.format(sample_dir, epoch, idx))
        print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))

    def train(self, args):
        """Train pix2pix"""
        d_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
                          .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
                          .minimize(self.g_loss, var_list=self.g_vars)

        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        self.g_sum = tf.summary.merge([self.d__sum,
                                       self.fake_B_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = tf.summary.merge([self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        counter = 1
        start_time = time.time()

        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in xrange(args.epoch):
            data = glob('./datasets/{}/train/*.jpg'.format(self.dataset_name))
            #np.random.shuffle(data)
            batch_idxs = min(len(data), args.train_size) // self.batch_size

            for idx in xrange(0, batch_idxs):
                batch_files = data[idx*self.batch_size:(idx+1)*self.batch_size]
                batch = [load_data(batch_file) for batch_file in batch_files]
                if (self.is_grayscale):
                    batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
                else:
                    batch_images = np.array(batch).astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict={ self.real_data: batch_images })
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict={ self.real_data: batch_images })
                self.writer.add_summary(summary_str, counter)

                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _, summary_str = self.sess.run([g_optim, self.g_sum],
                                               feed_dict={ self.real_data: batch_images })
                self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval({self.real_data: batch_images})
                errD_real = self.d_loss_real.eval({self.real_data: batch_images})
                errG = self.g_loss.eval({self.real_data: batch_images})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                    % (epoch, idx, batch_idxs,
                       time.time() - start_time, errD_fake+errD_real, errG))

                if np.mod(counter, 100) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)

                if np.mod(counter, 500) == 2:
                    self.save(args.checkpoint_dir, counter)
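    # Training writes scalar and image summaries to ./logs via the FileWriter above,
    # so progress can be monitored with TensorBoard while a run is in flight:
    #
    #     tensorboard --logdir=./logs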
    def discriminator(self, image, y=None, reuse=False):

        with tf.variable_scope("discriminator") as scope:

            # image is 256 x 256 x (input_c_dim + output_c_dim)
            if reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert not tf.get_variable_scope().reuse

            h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
            # h0 is (128 x 128 x self.df_dim)
            h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
            # h1 is (64 x 64 x self.df_dim*2)
            h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
            # h2 is (32 x 32 x self.df_dim*4)
            h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, d_h=1, d_w=1, name='d_h3_conv')))
            # h3 is (32 x 32 x self.df_dim*8) -- the stride-1 conv keeps the spatial size
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')

            return tf.nn.sigmoid(h4), h4
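    # The spatial sizes in the comments follow from SAME padding, where a strided
    # conv produces ceil(size / stride). A quick standalone check of the chain above:
    #
    #     import math
    #     size = 256
    #     for layer, stride in [("h0", 2), ("h1", 2), ("h2", 2), ("h3", 1)]:
    #         size = int(math.ceil(size / float(stride)))
    #         print(layer, size)   # h0 128, h1 64, h2 32, h3 32
    #
    # Note that h4 collapses everything into a single logit per image via `linear`,
    # rather than the patch-wise (70x70 PatchGAN) output used in the paper.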
    def generator(self, image, y=None):
        with tf.variable_scope("generator") as scope:

            s = self.output_size
            s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)

            # image is (256 x 256 x input_c_dim)
            e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
            # e1 is (128 x 128 x self.gf_dim)
            e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim*2, name='g_e2_conv'))
            # e2 is (64 x 64 x self.gf_dim*2)
            e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim*4, name='g_e3_conv'))
            # e3 is (32 x 32 x self.gf_dim*4)
            e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim*8, name='g_e4_conv'))
            # e4 is (16 x 16 x self.gf_dim*8)
            e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim*8, name='g_e5_conv'))
            # e5 is (8 x 8 x self.gf_dim*8)
            e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim*8, name='g_e6_conv'))
            # e6 is (4 x 4 x self.gf_dim*8)
            e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim*8, name='g_e7_conv'))
            # e7 is (2 x 2 x self.gf_dim*8)
            e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim*8, name='g_e8_conv'))
            # e8 is (1 x 1 x self.gf_dim*8)

            self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
                [self.batch_size, s128, s128, self.gf_dim*8], name='g_d1', with_w=True)
            d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
            d1 = tf.concat([d1, e7], 3)
            # d1 is (2 x 2 x self.gf_dim*8*2)

            self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
                [self.batch_size, s64, s64, self.gf_dim*8], name='g_d2', with_w=True)
            d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
            d2 = tf.concat([d2, e6], 3)
            # d2 is (4 x 4 x self.gf_dim*8*2)

            self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
                [self.batch_size, s32, s32, self.gf_dim*8], name='g_d3', with_w=True)
            d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
            d3 = tf.concat([d3, e5], 3)
            # d3 is (8 x 8 x self.gf_dim*8*2)

            self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
                [self.batch_size, s16, s16, self.gf_dim*8], name='g_d4', with_w=True)
            d4 = self.g_bn_d4(self.d4)
            d4 = tf.concat([d4, e4], 3)
            # d4 is (16 x 16 x self.gf_dim*8*2)

            self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
                [self.batch_size, s8, s8, self.gf_dim*4], name='g_d5', with_w=True)
            d5 = self.g_bn_d5(self.d5)
            d5 = tf.concat([d5, e3], 3)
            # d5 is (32 x 32 x self.gf_dim*4*2)

            self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
                [self.batch_size, s4, s4, self.gf_dim*2], name='g_d6', with_w=True)
            d6 = self.g_bn_d6(self.d6)
            d6 = tf.concat([d6, e2], 3)
            # d6 is (64 x 64 x self.gf_dim*2*2)

            self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
                [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
            d7 = self.g_bn_d7(self.d7)
            d7 = tf.concat([d7, e1], 3)
            # d7 is (128 x 128 x self.gf_dim*1*2)

            self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
                [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
            # d8 is (256 x 256 x output_c_dim)

            return tf.nn.tanh(self.d8)
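    # The generator is the U-Net from the paper: eight stride-2 encoder convs take
    # 256x256 down to 1x1, and eight deconvs mirror the path back up, concatenating
    # each decoder layer with its encoder counterpart (hence the "*2" channel counts
    # in the comments above). A standalone sketch of the encoder bookkeeping, assuming
    # 256x256 inputs and gf_dim=64:
    #
    #     gf_dim, size = 64, 256
    #     channels = [gf_dim, gf_dim*2, gf_dim*4] + [gf_dim*8]*5   # e1..e8
    #     for i, c in enumerate(channels, 1):
    #         size //= 2
    #         print("e%d: %dx%dx%d" % (i, size, size, c))          # ends at 1x1x512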
    def sampler(self, image, y=None):

        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            s = self.output_size
            s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)

            # image is (256 x 256 x input_c_dim)
            e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
            # e1 is (128 x 128 x self.gf_dim)
            e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim*2, name='g_e2_conv'))
            # e2 is (64 x 64 x self.gf_dim*2)
            e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim*4, name='g_e3_conv'))
            # e3 is (32 x 32 x self.gf_dim*4)
            e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim*8, name='g_e4_conv'))
            # e4 is (16 x 16 x self.gf_dim*8)
            e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim*8, name='g_e5_conv'))
            # e5 is (8 x 8 x self.gf_dim*8)
            e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim*8, name='g_e6_conv'))
            # e6 is (4 x 4 x self.gf_dim*8)
            e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim*8, name='g_e7_conv'))
            # e7 is (2 x 2 x self.gf_dim*8)
            e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim*8, name='g_e8_conv'))
            # e8 is (1 x 1 x self.gf_dim*8)

            self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
                [self.batch_size, s128, s128, self.gf_dim*8], name='g_d1', with_w=True)
            d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
            d1 = tf.concat([d1, e7], 3)
            # d1 is (2 x 2 x self.gf_dim*8*2)

            self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
                [self.batch_size, s64, s64, self.gf_dim*8], name='g_d2', with_w=True)
            d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
            d2 = tf.concat([d2, e6], 3)
            # d2 is (4 x 4 x self.gf_dim*8*2)

            self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
                [self.batch_size, s32, s32, self.gf_dim*8], name='g_d3', with_w=True)
            d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
            d3 = tf.concat([d3, e5], 3)
            # d3 is (8 x 8 x self.gf_dim*8*2)

            self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
                [self.batch_size, s16, s16, self.gf_dim*8], name='g_d4', with_w=True)
            d4 = self.g_bn_d4(self.d4)
            d4 = tf.concat([d4, e4], 3)
            # d4 is (16 x 16 x self.gf_dim*8*2)

            self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
                [self.batch_size, s8, s8, self.gf_dim*4], name='g_d5', with_w=True)
            d5 = self.g_bn_d5(self.d5)
            d5 = tf.concat([d5, e3], 3)
            # d5 is (32 x 32 x self.gf_dim*4*2)

            self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
                [self.batch_size, s4, s4, self.gf_dim*2], name='g_d6', with_w=True)
            d6 = self.g_bn_d6(self.d6)
            d6 = tf.concat([d6, e2], 3)
            # d6 is (64 x 64 x self.gf_dim*2*2)

            self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
                [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
            d7 = self.g_bn_d7(self.d7)
            d7 = tf.concat([d7, e1], 3)
            # d7 is (128 x 128 x self.gf_dim*1*2)

            self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
                [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
            # d8 is (256 x 256 x output_c_dim)

            return tf.nn.tanh(self.d8)

    def save(self, checkpoint_dir, step):
        model_name = "pix2pix.model"
        model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoint...")

        model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False
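    # Checkpoints are namespaced by dataset, batch size, and output size, so with the
    # defaults save()/load() read and write paths of the form (step numbers vary):
    #
    #     ./checkpoint/facades_1_256/pix2pix.model-5000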
    def test(self, args):
        """Test pix2pix"""
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        sample_files = glob('./datasets/{}/val/*.jpg'.format(self.dataset_name))

        # sort testing input
        n = [int(i) for i in map(lambda x: x.split('/')[-1].split('.jpg')[0], sample_files)]
        sample_files = [x for (y, x) in sorted(zip(n, sample_files))]

        # load testing input
        print("Loading testing images ...")
        sample = [load_data(sample_file, is_test=True) for sample_file in sample_files]

        if (self.is_grayscale):
            sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_images = np.array(sample).astype(np.float32)

        sample_images = [sample_images[i:i+self.batch_size]
                         for i in xrange(0, len(sample_images), self.batch_size)]
        sample_images = np.array(sample_images)
        print(sample_images.shape)

        start_time = time.time()
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for i, sample_image in enumerate(sample_images):
            idx = i+1
            print("sampling image ", idx)
            samples = self.sess.run(
                self.fake_B_sample,
                feed_dict={self.real_data: sample_image}
            )
            save_images(samples, [self.batch_size, 1],
                        './{}/test_{:04d}.png'.format(args.test_dir, idx))
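Note that `test()` sorts inputs by interpreting each filename stem as an integer, which assumes the validation files are named `1.jpg`, `2.jpg`, and so on. A slightly more defensive variant (a sketch, not part of the original code) would be:

```python
import os

def numeric_key(path):
    # './datasets/facades/val/12.jpg' -> (0, 12); non-numeric stems sort after.
    stem = os.path.splitext(os.path.basename(path))[0]
    return (0, int(stem)) if stem.isdigit() else (1, stem)

files = ["./val/10.jpg", "./val/2.jpg", "./val/1.jpg"]
print(sorted(files, key=numeric_key))  # ['./val/1.jpg', './val/2.jpg', './val/10.jpg']
```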
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import tensorflow as tf

from tensorflow.python.framework import ops

from utils import *

class batch_norm(object):
    # h1 = lrelu(tf.contrib.layers.batch_norm(conv2d(h0, self.df_dim*2, name='d_h1_conv'), decay=0.9, updates_collections=None, epsilon=0.00001, scale=True, scope="d_h1_conv"))
    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name

    def __call__(self, x, train=True):
        return tf.contrib.layers.batch_norm(x, decay=self.momentum, updates_collections=None, epsilon=self.epsilon, scale=True, scope=self.name)

def binary_cross_entropy(preds, targets, name=None):
    """Computes binary cross entropy given `preds`.

    For brevity, let `x = preds`, `z = targets`. The logistic loss is

        loss(x, z) = - sum_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i]))

    Args:
        preds: A `Tensor` of type `float32` or `float64`.
        targets: A `Tensor` of the same type and shape as `preds`.
    """
    eps = 1e-12
    with ops.op_scope([preds, targets], name, "bce_loss") as name:
        preds = ops.convert_to_tensor(preds, name="preds")
        targets = ops.convert_to_tensor(targets, name="targets")
        return tf.reduce_mean(-(targets * tf.log(preds + eps) +
                                (1. - targets) * tf.log(1. - preds + eps)))

def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    return tf.concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)

def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

        return conv

def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))

        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                            strides=[1, d_h, d_w, 1])

        # Support for versions of TensorFlow before 0.7.0
        except AttributeError:
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                                    strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        if with_w:
            return deconv, w, biases
        else:
            return deconv


def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)

def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    shape = input_.get_shape().as_list()

    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow-gpu
numpy
scipy
pillow
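The requirements file is unpinned, but the code targets TF 1.x-era APIs (`tf.contrib.layers.batch_norm`, `tf.app.run`) and the long-removed `scipy.misc.imread`/`imresize` helpers, so recent releases will not work. Hypothetical pins that match those APIs might look like:

```
tensorflow-gpu==1.4.1
numpy==1.14.5
scipy==1.0.0
pillow==5.0.0
```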
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""
Some code from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
from time import gmtime, strftime

pp = pprint.PrettyPrinter()

get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

# -----------------------------
# newly added functions for pix2pix

def load_data(image_path, flip=True, is_test=False):
    img_A, img_B = load_image(image_path)
    img_A, img_B = preprocess_A_and_B(img_A, img_B, flip=flip, is_test=is_test)

    img_A = img_A/127.5 - 1.
    img_B = img_B/127.5 - 1.

    img_AB = np.concatenate((img_A, img_B), axis=2)
    # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim)
    return img_AB

def load_image(image_path):
    input_img = imread(image_path)
    w = int(input_img.shape[1])
    w2 = int(w/2)
    img_A = input_img[:, 0:w2]
    img_B = input_img[:, w2:w]

    return img_A, img_B

def preprocess_A_and_B(img_A, img_B, load_size=286, fine_size=256, flip=True, is_test=False):
    if is_test:
        img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
        img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
    else:
        img_A = scipy.misc.imresize(img_A, [load_size, load_size])
        img_B = scipy.misc.imresize(img_B, [load_size, load_size])

        h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
        w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
        img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]
        img_B = img_B[h1:h1+fine_size, w1:w1+fine_size]

        if flip and np.random.random() > 0.5:
            img_A = np.fliplr(img_A)
            img_B = np.fliplr(img_B)

    return img_A, img_B

# -----------------------------

def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)

def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale=False):
    if (is_grayscale):
        return scipy.misc.imread(path, flatten=True).astype(np.float)
    else:
        return scipy.misc.imread(path).astype(np.float)

def merge_images(images, size):
    return inverse_transform(images)

def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image

    return img

def imsave(images, size, path):
    return scipy.misc.imsave(path, merge(images, size))

def transform(image, npx=64, is_crop=True, resize_w=64):
    # npx : # of pixels width/height of image
    if is_crop:
        cropped_image = center_crop(image, npx, resize_w=resize_w)
    else:
        cropped_image = image
    return np.array(cropped_image)/127.5 - 1.

def inverse_transform(images):
    return (images+1.)/2.
--------------------------------------------------------------------------------
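The pipeline above maps pixel values to [-1, 1] for the tanh generator (`img/127.5 - 1` in `load_data`), and `inverse_transform` maps outputs back to [0, 1] for saving. A minimal round-trip check of that normalization:

```python
import numpy as np

img = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float64)
normalized = img / 127.5 - 1.0        # what load_data does
restored = (normalized + 1.0) / 2.0   # what inverse_transform does
assert np.allclose(restored, img / 255.0)
```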