├── .gitignore
├── LICENSE
├── README.md
├── download_dataset.sh
├── examples.jpg
├── label_to_facades.png
├── main.py
├── model.py
├── ops.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | sample/*
2 | logs/*
3 | test/*
4 | datasets/*
5 | checkpoint/*
6 | val/*
7 | *.pyc
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2016-2018 Yen-Chen Lin http://yclin.me/
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pix2pix-tensorflow
2 |
3 | TensorFlow implementation of [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/pdf/1611.07004v1.pdf) that learns a mapping from input images to output images.
4 |
5 | Here are some results generated by the authors of the paper:
6 |
7 |
8 |
9 | ## Setup
10 |
11 | ### Prerequisites
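12 | - Linux
13 | - Python with NumPy, SciPy < 1.2 (for `scipy.misc.imread`/`imresize`), Pillow and six
14 | - NVIDIA GPU + CUDA 8.0 + cuDNN v5.1
15 | - TensorFlow 1.x (1.0 or later, but not 2.x: the code relies on `tf.Session`, `tf.placeholder`, `tf.summary` and `tf.contrib`)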
16 |
17 | ### Getting Started
18 | - Clone this repo:
19 | ```bash
20 | git clone git@github.com:yenchenlin/pix2pix-tensorflow.git
21 | cd pix2pix-tensorflow
22 | ```
23 | - Download the dataset (script borrowed from [torch code](https://github.com/phillipi/pix2pix/blob/master/datasets/download_dataset.sh)):
24 | ```bash
25 | bash ./download_dataset.sh facades
26 | ```
27 | - Train the model:
28 | ```bash
29 | python main.py --phase train
30 | ```
31 | - Test the model:
32 | ```bash
33 | python main.py --phase test
34 | ```
35 |
36 | ## Results
37 | Here are the results generated by this implementation:
38 |
39 | - Facades:
40 |
41 |
42 |
43 | More results on other datasets coming soon!
44 |
45 | **Note**: To keep the discriminator (D) from converging too quickly, the generator (G) is updated twice for every D update. This differs from the original paper but matches [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow), on which this project is based.
46 |
47 | ## Train
48 | The code currently supports the [CMP Facades](http://cmp.felk.cvut.cz/~tylecr1/facade/) dataset. Reproducing the results above takes 200 epochs of training (the default `--epoch`); the exact training time depends on your hardware.
49 |
50 | ## Test
51 | Test the model on the validation set of the [CMP Facades](http://cmp.felk.cvut.cz/~tylecr1/facade/) dataset. Given the corresponding label maps, it writes the synthesized images to `./test`.
52 |
53 |
54 | ## Acknowledgments
55 | Code borrows heavily from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow/blob/master/model.py). Thanks for their excellent work!
56 |
57 | ## License
58 | MIT
59 |
--------------------------------------------------------------------------------
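The README's two-G-steps-per-D-step note, combined with the 200-epoch default, fixes how many optimizer steps a full training run performs. The following framework-free sketch is an illustration, not code from this repo: `training_steps` is a hypothetical helper, and the 400-image split in the usage line is only an illustrative number.

```python
def training_steps(num_images, epochs=200, batch_size=1, g_steps_per_d=2):
    """Return (d_steps, g_steps) for a full run.

    Mirrors the loop in model.py's train(): each epoch runs
    num_images // batch_size batches, and each batch does one
    discriminator step followed by `g_steps_per_d` generator steps.
    """
    batches_per_epoch = num_images // batch_size
    d_steps = epochs * batches_per_epoch
    return d_steps, d_steps * g_steps_per_d


# Illustrative only: a hypothetical 400-image training split at the default settings.
print(training_steps(400))  # (80000, 160000)
```

Because each batch also runs three extra forward passes to log `d_loss_fake`, `d_loss_real` and `g_loss`, wall-clock time per batch is higher than the step count alone suggests.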
/download_dataset.sh:
--------------------------------------------------------------------------------
1 | mkdir -p datasets
2 | FILE=$1
3 | URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
4 | TAR_FILE=./datasets/$FILE.tar.gz
5 | TARGET_DIR=./datasets/$FILE/
6 | wget "$URL" -O "$TAR_FILE"
7 | mkdir -p "$TARGET_DIR"
8 | tar -zxvf "$TAR_FILE" -C ./datasets/
9 | rm "$TAR_FILE"
10 |
--------------------------------------------------------------------------------
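`load_image()` in utils.py expects each JPEG in the downloaded archive to hold two images side by side, split at half the width. Below is a NumPy-only sketch of that split on a synthetic array, with no real data loaded; `split_ab` is a hypothetical name. model.py then feeds the right half (`real_A`) to the generator and uses the left half (`real_B`) as the target.

```python
import numpy as np


def split_ab(image):
    """Split a side-by-side A|B image into its two halves.

    Same convention as utils.load_image(): columns [0, w/2) are A,
    columns [w/2, w) are B.
    """
    half = image.shape[1] // 2
    return image[:, :half], image[:, half:]


# Synthetic 256 x 512 RGB stand-in for one dataset file.
ab = np.zeros((256, 512, 3), dtype=np.uint8)
ab[:, 256:] = 255  # make the B half white so the split is visible
img_a, img_b = split_ab(ab)
print(img_a.shape, img_b.shape, img_a.max(), img_b.min())  # (256, 256, 3) (256, 256, 3) 0 255
```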
/examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yenchenlin/pix2pix-tensorflow/ba40020706ad3a1fbefa1da7bc7a05b7b031fb9e/examples.jpg
--------------------------------------------------------------------------------
/label_to_facades.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yenchenlin/pix2pix-tensorflow/ba40020706ad3a1fbefa1da7bc7a05b7b031fb9e/label_to_facades.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import scipy.misc
4 | import numpy as np
5 |
6 | from model import pix2pix
7 | import tensorflow as tf
8 |
9 | parser = argparse.ArgumentParser(description='')
10 | parser.add_argument('--dataset_name', dest='dataset_name', default='facades', help='name of the dataset')
11 | parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epochs')
12 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
13 | parser.add_argument('--train_size', dest='train_size', type=int, default=int(1e8), help='# images used to train')
14 | parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
15 | parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
16 | parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
17 | parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
18 | parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
19 | parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
20 | parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iter at starting learning rate')
21 | parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
22 | parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
23 | parser.add_argument('--flip', dest='flip', type=lambda s: s.lower() in ('true', '1'), default=True, help='whether to flip images for data augmentation')
24 | parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
25 | parser.add_argument('--phase', dest='phase', default='train', help='train, test')
26 | parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=50, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)')
27 | parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every latest_freq sgd iterations (overwrites the previous latest model)')
28 | parser.add_argument('--print_freq', dest='print_freq', type=int, default=50, help='print the debug information every print_freq iterations')
29 | parser.add_argument('--continue_train', dest='continue_train', type=lambda s: s.lower() in ('true', '1'), default=False, help='if continue training, load the latest model: 1: true, 0: false')
30 | parser.add_argument('--serial_batches', dest='serial_batches', type=lambda s: s.lower() in ('true', '1'), default=False, help='if 1, takes images in order to make batches, otherwise takes them randomly')
31 | parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=lambda s: s.lower() in ('true', '1'), default=True, help='iter into serial image list')
32 | parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
33 | parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='samples are saved here')
34 | parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test samples are saved here')
35 | parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=100.0, help='weight on L1 term in objective')
36 |
37 | args = parser.parse_args()
38 |
39 | def main(_):
40 |     if not os.path.exists(args.checkpoint_dir):
41 |         os.makedirs(args.checkpoint_dir)
42 |     if not os.path.exists(args.sample_dir):
43 |         os.makedirs(args.sample_dir)
44 |     if not os.path.exists(args.test_dir):
45 |         os.makedirs(args.test_dir)
46 |
47 |     with tf.Session() as sess:
48 |         model = pix2pix(sess, image_size=args.fine_size, batch_size=args.batch_size, output_size=args.fine_size,
49 |                         gf_dim=args.ngf, df_dim=args.ndf, L1_lambda=args.L1_lambda, input_c_dim=args.input_nc,
50 |                         output_c_dim=args.output_nc, dataset_name=args.dataset_name, checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir)
51 |
52 |         if args.phase == 'train':
53 |             model.train(args)
54 |         else:
55 |             model.test(args)
56 |
57 | if __name__ == '__main__':
58 |     tf.app.run()
59 |
--------------------------------------------------------------------------------
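Boolean flags such as `--flip` are easy to get wrong with argparse: argparse passes the raw command-line string to `type`, and `bool()` of any non-empty string (including `"False"`) is `True`, so `type=bool` can never switch a flag off. A minimal sketch of a stricter parser follows; `str2bool` is a hypothetical helper, not part of this repo.

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings; reject anything else."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


naive = argparse.ArgumentParser()
naive.add_argument("--flip", type=bool, default=True)
print(naive.parse_args(["--flip", "False"]).flip)  # True, because bool("False") is True

fixed = argparse.ArgumentParser()
fixed.add_argument("--flip", type=str2bool, default=True)
print(fixed.parse_args(["--flip", "False"]).flip)  # False
```

Raising `argparse.ArgumentTypeError` makes argparse report a normal usage error for a typo like `--flip Flase` instead of silently treating it as a value.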
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import time
4 | from glob import glob
5 | import tensorflow as tf
6 | import numpy as np
7 | from six.moves import xrange
8 |
9 | from ops import *
10 | from utils import *
11 |
12 | class pix2pix(object):
13 |     def __init__(self, sess, image_size=256,
14 |                  batch_size=1, sample_size=1, output_size=256,
15 |                  gf_dim=64, df_dim=64, L1_lambda=100,
16 |                  input_c_dim=3, output_c_dim=3, dataset_name='facades',
17 |                  checkpoint_dir=None, sample_dir=None):
18 |         """
19 |         Build a pix2pix model: a U-Net generator and a conditional discriminator.
20 |         Args:
21 |             sess: TensorFlow session
22 |             batch_size: The size of a batch. Fixed when the graph is built.
23 |             output_size: (optional) The resolution in pixels of the images. [256]
24 |             gf_dim: (optional) Number of generator filters in the first conv layer. [64]
25 |             df_dim: (optional) Number of discriminator filters in the first conv layer. [64]
26 |             input_c_dim: (optional) Number of input image channels. For grayscale input, set to 1. [3]
27 |             output_c_dim: (optional) Number of output image channels. For grayscale output, set to 1. [3]
28 |         """
29 |         self.sess = sess
30 |         self.is_grayscale = (input_c_dim == 1)
31 |         self.batch_size = batch_size
32 |         self.image_size = image_size
33 |         self.sample_size = sample_size
34 |         self.output_size = output_size
35 |
36 |         self.gf_dim = gf_dim
37 |         self.df_dim = df_dim
38 |
39 |         self.input_c_dim = input_c_dim
40 |         self.output_c_dim = output_c_dim
41 |
42 |         self.L1_lambda = L1_lambda
43 |
44 |         # batch normalization: mitigates poor initialization and helps gradient flow
45 |         self.d_bn1 = batch_norm(name='d_bn1')
46 |         self.d_bn2 = batch_norm(name='d_bn2')
47 |         self.d_bn3 = batch_norm(name='d_bn3')
48 |
49 |         self.g_bn_e2 = batch_norm(name='g_bn_e2')
50 |         self.g_bn_e3 = batch_norm(name='g_bn_e3')
51 |         self.g_bn_e4 = batch_norm(name='g_bn_e4')
52 |         self.g_bn_e5 = batch_norm(name='g_bn_e5')
53 |         self.g_bn_e6 = batch_norm(name='g_bn_e6')
54 |         self.g_bn_e7 = batch_norm(name='g_bn_e7')
55 |         self.g_bn_e8 = batch_norm(name='g_bn_e8')
56 |
57 |         self.g_bn_d1 = batch_norm(name='g_bn_d1')
58 |         self.g_bn_d2 = batch_norm(name='g_bn_d2')
59 |         self.g_bn_d3 = batch_norm(name='g_bn_d3')
60 |         self.g_bn_d4 = batch_norm(name='g_bn_d4')
61 |         self.g_bn_d5 = batch_norm(name='g_bn_d5')
62 |         self.g_bn_d6 = batch_norm(name='g_bn_d6')
63 |         self.g_bn_d7 = batch_norm(name='g_bn_d7')
64 |
65 |         self.dataset_name = dataset_name
66 |         self.checkpoint_dir = checkpoint_dir
67 |         self.build_model()
68 |
69 |     def build_model(self):
70 |         self.real_data = tf.placeholder(tf.float32,
71 |                                         [self.batch_size, self.image_size, self.image_size,
72 |                                          self.input_c_dim + self.output_c_dim],
73 |                                         name='real_A_and_B_images')
74 |
75 |         self.real_B = self.real_data[:, :, :, :self.output_c_dim]
76 |         self.real_A = self.real_data[:, :, :, self.output_c_dim:self.output_c_dim + self.input_c_dim]
77 |
78 |         self.fake_B = self.generator(self.real_A)
79 |
80 |         self.real_AB = tf.concat([self.real_A, self.real_B], 3)
81 |         self.fake_AB = tf.concat([self.real_A, self.fake_B], 3)
82 |         self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
83 |         self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)
84 |
85 |         self.fake_B_sample = self.sampler(self.real_A)
86 |
87 |         self.d_sum = tf.summary.histogram("d", self.D)
88 |         self.d__sum = tf.summary.histogram("d_", self.D_)
89 |         self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)
90 |
91 |         self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
92 |         self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
93 |         self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
94 |                       + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))
95 |
96 |         self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
97 |         self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
98 |
99 |         self.d_loss = self.d_loss_real + self.d_loss_fake
100 |
101 |         self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
102 |         self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
103 |
104 |         t_vars = tf.trainable_variables()
105 |
106 |         self.d_vars = [var for var in t_vars if 'd_' in var.name]
107 |         self.g_vars = [var for var in t_vars if 'g_' in var.name]
108 |
109 |         self.saver = tf.train.Saver()
110 |
111 |
112 |     def load_random_samples(self):
113 |         data = np.random.choice(glob('./datasets/{}/val/*.jpg'.format(self.dataset_name)), self.batch_size)
114 |         sample = [load_data(sample_file) for sample_file in data]
115 |
116 |         if (self.is_grayscale):
117 |             sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
118 |         else:
119 |             sample_images = np.array(sample).astype(np.float32)
120 |         return sample_images
121 |
122 |     def sample_model(self, sample_dir, epoch, idx):
123 |         sample_images = self.load_random_samples()
124 |         samples, d_loss, g_loss = self.sess.run(
125 |             [self.fake_B_sample, self.d_loss, self.g_loss],
126 |             feed_dict={self.real_data: sample_images}
127 |         )
128 |         save_images(samples, [self.batch_size, 1],
129 |                     './{}/train_{:02d}_{:04d}.png'.format(sample_dir, epoch, idx))
130 |         print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))
131 |
132 |     def train(self, args):
133 |         """Train pix2pix"""
134 |         d_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
135 |                           .minimize(self.d_loss, var_list=self.d_vars)
136 |         g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
137 |                           .minimize(self.g_loss, var_list=self.g_vars)
138 |
139 |         init_op = tf.global_variables_initializer()
140 |         self.sess.run(init_op)
141 |
142 |         self.g_sum = tf.summary.merge([self.d__sum,
143 |             self.fake_B_sum, self.d_loss_fake_sum, self.g_loss_sum])
144 |         self.d_sum = tf.summary.merge([self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
145 |         self.writer = tf.summary.FileWriter("./logs", self.sess.graph)
146 |
147 |         counter = 1
148 |         start_time = time.time()
149 |
150 |         if self.load(self.checkpoint_dir):
151 |             print(" [*] Load SUCCESS")
152 |         else:
153 |             print(" [!] Load failed...")
154 |
155 |         for epoch in xrange(args.epoch):
156 |             data = glob('./datasets/{}/train/*.jpg'.format(self.dataset_name))
157 |             #np.random.shuffle(data)
158 |             batch_idxs = min(len(data), args.train_size) // self.batch_size
159 |
160 |             for idx in xrange(0, batch_idxs):
161 |                 batch_files = data[idx*self.batch_size:(idx+1)*self.batch_size]
162 |                 batch = [load_data(batch_file) for batch_file in batch_files]
163 |                 if (self.is_grayscale):
164 |                     batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
165 |                 else:
166 |                     batch_images = np.array(batch).astype(np.float32)
167 |
168 |                 # Update D network
169 |                 _, summary_str = self.sess.run([d_optim, self.d_sum],
170 |                                                feed_dict={ self.real_data: batch_images })
171 |                 self.writer.add_summary(summary_str, counter)
172 |
173 |                 # Update G network
174 |                 _, summary_str = self.sess.run([g_optim, self.g_sum],
175 |                                                feed_dict={ self.real_data: batch_images })
176 |                 self.writer.add_summary(summary_str, counter)
177 |
178 |                 # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
179 |                 _, summary_str = self.sess.run([g_optim, self.g_sum],
180 |                                                feed_dict={ self.real_data: batch_images })
181 |                 self.writer.add_summary(summary_str, counter)
182 |
183 |                 errD_fake = self.d_loss_fake.eval({self.real_data: batch_images})
184 |                 errD_real = self.d_loss_real.eval({self.real_data: batch_images})
185 |                 errG = self.g_loss.eval({self.real_data: batch_images})
186 |
187 |                 counter += 1
188 |                 print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
189 |                     % (epoch, idx, batch_idxs,
190 |                         time.time() - start_time, errD_fake+errD_real, errG))
191 |
192 |                 if np.mod(counter, 100) == 1:
193 |                     self.sample_model(args.sample_dir, epoch, idx)
194 |
195 |                 if np.mod(counter, 500) == 2:
196 |                     self.save(args.checkpoint_dir, counter)
197 |
198 |     def discriminator(self, image, y=None, reuse=False):
199 |
200 |         with tf.variable_scope("discriminator") as scope:
201 |
202 |             # image is 256 x 256 x (input_c_dim + output_c_dim)
203 |             if reuse:
204 |                 tf.get_variable_scope().reuse_variables()
205 |             else:
206 |                 assert tf.get_variable_scope().reuse == False
207 |
208 |             h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))
209 |             # h0 is (128 x 128 x self.df_dim)
210 |             h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))
211 |             # h1 is (64 x 64 x self.df_dim*2)
212 |             h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))
213 |             # h2 is (32 x 32 x self.df_dim*4)
214 |             h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, d_h=1, d_w=1, name='d_h3_conv')))
215 |             # h3 is (32 x 32 x self.df_dim*8): stride 1 keeps the spatial size of h2
216 |             h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')
217 |
218 |             return tf.nn.sigmoid(h4), h4
219 |
220 |     def generator(self, image, y=None):
221 |         with tf.variable_scope("generator") as scope:
222 |
223 |             s = self.output_size
224 |             s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)
225 |
226 |             # image is (256 x 256 x input_c_dim)
227 |             e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
228 |             # e1 is (128 x 128 x self.gf_dim)
229 |             e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim*2, name='g_e2_conv'))
230 |             # e2 is (64 x 64 x self.gf_dim*2)
231 |             e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim*4, name='g_e3_conv'))
232 |             # e3 is (32 x 32 x self.gf_dim*4)
233 |             e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim*8, name='g_e4_conv'))
234 |             # e4 is (16 x 16 x self.gf_dim*8)
235 |             e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim*8, name='g_e5_conv'))
236 |             # e5 is (8 x 8 x self.gf_dim*8)
237 |             e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim*8, name='g_e6_conv'))
238 |             # e6 is (4 x 4 x self.gf_dim*8)
239 |             e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim*8, name='g_e7_conv'))
240 |             # e7 is (2 x 2 x self.gf_dim*8)
241 |             e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim*8, name='g_e8_conv'))
242 |             # e8 is (1 x 1 x self.gf_dim*8)
243 |
244 |             self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
245 |                 [self.batch_size, s128, s128, self.gf_dim*8], name='g_d1', with_w=True)
246 |             d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
247 |             d1 = tf.concat([d1, e7], 3)
248 |             # d1 is (2 x 2 x self.gf_dim*8*2)
249 |
250 |             self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
251 |                 [self.batch_size, s64, s64, self.gf_dim*8], name='g_d2', with_w=True)
252 |             d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
253 |             d2 = tf.concat([d2, e6], 3)
254 |             # d2 is (4 x 4 x self.gf_dim*8*2)
255 |
256 |             self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
257 |                 [self.batch_size, s32, s32, self.gf_dim*8], name='g_d3', with_w=True)
258 |             d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
259 |             d3 = tf.concat([d3, e5], 3)
260 |             # d3 is (8 x 8 x self.gf_dim*8*2)
261 |
262 |             self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
263 |                 [self.batch_size, s16, s16, self.gf_dim*8], name='g_d4', with_w=True)
264 |             d4 = self.g_bn_d4(self.d4)
265 |             d4 = tf.concat([d4, e4], 3)
266 |             # d4 is (16 x 16 x self.gf_dim*8*2)
267 |
268 |             self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
269 |                 [self.batch_size, s8, s8, self.gf_dim*4], name='g_d5', with_w=True)
270 |             d5 = self.g_bn_d5(self.d5)
271 |             d5 = tf.concat([d5, e3], 3)
272 |             # d5 is (32 x 32 x self.gf_dim*4*2)
273 |
274 |             self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
275 |                 [self.batch_size, s4, s4, self.gf_dim*2], name='g_d6', with_w=True)
276 |             d6 = self.g_bn_d6(self.d6)
277 |             d6 = tf.concat([d6, e2], 3)
278 |             # d6 is (64 x 64 x self.gf_dim*2*2)
279 |
280 |             self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
281 |                 [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
282 |             d7 = self.g_bn_d7(self.d7)
283 |             d7 = tf.concat([d7, e1], 3)
284 |             # d7 is (128 x 128 x self.gf_dim*1*2)
285 |
286 |             self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
287 |                 [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
288 |             # d8 is (256 x 256 x output_c_dim)
289 |
290 |             return tf.nn.tanh(self.d8)
291 |
292 |     def sampler(self, image, y=None):
293 |
294 |         with tf.variable_scope("generator") as scope:
295 |             scope.reuse_variables()
296 |
297 |             s = self.output_size
298 |             s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)
299 |
300 |             # image is (256 x 256 x input_c_dim)
301 |             e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
302 |             # e1 is (128 x 128 x self.gf_dim)
303 |             e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim*2, name='g_e2_conv'))
304 |             # e2 is (64 x 64 x self.gf_dim*2)
305 |             e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim*4, name='g_e3_conv'))
306 |             # e3 is (32 x 32 x self.gf_dim*4)
307 |             e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim*8, name='g_e4_conv'))
308 |             # e4 is (16 x 16 x self.gf_dim*8)
309 |             e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim*8, name='g_e5_conv'))
310 |             # e5 is (8 x 8 x self.gf_dim*8)
311 |             e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim*8, name='g_e6_conv'))
312 |             # e6 is (4 x 4 x self.gf_dim*8)
313 |             e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim*8, name='g_e7_conv'))
314 |             # e7 is (2 x 2 x self.gf_dim*8)
315 |             e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim*8, name='g_e8_conv'))
316 |             # e8 is (1 x 1 x self.gf_dim*8)
317 |
318 |             self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
319 |                 [self.batch_size, s128, s128, self.gf_dim*8], name='g_d1', with_w=True)
320 |             d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
321 |             d1 = tf.concat([d1, e7], 3)
322 |             # d1 is (2 x 2 x self.gf_dim*8*2)
323 |
324 |             self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
325 |                 [self.batch_size, s64, s64, self.gf_dim*8], name='g_d2', with_w=True)
326 |             d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
327 |             d2 = tf.concat([d2, e6], 3)
328 |             # d2 is (4 x 4 x self.gf_dim*8*2)
329 |
330 |             self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
331 |                 [self.batch_size, s32, s32, self.gf_dim*8], name='g_d3', with_w=True)
332 |             d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
333 |             d3 = tf.concat([d3, e5], 3)
334 |             # d3 is (8 x 8 x self.gf_dim*8*2)
335 |
336 |             self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
337 |                 [self.batch_size, s16, s16, self.gf_dim*8], name='g_d4', with_w=True)
338 |             d4 = self.g_bn_d4(self.d4)
339 |             d4 = tf.concat([d4, e4], 3)
340 |             # d4 is (16 x 16 x self.gf_dim*8*2)
341 |
342 |             self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
343 |                 [self.batch_size, s8, s8, self.gf_dim*4], name='g_d5', with_w=True)
344 |             d5 = self.g_bn_d5(self.d5)
345 |             d5 = tf.concat([d5, e3], 3)
346 |             # d5 is (32 x 32 x self.gf_dim*4*2)
347 |
348 |             self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
349 |                 [self.batch_size, s4, s4, self.gf_dim*2], name='g_d6', with_w=True)
350 |             d6 = self.g_bn_d6(self.d6)
351 |             d6 = tf.concat([d6, e2], 3)
352 |             # d6 is (64 x 64 x self.gf_dim*2*2)
353 |
354 |             self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
355 |                 [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
356 |             d7 = self.g_bn_d7(self.d7)
357 |             d7 = tf.concat([d7, e1], 3)
358 |             # d7 is (128 x 128 x self.gf_dim*1*2)
359 |
360 |             self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
361 |                 [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
362 |             # d8 is (256 x 256 x output_c_dim)
363 |
364 |             return tf.nn.tanh(self.d8)
365 |
366 |     def save(self, checkpoint_dir, step):
367 |         model_name = "pix2pix.model"
368 |         model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
369 |         checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
370 |
371 |         if not os.path.exists(checkpoint_dir):
372 |             os.makedirs(checkpoint_dir)
373 |
374 |         self.saver.save(self.sess,
375 |                         os.path.join(checkpoint_dir, model_name),
376 |                         global_step=step)
377 |
378 |     def load(self, checkpoint_dir):
379 |         print(" [*] Reading checkpoint...")
380 |
381 |         model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
382 |         checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
383 |
384 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
385 |         if ckpt and ckpt.model_checkpoint_path:
386 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
387 |             self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
388 |             return True
389 |         else:
390 |             return False
391 |
392 |     def test(self, args):
393 |         """Test pix2pix"""
394 |         init_op = tf.global_variables_initializer()
395 |         self.sess.run(init_op)
396 |
397 |         sample_files = glob('./datasets/{}/val/*.jpg'.format(self.dataset_name))
398 |
399 |         # sort testing input numerically by file name (1.jpg, 2.jpg, ..., 10.jpg)
400 |         n = [int(i) for i in map(lambda x: x.split('/')[-1].split('.jpg')[0], sample_files)]
401 |         sample_files = [x for (y, x) in sorted(zip(n, sample_files))]
402 |
403 |         # load testing input
404 |         print("Loading testing images ...")
405 |         sample = [load_data(sample_file, is_test=True) for sample_file in sample_files]
406 |
407 |         if (self.is_grayscale):
408 |             sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
409 |         else:
410 |             sample_images = np.array(sample).astype(np.float32)
411 |
412 |         sample_images = [sample_images[i:i+self.batch_size]
413 |                          for i in xrange(0, len(sample_images), self.batch_size)]
414 |         sample_images = np.array(sample_images)
415 |         print(sample_images.shape)
416 |
417 |         start_time = time.time()
418 |         if self.load(self.checkpoint_dir):
419 |             print(" [*] Load SUCCESS")
420 |         else:
421 |             print(" [!] Load failed...")
422 |
423 |         for i, sample_image in enumerate(sample_images):
424 |             idx = i+1
425 |             print("sampling image ", idx)
426 |             samples = self.sess.run(
427 |                 self.fake_B_sample,
428 |                 feed_dict={self.real_data: sample_image}
429 |             )
430 |             save_images(samples, [self.batch_size, 1],
431 |                         './{}/test_{:04d}.png'.format(args.test_dir, idx))
432 |
--------------------------------------------------------------------------------
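`build_model()` combines a sigmoid cross-entropy GAN loss with an L1 reconstruction term weighted by `L1_lambda`. The NumPy sketch below restates those formulas so they can be checked on small arrays. It is not part of the repo: `pix2pix_losses` is a hypothetical name, and the numerically stable cross-entropy expression is the one documented for `tf.nn.sigmoid_cross_entropy_with_logits`.

```python
import numpy as np


def sigmoid_ce_with_logits(logits, labels):
    """Stable elementwise sigmoid cross-entropy: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))


def pix2pix_losses(d_logits_real, d_logits_fake, real_b, fake_b, l1_lambda=100.0):
    """Return (d_loss, g_loss) following the formulas in build_model()."""
    d_loss_real = sigmoid_ce_with_logits(d_logits_real, np.ones_like(d_logits_real)).mean()
    d_loss_fake = sigmoid_ce_with_logits(d_logits_fake, np.zeros_like(d_logits_fake)).mean()
    g_adv = sigmoid_ce_with_logits(d_logits_fake, np.ones_like(d_logits_fake)).mean()
    g_l1 = np.abs(np.asarray(real_b) - np.asarray(fake_b)).mean()
    return d_loss_real + d_loss_fake, g_adv + l1_lambda * g_l1


# An undecided discriminator (logit 0 on both pairs) and a generator whose
# output is off by 0.1 everywhere in the [-1, 1] image range.
d_loss, g_loss = pix2pix_losses(np.zeros(1), np.zeros(1),
                                np.zeros((1, 4, 4, 3)), np.full((1, 4, 4, 3), 0.1))
print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 10.6931
```

With the default `L1_lambda = 100`, an average per-pixel error of 0.1 already contributes 10 to `g_loss`, far more than the adversarial term's log 2 ≈ 0.69 at an undecided discriminator.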
/ops.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from tensorflow.python.framework import ops
6 |
7 | from utils import *
8 |
9 | class batch_norm(object):
10 |     # h1 = lrelu(tf.contrib.layers.batch_norm(conv2d(h0, self.df_dim*2, name='d_h1_conv'),decay=0.9,updates_collections=None,epsilon=0.00001,scale=True,scope="d_h1_conv"))
11 |     def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
12 |         with tf.variable_scope(name):
13 |             self.epsilon = epsilon
14 |             self.momentum = momentum
15 |             self.name = name
16 |
17 |     def __call__(self, x, train=True):
18 |         return tf.contrib.layers.batch_norm(x, decay=self.momentum, updates_collections=None, epsilon=self.epsilon, scale=True, is_training=train, scope=self.name)
19 |
20 | def binary_cross_entropy(preds, targets, name=None):
21 |     """Computes binary cross entropy given predicted probabilities `preds`.
22 |
23 |     For brevity, let `x = preds`, `z = targets`. The logistic loss is
24 |
25 |         loss(x, z) = - mean_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i]))
26 |
27 |     Args:
28 |         preds: A `Tensor` of type `float32` or `float64`.
29 |         targets: A `Tensor` of the same type and shape as `preds`.
30 |     """
31 |     eps = 1e-12
32 |     with tf.name_scope(name, "bce_loss", [preds, targets]) as name:
33 |         preds = ops.convert_to_tensor(preds, name="preds")
34 |         targets = ops.convert_to_tensor(targets, name="targets")
35 |         return tf.reduce_mean(-(targets * tf.log(preds + eps) +
36 |                                 (1. - targets) * tf.log(1. - preds + eps)))
37 |
38 | def conv_cond_concat(x, y):
39 |     """Concatenate conditioning vector on feature map axis."""
40 |     x_shapes = x.get_shape()
41 |     y_shapes = y.get_shape()
42 |     return tf.concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
43 |
44 | def conv2d(input_, output_dim,
45 |            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
46 |            name="conv2d"):
47 |     with tf.variable_scope(name):
48 |         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
49 |                             initializer=tf.truncated_normal_initializer(stddev=stddev))
50 |         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
51 |
52 |         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
53 |         conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
54 |
55 |         return conv
56 |
57 | def deconv2d(input_, output_shape,
58 |              k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
59 |              name="deconv2d", with_w=False):
60 |     with tf.variable_scope(name):
61 |         # filter : [height, width, output_channels, in_channels]
62 |         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
63 |                             initializer=tf.random_normal_initializer(stddev=stddev))
64 |
65 |         try:
66 |             deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
67 |                                             strides=[1, d_h, d_w, 1])
68 |
69 |         # Support for versions of TensorFlow before 0.7.0
70 |         except AttributeError:
71 |             deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
72 |                                     strides=[1, d_h, d_w, 1])
73 |
74 |         biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
75 |         deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
76 |
77 |         if with_w:
78 |             return deconv, w, biases
79 |         else:
80 |             return deconv
81 |
82 |
83 | def lrelu(x, leak=0.2, name="lrelu"):
84 |     return tf.maximum(x, leak*x)
85 |
86 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
87 |     shape = input_.get_shape().as_list()
88 |
89 |     with tf.variable_scope(scope or "Linear"):
90 |         matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
91 |                                  tf.random_normal_initializer(stddev=stddev))
92 |         bias = tf.get_variable("bias", [output_size],
93 |                                initializer=tf.constant_initializer(bias_start))
94 |         if with_w:
95 |             return tf.matmul(input_, matrix) + bias, matrix, bias
96 |         else:
97 |             return tf.matmul(input_, matrix) + bias
98 |
--------------------------------------------------------------------------------
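`conv2d` and `deconv2d` default to 5x5 kernels with stride 2 and `'SAME'` padding, so every encoder step halves the spatial size (ceil(size / stride)) and every decoder step is asked for an output twice as large. The pure-Python trace below follows those rules through the layer lists in model.py. The helper names (`conv_same_out`, `discriminator_shapes`, `unet_shapes`) are hypothetical, and only shapes are computed; no tensors are built.

```python
import math


def conv_same_out(size, stride):
    """Spatial size after a 'SAME'-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)


def discriminator_shapes(image_size=256, df_dim=64):
    """(size, channels) of h0..h3 in model.py's discriminator()."""
    strides = [2, 2, 2, 1]  # h3 uses d_h = d_w = 1
    channels = [df_dim, df_dim * 2, df_dim * 4, df_dim * 8]
    size, out = image_size, []
    for stride, ch in zip(strides, channels):
        size = conv_same_out(size, stride)
        out.append((size, ch))
    return out


def unet_shapes(image_size=256, gf_dim=64):
    """(size, channels) of e1..e8, and of d1..d7 after the skip concat, in generator()."""
    enc_channels = [gf_dim, gf_dim * 2, gf_dim * 4] + [gf_dim * 8] * 5
    size, encoder = image_size, []
    for ch in enc_channels:
        size = conv_same_out(size, 2)
        encoder.append((size, ch))
    dec_channels = [gf_dim * 8] * 4 + [gf_dim * 4, gf_dim * 2, gf_dim]
    decoder = []
    for i, ch in enumerate(dec_channels):
        size *= 2  # each deconv2d output_shape is twice the previous size
        skip_size, skip_ch = encoder[-(i + 2)]  # d1 pairs with e7, ..., d7 with e1
        assert skip_size == size
        decoder.append((size, ch + skip_ch))  # tf.concat along the channel axis
    return encoder, decoder


enc, dec = unet_shapes()
print(enc[-1], dec[-1], discriminator_shapes()[-1])  # (1, 512) (128, 128) (32, 512)
```

With the defaults, the discriminator's final `linear` layer therefore sees 32 * 32 * 512 = 524,288 features per image, and the last `deconv2d` (`g_d8`) maps the 128 x 128 x 128 tensor back to 256 x 256 x `output_c_dim`.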
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu<2.0
2 | numpy
3 | scipy<1.2
4 | pillow
5 | six
6 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some code from https://github.com/Newmu/dcgan_code
3 | """
4 | from __future__ import division
5 | import math
6 | import json
7 | import random
8 | import pprint
9 | import scipy.misc
10 | import numpy as np
11 | from time import gmtime, strftime
12 |
13 | pp = pprint.PrettyPrinter()
14 |
15 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
16 |
17 | # -----------------------------
18 | # functions added for pix2pix
19 |
20 | def load_data(image_path, flip=True, is_test=False):
21 |     img_A, img_B = load_image(image_path)
22 |     img_A, img_B = preprocess_A_and_B(img_A, img_B, flip=flip, is_test=is_test)
23 |
24 |     img_A = img_A/127.5 - 1.
25 |     img_B = img_B/127.5 - 1.
26 |
27 |     img_AB = np.concatenate((img_A, img_B), axis=2)
28 |     # img_AB shape: (fine_size, fine_size, input_c_dim + output_c_dim)
29 |     return img_AB
30 |
31 | def load_image(image_path):
32 |     input_img = imread(image_path)
33 |     w = int(input_img.shape[1])
34 |     w2 = int(w/2)
35 |     img_A = input_img[:, 0:w2]
36 |     img_B = input_img[:, w2:w]
37 |
38 |     return img_A, img_B
39 |
40 | def preprocess_A_and_B(img_A, img_B, load_size=286, fine_size=256, flip=True, is_test=False):
41 |     if is_test:
42 |         img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
43 |         img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
44 |     else:
45 |         img_A = scipy.misc.imresize(img_A, [load_size, load_size])
46 |         img_B = scipy.misc.imresize(img_B, [load_size, load_size])
47 |
48 |         h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
49 |         w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size)))
50 |         img_A = img_A[h1:h1+fine_size, w1:w1+fine_size]
51 |         img_B = img_B[h1:h1+fine_size, w1:w1+fine_size]
52 |
53 |         if flip and np.random.random() > 0.5:
54 |             img_A = np.fliplr(img_A)
55 |             img_B = np.fliplr(img_B)
56 |
57 |     return img_A, img_B
58 |
59 | # -----------------------------
60 |
61 | def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
62 |     return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)
63 |
64 | def save_images(images, size, image_path):
65 |     return imsave(inverse_transform(images), size, image_path)
66 |
67 | def imread(path, is_grayscale = False):
68 |     if (is_grayscale):
69 |         return scipy.misc.imread(path, flatten = True).astype(np.float64)
70 |     else:
71 |         return scipy.misc.imread(path).astype(np.float64)
72 |
73 | def merge_images(images, size):
74 |     return inverse_transform(images)
75 |
76 | def merge(images, size):
77 |     h, w = images.shape[1], images.shape[2]
78 |     img = np.zeros((h * size[0], w * size[1], 3))
79 |     for idx, image in enumerate(images):
80 |         i = idx % size[1]
81 |         j = idx // size[1]
82 |         img[j*h:j*h+h, i*w:i*w+w, :] = image
83 |
84 |     return img
85 |
86 | def imsave(images, size, path):
87 |     return scipy.misc.imsave(path, merge(images, size))
88 |
89 | def transform(image, npx=64, is_crop=True, resize_w=64):
90 |     # npx : # of pixels width/height of image
91 |     if is_crop:
92 |         cropped_image = center_crop(image, npx, resize_w=resize_w)
93 |     else:
94 |         cropped_image = image
95 |     return np.array(cropped_image)/127.5 - 1.
96 |
97 | def inverse_transform(images):
98 |     return (images+1.)/2.
99 |
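100 | def center_crop(x, crop_h, crop_w=None, resize_w=64):
101 |     # Used by transform(): take the central crop_h x crop_w region, then resize it to resize_w x resize_w.
102 |     if crop_w is None:
103 |         crop_w = crop_h
104 |     h, w = x.shape[:2]
105 |     j = int(round((h - crop_h) / 2.))
106 |     i = int(round((w - crop_w) / 2.))
107 |     return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], [resize_w, resize_w])
108 |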
--------------------------------------------------------------------------------
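`preprocess_A_and_B()` applies the same random jitter to both halves of a training pair: resize to `load_size`, take one random `fine_size` crop, and mirror both with probability 0.5. `load_data()` and `inverse_transform()` then handle the pixel-range mapping. Because `scipy.misc.imresize` is gone from current SciPy, the NumPy-only sketch below swaps in a nearest-neighbour resize (imresize defaults to bilinear), so its pixels will not match the original exactly. All function names here are hypothetical.

```python
import numpy as np


def resize_nearest(img, size):
    """Nearest-neighbour resize to (size, size); a stand-in for scipy.misc.imresize."""
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size  # source row for each output row
    cols = np.arange(size) * w // size  # source column for each output column
    return img[rows][:, cols]


def jitter_pair(img_a, img_b, load_size=286, fine_size=256, flip=True, rng=None):
    """Apply one shared resize, random crop and optional mirror to both images."""
    rng = np.random.default_rng() if rng is None else rng
    img_a = resize_nearest(img_a, load_size)
    img_b = resize_nearest(img_b, load_size)
    # Offsets in [1, load_size - fine_size], like the ceil(uniform(1e-2, ...)) draw in utils.py.
    top = int(rng.integers(1, load_size - fine_size + 1))
    left = int(rng.integers(1, load_size - fine_size + 1))
    img_a = img_a[top:top + fine_size, left:left + fine_size]
    img_b = img_b[top:top + fine_size, left:left + fine_size]
    if flip and rng.random() > 0.5:
        img_a, img_b = np.fliplr(img_a), np.fliplr(img_b)
    return img_a, img_b


def to_model_range(img):
    """uint8 pixels in [0, 255] -> floats in [-1, 1], as load_data() does."""
    return img.astype(np.float32) / 127.5 - 1.0


def from_model_range(img):
    """tanh outputs in [-1, 1] -> [0, 1], as inverse_transform() does."""
    return (img + 1.0) / 2.0


rng = np.random.default_rng(0)
a = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
ja, jb = jitter_pair(a, a.copy(), rng=rng)
print(ja.shape, np.array_equal(ja, jb))  # (256, 256, 3) True
```

Drawing the crop offsets and the flip decision once per pair is what keeps the input and target pixel-aligned. Independent jitter per image would break the L1 term.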