├── .gitignore ├── .ipynb_checkpoints ├── Blog for CycleGan-checkpoint.ipynb └── CycleGAN_notebook-checkpoint.ipynb ├── README.md ├── __pycache__ └── layers.cpython-36.pyc ├── download_datasets.sh ├── images ├── Generator.jpg ├── Results.jpg ├── discriminator.jpg ├── distortion.jpg └── model.jpg ├── layers.py ├── main.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | output/ 3 | input/ 4 | MNIST_data 5 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/Blog for CycleGan-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Cross-domain image transfer has been a hot topic, and various models have been used to achieve this task. Recently, realising the power of GANs, people have started building models based on them, and the results are quite impressive. In this blog, we will try to understand the recent development in this direction by Kim et al.: https://arxiv.org/pdf/1703.05192v1.pdf." 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "collapsed": true 14 | }, 15 | "source": [ 16 | "There are various components to transferring an image from one category to another, so let us first look at them in detail.\n", 17 | "\n", 18 | "1.) Extracting the features of an image (Encoding): First of all, what do we mean by the features of an image? Suppose we want to convert the image of a man into that of a look-alike woman. First, we need to understand various general things about the human face, like which part is the nose, the ears, the eyebrows, the hair and so on. These are high-level features, but there are also low-level features like edges and curves, and it is through these low-level features that we recognise the high-level ones, which help us break down the image and understand its details. The model used to extract these features in almost all neural networks is called a Convolutional Network, which you can read about in detail here: [Link to nice blog post on Convolution]\n", 19 | "\n", 20 | "\n", 21 | "\n", 22 | "2.) Learning a transformation from the features of one domain to those of another (Transfer): After the first step we have the high- and low-level features of an image, and we would like to convert them into the features of the other category. Let us look at one of the transformations we need to learn. Suppose we detect that a man's face has a beard and we want to convert this face to that of a woman; then we need to remove the beard and smooth out that part of the face to match the skin colour. Another example might be converting thick eyebrows to thin eyebrows, and so on.\n", 23 | "\n", 24 | "\n", 25 | "3.) Making the final image (Decoding): This step can be seen as the reversal of step 1. Here we take the different features of an image and construct a face out of them.\n", 26 | "\n", 27 | "\n", 28 | "One thing to keep in mind: we will not learn very high-level features like whether something is an ear or a nose, but rather things like the upper part of the nose or the left part of the left eyebrow, and so on. Basically, mid-level features.\n", 29 | "\n", 30 | "\n", 31 | "So, now that we have understood the key elements of an image transformation, let us see how they model this using a neural network.
Once you are clear about these steps and have a basic knowledge of Deep Learning, you are ready to build the model." 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "Extracting the features:\n", 39 | "\n", 40 | "For the basics of convolutional networks you can go through this very intuitive blog post by []. So the first step is extracting 64 very low-level features (using a window of [7,7]). This can easily be done via a conv layer as follows:" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "collapsed": false 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "o_c1 = general_conv2d(input, 64, 7, 7, 1, 1)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "Now we can keep extracting further features, and since we don't want to blow up the size of the image, we will use a stride of 2. This prevents overlap as well as reducing the size of the output. So we extract further features as follows:" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": { 65 | "collapsed": true 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "o_c2 = general_conv2d(o_c1, 64*2, 3, 3, 2, 2)\n", 70 | "o_c3 = general_conv2d(o_c2, 64*4, 3, 3, 2, 2)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "So, as you can see, at each layer we extract higher-level features from the previous layer's lower-level features, moving towards very high-level features." 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "We can assume that after these steps we will have a fairly decent set of features that we would like to transform into those of a woman's face. So we will build the transformation layers as follows. For these transformations we would like to maintain the same number of features, so the intermediate layers will keep the number of feature maps unchanged."
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": { 91 | "collapsed": true 92 | }, 93 | "outputs": [], 94 | "source": [] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python 2", 100 | "language": "python", 101 | "name": "python2" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 2 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython2", 113 | "version": "2.7.11" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } 119 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/CycleGAN_notebook-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Library imports " 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "collapsed": false 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "# Basic Code is taken from https://github.com/ckmarkoh/GAN-tensorflow\n", 19 | "\n", 20 | "import tensorflow as tf\n", 21 | "from tensorflow.examples.tutorials.mnist import input_data\n", 22 | "import numpy as np\n", 23 | "from skimage.io import imsave\n", 24 | "import os\n", 25 | "import shutil" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Constants and flags " 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": { 39 | "collapsed": true 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "img_height = 28\n", 44 | "img_width = 28\n", 45 | "img_size = img_height * img_width\n", 46 | "\n", 47 | "to_train = True\n", 48 | "to_restore = False\n", 49 | "output_path = \"output\"\n", 50 | "\n", 51 | "max_epoch = 500\n", 52 | "\n", 53 | "h1_size = 150\n", 54 | "h2_size = 300\n", 55 | "z_size = 100\n", 56 | "batch_size = 256\n", 57 | "ngf = 128" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## Function definitions " 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "### Convolution layer \n", 72 | "Defines a general 2D convolution layer with batch normalization" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "collapsed": true 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7,\n", 84 | "                   s_h=1, s_w=1, stddev=0.02,\n", 85 | "                   padding=\"SAME\", name=\"conv2d\",\n", 86 | "                   do_norm=True, do_relu=True):\n", 87 | "    '''Defines a general 2D convolution layer with batch normalization'''\n", 88 | "    \n", 89 | "    with tf.variable_scope(name):\n", 90 | "        initializer = tf.truncated_normal_initializer(stddev=stddev)\n", 91 | "        w = tf.get_variable('w',\n", 92 | "                            [f_h, f_w, inputconv.get_shape()[-1], o_d],\n", 93 | "                            initializer=initializer)\n", 94 | "        conv = tf.nn.conv2d(inputconv,\n", 95 | "                            filter=w, strides=[1, s_h, s_w, 1],\n", 96 | "                            padding=padding)\n", 97 | "        biases = tf.get_variable('b',\n", 98 | "                                 [o_d],\n", 99 | "                                 initializer=tf.constant_initializer(0.0))\n", 100 | "        conv = tf.nn.bias_add(conv,biases)\n", 101 | "        \n", 102 | "        # Add batch_norm layer\n", 103 | "        if do_norm:\n", 104 | "            dims = conv.get_shape()\n", 105 | "            scale =
tf.get_variable('scale',\n", 106 | "                                    [dims[1],dims[2],dims[3]],\n", 107 | "                                    initializer=tf.constant_initializer(1.0))\n", 108 | "            beta = tf.get_variable('beta',\n", 109 | "                                   [dims[1],dims[2],dims[3]],\n", 110 | "                                   initializer=tf.constant_initializer(0.0))\n", 111 | "            conv_mean, conv_var = tf.nn.moments(conv,[0])\n", 112 | "            conv = tf.nn.batch_normalization(conv, conv_mean, conv_var, beta, scale, 0.001)\n", 113 | "        # Add ReLU activation\n", 114 | "        if do_relu:\n", 115 | "            conv = tf.nn.relu(conv)\n", 116 | "        return conv" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "### Resnet block " 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "collapsed": false 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "def build_resnet_block(inputres, dim, name=\"resnet\"):\n", 135 | "    out_res = inputres\n", 136 | "    with tf.variable_scope(name):\n", 137 | "        out_res = general_conv2d(inputres, dim, 3, 3, 1, 1, 0.02, \"SAME\", \"c1\")\n", 138 | "        out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, \"SAME\", \"c2\", do_relu=False)\n", 139 | "        return tf.nn.relu(out_res + inputres)\n", 140 | "\n", 141 | "def build_generator_resnet_6blocks(inputgen, name=\"generator\"):\n", 142 | "    with tf.variable_scope(name):\n", 143 | "        f = 7\n", 144 | "        ks = 3\n", 145 | "        o_c1 = general_conv2d(inputgen, ngf, f, f, 1, 1, 0.02, \"SAME\", \"c1\")\n", 146 | "        o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, \"SAME\", \"c2\")\n", 147 | "        o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, \"SAME\", \"c3\")" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### Show results " 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "collapsed": true 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):\n", 166 | "    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5\n", 167 | "    img_h, img_w = batch_res.shape[1], batch_res.shape[2]\n", 168 | "    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)\n", 169 | "    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)\n", 170 | "    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)\n", 171 | "    for i, res in enumerate(batch_res):\n", 172 | "        if i >= grid_size[0] * grid_size[1]:\n", 173 | "            break\n", 174 | "        img = (res) * 255\n", 175 | "        img = img.astype(np.uint8)\n", 176 | "        row = (i // grid_size[1]) * (img_h + grid_pad)\n", 177 | "        col = (i % grid_size[1]) * (img_w + grid_pad)\n", 178 | "        img_grid[row:row + img_h, col:col + img_w] = img\n", 179 | "    imsave(fname, img_grid)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "### Training and testing routines " 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "collapsed": false 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "def train():\n", 198 | "    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n", 199 | "\n", 200 | "    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name=\"x_data\")\n", 201 | "    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name=\"z_prior\")\n", 202 | "    keep_prob = tf.placeholder(tf.float32, name=\"keep_prob\")\n", 203 | "    global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n", 204 | "\n", 205 | "    x_generated, g_params =
build_generator(z_prior)\n", 206 | "    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)\n", 207 | "\n", 208 | "    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))\n", 209 | "    g_loss = - tf.log(y_generated)\n", 210 | "\n", 211 | "    optimizer = tf.train.AdamOptimizer(0.0001)\n", 212 | "\n", 213 | "    d_trainer = optimizer.minimize(d_loss, var_list=d_params)\n", 214 | "    g_trainer = optimizer.minimize(g_loss, var_list=g_params)\n", 215 | "\n", 216 | "    init = tf.initialize_all_variables()\n", 217 | "\n", 218 | "    saver = tf.train.Saver()\n", 219 | "\n", 220 | "    sess = tf.Session()\n", 221 | "\n", 222 | "    sess.run(init)\n", 223 | "\n", 224 | "    if to_restore:\n", 225 | "        chkpt_fname = tf.train.latest_checkpoint(output_path)\n", 226 | "        saver.restore(sess, chkpt_fname)\n", 227 | "    else:\n", 228 | "        if os.path.exists(output_path):\n", 229 | "            shutil.rmtree(output_path)\n", 230 | "        os.mkdir(output_path)\n", 231 | "\n", 232 | "\n", 233 | "    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n", 234 | "\n", 235 | "    for i in range(sess.run(global_step), max_epoch):\n", 236 | "        for j in range(60000 // batch_size):\n", 237 | "            print(\"epoch:%s, iter:%s\" % (i, j))\n", 238 | "            x_value, _ = mnist.train.next_batch(batch_size)\n", 239 | "            x_value = 2 * x_value.astype(np.float32) - 1\n", 240 | "            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n", 241 | "            sess.run(d_trainer,\n", 242 | "                     feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.float32(0.7)})\n", 243 | "            if j % 1 == 0:\n", 244 | "                sess.run(g_trainer,\n", 245 | "                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.float32(0.7)})\n", 246 | "        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})\n", 247 | "        show_result(x_gen_val, \"output/sample{0}.jpg\".format(i))\n", 248 | "        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n", 249 | "        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})\n", 250 | "        show_result(x_gen_val, \"output/random_sample{0}.jpg\".format(i))\n", 251 | "        sess.run(tf.assign(global_step, i + 1))\n", 252 | "        saver.save(sess, os.path.join(output_path, \"model\"), global_step=global_step)\n", 253 | "\n", 254 | "\n", 255 | "def test():\n", 256 | "    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name=\"z_prior\")\n", 257 | "    x_generated, _ = build_generator(z_prior)\n", 258 | "    chkpt_fname = tf.train.latest_checkpoint(output_path)\n", 259 | "\n", 260 | "    init = tf.initialize_all_variables()\n", 261 | "    sess = tf.Session()\n", 262 | "    saver = tf.train.Saver()\n", 263 | "    sess.run(init)\n", 264 | "    saver.restore(sess, chkpt_fname)\n", 265 | "    z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n", 266 | "    x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})\n", 267 | "    show_result(x_gen_val, \"output/test_result.jpg\")" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "## Main/driver" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "collapsed": false 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "if __name__ == '__main__':\n", 286 | "    if to_train:\n", 287 | "        train()\n", 288 | "    else:\n", 289 | "        test()" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": { 296 | "collapsed": true 297 | }, 298 |
"outputs": [], 299 | "source": [] 300 | } 301 | ], 302 | "metadata": { 303 | "kernelspec": { 304 | "display_name": "Python 2", 305 | "language": "python", 306 | "name": "python2" 307 | }, 308 | "language_info": { 309 | "codemirror_mode": { 310 | "name": "ipython", 311 | "version": 2 312 | }, 313 | "file_extension": ".py", 314 | "mimetype": "text/x-python", 315 | "name": "python", 316 | "nbconvert_exporter": "python", 317 | "pygments_lexer": "ipython2", 318 | "version": "2.7.11" 319 | } 320 | }, 321 | "nbformat": 4, 322 | "nbformat_minor": 2 323 | } 324 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CycleGAN 2 | Tensorflow implementation of CycleGAN. 3 | 4 | 1. [Original implementation](https://github.com/junyanz/CycleGAN/) 5 | 2. [Paper](https://arxiv.org/abs/1703.10593) 6 | 7 | 8 | 9 | ### CycleGAN model 10 | 11 | CycleGAN model can be summarized in the following image. For full details about implementation and understanding CycleGAN you can read the tutorial at this [link](https://hardikbansal.github.io/CycleGANBlog/) 12 | 13 | 14 |

15 | ![Model](images/model.jpg) 16 |

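At its core, the model trains two generators (A→B and B→A) together with two discriminators, and couples them through a cycle-consistency term. The sketch below condenses the loss definitions from `loss_calc` in this repo's `main.py` (least-squares adversarial terms plus an L1 cycle loss weighted by 10); the wrapper function and its argument names are only illustrative:

```python
import tensorflow as tf

def cyclegan_losses(input_A, input_B, cyc_A, cyc_B,
                    rec_A, rec_B, fake_rec_A, fake_rec_B,
                    fake_pool_rec_A, fake_pool_rec_B, lam=10.0):
    # L1 cycle-consistency: G_B(G_A(A)) should reconstruct A, and vice versa.
    cyc_loss = tf.reduce_mean(tf.abs(input_A - cyc_A)) + \
               tf.reduce_mean(tf.abs(input_B - cyc_B))
    # Generators: push the discriminator's score on fakes towards 1 (LSGAN).
    g_loss_A = lam * cyc_loss + tf.reduce_mean(tf.squared_difference(fake_rec_B, 1))
    g_loss_B = lam * cyc_loss + tf.reduce_mean(tf.squared_difference(fake_rec_A, 1))
    # Discriminators: real images towards 1, pooled fake images towards 0.
    d_loss_A = (tf.reduce_mean(tf.square(fake_pool_rec_A)) +
                tf.reduce_mean(tf.squared_difference(rec_A, 1))) / 2.0
    d_loss_B = (tf.reduce_mean(tf.square(fake_pool_rec_B)) +
                tf.reduce_mean(tf.squared_difference(rec_B, 1))) / 2.0
    return g_loss_A, g_loss_B, d_loss_A, d_loss_B
```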
17 | 18 | ##### Generator 19 | 20 | 21 |

22 | ![Generator](images/Generator.jpg) 23 |

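For 256×256 inputs, training uses `build_generator_resnet_9blocks` from `model.py`: three encoding convolutions, nine residual blocks that keep the number of feature maps fixed, and two transposed convolutions followed by a final convolution with a `tanh` output. Below is a condensed sketch; it assumes the helpers (`general_conv2d`, `general_deconv2d`, `build_resnet_block`) and the constants `ngf` and `batch_size` from `layers.py`/`model.py` are in scope:

```python
def generator_sketch(inputgen, name="generator"):
    # Condensed from build_generator_resnet_9blocks in model.py (ngf = 32).
    with tf.variable_scope(name):
        f, ks = 7, 3
        # Encode: reflect-padded 7x7 conv, then two stride-2 downsampling convs.
        pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")
        o = general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02, name="c1")
        o = general_conv2d(o, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
        o = general_conv2d(o, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3")
        # Transform: nine residual blocks, feature count unchanged.
        for i in range(9):
            o = build_resnet_block(o, ngf * 4, "r%d" % (i + 1))
        # Decode: two stride-2 deconvs back to full resolution, then tanh output.
        o = general_deconv2d(o, [batch_size, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c4")
        o = general_deconv2d(o, [batch_size, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5")
        o = general_conv2d(o, 3, f, f, 1, 1, 0.02, "SAME", "c6", do_relu=False)
        return tf.nn.tanh(o, "t1")
```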
24 | 25 | ##### Discriminator 26 | 27 |

28 | ![Discriminator](images/discriminator.jpg) 29 |

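The discriminator (`build_gen_discriminator` in `model.py`) is fully convolutional in the PatchGAN style: four 4×4 convolutions with leaky ReLU (no normalization on the first layer), followed by a 1-channel convolution, so the output is a spatial map of real/fake scores rather than a single scalar. A condensed sketch, again assuming `general_conv2d` and `ndf` from `layers.py`/`model.py` are in scope:

```python
def discriminator_sketch(inputdisc, name="discriminator"):
    # Condensed from build_gen_discriminator in model.py (ndf = 64).
    with tf.variable_scope(name):
        f = 4
        o = general_conv2d(inputdisc, ndf, f, f, 2, 2, 0.02, "SAME", "c1",
                           do_norm=False, relufactor=0.2)
        o = general_conv2d(o, ndf * 2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2)
        o = general_conv2d(o, ndf * 4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)
        o = general_conv2d(o, ndf * 8, f, f, 1, 1, 0.02, "SAME", "c4", relufactor=0.2)
        # 1-channel score map; no norm or activation so it feeds the LSGAN loss directly.
        return general_conv2d(o, 1, f, f, 1, 1, 0.02, "SAME", "c5",
                              do_norm=False, do_relu=False)
```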
30 | 31 | ### Our Results 32 | 33 | We ran the model on the [horse2zebra dataset](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip), but because of a lack of resources we ran it for only 100 epochs and got the following results. 34 | 35 |

36 | ![Results](images/Results.jpg) 37 |

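To reproduce these results, fetch a dataset with the bundled script, e.g. `bash download_datasets.sh horse2zebra` (it downloads and unzips into `./input/`), and then start training with `python main.py`; the `to_train` flag in `main.py` is `True` by default, and checkpoints and intermediate images are written under `./output/`.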
38 | 39 | ### Final Comments 40 | 41 | 1. During training we noticed that the output results were sensitive to initialization. Thanks to [vanhuyz](https://github.com/vanhuyz) for pointing this out and suggesting training multiple times to get the best results. You might notice the background color being reversed, as in the following image. This effect can be observed after only 10-20 epochs, so if you see it you can simply run the training again. 42 | 43 |

44 | ![Fail](images/distortion.jpg) 45 |

46 | 47 | 2. We also think that this model is not a good fit for changing the shape of objects. We tried to run the model to convert a man's face into a look-alike woman's face. For that we used the celebA dataset, but the results are not good and the images produced are quite distorted. 48 | 49 | 50 | ### Blog 51 | 52 | If you would like to understand the paper and see how to implement it on your own, you can have a look at the blog by [me](https://hardikbansal.github.io/CycleGANBlog/) 53 | -------------------------------------------------------------------------------- /__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /download_datasets.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then 4 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" 5 | exit 1 6 | fi 7 | 8 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 9 | ZIP_FILE=./input/$FILE.zip 10 | TARGET_DIR=./input/$FILE/ 11 | wget -N $URL -O $ZIP_FILE 12 | mkdir -p $TARGET_DIR 13 | unzip $ZIP_FILE -d ./input/ 14 | rm $ZIP_FILE -------------------------------------------------------------------------------- /images/Generator.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/Generator.jpg -------------------------------------------------------------------------------- /images/Results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/Results.jpg -------------------------------------------------------------------------------- /images/discriminator.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/discriminator.jpg -------------------------------------------------------------------------------- /images/distortion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/distortion.jpg -------------------------------------------------------------------------------- /images/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/model.jpg -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import
tensorflow as tf 2 | 3 | def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False): 4 | 5 | with tf.variable_scope(name): 6 | if alt_relu_impl: 7 | f1 = 0.5 * (1 + leak) 8 | f2 = 0.5 * (1 - leak) 9 | # lrelu = 1/2 * (1 + leak) * x + 1/2 * (1 - leak) * |x| 10 | return f1 * x + f2 * abs(x) 11 | else: 12 | return tf.maximum(x, leak*x) 13 | 14 | def instance_norm(x): 15 | 16 | with tf.variable_scope("instance_norm"): 17 | epsilon = 1e-5 18 | mean, var = tf.nn.moments(x, [1, 2], keep_dims=True) 19 | scale = tf.get_variable('scale',[x.get_shape()[-1]], 20 | initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02)) 21 | offset = tf.get_variable('offset',[x.get_shape()[-1]],initializer=tf.constant_initializer(0.0)) 22 | out = scale*tf.div(x-mean, tf.sqrt(var+epsilon)) + offset 23 | 24 | return out 25 | 26 | 27 | def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="conv2d", do_norm=True, do_relu=True, relufactor=0): 28 | with tf.variable_scope(name): 29 | 30 | conv = tf.contrib.layers.conv2d(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0)) 31 | if do_norm: 32 | conv = instance_norm(conv) 33 | # conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm") 34 | 35 | if do_relu: 36 | if(relufactor == 0): 37 | conv = tf.nn.relu(conv,"relu") 38 | else: 39 | conv = lrelu(conv, relufactor, "lrelu") 40 | 41 | return conv 42 | 43 | 44 | 45 | def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0): 46 | with tf.variable_scope(name): 47 | 48 | conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0)) 49 | 50 | if do_norm: 51 | conv = instance_norm(conv) 52 | # conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm") 53 | 54 | if do_relu: 55 | if(relufactor == 0): 56 | conv = tf.nn.relu(conv,"relu") 57 | else: 58 | conv = lrelu(conv, relufactor, "lrelu") 59 | 60 | return conv 61 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.examples.tutorials.mnist import input_data 3 | import numpy as np 4 | from scipy.misc import imsave 5 | import os 6 | import shutil 7 | from PIL import Image 8 | import time 9 | import random 10 | import sys 11 | 12 | 13 | from layers import * 14 | from model import * 15 | 16 | img_height = 256 17 | img_width = 256 18 | img_layer = 3 19 | img_size = img_height * img_width 20 | 21 | to_train = True 22 | to_test = False 23 | to_restore = False 24 | output_path = "./output" 25 | check_dir = "./output/checkpoints/" 26 | 27 | 28 | temp_check = 0 29 | 30 | 31 | 32 | max_epoch = 1 33 | max_images = 100 34 | 35 | h1_size = 150 36 | h2_size = 300 37 | z_size = 100 38 | batch_size = 1 39 | pool_size = 50 40 | sample_size = 10 41 | save_training_images = True 42 | ngf = 32 43 | ndf = 64 44 | 45 | class CycleGAN(): 46 | 47 | def input_setup(self): 48 | 49 | ''' 50 | This function sets up the variables for reading the image input.
51 | 52 | filenames_A/filenames_B -> lists of all training image files 53 | self.image_A/self.image_B -> Input images with values scaled to the range [-1,1] 54 | ''' 55 | 56 | filenames_A = tf.train.match_filenames_once("./input/horse2zebra/trainA/*.jpg") 57 | self.queue_length_A = tf.size(filenames_A) 58 | filenames_B = tf.train.match_filenames_once("./input/horse2zebra/trainB/*.jpg") 59 | self.queue_length_B = tf.size(filenames_B) 60 | 61 | filename_queue_A = tf.train.string_input_producer(filenames_A) 62 | filename_queue_B = tf.train.string_input_producer(filenames_B) 63 | 64 | image_reader = tf.WholeFileReader() 65 | _, image_file_A = image_reader.read(filename_queue_A) 66 | _, image_file_B = image_reader.read(filename_queue_B) 67 | 68 | self.image_A = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_A),[256,256]),127.5),1) 69 | self.image_B = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_B),[256,256]),127.5),1) 70 | 71 | 72 | 73 | def input_read(self, sess): 74 | 75 | 76 | ''' 77 | It reads the input images from the image folder. 78 | 79 | self.fake_images_A/self.fake_images_B -> Pools of generated images used for calculating the discriminator losses 80 | self.A_input/self.B_input -> Store all the training images as numpy arrays 81 | ''' 82 | 83 | # Loading images into the tensors 84 | coord = tf.train.Coordinator() 85 | threads = tf.train.start_queue_runners(coord=coord) 86 | 87 | num_files_A = sess.run(self.queue_length_A) 88 | num_files_B = sess.run(self.queue_length_B) 89 | 90 | self.fake_images_A = np.zeros((pool_size,1,img_height, img_width, img_layer)) 91 | self.fake_images_B = np.zeros((pool_size,1,img_height, img_width, img_layer)) 92 | 93 | 94 | self.A_input = np.zeros((max_images, batch_size, img_height, img_width, img_layer)) 95 | self.B_input = np.zeros((max_images, batch_size, img_height, img_width, img_layer)) 96 | 97 | for i in range(max_images): 98 | image_tensor = sess.run(self.image_A) 99 | if(image_tensor.size == img_size*batch_size*img_layer): 100 | self.A_input[i] = image_tensor.reshape((batch_size,img_height, img_width, img_layer)) 101 | 102 | for i in range(max_images): 103 | image_tensor = sess.run(self.image_B) 104 | if(image_tensor.size == img_size*batch_size*img_layer): 105 | self.B_input[i] = image_tensor.reshape((batch_size,img_height, img_width, img_layer)) 106 | 107 | 108 | coord.request_stop() 109 | coord.join(threads) 110 | 111 | 112 | 113 | 114 | def model_setup(self): 115 | 116 | ''' This function sets up the model to train 117 | 118 | self.input_A/self.input_B -> Set of training images. 119 | self.fake_A/self.fake_B -> Images generated by the corresponding generator from input_A and input_B 120 | self.lr -> Learning rate variable 121 | self.cyc_A/self.cyc_B -> Images obtained by passing self.fake_B/self.fake_A back through the corresponding generator.
This is used to calculate the cyclic loss. 122 | ''' 123 | 124 | self.input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_A") 125 | self.input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_B") 126 | 127 | self.fake_pool_A = tf.placeholder(tf.float32, [None, img_width, img_height, img_layer], name="fake_pool_A") 128 | self.fake_pool_B = tf.placeholder(tf.float32, [None, img_width, img_height, img_layer], name="fake_pool_B") 129 | 130 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 131 | 132 | self.num_fake_inputs = 0 133 | 134 | self.lr = tf.placeholder(tf.float32, shape=[], name="lr") 135 | 136 | with tf.variable_scope("Model") as scope: 137 | self.fake_B = build_generator_resnet_9blocks(self.input_A, name="g_A") 138 | self.fake_A = build_generator_resnet_9blocks(self.input_B, name="g_B") 139 | self.rec_A = build_gen_discriminator(self.input_A, "d_A") 140 | self.rec_B = build_gen_discriminator(self.input_B, "d_B") 141 | 142 | scope.reuse_variables() 143 | 144 | self.fake_rec_A = build_gen_discriminator(self.fake_A, "d_A") 145 | self.fake_rec_B = build_gen_discriminator(self.fake_B, "d_B") 146 | self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B") 147 | self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A") 148 | 149 | scope.reuse_variables() 150 | 151 | self.fake_pool_rec_A = build_gen_discriminator(self.fake_pool_A, "d_A") 152 | self.fake_pool_rec_B = build_gen_discriminator(self.fake_pool_B, "d_B") 153 | 154 | def loss_calc(self): 155 | 156 | ''' In this function we define the variables for the loss calculations and the training model 157 | 158 | d_loss_A/d_loss_B -> loss for discriminator A/B 159 | g_loss_A/g_loss_B -> loss for generator A/B 160 | *_trainer -> Various trainers for the above loss functions 161 | *_summ -> Summary variables for the above loss functions''' 162 | 163 | cyc_loss = tf.reduce_mean(tf.abs(self.input_A-self.cyc_A)) + tf.reduce_mean(tf.abs(self.input_B-self.cyc_B)) 164 | 165 | disc_loss_A = tf.reduce_mean(tf.squared_difference(self.fake_rec_A,1)) 166 | disc_loss_B = tf.reduce_mean(tf.squared_difference(self.fake_rec_B,1)) 167 | 168 | g_loss_A = cyc_loss*10 + disc_loss_B 169 | g_loss_B = cyc_loss*10 + disc_loss_A 170 | 171 | d_loss_A = (tf.reduce_mean(tf.square(self.fake_pool_rec_A)) + tf.reduce_mean(tf.squared_difference(self.rec_A,1)))/2.0 172 | d_loss_B = (tf.reduce_mean(tf.square(self.fake_pool_rec_B)) + tf.reduce_mean(tf.squared_difference(self.rec_B,1)))/2.0 173 | 174 | 175 | optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5) 176 | 177 | self.model_vars = tf.trainable_variables() 178 | 179 | d_A_vars = [var for var in self.model_vars if 'd_A' in var.name] 180 | g_A_vars = [var for var in self.model_vars if 'g_A' in var.name] 181 | d_B_vars = [var for var in self.model_vars if 'd_B' in var.name] 182 | g_B_vars = [var for var in self.model_vars if 'g_B' in var.name] 183 | 184 | self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars) 185 | self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars) 186 | self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars) 187 | self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars) 188 | 189 | for var in self.model_vars: print(var.name) 190 | 191 | #Summary variables for tensorboard 192 | 193 | self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A) 194 | self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B) 195 | self.d_A_loss_summ =
tf.summary.scalar("d_A_loss", d_loss_A) 196 | self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B) 197 | 198 | def save_training_images(self, sess, epoch): 199 | 200 | if not os.path.exists("./output/imgs"): 201 | os.makedirs("./output/imgs") 202 | 203 | for i in range(0,10): 204 | fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run([self.fake_A, self.fake_B, self.cyc_A, self.cyc_B],feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]}) 205 | imsave("./output/imgs/fakeB_"+ str(epoch) + "_" + str(i)+".jpg",((fake_A_temp[0]+1)*127.5).astype(np.uint8)) 206 | imsave("./output/imgs/fakeA_"+ str(epoch) + "_" + str(i)+".jpg",((fake_B_temp[0]+1)*127.5).astype(np.uint8)) 207 | imsave("./output/imgs/cycA_"+ str(epoch) + "_" + str(i)+".jpg",((cyc_A_temp[0]+1)*127.5).astype(np.uint8)) 208 | imsave("./output/imgs/cycB_"+ str(epoch) + "_" + str(i)+".jpg",((cyc_B_temp[0]+1)*127.5).astype(np.uint8)) 209 | imsave("./output/imgs/inputA_"+ str(epoch) + "_" + str(i)+".jpg",((self.A_input[i][0]+1)*127.5).astype(np.uint8)) 210 | imsave("./output/imgs/inputB_"+ str(epoch) + "_" + str(i)+".jpg",((self.B_input[i][0]+1)*127.5).astype(np.uint8)) 211 | 212 | def fake_image_pool(self, num_fakes, fake, fake_pool): 213 | ''' This function saves the generated image to the corresponding pool of images. 214 | Initially it keeps filling the pool till it is full; after that, with probability 0.5, it 215 | randomly selects an already stored image and replaces it with the new one.''' 216 | 217 | if(num_fakes < pool_size): 218 | fake_pool[num_fakes] = fake 219 | return fake 220 | else : 221 | p = random.random() 222 | if p > 0.5: 223 | random_id = random.randint(0,pool_size-1) 224 | temp = fake_pool[random_id] 225 | fake_pool[random_id] = fake 226 | return temp 227 | else : 228 | return fake 229 | 230 | 231 | def train(self): 232 | 233 | 234 | ''' Training Function ''' 235 | 236 | 237 | # Load Dataset from the dataset folder 238 | self.input_setup() 239 | 240 | #Build the network 241 | self.model_setup() 242 | 243 | #Loss function calculations 244 | self.loss_calc() 245 | 246 | # Initializing the global variables 247 | init = tf.global_variables_initializer() 248 | saver = tf.train.Saver() 249 | 250 | with tf.Session() as sess: 251 | sess.run(init) 252 | 253 | #Read input to nd array 254 | self.input_read(sess) 255 | 256 | #Restore the model to run the model from last checkpoint 257 | if to_restore: 258 | chkpt_fname = tf.train.latest_checkpoint(check_dir) 259 | saver.restore(sess, chkpt_fname) 260 | 261 | writer = tf.summary.FileWriter("./output/2") 262 | 263 | if not os.path.exists(check_dir): 264 | os.makedirs(check_dir) 265 | 266 | # Training Loop 267 | for epoch in range(sess.run(self.global_step),100): 268 | print ("In the epoch ", epoch) 269 | saver.save(sess,os.path.join(check_dir,"cyclegan"),global_step=epoch) 270 | 271 | # Dealing with the learning rate as per the epoch number 272 | if(epoch < 100) : 273 | curr_lr = 0.0002 274 | else: 275 | curr_lr = 0.0002 - 0.0002*(epoch-100)/100 276 | 277 | if(save_training_images): 278 | self.save_training_images(sess, epoch) 279 | 280 | # sys.exit() 281 | 282 | for ptr in range(0,max_images): 283 | print("In the iteration ",ptr) 284 | print("Starting",time.time()*1000.0) 285 | 286 | # Optimizing the G_A network 287 | 288 | _, fake_B_temp, summary_str = sess.run([self.g_A_trainer, self.fake_B, self.g_A_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr}) 289 | 290 | writer.add_summary(summary_str, epoch*max_images + ptr) 291 |
fake_B_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_B_temp, self.fake_images_B) 292 | 293 | # Optimizing the D_B network 294 | _, summary_str = sess.run([self.d_B_trainer, self.d_B_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr, self.fake_pool_B:fake_B_temp1}) 295 | writer.add_summary(summary_str, epoch*max_images + ptr) 296 | 297 | 298 | # Optimizing the G_B network 299 | _, fake_A_temp, summary_str = sess.run([self.g_B_trainer, self.fake_A, self.g_B_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr}) 300 | 301 | writer.add_summary(summary_str, epoch*max_images + ptr) 302 | 303 | 304 | fake_A_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_A_temp, self.fake_images_A) 305 | 306 | # Optimizing the D_A network 307 | _, summary_str = sess.run([self.d_A_trainer, self.d_A_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr, self.fake_pool_A:fake_A_temp1}) 308 | 309 | writer.add_summary(summary_str, epoch*max_images + ptr) 310 | 311 | self.num_fake_inputs+=1 312 | 313 | 314 | 315 | sess.run(tf.assign(self.global_step, epoch + 1)) 316 | 317 | writer.add_graph(sess.graph) 318 | 319 | def test(self): 320 | 321 | 322 | ''' Testing Function''' 323 | 324 | print("Testing the results") 325 | 326 | self.input_setup() 327 | 328 | self.model_setup() 329 | saver = tf.train.Saver() 330 | init = tf.global_variables_initializer() 331 | 332 | with tf.Session() as sess: 333 | 334 | sess.run(init) 335 | 336 | self.input_read(sess) 337 | 338 | chkpt_fname = tf.train.latest_checkpoint(check_dir) 339 | saver.restore(sess, chkpt_fname) 340 | 341 | if not os.path.exists("./output/imgs/test/"): 342 | os.makedirs("./output/imgs/test/") 343 | 344 | for i in range(0,100): 345 | fake_A_temp, fake_B_temp = sess.run([self.fake_A, self.fake_B],feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]}) 346 | imsave("./output/imgs/test/fakeB_"+str(i)+".jpg",((fake_A_temp[0]+1)*127.5).astype(np.uint8)) 347 | imsave("./output/imgs/test/fakeA_"+str(i)+".jpg",((fake_B_temp[0]+1)*127.5).astype(np.uint8)) 348 | imsave("./output/imgs/test/inputA_"+str(i)+".jpg",((self.A_input[i][0]+1)*127.5).astype(np.uint8)) 349 | imsave("./output/imgs/test/inputB_"+str(i)+".jpg",((self.B_input[i][0]+1)*127.5).astype(np.uint8)) 350 | 351 | 352 | def main(): 353 | 354 | model = CycleGAN() 355 | if to_train: 356 | model.train() 357 | elif to_test: 358 | model.test() 359 | 360 | if __name__ == '__main__': 361 | 362 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Basic Code is taken from https://github.com/ckmarkoh/GAN-tensorflow 2 | 3 | import tensorflow as tf 4 | from tensorflow.examples.tutorials.mnist import input_data 5 | import numpy as np 6 | from scipy.misc import imsave 7 | import os 8 | import shutil 9 | from PIL import Image 10 | import time 11 | import random 12 | 13 | 14 | from layers import * 15 | 16 | img_height = 256 17 | img_width = 256 18 | img_layer = 3 19 | img_size = img_height * img_width 20 | 21 | 22 | batch_size = 1 23 | pool_size = 50 24 | ngf = 32 25 | ndf = 64 26 | 27 | 28 | 29 | 30 | 31 | def build_resnet_block(inputres, dim, name="resnet"): 32 | 33 | with tf.variable_scope(name): 34 | 35 | out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT") 36 | out_res = general_conv2d(out_res, 
dim, 3, 3, 1, 1, 0.02, "VALID","c1") 37 | out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT") 38 | out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID","c2",do_relu=False) 39 | 40 | return tf.nn.relu(out_res + inputres) 41 | 42 | 43 | def build_generator_resnet_6blocks(inputgen, name="generator"): 44 | with tf.variable_scope(name): 45 | f = 7 46 | ks = 3 47 | 48 | pad_input = tf.pad(inputgen,[[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT") 49 | o_c1 = general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02,name="c1") 50 | o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02,"SAME","c2") 51 | o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02,"SAME","c3") 52 | 53 | o_r1 = build_resnet_block(o_c3, ngf*4, "r1") 54 | o_r2 = build_resnet_block(o_r1, ngf*4, "r2") 55 | o_r3 = build_resnet_block(o_r2, ngf*4, "r3") 56 | o_r4 = build_resnet_block(o_r3, ngf*4, "r4") 57 | o_r5 = build_resnet_block(o_r4, ngf*4, "r5") 58 | o_r6 = build_resnet_block(o_r5, ngf*4, "r6") 59 | 60 | o_c4 = general_deconv2d(o_r6, [batch_size,64,64,ngf*2], ngf*2, ks, ks, 2, 2, 0.02,"SAME","c4") 61 | o_c5 = general_deconv2d(o_c4, [batch_size,128,128,ngf], ngf, ks, ks, 2, 2, 0.02,"SAME","c5") 62 | o_c5_pad = tf.pad(o_c5,[[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT") 63 | o_c6 = general_conv2d(o_c5_pad, img_layer, f, f, 1, 1, 0.02,"VALID","c6",do_relu=False) 64 | 65 | # Adding the tanh layer 66 | 67 | out_gen = tf.nn.tanh(o_c6,"t1") 68 | 69 | 70 | return out_gen 71 | 72 | def build_generator_resnet_9blocks(inputgen, name="generator"): 73 | with tf.variable_scope(name): 74 | f = 7 75 | ks = 3 76 | 77 | pad_input = tf.pad(inputgen,[[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT") 78 | o_c1 = general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02,name="c1") 79 | o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02,"SAME","c2") 80 | o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02,"SAME","c3") 81 | 82 | o_r1 = build_resnet_block(o_c3, ngf*4, "r1") 83 | o_r2 = build_resnet_block(o_r1, ngf*4, "r2") 84 | o_r3 = build_resnet_block(o_r2, ngf*4, "r3") 85 | o_r4 = build_resnet_block(o_r3, ngf*4, "r4") 86 | o_r5 = build_resnet_block(o_r4, ngf*4, "r5") 87 | o_r6 = build_resnet_block(o_r5, ngf*4, "r6") 88 | o_r7 = build_resnet_block(o_r6, ngf*4, "r7") 89 | o_r8 = build_resnet_block(o_r7, ngf*4, "r8") 90 | o_r9 = build_resnet_block(o_r8, ngf*4, "r9") 91 | 92 | o_c4 = general_deconv2d(o_r9, [batch_size,128,128,ngf*2], ngf*2, ks, ks, 2, 2, 0.02,"SAME","c4") 93 | o_c5 = general_deconv2d(o_c4, [batch_size,256,256,ngf], ngf, ks, ks, 2, 2, 0.02,"SAME","c5") 94 | o_c6 = general_conv2d(o_c5, img_layer, f, f, 1, 1, 0.02,"SAME","c6",do_relu=False) 95 | 96 | # Adding the tanh layer 97 | 98 | out_gen = tf.nn.tanh(o_c6,"t1") 99 | 100 | 101 | return out_gen 102 | 103 | 104 | def build_gen_discriminator(inputdisc, name="discriminator"): 105 | 106 | with tf.variable_scope(name): 107 | f = 4 108 | 109 | o_c1 = general_conv2d(inputdisc, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2) 110 | o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2) 111 | o_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2) 112 | o_c4 = general_conv2d(o_c3, ndf*8, f, f, 1, 1, 0.02, "SAME", "c4",relufactor=0.2) 113 | o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5",do_norm=False,do_relu=False) 114 | 115 | return o_c5 116 | 117 | 118 | def patch_discriminator(inputdisc, name="discriminator"): 119 | 120 | with tf.variable_scope(name): 121 | f= 4 122 | 123 | patch_input = 
tf.random_crop(inputdisc,[1,70,70,3]) 124 | o_c1 = general_conv2d(patch_input, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2) 125 | o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2) 126 | o_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2) 127 | o_c4 = general_conv2d(o_c3, ndf*8, f, f, 2, 2, 0.02, "SAME", "c4", relufactor=0.2) 128 | o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5",do_norm=False,do_relu=False) 129 | 130 | return o_c5 --------------------------------------------------------------------------------