├── images
│   ├── .gitkeep
│   ├── iter_0.jpg
│   ├── iter_500.jpg
│   └── iter_1000.jpg
├── model
│   ├── __init__.py
│   ├── utils.py
│   ├── discriminator.py
│   └── generator.py
├── models
│   └── .gitkeep
├── README.md
├── example.py
└── pix2pix.py

/images/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/images/iter_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyyub/tensorflow-pix2pix/HEAD/images/iter_0.jpg
--------------------------------------------------------------------------------
/images/iter_500.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyyub/tensorflow-pix2pix/HEAD/images/iter_500.jpg
--------------------------------------------------------------------------------
/images/iter_1000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eyyub/tensorflow-pix2pix/HEAD/images/iter_1000.jpg
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf

def get_shape(tensor):
    # Static shape of a tensor, as a Python list.
    return tensor.get_shape().as_list()

def batch_norm(*args, **kwargs):
    # Thin wrapper around tf.layers.batch_normalization, grouped under a 'bn' name scope.
    with tf.name_scope('bn'):
        bn = tf.layers.batch_normalization(*args, **kwargs)
    return bn

def lkrelu(x, slope=0.01):
    # Leaky ReLU activation: max(slope * x, x).
    return tf.maximum(slope * x, x)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensorflow-pix2pix
A lightweight TensorFlow implementation of [pix2pix](https://arxiv.org/abs/1611.07004).

[@eyyub_s](https://twitter.com/eyyub_s)

![](https://raw.githubusercontent.com/phillipi/pix2pix/master/imgs/examples.jpg)

## Build the example
First, download the [CMP Facade](http://cmp.felk.cvut.cz/~tylecr1/facade/) dataset.

Then, run `python build_dataset.py <facade dataset path>`.
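
If you would rather drive training from your own script, a minimal loop over the `Pix2pix` class could look like the sketch below. Here `load_image_pairs` is a hypothetical stand-in for your own data loading, not a function this repo provides; images are assumed to be 256x256 `float32` arrays scaled to [-1, 1] to match the generator's tanh output.

```python
import tensorflow as tf
from pix2pix import Pix2pix

# Hypothetical loader: two (N, 256, 256, 3) float32 arrays in [-1, 1].
a, b = load_image_pairs()

model = Pix2pix(256, 256, ichan=3, ochan=3)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        # The source image `a` is both the generator input and the conditioning
        # half of the discriminator pair; `b` is the target.
        gloss, dloss = model.train_step(sess, a, a, b)
        if step % 100 == 0:
            print('step %d: G loss %f, D loss %f' % (step, gloss, dloss))
    fake_b = model.sample_generator(sess, a)  # translated images in [-1, 1]
```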
--------------------------------------------------------------------------------
/model/discriminator.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from .utils import get_shape, batch_norm, lkrelu

# PatchGAN discriminator
class Discriminator(object):
    def __init__(self, inputs, is_training, stddev=0.02, center=True, scale=True, reuse=None):
        self._is_training = is_training
        self._stddev = stddev
        with tf.variable_scope('D', initializer=tf.truncated_normal_initializer(stddev=self._stddev), reuse=reuse):
            self._center = center
            self._scale = scale
            self._inputs = inputs
            self._discriminator = self._build_discriminator(inputs)

    def _build_layer(self, name, inputs, k, bn=True):
        # 4x4 convolution with stride 2 and k filters, optional batch norm, leaky ReLU.
        layer = dict()
        with tf.variable_scope(name):
            layer['filters'] = tf.get_variable('filters', [4, 4, get_shape(inputs)[-1], k])
            layer['conv'] = tf.nn.conv2d(inputs, layer['filters'], strides=[1, 2, 2, 1], padding='SAME')
            layer['bn'] = batch_norm(layer['conv'], center=self._center, scale=self._scale, training=self._is_training) if bn else layer['conv']
            layer['fmap'] = lkrelu(layer['bn'], slope=0.2)
        return layer

    def _build_discriminator(self, inputs):
        discriminator = dict()

        # C64-C128-C256-C512 => PatchGAN
        discriminator['l1'] = self._build_layer('l1', inputs, 64, bn=False)
        discriminator['l2'] = self._build_layer('l2', discriminator['l1']['fmap'], 128)
        discriminator['l3'] = self._build_layer('l3', discriminator['l2']['fmap'], 256)
        discriminator['l4'] = self._build_layer('l4', discriminator['l3']['fmap'], 512)
        with tf.variable_scope('l5'):
            # Final 1-channel, stride-1 convolution: each output unit scores one input patch.
            l5 = dict()
            l5['filters'] = tf.get_variable('filters', [4, 4, get_shape(discriminator['l4']['fmap'])[-1], 1])
            l5['conv'] = tf.nn.conv2d(discriminator['l4']['fmap'], l5['filters'], strides=[1, 1, 1, 1], padding='SAME')
            l5['bn'] = batch_norm(l5['conv'], center=self._center, scale=self._scale, training=self._is_training)
            l5['fmap'] = tf.nn.sigmoid(l5['bn'])
            discriminator['l5'] = l5
        return discriminator
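
# --- Illustrative shape check (editor's sketch, not part of the original repo). ---
# For a 256x256 input pair, the four stride-2 convolutions (C64-C128-C256-C512)
# reduce the spatial size to 16x16, so 'l5' emits a 16x16 grid of per-patch
# real/fake probabilities rather than a single scalar: this is the "PatchGAN".
# Run from the repo root as `python -m model.discriminator` (TensorFlow 1.x);
# the relative import above prevents running this file directly as a script.
if __name__ == '__main__':
    # 6 channels: a 3-channel input image concatenated with a 3-channel target,
    # mirroring how pix2pix.py feeds the discriminator.
    pair = tf.placeholder(tf.float32, [1, 256, 256, 6])
    is_training = tf.placeholder(tf.bool)
    d = Discriminator(pair, is_training)
    print(d._discriminator['l5']['fmap'].shape)  # (1, 16, 16, 1)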
--------------------------------------------------------------------------------
/pix2pix.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from model.discriminator import Discriminator
from model.generator import Generator

class Pix2pix(object):
    def __init__(self, width, height, ichan, ochan, l1_weight=100., lr=0.0002, beta1=0.5):
        """
        width: image width in pixels.
        height: image height in pixels.
        ichan: number of channels in the input images.
        ochan: number of channels in the output images.
        l1_weight: weight of the L1 reconstruction loss.
        lr: learning rate for the Adam optimizer.
        beta1: beta1 parameter for the Adam optimizer.
        """
        self._is_training = tf.placeholder(tf.bool)

        # NHWC placeholders; the (width, height) order only matters for non-square images.
        self._g_inputs = tf.placeholder(tf.float32, [None, width, height, ichan])
        self._d_inputs_a = tf.placeholder(tf.float32, [None, width, height, ichan])
        self._d_inputs_b = tf.placeholder(tf.float32, [None, width, height, ochan])
        self._g = Generator(self._g_inputs, self._is_training, ochan)
        # The discriminator sees the conditioning image concatenated with either
        # the real target (real_d) or the generator's output (fake_d); the two
        # instances share weights via reuse=True.
        self._real_d = Discriminator(tf.concat([self._d_inputs_a, self._d_inputs_b], axis=3), self._is_training)
        self._fake_d = Discriminator(tf.concat([self._d_inputs_a, self._g._decoder['cl9']['fmap']], axis=3), self._is_training, reuse=True)

        # Generator: non-saturating GAN loss plus a weighted L1 reconstruction term.
        self._g_loss = -tf.reduce_mean(tf.log(self._fake_d._discriminator['l5']['fmap'])) + l1_weight * tf.reduce_mean(tf.abs(self._d_inputs_b - self._g._decoder['cl9']['fmap']))
        # Discriminator: maximize log D(real) + log(1 - D(fake)).
        self._d_loss = -tf.reduce_mean(tf.log(self._real_d._discriminator['l5']['fmap']) + tf.log(1.0 - self._fake_d._discriminator['l5']['fmap']))

        # Group each network's batch-norm moving-average updates with its optimizer step.
        g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='G')
        with tf.control_dependencies(g_update_ops):
            self._g_train_step = tf.train.AdamOptimizer(lr, beta1=beta1).minimize(self._g_loss,
                    var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G'))

        d_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='D')
        with tf.control_dependencies(d_update_ops):
            self._d_train_step = tf.train.AdamOptimizer(lr, beta1=beta1).minimize(self._d_loss,
                    var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D'))

    def train_step(self, sess, g_inputs, d_inputs_a, d_inputs_b, is_training=True):
        # One discriminator update followed by one generator update.
        _, dloss_curr = sess.run([self._d_train_step, self._d_loss],
                feed_dict={self._d_inputs_a : d_inputs_a, self._d_inputs_b : d_inputs_b, self._g_inputs : g_inputs, self._is_training : is_training})
        _, gloss_curr = sess.run([self._g_train_step, self._g_loss],
                feed_dict={self._g_inputs : g_inputs, self._d_inputs_a : d_inputs_a, self._d_inputs_b : d_inputs_b, self._is_training : is_training})
        return (gloss_curr, dloss_curr)

    def sample_generator(self, sess, g_inputs, is_training=False):
        # Run the generator alone and return its tanh output images.
        return sess.run(self._g._decoder['cl9']['fmap'], feed_dict={self._g_inputs : g_inputs, self._is_training : is_training})
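
# --- Illustrative smoke test (editor's sketch, not part of the original repo). ---
# Builds the full model on 256x256 images (the encoder's eight stride-2 layers
# require the spatial size to be divisible by 2^8 = 256) and runs one training
# step on random data, using only the classes defined above (TensorFlow 1.x).
if __name__ == '__main__':
    model = Pix2pix(256, 256, ichan=3, ochan=3)
    a = np.random.uniform(-1., 1., (1, 256, 256, 3)).astype(np.float32)  # dummy source image
    b = np.random.uniform(-1., 1., (1, 256, 256, 3)).astype(np.float32)  # dummy target image
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        gloss, dloss = model.train_step(sess, a, a, b)
        print('G loss: %f, D loss: %f' % (gloss, dloss))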
--------------------------------------------------------------------------------
/model/generator.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from .utils import get_shape, batch_norm, lkrelu

# U-Net generator
class Generator(object):
    def __init__(self, inputs, is_training, ochan, stddev=0.02, center=True, scale=True, reuse=None):
        self._is_training = is_training
        self._stddev = stddev
        self._ochan = ochan
        with tf.variable_scope('G', initializer=tf.truncated_normal_initializer(stddev=self._stddev), reuse=reuse):
            self._center = center
            self._scale = scale
            self._prob = 0.5 # dropout keep probability, constant from the pix2pix paper
            self._inputs = inputs
            self._encoder = self._build_encoder(inputs)
            self._decoder = self._build_decoder(self._encoder)

    def _build_encoder_layer(self, name, inputs, k, bn=True, use_dropout=False):
        # 4x4 convolution with stride 2 and k filters, optional batch norm and dropout, leaky ReLU.
        layer = dict()
        with tf.variable_scope(name):
            layer['filters'] = tf.get_variable('filters', [4, 4, get_shape(inputs)[-1], k])
            layer['conv'] = tf.nn.conv2d(inputs, layer['filters'], strides=[1, 2, 2, 1], padding='SAME')
            layer['bn'] = batch_norm(layer['conv'], center=self._center, scale=self._scale, training=self._is_training) if bn else layer['conv']
            layer['dropout'] = tf.nn.dropout(layer['bn'], self._prob) if use_dropout else layer['bn']
            layer['fmap'] = lkrelu(layer['dropout'], slope=0.2)
        return layer

    def _build_encoder(self, inputs):
        encoder = dict()

        # C64-C128-C256-C512-C512-C512-C512-C512
        with tf.variable_scope('encoder'):
            encoder['l1'] = self._build_encoder_layer('l1', inputs, 64, bn=False)
            encoder['l2'] = self._build_encoder_layer('l2', encoder['l1']['fmap'], 128)
            encoder['l3'] = self._build_encoder_layer('l3', encoder['l2']['fmap'], 256)
            encoder['l4'] = self._build_encoder_layer('l4', encoder['l3']['fmap'], 512)
            encoder['l5'] = self._build_encoder_layer('l5', encoder['l4']['fmap'], 512)
            encoder['l6'] = self._build_encoder_layer('l6', encoder['l5']['fmap'], 512)
            encoder['l7'] = self._build_encoder_layer('l7', encoder['l6']['fmap'], 512)
            encoder['l8'] = self._build_encoder_layer('l8', encoder['l7']['fmap'], 512)
        return encoder

    def _build_decoder_layer(self, name, inputs, output_shape_from, use_dropout=False):
        # 4x4 transposed convolution with stride 2 that upsamples to the shape of
        # output_shape_from, then batch norm, optional dropout, and ReLU.
        layer = dict()

        with tf.variable_scope(name):
            output_shape = tf.shape(output_shape_from)
            layer['filters'] = tf.get_variable('filters', [4, 4, get_shape(output_shape_from)[-1], get_shape(inputs)[-1]])
            layer['conv'] = tf.nn.conv2d_transpose(inputs, layer['filters'], output_shape=output_shape, strides=[1, 2, 2, 1], padding='SAME')
            layer['bn'] = batch_norm(tf.reshape(layer['conv'], output_shape), center=self._center, scale=self._scale, training=self._is_training)
            layer['dropout'] = tf.nn.dropout(layer['bn'], self._prob) if use_dropout else layer['bn']
            layer['fmap'] = tf.nn.relu(layer['dropout'])
        return layer

    def _build_decoder(self, encoder):
        decoder = dict()

        # CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
        with tf.variable_scope('decoder'): # U-Net
            decoder['dl1'] = self._build_decoder_layer('dl1', encoder['l8']['fmap'], output_shape_from=encoder['l7']['fmap'], use_dropout=True)

            # Each fmap_concat below is a U-Net skip connection: the decoder feature
            # map is concatenated with the mirrored encoder feature map.
            fmap_concat = tf.concat([decoder['dl1']['fmap'], encoder['l7']['fmap']], axis=3)
            decoder['dl2'] = self._build_decoder_layer('dl2', fmap_concat, output_shape_from=encoder['l6']['fmap'], use_dropout=True)

            fmap_concat = tf.concat([decoder['dl2']['fmap'], encoder['l6']['fmap']], axis=3)
            decoder['dl3'] = self._build_decoder_layer('dl3', fmap_concat, output_shape_from=encoder['l5']['fmap'], use_dropout=True)

            fmap_concat = tf.concat([decoder['dl3']['fmap'], encoder['l5']['fmap']], axis=3)
            decoder['dl4'] = self._build_decoder_layer('dl4', fmap_concat, output_shape_from=encoder['l4']['fmap'])

            fmap_concat = tf.concat([decoder['dl4']['fmap'], encoder['l4']['fmap']], axis=3)
            decoder['dl5'] = self._build_decoder_layer('dl5', fmap_concat, output_shape_from=encoder['l3']['fmap'])

            fmap_concat = tf.concat([decoder['dl5']['fmap'], encoder['l3']['fmap']], axis=3)
            decoder['dl6'] = self._build_decoder_layer('dl6', fmap_concat, output_shape_from=encoder['l2']['fmap'])

            fmap_concat = tf.concat([decoder['dl6']['fmap'], encoder['l2']['fmap']], axis=3)
            decoder['dl7'] = self._build_decoder_layer('dl7', fmap_concat, output_shape_from=encoder['l1']['fmap'])

            fmap_concat = tf.concat([decoder['dl7']['fmap'], encoder['l1']['fmap']], axis=3)
            decoder['dl8'] = self._build_decoder_layer('dl8', fmap_concat, output_shape_from=self._inputs)

            with tf.variable_scope('cl9'):
                # Final stride-1 convolution to ochan channels, squashed to [-1, 1] by tanh.
                cl9 = dict()
                cl9['filters'] = tf.get_variable('filters', [4, 4, get_shape(decoder['dl8']['fmap'])[-1], self._ochan])
                cl9['conv'] = tf.nn.conv2d(decoder['dl8']['fmap'], cl9['filters'], strides=[1, 1, 1, 1], padding='SAME')
                cl9['fmap'] = tf.nn.tanh(cl9['conv'])
                decoder['cl9'] = cl9
        return decoder
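
# --- Illustrative shape check (editor's sketch, not part of the original repo). ---
# Feeds a random 256x256 RGB image through the U-Net and confirms that the final
# tanh feature map ('cl9') keeps the input's spatial size with `ochan` channels.
# Run from the repo root as `python -m model.generator` (TensorFlow 1.x); the
# relative import above prevents running this file directly as a script.
if __name__ == '__main__':
    inputs = tf.placeholder(tf.float32, [1, 256, 256, 3])
    is_training = tf.placeholder(tf.bool)
    g = Generator(inputs, is_training, ochan=3)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(g._decoder['cl9']['fmap'],
                       feed_dict={inputs: np.random.uniform(-1., 1., (1, 256, 256, 3)), is_training: False})
        print(out.shape)  # (1, 256, 256, 3)
--------------------------------------------------------------------------------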