├── .gitignore
├── .ipynb_checkpoints
│   ├── Blog for CycleGan-checkpoint.ipynb
│   └── CycleGAN_notebook-checkpoint.ipynb
├── README.md
├── __pycache__
│   └── layers.cpython-36.pyc
├── download_datasets.sh
├── images
│   ├── Generator.jpg
│   ├── Results.jpg
│   ├── discriminator.jpg
│   ├── distortion.jpg
│   └── model.jpg
├── layers.py
├── main.py
└── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | output/
3 | input/
4 | MNIST_data
5 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/Blog for CycleGan-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Cross domain image transfer has been a hot topic and various models have been used to acheive this task. Recenetly, realising the power of GAN networks people have started making model based on GANs and the results are quite impressive. In this blog, we will try to understand the recent development in this direction by Kim et. al https://arxiv.org/pdf/1703.05192v1.pdf."
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {
13 | "collapsed": true
14 | },
15 | "source": [
16 | "There are various components of transferring an image from one cateogry to another category. So let us first look at these things in detail\n",
17 | "\n",
18 | "1.) Extracting the features of an image (Decoding): So, first of all what do we mean by features of an image. Suppose we want to convert an image of man to a look-alike woman, first, we need to understand various general things about the human face like this thing is nose, or ear, or eyebrows or hair etc. These are some high level features but there are some low level features like edges and curves. Using these low level features only we recongise these high level features which help use break down the image and understand the details of it. Model that is used to extract these features almost in all neutral network is call Convolution Network which you can read in detail about in this paper: [Link to nice blog post on Convolution]\n",
19 | "\n",
20 | "\n",
21 | "\n",
22 | "2.) Learning transformation from the feature of one model to that of another model (Transfer): After first step, we have high and low level features of an image and we would like to convert these details to that of another category. Let us look at one of these transformatin that we need to learn. Suppose we detected that a man's face has some beard and we like to convert this face to that of a woman then we need to remove the beard and smothen out that part of face to match with color of face. Another example might include thick eyebrows to thin eyebrows and so on.\n",
23 | "\n",
24 | "\n",
25 | "3.) Making the final image (Encoding): This step can bee send as the reversal of step 1. Here, we will look at different features of an image and construct a face out of it.\n",
26 | "\n",
27 | "\n",
28 | "Things need to keep in mind. We will not learn high level features like this is an ear or this part is nose because these are very high level features but maybe like this the upper part of nose or like this is left part of the left eyebrow and so on. Basically some mid level features.\n",
29 | "\n",
30 | "\n",
31 | "So, now that we have understood the key elements of an image transformation, let use see how they model this using neural network. Once, you are clear with these steps and have basic knowledge of Deep Learning you are good to go to make the model."
32 | ]
33 | },
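 34 |   {
 35 |    "cell_type": "markdown",
 36 |    "metadata": {},
 37 |    "source": [
 38 |     "Putting the three components together, the whole translation is simply encode, transform, decode applied in sequence. The sketch below uses hypothetical helper names (encoder, transform, decoder); the concrete layers behind each of them are built up in the rest of this blog."
 39 |    ]
 40 |   },
 41 |   {
 42 |    "cell_type": "code",
 43 |    "execution_count": null,
 44 |    "metadata": {
 45 |     "collapsed": true
 46 |    },
 47 |    "outputs": [],
 48 |    "source": [
 49 |     "# A sketch of the overall pipeline; the helper names are placeholders\n",
 50 |     "# for the layers defined later in this blog.\n",
 51 |     "features = encoder(input_image)          # step 1: extract features\n",
 52 |     "trans_features = transform(features)     # step 2: map features to the target domain\n",
 53 |     "output_image = decoder(trans_features)   # step 3: build the output image"
 54 |    ]
 55 |   },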
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "Extracting the features:\n",
39 | "\n",
40 | "For the basic about convolution network you can go through this very intuitive blog post by []. So the first step is extracting 64 very low level features (using a window of [7,7]).This can easily be done via a conv layer as follow:"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {
47 | "collapsed": false
48 | },
49 | "outputs": [],
50 | "source": [
51 | "o_c1 = general_conv2d(input, 64, 7, 7, 1, 1)"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
58 | "Now, we can keep on extracting further features and since we dont want to blow up the size of image, we will use stride of 2. It will prevent overlapping as well as reduce the size of output. So, we further extract features as follow"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {
65 | "collapsed": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "o_c2 = general_conv2d(o_c1, 64*2, 3, 3, 2, 2)\n",
70 | "o_c3 = general_conv2d(o_c2, 64*4, 3, 3, 2, 2)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "So, as you can see we are extracting more and more high-level features from the previous layer of low-level features moving towards extracting very high level features."
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "So we can assume that after these step we will have fairly decent amount of features that we would like to transform to a woman's face. So we will build the transformation layers as follow. For these transformation we would like to maintain the same number of features, so the number of features of intermediate output of intermediate layers will be same."
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {
91 | "collapsed": true
92 | },
93 | "outputs": [],
94 | "source": []
95 | }
96 | ],
97 | "metadata": {
98 | "kernelspec": {
99 | "display_name": "Python 2",
100 | "language": "python",
101 | "name": "python2"
102 | },
103 | "language_info": {
104 | "codemirror_mode": {
105 | "name": "ipython",
106 | "version": 2
107 | },
108 | "file_extension": ".py",
109 | "mimetype": "text/x-python",
110 | "name": "python",
111 | "nbconvert_exporter": "python",
112 | "pygments_lexer": "ipython2",
113 | "version": "2.7.11"
114 | }
115 | },
116 | "nbformat": 4,
117 | "nbformat_minor": 2
118 | }
119 |
--------------------------------------------------------------------------------
/.ipynb_checkpoints/CycleGAN_notebook-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Library imports "
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "collapsed": false
15 | },
16 | "outputs": [],
17 | "source": [
18 | "# Basic Code is taken from https://github.com/ckmarkoh/GAN-tensorflow\n",
19 | "\n",
20 | "import tensorflow as tf\n",
21 | "from tensorflow.examples.tutorials.mnist import input_data\n",
22 | "import numpy as np\n",
23 | "from skimage.io import imsave\n",
24 | "import os\n",
25 | "import shutil"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## Constants and flags "
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {
39 | "collapsed": true
40 | },
41 | "outputs": [],
42 | "source": [
43 | "img_height = 28\n",
44 | "img_width = 28\n",
45 | "img_size = img_height * img_width\n",
46 | "\n",
47 | "to_train = True\n",
48 | "to_restore = False\n",
49 | "output_path = \"output\"\n",
50 | "\n",
51 | "max_epoch = 500\n",
52 | "\n",
53 | "h1_size = 150\n",
54 | "h2_size = 300\n",
55 | "z_size = 100\n",
56 | "batch_size = 256\n",
57 | "ngf = 128"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "## Function definitions "
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "### Convolution layer \n",
72 | "Defines a general 2D convolution layer with batch normalization"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {
79 | "collapsed": true
80 | },
81 | "outputs": [],
82 | "source": [
83 | "def general_conv2d(inputconv, name=\"conv2d\",\n",
84 | " o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, \n",
85 | " stddev=0.02, padding=None, \n",
86 | " do_norm=True, do_relu=True):\n",
87 | " '''Defines a general 2D convolution layer with batch normalization'''\n",
88 | " \n",
89 | " with tf.variable_scope(name):\n",
90 | " initializer = tf.truncated_normal_intializer(stddev=stddev)\n",
91 | " w = tf.get_variable('w',\n",
92 | " [f_h, f_w, inputconv.get_shape(-1), o_d], \n",
93 | " initializer=initializer)\n",
94 | " conv = tf.nn.conv2d(inputconv,\n",
95 | " filter=w,strides=[1, s_w, s_h, 1],\n",
96 | " padding=padding)\n",
97 | " biases = tf.get_variable('b',\n",
98 | " [o_d],\n",
99 | " initializer=tf.constant_initializer(0.0))\n",
100 | " conv = tf.nn.bias_add(conv,biases)\n",
101 | " \n",
102 | " # Add batch_norm layer\n",
103 | " if do_norm:\n",
104 | " dims = conv.get_shape()\n",
105 | " scale = tf.get_variable('scale',\n",
106 | " [dims[1],dims[2],dims[3]],\n",
107 | " tf.constant_initializer(1))\n",
108 | " beta = tf.get_variable('beta',\n",
109 | " [dims[1],dims[2],dims[3]],\n",
110 | " tf.constant_initializer(0))\n",
111 | " conv_mean, conv_var = tf.nn.moments(conv,[0])\n",
112 | " conv = tf.nn.batch_normalization(conv, conv_mean, conv_var, beta, scale, 0.001)\n",
113 | " # Add ReLU activation\n",
114 | " if do_relu:\n",
115 | " conv = tf.nn.relu(conv,0)\n",
116 | " return conv"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "### Resnet block "
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {
130 | "collapsed": false
131 | },
132 | "outputs": [],
133 | "source": [
134 | "def build_resnet_block(inputres, dim, name=\"resnet\"):\n",
135 | " out_res = inputres\n",
136 | " with tf.variable_scope(name):\n",
137 | " out_res = general_conv2d(inputres, dim, 3, 3, 1, 1, 0.02, \"SAME\", \"c1\")\n",
138 | " out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, \"SAME\", \"c2\", do_relu=False)\n",
139 | " return tf.nn.relu(out_res + inputres)\n",
140 | "\n",
141 | "def build_generator_resnet_6blocks(inputgen, name=\"generator\"):\n",
142 | " with tf.variable_scope(name):\n",
143 | " f = 7\n",
144 | " ks = 3\n",
145 | " o_c1 = general_conv2d(inputgen, ngf, f, f, 1, 1, 0.02, \"SAME\", \"c1\")\n",
146 | " o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, None, \"c2\")\n",
147 | " o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, None, \"c3\")"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "### Show results "
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "collapsed": true
162 | },
163 | "outputs": [],
164 | "source": [
165 | "def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):\n",
166 | " batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5\n",
167 | " img_h, img_w = batch_res.shape[1], batch_res.shape[2]\n",
168 | " grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)\n",
169 | " grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)\n",
170 | " img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)\n",
171 | " for i, res in enumerate(batch_res):\n",
172 | " if i >= grid_size[0] * grid_size[1]:\n",
173 | " break\n",
174 | " img = (res) * 255\n",
175 | " img = img.astype(np.uint8)\n",
176 | " row = (i // grid_size[0]) * (img_h + grid_pad)\n",
177 | " col = (i % grid_size[1]) * (img_w + grid_pad)\n",
178 | " img_grid[row:row + img_h, col:col + img_w] = img\n",
179 | " imsave(fname, img_grid)"
180 | ]
181 | },
182 | {
183 | "cell_type": "markdown",
184 | "metadata": {},
185 | "source": [
186 | "### Training and testing routines "
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {
193 | "collapsed": false
194 | },
195 | "outputs": [],
196 | "source": [
197 | "def train():\n",
198 | " mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
199 | "\n",
200 | " x_data = tf.placeholder(tf.float32, [batch_size, img_size], name=\"x_data\")\n",
201 | " z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name=\"z_prior\")\n",
202 | " keep_prob = tf.placeholder(tf.float32, name=\"keep_prob\")\n",
203 | " global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n",
204 | "\n",
205 | " x_generated, g_params = build_generator(z_prior)\n",
206 | " y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)\n",
207 | "\n",
208 | " d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))\n",
209 | " g_loss = - tf.log(y_generated)\n",
210 | "\n",
211 | " optimizer = tf.train.AdamOptimizer(0.0001)\n",
212 | "\n",
213 | " d_trainer = optimizer.minimize(d_loss, var_list=d_params)\n",
214 | " g_trainer = optimizer.minimize(g_loss, var_list=g_params)\n",
215 | "\n",
216 | " init = tf.initialize_all_variables()\n",
217 | "\n",
218 | " saver = tf.train.Saver()\n",
219 | "\n",
220 | " sess = tf.Session()\n",
221 | "\n",
222 | " sess.run(init)\n",
223 | "\n",
224 | " if to_restore:\n",
225 | " chkpt_fname = tf.train.latest_checkpoint(output_path)\n",
226 | " saver.restore(sess, chkpt_fname)\n",
227 | " else:\n",
228 | " if os.path.exists(output_path):\n",
229 | " shutil.rmtree(output_path)\n",
230 | " os.mkdir(output_path)\n",
231 | "\n",
232 | "\n",
233 | " z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
234 | "\n",
235 | " for i in range(sess.run(global_step), max_epoch):\n",
236 | " for j in range(60000 / batch_size):\n",
237 | " print(\"epoch:%s, iter:%s\" % (i, j))\n",
238 | " x_value, _ = mnist.train.next_batch(batch_size)\n",
239 | " x_value = 2 * x_value.astype(np.float32) - 1\n",
240 | " z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
241 | " sess.run(d_trainer,\n",
242 | " feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})\n",
243 | " if j % 1 == 0:\n",
244 | " sess.run(g_trainer,\n",
245 | " feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})\n",
246 | " x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})\n",
247 | " show_result(x_gen_val, \"output/sample{0}.jpg\".format(i))\n",
248 | " z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
249 | " x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})\n",
250 | " show_result(x_gen_val, \"output/random_sample{0}.jpg\".format(i))\n",
251 | " sess.run(tf.assign(global_step, i + 1))\n",
252 | " saver.save(sess, os.path.join(output_path, \"model\"), global_step=global_step)\n",
253 | "\n",
254 | "\n",
255 | "def test():\n",
256 | " z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name=\"z_prior\")\n",
257 | " x_generated, _ = build_generator(z_prior)\n",
258 | " chkpt_fname = tf.train.latest_checkpoint(output_path)\n",
259 | "\n",
260 | " init = tf.initialize_all_variables()\n",
261 | " sess = tf.Session()\n",
262 | " saver = tf.train.Saver()\n",
263 | " sess.run(init)\n",
264 | " saver.restore(sess, chkpt_fname)\n",
265 | " z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)\n",
266 | " x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})\n",
267 | " show_result(x_gen_val, \"output/test_result.jpg\")"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "metadata": {},
273 | "source": [
274 | "## Main/driver"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {
281 | "collapsed": false
282 | },
283 | "outputs": [],
284 | "source": [
285 | "if __name__ == '__main__':\n",
286 | " if to_train:\n",
287 | " train()\n",
288 | " else:\n",
289 | " test()"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {
296 | "collapsed": true
297 | },
298 | "outputs": [],
299 | "source": []
300 | }
301 | ],
302 | "metadata": {
303 | "kernelspec": {
304 | "display_name": "Python 2",
305 | "language": "python",
306 | "name": "python2"
307 | },
308 | "language_info": {
309 | "codemirror_mode": {
310 | "name": "ipython",
311 | "version": 2
312 | },
313 | "file_extension": ".py",
314 | "mimetype": "text/x-python",
315 | "name": "python",
316 | "nbconvert_exporter": "python",
317 | "pygments_lexer": "ipython2",
318 | "version": "2.7.11"
319 | }
320 | },
321 | "nbformat": 4,
322 | "nbformat_minor": 2
323 | }
324 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CycleGAN
2 | Tensorflow implementation of CycleGAN.
3 |
4 | 1. [Original implementation](https://github.com/junyanz/CycleGAN/)
5 | 2. [Paper](https://arxiv.org/abs/1703.10593)
6 |
7 |
8 |
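 9 | ### Download data and train
10 | 
11 | A typical way to fetch one of the supported datasets and start training with the defaults in `main.py` (`to_train = True`):
12 | 
13 | ```bash
14 | bash download_datasets.sh horse2zebra   # unpacks into ./input/horse2zebra/
15 | python main.py                          # checkpoints and samples are written to ./output/
16 | ```
17 | 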
9 | ### CycleGAN model
10 |
11 | The CycleGAN model can be summarized in the following image. For full details about the implementation and an explanation of CycleGAN, you can read the tutorial at this [link](https://hardikbansal.github.io/CycleGANBlog/).
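12 | 
13 | At the heart of the model are a cycle-consistency loss and least-squares adversarial losses. As a rough sketch, the losses defined in `loss_calc` in `main.py` look like this (`self.` prefixes dropped for readability):
14 | 
15 | ```python
16 | # Cycle-consistency: images translated A->B->A (and B->A->B) should match the originals.
17 | cyc_loss = tf.reduce_mean(tf.abs(input_A - cyc_A)) + tf.reduce_mean(tf.abs(input_B - cyc_B))
18 | 
19 | # Generator A is trained to make discriminator B score its fakes as real (1).
20 | g_loss_A = cyc_loss * 10 + tf.reduce_mean(tf.squared_difference(fake_rec_B, 1))
21 | 
22 | # Discriminator A should score pooled fakes as 0 and real images as 1.
23 | d_loss_A = (tf.reduce_mean(tf.square(fake_pool_rec_A))
24 |             + tf.reduce_mean(tf.squared_difference(rec_A, 1))) / 2.0
25 | ```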
12 |
13 |
14 |
15 |
16 |
17 |
18 | ##### Generator
19 |
20 |
21 |
22 |
23 |
24 |
25 | ##### Discriminator
26 |
27 |
28 |
29 |
30 |
31 | ### Our Results
32 |
33 | We ran the model on the [horse2zebra dataset](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip), but because of limited resources we ran it for only 100 epochs and got the following results.
34 |
35 |
36 |
37 |
38 |
39 | ### Final Comments
40 |
41 | 1. During training we noticed that the output was sensitive to initialization. Thanks to [vanhuyz](https://github.com/vanhuyz) for pointing this out and suggesting training multiple times to get the best results. You might notice the background color being reversed, as in the following image. This effect can be observed after only 10-20 epochs, so if it appears you can try running the code again.
42 |
43 |
44 |
45 |
46 |
47 | 2. We also think that this model is not a good fit for changing the shape of an object. We tried running the model to convert a man's face into a look-alike woman's face using the celebA dataset, but the results were poor and the generated images quite distorted.
48 |
49 |
50 | ### Blog
51 |
52 | If you would like to understand the paper and see how to implement it on your own, you can have a look at the blog by [me](https://hardikbansal.github.io/CycleGANBlog/).
53 |
--------------------------------------------------------------------------------
/__pycache__/layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/__pycache__/layers.cpython-36.pyc
--------------------------------------------------------------------------------
/download_datasets.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
 3 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
4 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
5 | exit 1
6 | fi
7 |
8 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
9 | ZIP_FILE=./input/$FILE.zip
10 | TARGET_DIR=./input/$FILE/
11 | mkdir -p ./input
12 | wget -N $URL -O $ZIP_FILE
13 | mkdir -p $TARGET_DIR
14 | unzip $ZIP_FILE -d ./input/
15 | rm $ZIP_FILE
--------------------------------------------------------------------------------
/images/Generator.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/Generator.jpg
--------------------------------------------------------------------------------
/images/Results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/Results.jpg
--------------------------------------------------------------------------------
/images/discriminator.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/discriminator.jpg
--------------------------------------------------------------------------------
/images/distortion.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/distortion.jpg
--------------------------------------------------------------------------------
/images/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/architrathore/CycleGAN/67cca4e710874ad61a63651bd69d624011ae8295/images/model.jpg
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
4 |
5 | with tf.variable_scope(name):
6 | if alt_relu_impl:
7 | f1 = 0.5 * (1 + leak)
8 | f2 = 0.5 * (1 - leak)
9 | # lrelu = 1/2 * (1 + leak) * x + 1/2 * (1 - leak) * |x|
10 | return f1 * x + f2 * abs(x)
11 | else:
12 | return tf.maximum(x, leak*x)
13 |
14 | def instance_norm(x):
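15 |     '''Instance normalization: normalise each sample over its spatial dimensions, with a learned per-channel scale and offset.'''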
15 |
16 | with tf.variable_scope("instance_norm"):
17 | epsilon = 1e-5
18 | mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
19 | scale = tf.get_variable('scale',[x.get_shape()[-1]],
20 | initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
21 | offset = tf.get_variable('offset',[x.get_shape()[-1]],initializer=tf.constant_initializer(0.0))
22 | out = scale*tf.div(x-mean, tf.sqrt(var+epsilon)) + offset
23 |
24 | return out
25 |
26 |
27 | def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="conv2d", do_norm=True, do_relu=True, relufactor=0):
28 | with tf.variable_scope(name):
29 |
30 |         conv = tf.contrib.layers.conv2d(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=tf.constant_initializer(0.0))
31 | if do_norm:
32 | conv = instance_norm(conv)
33 | # conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm")
34 |
35 | if do_relu:
36 | if(relufactor == 0):
37 | conv = tf.nn.relu(conv,"relu")
38 | else:
39 | conv = lrelu(conv, relufactor, "lrelu")
40 |
41 | return conv
42 |
43 |
44 |
45 | def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0):
46 | with tf.variable_scope(name):
47 |
48 | conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))
49 |
50 | if do_norm:
51 | conv = instance_norm(conv)
52 | # conv = tf.contrib.layers.batch_norm(conv, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope="batch_norm")
53 |
54 | if do_relu:
55 | if(relufactor == 0):
56 | conv = tf.nn.relu(conv,"relu")
57 | else:
58 | conv = lrelu(conv, relufactor, "lrelu")
59 |
60 | return conv
61 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.examples.tutorials.mnist import input_data
3 | import numpy as np
  4 | from skimage.io import imsave
5 | import os
6 | import shutil
7 | from PIL import Image
8 | import time
9 | import random
10 | import sys
11 |
12 |
13 | from layers import *
14 | from model import *
15 |
16 | img_height = 256
17 | img_width = 256
18 | img_layer = 3
19 | img_size = img_height * img_width
20 |
21 | to_train = True
22 | to_test = False
23 | to_restore = False
24 | output_path = "./output"
25 | check_dir = "./output/checkpoints/"
26 |
27 |
28 | temp_check = 0
29 |
30 |
31 |
32 | max_epoch = 1
33 | max_images = 100
34 |
35 | h1_size = 150
36 | h2_size = 300
37 | z_size = 100
38 | batch_size = 1
39 | pool_size = 50
40 | sample_size = 10
41 | save_training_images = True
42 | ngf = 32
43 | ndf = 64
44 |
45 | class CycleGAN():
46 |
47 | def input_setup(self):
48 |
49 | '''
 50 |         This function sets up the variables for reading image input.
 51 | 
 52 |         filenames_A/filenames_B -> lists of all training images
 53 |         self.image_A/self.image_B -> input images with values scaled to [-1, 1]
54 | '''
55 |
56 | filenames_A = tf.train.match_filenames_once("./input/horse2zebra/trainA/*.jpg")
57 | self.queue_length_A = tf.size(filenames_A)
58 | filenames_B = tf.train.match_filenames_once("./input/horse2zebra/trainB/*.jpg")
59 | self.queue_length_B = tf.size(filenames_B)
60 |
61 | filename_queue_A = tf.train.string_input_producer(filenames_A)
62 | filename_queue_B = tf.train.string_input_producer(filenames_B)
63 |
64 | image_reader = tf.WholeFileReader()
65 | _, image_file_A = image_reader.read(filename_queue_A)
66 | _, image_file_B = image_reader.read(filename_queue_B)
67 |
68 | self.image_A = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_A),[256,256]),127.5),1)
69 | self.image_B = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_B),[256,256]),127.5),1)
70 |
71 |
72 |
73 | def input_read(self, sess):
74 |
75 |
76 | '''
 77 |         Reads the input images from the image folders into numpy arrays.
 78 | 
 79 |         self.fake_images_A/self.fake_images_B -> pools of generated images, used when calculating the discriminator loss
 80 |         self.A_input/self.B_input -> all the training images, stored as numpy arrays
81 | '''
82 |
83 | # Loading images into the tensors
84 | coord = tf.train.Coordinator()
85 | threads = tf.train.start_queue_runners(coord=coord)
86 |
87 | num_files_A = sess.run(self.queue_length_A)
88 | num_files_B = sess.run(self.queue_length_B)
89 |
90 | self.fake_images_A = np.zeros((pool_size,1,img_height, img_width, img_layer))
91 | self.fake_images_B = np.zeros((pool_size,1,img_height, img_width, img_layer))
92 |
93 |
94 | self.A_input = np.zeros((max_images, batch_size, img_height, img_width, img_layer))
95 | self.B_input = np.zeros((max_images, batch_size, img_height, img_width, img_layer))
96 |
97 | for i in range(max_images):
98 | image_tensor = sess.run(self.image_A)
 99 |             if(image_tensor.size == img_size*batch_size*img_layer):
100 | self.A_input[i] = image_tensor.reshape((batch_size,img_height, img_width, img_layer))
101 |
102 | for i in range(max_images):
103 | image_tensor = sess.run(self.image_B)
104 |             if(image_tensor.size == img_size*batch_size*img_layer):
105 | self.B_input[i] = image_tensor.reshape((batch_size,img_height, img_width, img_layer))
106 |
107 |
108 | coord.request_stop()
109 | coord.join(threads)
110 |
111 |
112 |
113 |
114 | def model_setup(self):
115 |
116 | ''' This function sets up the model to train
117 |
118 | self.input_A/self.input_B -> Set of training images.
119 |         self.fake_A/self.fake_B -> images produced by the generators from input_B and input_A respectively
120 |         self.lr -> learning rate variable
121 |         self.cyc_A/self.cyc_B -> images obtained by feeding self.fake_B/self.fake_A back through the opposite generator. These are used to calculate the cyclic loss
122 | '''
123 |
124 | self.input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_A")
125 | self.input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_B")
126 |
127 | self.fake_pool_A = tf.placeholder(tf.float32, [None, img_width, img_height, img_layer], name="fake_pool_A")
128 | self.fake_pool_B = tf.placeholder(tf.float32, [None, img_width, img_height, img_layer], name="fake_pool_B")
129 |
130 | self.global_step = tf.Variable(0, name="global_step", trainable=False)
131 |
132 | self.num_fake_inputs = 0
133 |
134 | self.lr = tf.placeholder(tf.float32, shape=[], name="lr")
135 |
136 | with tf.variable_scope("Model") as scope:
137 | self.fake_B = build_generator_resnet_9blocks(self.input_A, name="g_A")
138 | self.fake_A = build_generator_resnet_9blocks(self.input_B, name="g_B")
139 | self.rec_A = build_gen_discriminator(self.input_A, "d_A")
140 | self.rec_B = build_gen_discriminator(self.input_B, "d_B")
141 |
142 | scope.reuse_variables()
143 |
144 | self.fake_rec_A = build_gen_discriminator(self.fake_A, "d_A")
145 | self.fake_rec_B = build_gen_discriminator(self.fake_B, "d_B")
146 | self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
147 | self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")
148 |
149 | scope.reuse_variables()
150 |
151 | self.fake_pool_rec_A = build_gen_discriminator(self.fake_pool_A, "d_A")
152 | self.fake_pool_rec_B = build_gen_discriminator(self.fake_pool_B, "d_B")
153 |
154 | def loss_calc(self):
155 |
156 |         ''' In this function we define the variables for the loss calculations and the trainers.
157 | 
158 |         d_loss_A/d_loss_B -> loss for discriminator A/B
159 |         g_loss_A/g_loss_B -> loss for generator A/B
160 |         *_trainer -> the various trainers for the above loss functions
161 |         *_summ -> summary variables for the above loss functions'''
162 |
163 | cyc_loss = tf.reduce_mean(tf.abs(self.input_A-self.cyc_A)) + tf.reduce_mean(tf.abs(self.input_B-self.cyc_B))
164 |
165 | disc_loss_A = tf.reduce_mean(tf.squared_difference(self.fake_rec_A,1))
166 | disc_loss_B = tf.reduce_mean(tf.squared_difference(self.fake_rec_B,1))
167 |
168 | g_loss_A = cyc_loss*10 + disc_loss_B
169 | g_loss_B = cyc_loss*10 + disc_loss_A
170 |
171 | d_loss_A = (tf.reduce_mean(tf.square(self.fake_pool_rec_A)) + tf.reduce_mean(tf.squared_difference(self.rec_A,1)))/2.0
172 | d_loss_B = (tf.reduce_mean(tf.square(self.fake_pool_rec_B)) + tf.reduce_mean(tf.squared_difference(self.rec_B,1)))/2.0
173 |
174 |
175 | optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)
176 |
177 | self.model_vars = tf.trainable_variables()
178 |
179 | d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
180 | g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
181 | d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
182 | g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]
183 |
184 | self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
185 | self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
186 | self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
187 | self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)
188 |
189 | for var in self.model_vars: print(var.name)
190 |
191 | #Summary variables for tensorboard
192 |
193 | self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
194 | self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
195 | self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
196 | self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
197 |
198 | def save_training_images(self, sess, epoch):
199 |
200 | if not os.path.exists("./output/imgs"):
201 | os.makedirs("./output/imgs")
202 |
203 | for i in range(0,10):
204 | fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run([self.fake_A, self.fake_B, self.cyc_A, self.cyc_B],feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})
205 | imsave("./output/imgs/fakeB_"+ str(epoch) + "_" + str(i)+".jpg",((fake_A_temp[0]+1)*127.5).astype(np.uint8))
206 | imsave("./output/imgs/fakeA_"+ str(epoch) + "_" + str(i)+".jpg",((fake_B_temp[0]+1)*127.5).astype(np.uint8))
207 | imsave("./output/imgs/cycA_"+ str(epoch) + "_" + str(i)+".jpg",((cyc_A_temp[0]+1)*127.5).astype(np.uint8))
208 | imsave("./output/imgs/cycB_"+ str(epoch) + "_" + str(i)+".jpg",((cyc_B_temp[0]+1)*127.5).astype(np.uint8))
209 | imsave("./output/imgs/inputA_"+ str(epoch) + "_" + str(i)+".jpg",((self.A_input[i][0]+1)*127.5).astype(np.uint8))
210 | imsave("./output/imgs/inputB_"+ str(epoch) + "_" + str(i)+".jpg",((self.B_input[i][0]+1)*127.5).astype(np.uint8))
211 |
212 | def fake_image_pool(self, num_fakes, fake, fake_pool):
213 |         ''' This function saves a generated image to the corresponding pool of images.
214 |         It keeps filling the pool until it is full; after that, with probability 0.5 it
215 |         replaces a randomly chosen stored image with the new one and returns the old image.'''
216 |
217 | if(num_fakes < pool_size):
218 | fake_pool[num_fakes] = fake
219 | return fake
220 | else :
221 | p = random.random()
222 | if p > 0.5:
223 | random_id = random.randint(0,pool_size-1)
224 | temp = fake_pool[random_id]
225 | fake_pool[random_id] = fake
226 | return temp
227 | else :
228 | return fake
229 |
230 |
231 | def train(self):
232 |
233 |
234 | ''' Training Function '''
235 |
236 |
237 | # Load Dataset from the dataset folder
238 | self.input_setup()
239 |
240 | #Build the network
241 | self.model_setup()
242 |
243 | #Loss function calculations
244 | self.loss_calc()
245 |
246 | # Initializing the global variables
247 | init = tf.global_variables_initializer()
248 | saver = tf.train.Saver()
249 |
250 | with tf.Session() as sess:
251 | sess.run(init)
252 |
253 | #Read input to nd array
254 | self.input_read(sess)
255 |
256 | #Restore the model to run the model from last checkpoint
257 | if to_restore:
258 | chkpt_fname = tf.train.latest_checkpoint(check_dir)
259 | saver.restore(sess, chkpt_fname)
260 |
261 | writer = tf.summary.FileWriter("./output/2")
262 |
263 | if not os.path.exists(check_dir):
264 | os.makedirs(check_dir)
265 |
266 | # Training Loop
267 | for epoch in range(sess.run(self.global_step),100):
268 | print ("In the epoch ", epoch)
269 | saver.save(sess,os.path.join(check_dir,"cyclegan"),global_step=epoch)
270 |
271 |                 # Learning rate: constant 0.0002 for the first 100 epochs, then linear decay (the decay branch only runs if the epoch range above is extended past 100)
272 | if(epoch < 100) :
273 | curr_lr = 0.0002
274 | else:
275 | curr_lr = 0.0002 - 0.0002*(epoch-100)/100
276 |
277 | if(save_training_images):
278 | self.save_training_images(sess, epoch)
279 |
280 | # sys.exit()
281 |
282 | for ptr in range(0,max_images):
283 | print("In the iteration ",ptr)
284 | print("Starting",time.time()*1000.0)
285 |
286 | # Optimizing the G_A network
287 |
288 | _, fake_B_temp, summary_str = sess.run([self.g_A_trainer, self.fake_B, self.g_A_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr})
289 |
290 | writer.add_summary(summary_str, epoch*max_images + ptr)
291 | fake_B_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_B_temp, self.fake_images_B)
292 |
293 | # Optimizing the D_B network
294 | _, summary_str = sess.run([self.d_B_trainer, self.d_B_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr, self.fake_pool_B:fake_B_temp1})
295 | writer.add_summary(summary_str, epoch*max_images + ptr)
296 |
297 |
298 | # Optimizing the G_B network
299 | _, fake_A_temp, summary_str = sess.run([self.g_B_trainer, self.fake_A, self.g_B_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr})
300 |
301 | writer.add_summary(summary_str, epoch*max_images + ptr)
302 |
303 |
304 | fake_A_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_A_temp, self.fake_images_A)
305 |
306 | # Optimizing the D_A network
307 | _, summary_str = sess.run([self.d_A_trainer, self.d_A_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr, self.fake_pool_A:fake_A_temp1})
308 |
309 | writer.add_summary(summary_str, epoch*max_images + ptr)
310 |
311 | self.num_fake_inputs+=1
312 |
313 |
314 |
315 | sess.run(tf.assign(self.global_step, epoch + 1))
316 |
317 | writer.add_graph(sess.graph)
318 |
319 | def test(self):
320 |
321 |
322 | ''' Testing Function'''
323 |
324 | print("Testing the results")
325 |
326 | self.input_setup()
327 |
328 | self.model_setup()
329 | saver = tf.train.Saver()
330 | init = tf.global_variables_initializer()
331 |
332 | with tf.Session() as sess:
333 |
334 | sess.run(init)
335 |
336 | self.input_read(sess)
337 |
338 | chkpt_fname = tf.train.latest_checkpoint(check_dir)
339 | saver.restore(sess, chkpt_fname)
340 |
341 | if not os.path.exists("./output/imgs/test/"):
342 | os.makedirs("./output/imgs/test/")
343 |
344 | for i in range(0,100):
345 | fake_A_temp, fake_B_temp = sess.run([self.fake_A, self.fake_B],feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})
346 | imsave("./output/imgs/test/fakeB_"+str(i)+".jpg",((fake_A_temp[0]+1)*127.5).astype(np.uint8))
347 | imsave("./output/imgs/test/fakeA_"+str(i)+".jpg",((fake_B_temp[0]+1)*127.5).astype(np.uint8))
348 | imsave("./output/imgs/test/inputA_"+str(i)+".jpg",((self.A_input[i][0]+1)*127.5).astype(np.uint8))
349 | imsave("./output/imgs/test/inputB_"+str(i)+".jpg",((self.B_input[i][0]+1)*127.5).astype(np.uint8))
350 |
351 |
352 | def main():
353 |
354 | model = CycleGAN()
355 | if to_train:
356 | model.train()
357 | elif to_test:
358 | model.test()
359 |
360 | if __name__ == '__main__':
361 |
362 | main()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Basic Code is taken from https://github.com/ckmarkoh/GAN-tensorflow
2 |
3 | import tensorflow as tf
4 | from tensorflow.examples.tutorials.mnist import input_data
5 | import numpy as np
  6 | from skimage.io import imsave
7 | import os
8 | import shutil
9 | from PIL import Image
10 | import time
11 | import random
12 |
13 |
14 | from layers import *
15 |
16 | img_height = 256
17 | img_width = 256
18 | img_layer = 3
19 | img_size = img_height * img_width
20 |
21 |
22 | batch_size = 1
23 | pool_size = 50
24 | ngf = 32
25 | ndf = 64
26 |
27 |
28 |
29 |
30 |
31 | def build_resnet_block(inputres, dim, name="resnet"):
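 32 |     '''A residual block: two 3x3 convolutions with reflect padding plus a skip connection; the number of features stays at dim.'''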
32 |
33 | with tf.variable_scope(name):
34 |
35 | out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
36 | out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID","c1")
37 | out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
38 | out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID","c2",do_relu=False)
39 |
40 | return tf.nn.relu(out_res + inputres)
41 |
42 |
43 | def build_generator_resnet_6blocks(inputgen, name="generator"):
44 | with tf.variable_scope(name):
45 | f = 7
46 | ks = 3
47 |
48 | pad_input = tf.pad(inputgen,[[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")
49 | o_c1 = general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02,name="c1")
50 | o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02,"SAME","c2")
51 | o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02,"SAME","c3")
52 |
53 | o_r1 = build_resnet_block(o_c3, ngf*4, "r1")
54 | o_r2 = build_resnet_block(o_r1, ngf*4, "r2")
55 | o_r3 = build_resnet_block(o_r2, ngf*4, "r3")
56 | o_r4 = build_resnet_block(o_r3, ngf*4, "r4")
57 | o_r5 = build_resnet_block(o_r4, ngf*4, "r5")
58 | o_r6 = build_resnet_block(o_r5, ngf*4, "r6")
59 |
60 | o_c4 = general_deconv2d(o_r6, [batch_size,64,64,ngf*2], ngf*2, ks, ks, 2, 2, 0.02,"SAME","c4")
61 | o_c5 = general_deconv2d(o_c4, [batch_size,128,128,ngf], ngf, ks, ks, 2, 2, 0.02,"SAME","c5")
62 | o_c5_pad = tf.pad(o_c5,[[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")
63 | o_c6 = general_conv2d(o_c5_pad, img_layer, f, f, 1, 1, 0.02,"VALID","c6",do_relu=False)
64 |
65 | # Adding the tanh layer
66 |
67 | out_gen = tf.nn.tanh(o_c6,"t1")
68 |
69 |
70 | return out_gen
71 |
72 | def build_generator_resnet_9blocks(inputgen, name="generator"):
73 | with tf.variable_scope(name):
74 | f = 7
75 | ks = 3
76 |
77 | pad_input = tf.pad(inputgen,[[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")
78 | o_c1 = general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02,name="c1")
79 | o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02,"SAME","c2")
80 | o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02,"SAME","c3")
81 |
82 | o_r1 = build_resnet_block(o_c3, ngf*4, "r1")
83 | o_r2 = build_resnet_block(o_r1, ngf*4, "r2")
84 | o_r3 = build_resnet_block(o_r2, ngf*4, "r3")
85 | o_r4 = build_resnet_block(o_r3, ngf*4, "r4")
86 | o_r5 = build_resnet_block(o_r4, ngf*4, "r5")
87 | o_r6 = build_resnet_block(o_r5, ngf*4, "r6")
88 | o_r7 = build_resnet_block(o_r6, ngf*4, "r7")
89 | o_r8 = build_resnet_block(o_r7, ngf*4, "r8")
90 | o_r9 = build_resnet_block(o_r8, ngf*4, "r9")
91 |
92 | o_c4 = general_deconv2d(o_r9, [batch_size,128,128,ngf*2], ngf*2, ks, ks, 2, 2, 0.02,"SAME","c4")
93 | o_c5 = general_deconv2d(o_c4, [batch_size,256,256,ngf], ngf, ks, ks, 2, 2, 0.02,"SAME","c5")
94 | o_c6 = general_conv2d(o_c5, img_layer, f, f, 1, 1, 0.02,"SAME","c6",do_relu=False)
95 |
96 | # Adding the tanh layer
97 |
98 | out_gen = tf.nn.tanh(o_c6,"t1")
99 |
100 |
101 | return out_gen
102 |
103 |
104 | def build_gen_discriminator(inputdisc, name="discriminator"):
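105 |     '''Discriminator built from strided 4x4 convolutions; outputs a grid of per-patch real/fake scores rather than a single scalar.'''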
105 |
106 | with tf.variable_scope(name):
107 | f = 4
108 |
109 | o_c1 = general_conv2d(inputdisc, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)
110 | o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2)
111 | o_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)
112 | o_c4 = general_conv2d(o_c3, ndf*8, f, f, 1, 1, 0.02, "SAME", "c4",relufactor=0.2)
113 | o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5",do_norm=False,do_relu=False)
114 |
115 | return o_c5
116 |
117 |
118 | def patch_discriminator(inputdisc, name="discriminator"):
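119 |     '''Discriminator applied to a random 70x70 crop of the input (a 70x70 PatchGAN).'''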
119 |
120 | with tf.variable_scope(name):
121 | f= 4
122 |
123 | patch_input = tf.random_crop(inputdisc,[1,70,70,3])
124 |         o_c1 = general_conv2d(patch_input, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)
125 | o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2)
126 | o_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)
127 | o_c4 = general_conv2d(o_c3, ndf*8, f, f, 2, 2, 0.02, "SAME", "c4", relufactor=0.2)
128 | o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5",do_norm=False,do_relu=False)
129 |
130 | return o_c5
--------------------------------------------------------------------------------