├── images └── .gitkeep ├── model ├── __init__.py ├── utils.py ├── discriminator.py └── generator.py ├── models └── .gitkeep ├── readme_imgs ├── a2o.jpg ├── h2z_1.JPG ├── h2z_2.JPG ├── h2z_3.JPG ├── h2z_4.JPG ├── h2z_5.JPG ├── h2z_6.JPG ├── l2l.jpg ├── weird.JPG ├── z2h_1.JPG ├── z2h_2.JPG ├── z2h_3.JPG ├── h2z_epoch1.jpg ├── failure_h2z.JPG ├── h2z_epoch200.jpg ├── lion2leopard.jpg ├── failure_h2z_2.JPG ├── failure_h2z_3.JPG ├── failure_h2z_4.JPG └── failure_h2z_5.JPG ├── README.md ├── example.py └── cyclegan.py /images/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /readme_imgs/a2o.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/a2o.jpg -------------------------------------------------------------------------------- /readme_imgs/h2z_1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_1.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_2.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_3.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_4.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_4.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_5.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_5.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_6.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_6.JPG -------------------------------------------------------------------------------- /readme_imgs/l2l.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/l2l.jpg -------------------------------------------------------------------------------- /readme_imgs/weird.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/weird.JPG -------------------------------------------------------------------------------- /readme_imgs/z2h_1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/z2h_1.JPG -------------------------------------------------------------------------------- /readme_imgs/z2h_2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/z2h_2.JPG -------------------------------------------------------------------------------- /readme_imgs/z2h_3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/z2h_3.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_epoch1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_epoch1.jpg -------------------------------------------------------------------------------- /readme_imgs/failure_h2z.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/failure_h2z.JPG -------------------------------------------------------------------------------- /readme_imgs/h2z_epoch200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/h2z_epoch200.jpg -------------------------------------------------------------------------------- /readme_imgs/lion2leopard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/lion2leopard.jpg -------------------------------------------------------------------------------- /readme_imgs/failure_h2z_2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/failure_h2z_2.JPG -------------------------------------------------------------------------------- /readme_imgs/failure_h2z_3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/failure_h2z_3.JPG -------------------------------------------------------------------------------- /readme_imgs/failure_h2z_4.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/failure_h2z_4.JPG -------------------------------------------------------------------------------- /readme_imgs/failure_h2z_5.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eyyub/tensorflow-cyclegan/HEAD/readme_imgs/failure_h2z_5.JPG -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def get_shape(tensor): 4 | return tensor.get_shape().as_list() 5 | 6 | def inst_norm(tensor): # Instance Normalization https://arxiv.org/abs/1607.08022 7 | epsilon = 1e-5 8 | with tf.variable_scope('in'): 9 | scale = tf.get_variable('scale', initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02), shape=[get_shape(tensor)[-1]]) 10 | center = tf.get_variable('center', initializer=tf.zeros_initializer(dtype=tf.float32), shape=[get_shape(tensor)[-1]]) 11 | instance_mean, instance_var = tf.nn.moments(tensor, axes=[1, 2], keep_dims=True) 12 | 13 | return scale * ((tensor - instance_mean) / tf.sqrt(instance_var + epsilon)) + center 14 | 15 | def lkrelu(x, slope=0.01): 16 | return tf.maximum(slope * x, x) 17 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from .utils import get_shape, inst_norm, lkrelu 4 | 5 | # PatchGAN 6 | class Discriminator(object): 7 | def __init__(self, name, inputs, stddev=0.02, reuse=None): 8 | self._stddev = stddev 9 | 10 | with tf.variable_scope(name, initializer=tf.truncated_normal_initializer(stddev=self._stddev), reuse=reuse): 11 | self._inputs = inputs 12 | self._discriminator = self._build_discriminator(inputs) 13 | 14 | def __getitem__(self, key): 15 | return self._discriminator[key] 16 | 17 | def _build_layer(self, name, inputs, k, use_in=True, use_dropout=False): 18 | layer = dict() 19 | with tf.variable_scope(name): 20 | layer['filters'] = tf.get_variable('filters', [4, 4, get_shape(inputs)[-1], k]) 21 | layer['conv'] = tf.nn.conv2d(inputs, layer['filters'], strides=[1, 2, 2, 1], padding='SAME') 22 | layer['bn'] = inst_norm(layer['conv']) if use_in else layer['conv'] 23 | layer['fmap'] = lkrelu(layer['bn'], slope=0.2) 24 | return layer 25 | 26 | def _build_discriminator(self, inputs, reuse=None): 27 | discriminator = dict() 28 | 29 | # C64-C128-C256-C512 -> PatchGAN 30 | discriminator['l1'] = self._build_layer('l1', inputs, 64, use_in=False) 31 | discriminator['l2'] = self._build_layer('l2', discriminator['l1']['fmap'], 128) 32 | discriminator['l3'] = self._build_layer('l3', discriminator['l2']['fmap'], 256) 33 | discriminator['l4'] = self._build_layer('l4', discriminator['l3']['fmap'], 512) 34 | with tf.variable_scope('l5'): 35 | l5 = dict() 36 | l5['filters'] = tf.get_variable('filters', [4, 4, get_shape(discriminator['l4']['fmap'])[-1], 1]) 37 | l5['conv'] = tf.nn.conv2d(discriminator['l4']['fmap'], l5['filters'], strides=[1, 1, 1, 1], padding='SAME') 38 | l5['fmap'] = l5['conv'] # no sigmoid because we use LSGAN loss 39 | discriminator['l5'] = l5 40 | return discriminator 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-cyclegan 2 | A lightweight [CycleGAN](https://arxiv.org/abs/1703.10593) tensorflow implementation. 3 | 4 | If you plan to use a CycleGAN model for real-world purposes, you should use the [Torch CycleGAN](https://github.com/junyanz/CycleGAN) implementation. 5 | 6 | [@eyyub_s](https://twitter.com/eyyub_s) 7 | 8 | ## Some examples 9 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/lion2leopard.jpg?raw=true) 10 | lion2leopard (cherry-picked) 11 | 12 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/l2l.jpg?raw=true) 13 | More lion2leopard (each classes contain only 100 instances!) 14 | 15 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/h2z_1.JPG?raw=true) 16 | 17 | horse2zebra 18 | 19 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/failure_h2z_5.JPG?raw=true) 20 | 21 | horse2zebra failure 22 | 23 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/z2h_1.JPG?raw=true) 24 | 25 | zebra2horse 26 | 27 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/weird.JPG?raw=true) 28 | 29 | wtf 30 | 31 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/h2z_epoch200.jpg?raw=true) 32 | 33 | More zebra2horse 34 | 35 | ![](https://github.com/Eyyub/tensorflow-cyclegan/blob/master/readme_imgs/a2o.jpg?raw=true) 36 | 37 | apple2orange 38 | 39 | See more in `readme_imgs/` 40 | 41 | ## Build horse2zebra 42 | - Download `horse2zebra.zip` from [CycleGAN datasets](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/) 43 | - Unzip it here `.` 44 | - Run `python build_dataset.py horse2zebra/trainA horse2zebra/trainB trainA trainB` 45 | - (make sure `dataset_trainA.npy` & `dataset_trainB.npy` are created) 46 | - Then, run `python example.py` 47 | - (If you want to stop and restart your training later you can do: `python example.py restore `) 48 | 49 | ## Requiremennts 50 | - Python 3.5 51 | - Tensorflow 52 | - Matplotlib 53 | - Pillow 54 | - (Only tested on Windows so far) 55 | 56 | ## _Very_ useful info 57 | - Training took me ~1day (GTX 1060 3g) 58 | - Each 100 steps the script adds an image in the `images/` folder 59 | - Each 1000 steps the model is saved in `models` 60 | - CycleGAN seems to be init-sensitive, if the generators only inverse colors: kill & re-try training 61 | 62 | ## Todo 63 | - [x] Image Pool 64 | - [ ] Add learning reate linear decay 65 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import matplotlib.pyplot as plt 5 | from cyclegan import CycleGAN 6 | 7 | A = np.load('dataset_trainA.npy') / 255. 8 | B = np.load('dataset_trainB.npy') / 255. 9 | 10 | iters = 200 * min(A.shape[0], B.shape[0]) 11 | batch_size = 1 12 | 13 | with tf.device('/gpu:0'): 14 | model = CycleGAN(256, 256, xchan=3, ychan=3) 15 | 16 | saver = tf.train.Saver() 17 | 18 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess: 19 | start = 0 20 | if len(sys.argv) > 1 and sys.argv[1] == 'restore': 21 | saver.restore(sess, 'models/model.ckpt') 22 | start = int(sys.argv[2]) 23 | else: 24 | sess.run(tf.global_variables_initializer()) 25 | 26 | for step in range(start, iters): 27 | # normalize because generator use tanh activation in its output layer 28 | a = 2. * np.expand_dims(A[np.random.randint(0, A.shape[0] - 1)], axis=0) - 1. 29 | b = 2. * np.expand_dims(B[np.random.randint(0, B.shape[0] - 1)], axis=0) - 1. 30 | 31 | d_a = 2. * np.expand_dims(A[np.random.randint(0, A.shape[0] - 1)], axis=0) - 1. 32 | d_b = 2. * np.expand_dims(B[np.random.randint(0, B.shape[0] - 1)], axis=0) - 1. 33 | 34 | (gxloss_curr, gyloss_curr), (dxloss_curr, dyloss_curr) = model.train_step(sess, a, b, d_a, d_b) #a: xs, b: ys 35 | print('Step %d: Gx loss: %f | Gy loss: %f | Dx loss: %f | Dy loss: %f' % (step, gxloss_curr, gyloss_curr, dxloss_curr, dyloss_curr)) 36 | 37 | if step % 100 == 0: 38 | fig = plt.figure() 39 | fig.set_size_inches(15, 15) 40 | fig.subplots_adjust(left=0, bottom=0, 41 | right=1, top=1, wspace=0, hspace=0.1) 42 | 43 | for i in range(0, 12*12, 6): 44 | ra = np.random.randint(0, A.shape[0] - 1) 45 | rb = np.random.randint(0, B.shape[0] - 1) 46 | 47 | # Plot 6 images 48 | 49 | # Plot real A image 50 | fig.add_subplot(12, 12, i + 1) 51 | plt.imshow(A[ra]) 52 | plt.axis('off') 53 | 54 | # Plot fake B image using above A image 55 | fig.add_subplot(12, 12, i + 2) 56 | b_from_a = model.sample_gy(sess, 2. * np.expand_dims(A[ra], axis=0) - 1.) 57 | plt.imshow((b_from_a[0] + 1.) / 2.) 58 | plt.axis('off') 59 | 60 | # Plot real B image 61 | fig.add_subplot(12, 12, i +3) 62 | plt.imshow(B[rb] ) 63 | plt.axis('off') 64 | 65 | # Plot fake A image using above B image 66 | fig.add_subplot(12, 12, i + 4) 67 | a_from_b = model.sample_gx(sess, 2. * np.expand_dims(B[rb], axis=0) - 1.) 68 | plt.imshow((a_from_b[0] + 1.) / 2.) 69 | plt.axis('off') 70 | 71 | # Plot recovered A image from the fake B image generated using the real A image 72 | fig.add_subplot(12, 12, i + 5) 73 | identity_a = model.sample_gx(sess, b_from_a) 74 | plt.imshow((identity_a[0] + 1.) / 2.) 75 | plt.axis('off') 76 | 77 | # Plot recovered B image from the fake A image generated using the real B image 78 | fig.add_subplot(12, 12, i + 6) 79 | identity_b = model.sample_gy(sess, a_from_b) 80 | plt.imshow((identity_b[0] + 1.) / 2.) 81 | plt.axis('off') 82 | plt.savefig('images/iter_%d.jpg' % step) 83 | plt.close() 84 | 85 | if step % 1000 == 0: 86 | # Save the model 87 | save_path = saver.save(sess, "models/model.ckpt") 88 | print("Model saved in file: %s" % save_path) 89 | -------------------------------------------------------------------------------- /cyclegan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from model.discriminator import Discriminator 4 | from model.generator import Generator 5 | 6 | class _ImagePool(object): 7 | def __init__(self, max_size): 8 | self._pool = [] 9 | self._max_size = max_size 10 | 11 | def sample_random(self, a): 12 | if len(self._pool) < self._max_size: 13 | self._pool.append(a) 14 | return a 15 | r = np.random.random() 16 | if r > 0.5: 17 | r = np.random.randint(0, len(self._pool) - 1) 18 | inst = self._pool[r] 19 | self._pool[r] = a 20 | return inst 21 | else: 22 | return a 23 | 24 | class CycleGAN(object): 25 | def __init__(self, width, height, xchan, ychan, lambda_=10., pool_size=50, lr=0.0002, beta1=0.5): 26 | """ 27 | width: image width in pixel. 28 | height: image height in pixel. 29 | ichan: number of channels used by input images. 30 | ochan: number of channels used by output images. 31 | lambda_: Cycle-Consistency weighting. 32 | pool_size: Image pool size. 33 | lr: learning rate for ADAM optimizer. 34 | beta1: beta1 parameter for ADAM optimizer. 35 | """ 36 | 37 | self._dx_pool = _ImagePool(pool_size) 38 | self._dy_pool = _ImagePool(pool_size) 39 | 40 | self._xs = tf.placeholder(tf.float32, [None, width, height, xchan]) 41 | self._ys = tf.placeholder(tf.float32, [None, width, height, ychan]) 42 | 43 | self._d_xs = tf.placeholder(tf.float32, [None, width, height, xchan]) 44 | self._d_ys = tf.placeholder(tf.float32, [None, width, height, ychan]) 45 | self._fake_d_xs = tf.placeholder(tf.float32, [None, width, height, xchan]) 46 | self._fake_d_ys = tf.placeholder(tf.float32, [None, width, height, ychan]) 47 | 48 | self._gx = Generator('Gx', self._ys, xchan) 49 | self._gy = Generator('Gy', self._xs, ychan) 50 | 51 | self._gx_from_gy = Generator('Gx', self._gy['l15']['fmap'], xchan, reuse=True) 52 | self._gy_from_gx = Generator('Gy', self._gx['l15']['fmap'], ychan, reuse=True) 53 | 54 | self._real_dx = Discriminator('Dx', self._d_xs) 55 | self._fake_dx = Discriminator('Dx', self._xs, reuse=True) 56 | self._fake_dx_g = Discriminator('Dx', self._gx['l15']['fmap'], reuse=True) 57 | 58 | self._real_dy = Discriminator('Dy', self._d_ys) 59 | self._fake_dy = Discriminator('Dy', self._ys, reuse=True) 60 | self._fake_dy_g = Discriminator('Dy', self._gy['l15']['fmap'], reuse=True) 61 | 62 | # Forward and backward Cycle-Consistency with LSGAN-kind losses 63 | cycle_loss = lambda_ * (tf.reduce_mean(tf.abs((self._gx_from_gy['l15']['fmap'] - self._xs))) + tf.reduce_mean(tf.abs((self._gy_from_gx['l15']['fmap'] - self._ys)))) 64 | self._gx_loss = 0.5 * tf.reduce_mean(tf.square(self._fake_dx_g['l5']['fmap'] - 1.)) + cycle_loss 65 | self._gy_loss = 0.5 * tf.reduce_mean(tf.square(self._fake_dy_g['l5']['fmap'] - 1.)) + cycle_loss 66 | 67 | self._dx_loss = 0.5 * tf.reduce_mean(tf.square(self._real_dx['l5']['fmap'] - 1.)) + 0.5 * tf.reduce_mean(tf.square(self._fake_dx['l5']['fmap'])) 68 | self._dy_loss = 0.5 * tf.reduce_mean(tf.square(self._real_dy['l5']['fmap'] - 1.)) + 0.5 * tf.reduce_mean(tf.square(self._fake_dy['l5']['fmap'])) 69 | 70 | self._gx_train_step = tf.train.AdamOptimizer(lr, beta1=beta1).minimize(self._gx_loss, 71 | var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gx')) 72 | 73 | self._gy_train_step = tf.train.AdamOptimizer(lr, beta1=beta1).minimize(self._gy_loss, 74 | var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Gy')) 75 | 76 | self._dx_train_step = tf.train.AdamOptimizer(lr, beta1=beta1).minimize(self._dx_loss, 77 | var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Dx')) 78 | 79 | self._dy_train_step = tf.train.AdamOptimizer(lr, beta1=beta1).minimize(self._dy_loss, 80 | var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Dy')) 81 | 82 | def train_step(self, sess, xs, ys, d_xs, d_ys): 83 | 84 | ops = [self._gx_train_step, self._gy_train_step, self._gx_loss, self._gy_loss, self._gx['l15']['fmap'], self._gy['l15']['fmap']] 85 | _, _, gxloss_curr, gyloss_curr, gxs, gys = sess.run(ops, feed_dict={self._xs: xs, self._ys: ys}) 86 | 87 | _, _, dxloss_curr, dyloss_curr = sess.run([self._dx_train_step, self._dy_train_step, self._dx_loss, self._dy_loss], 88 | feed_dict={self._xs: self._dx_pool.sample_random(gxs), 89 | self._ys: self._dy_pool.sample_random(gys), 90 | self._d_xs: d_xs, self._d_ys: d_ys}) 91 | 92 | return ((gxloss_curr, gyloss_curr), (dxloss_curr, dyloss_curr)) 93 | 94 | def sample_gx(self, sess, ys): 95 | return sess.run(self._gx['l15']['fmap'], feed_dict={self._ys: ys}) 96 | 97 | def sample_gy(self, sess, xs): 98 | return sess.run(self._gy['l15']['fmap'], feed_dict={self._xs: xs}) 99 | -------------------------------------------------------------------------------- /model/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from .utils import get_shape, inst_norm, lkrelu 4 | 5 | # 9-ResBlock Generator 6 | class Generator(object): 7 | def __init__(self, name, inputs, ochan, stddev=0.02, center=True, scale=True, reuse=None): 8 | self._stddev = stddev 9 | self._ochan = ochan 10 | with tf.variable_scope(name, initializer=tf.truncated_normal_initializer(stddev=self._stddev), reuse=reuse): 11 | self._inputs = inputs 12 | self._resnet = self._build_resnet(self._inputs) 13 | 14 | def __getitem__(self, key): 15 | return self._resnet[key] 16 | 17 | def _build_conv_layer(self, name, inputs, k, rfsize, stride, use_in=True, f=tf.nn.relu, reflect=False): 18 | layer = dict() 19 | with tf.variable_scope(name): 20 | layer['filters'] = tf.get_variable('filters', [rfsize, rfsize, get_shape(inputs)[-1], k]) 21 | 22 | if reflect: 23 | # pad with 3, not indicated by the paper but torch CycleGAN does it this way 24 | layer['conv'] = tf.nn.conv2d(tf.pad(inputs, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT'), layer['filters'], strides=[1, stride, stride, 1], padding='VALID') 25 | else: 26 | layer['conv'] = tf.nn.conv2d(inputs, layer['filters'], strides=[1, stride, stride, 1], padding='SAME') 27 | layer['bn'] = inst_norm(layer['conv']) if use_in else layer['conv'] 28 | layer['fmap'] = f(layer['bn']) 29 | return layer 30 | 31 | def _build_residual_layer(self, name, inputs, k, rfsize, blocksize=2, stride=1): # rfsize: receptive field size 32 | layer = dict() 33 | with tf.variable_scope(name): 34 | with tf.variable_scope('layer1'): 35 | layer['filters1'] = tf.get_variable('filters1', [rfsize, rfsize, get_shape(inputs)[-1], k]) 36 | layer['conv1'] = tf.nn.conv2d(tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT'), layer['filters1'], strides=[1, stride, stride, 1], padding='VALID') 37 | layer['bn1'] = inst_norm(layer['conv1']) 38 | layer['fmap1'] = tf.nn.relu(layer['bn1']) 39 | 40 | with tf.variable_scope('layer2'): 41 | layer['filters2'] = tf.get_variable('filters2', [rfsize, rfsize, get_shape(inputs)[-1], k]) 42 | layer['conv2'] = tf.nn.conv2d(tf.pad(layer['fmap1'], [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT'), layer['filters2'], strides=[1, stride, stride, 1], padding='VALID') 43 | layer['bn2'] = inst_norm(layer['conv2']) 44 | 45 | # No ReLu here (following http://torch.ch/blog/2016/02/04/resnets.html, as indicated by the authors) 46 | layer['fmap2'] = layer['bn2'] + inputs 47 | return layer 48 | 49 | def _build_deconv_layer(self, name, inputs, k, output_shape, rfsize): # fractional-strided conv layer 50 | layer = dict() 51 | 52 | with tf.variable_scope(name): 53 | output_shape = [tf.shape(inputs)[0]] + output_shape 54 | layer['filters'] = tf.get_variable('filters', [rfsize, rfsize, output_shape[-1], get_shape(inputs)[-1]]) 55 | layer['conv'] = tf.nn.conv2d_transpose(inputs, layer['filters'], output_shape=output_shape, strides=[1, 2, 2, 1], padding='SAME') 56 | layer['bn'] = inst_norm(tf.reshape(layer['conv'], output_shape)) 57 | layer['fmap'] = tf.nn.relu(layer['bn']) 58 | return layer 59 | 60 | def _build_resnet(self, inputs): 61 | resnet = dict() 62 | 63 | inputs_shape = get_shape(inputs) 64 | width = inputs_shape[1] 65 | height = inputs_shape[2] 66 | 67 | # c7s1-32,d64,d128,R128,R128,R128,R128,R128,R128,R128,R128,R128,u64,u32,c7s1-3 See paper §7.5 68 | with tf.variable_scope('resnet'): 69 | resnet['l1'] = self._build_conv_layer('c7s1-32_1', inputs, k=32, rfsize=7, stride=1, reflect=True) 70 | resnet['l2'] = self._build_conv_layer('d64_1', resnet['l1']['fmap'], k=64, rfsize=3, stride=2) 71 | resnet['l3'] = self._build_conv_layer('d128_1', resnet['l2']['fmap'], k=128, rfsize=3, stride=2) 72 | resnet['l4'] = self._build_residual_layer('r128_1', resnet['l3']['fmap'], k=128, rfsize=3, stride=1) 73 | resnet['l5'] = self._build_residual_layer('r128_2', resnet['l4']['fmap2'], k=128, rfsize=3, stride=1) 74 | resnet['l6'] = self._build_residual_layer('r128_3', resnet['l5']['fmap2'], k=128, rfsize=3, stride=1) 75 | resnet['l7'] = self._build_residual_layer('r128_4', resnet['l6']['fmap2'], k=128, rfsize=3, stride=1) 76 | resnet['l8'] = self._build_residual_layer('r128_5', resnet['l7']['fmap2'], k=128, rfsize=3, stride=1) 77 | resnet['l9'] = self._build_residual_layer('r128_6', resnet['l8']['fmap2'], k=128, rfsize=3, stride=1) 78 | resnet['l10'] = self._build_residual_layer('r128_7', resnet['l9']['fmap2'], k=128, rfsize=3, stride=1) 79 | resnet['l11'] = self._build_residual_layer('r128_8', resnet['l10']['fmap2'], k=128, rfsize=3, stride=1) 80 | resnet['l12'] = self._build_residual_layer('r128_9', resnet['l11']['fmap2'], k=128, rfsize=3, stride=1) 81 | resnet['l13'] = self._build_deconv_layer('u64_1', resnet['l12']['fmap2'], k=64, output_shape=[width//2, height//2, 64], rfsize=3) 82 | resnet['l14'] = self._build_deconv_layer('u32_1', resnet['l13']['fmap'], k=32, output_shape=[width, height, 32], rfsize=3) 83 | resnet['l15'] = self._build_conv_layer('c7s1-3_1', resnet['l14']['fmap'], f=tf.nn.tanh, k=get_shape(inputs)[-1], rfsize=7, stride=1, use_in=False, reflect=True) 84 | return resnet 85 | --------------------------------------------------------------------------------