├── CycleGAN-keras.ipynb
├── README.md
├── data
├── data_loader.py
└── save_data.py
├── models
├── __init__.py
├── loss.py
├── networks.py
└── train_function.py
├── options
├── __init__.py
├── base_options.py
├── test_options.py
└── train_options.py
├── train.py
└── util
├── __init__.py
├── image_pool.py
└── util.py
/CycleGAN-keras.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Keras implementation of https://github.com/junyanz/CycleGAN"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# tf.Session(config=tf.ConfigProto(log_device_placement=True))"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# import numpy as np\n",
26 | "# np.random.seed(9999)"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {
33 | "scrolled": true
34 | },
35 | "outputs": [],
36 | "source": [
37 | "import os\n",
38 | "import keras.backend as K\n",
39 | "import tensorflow as tf\n",
40 | "import numpy as np\n",
41 | "import glob\n",
42 | "import time\n",
43 | "import warnings\n",
44 | "from PIL import Image\n",
45 | "from random import randint, shuffle, uniform\n",
46 | "warnings.simplefilter('error', Image.DecompressionBombWarning)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "from keras.optimizers import RMSprop, SGD, Adam\n",
56 | "from keras.models import Sequential, Model\n",
57 | "from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input, Dropout\n",
58 | "from keras.layers import Conv2DTranspose, UpSampling2D, Activation, Add, Lambda\n",
59 | "from keras.layers.advanced_activations import LeakyReLU\n",
60 | "from keras.activations import relu\n",
61 | "from keras.initializers import RandomNormal\n",
62 | "from keras_contrib.layers.normalization import InstanceNormalization"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "# Weights initializations\n",
72 | "\n",
73 | "# for convolution kernel\n",
74 | "conv_init = RandomNormal(0, 0.02)\n",
75 | "# for batch normalization\n",
76 | "gamma_init = RandomNormal(1., 0.02) "
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "def conv2d(f, *a, **k):\n",
86 | " return Conv2D(f, kernel_initializer = conv_init, *a, **k)\n",
87 | "def batchnorm():\n",
88 | " return BatchNormalization(momentum=0.9, axis=3, epsilon=1e-5, gamma_initializer = gamma_init)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "def conv_block(x, filters, size, stride=(2, 2), has_norm_layer=True, use_norm_instance=False,\n",
98 | " has_activation_layer=True, use_leaky_relu=False, padding='same'):\n",
99 | " x = conv2d(filters, (size, size), strides=stride, padding=padding)(x)\n",
100 | " if has_norm_layer:\n",
101 | " if not use_norm_instance:\n",
102 | " x = batchnorm()(x)\n",
103 | " else:\n",
104 | " x = InstanceNormalization(axis=1)(x)\n",
105 | " if has_activation_layer:\n",
106 | " if not use_leaky_relu:\n",
107 | " x = Activation('relu')(x)\n",
108 | " else:\n",
109 | " x = LeakyReLU(alpha=0.2)(x)\n",
110 | " return x\n",
111 | "\n",
112 | "def res_block(x, filters=256, use_dropout=False):\n",
113 | " y = conv_block(x, filters, 3, (1, 1))\n",
114 | " if use_dropout:\n",
115 | " y = Dropout(0.5)(y)\n",
116 | " y = conv_block(y, filters, 3, (1, 1), has_activation_layer=False)\n",
117 | " return Add()([y, x])\n",
118 | "\n",
119 | "# decoder block\n",
120 | "def up_block(x, filters, size, use_conv_transpose=True, use_norm_instance=False):\n",
121 | " if use_conv_transpose:\n",
122 | " x = Conv2DTranspose(filters, kernel_size=size, strides=2, padding='same',\n",
123 | " use_bias=True if use_norm_instance else False,\n",
124 | " kernel_initializer=RandomNormal(0, 0.02))(x)\n",
125 | " x = batchnorm()(x)\n",
126 | " x = Activation('relu')(x)\n",
127 | " else:\n",
128 | " x = UpSampling2D()(x)\n",
129 | " x = conv_block(x, filters, size, (1, 1))\n",
130 | " return x"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "# Defines the PatchGAN discriminator"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "def n_layer_discriminator(image_size=256, input_nc=3, ndf=64, hidden_layers=3):\n",
149 | " \"\"\"\n",
150 | " input_nc: input channels\n",
151 | " ndf: filters of the first layer\n",
152 | " \"\"\"\n",
153 | " inputs = Input(shape=(image_size, image_size, input_nc))\n",
154 | " x = inputs\n",
155 | " \n",
156 | " x = ZeroPadding2D(padding=(1, 1))(x)\n",
157 | " x = conv_block(x, ndf, 4, has_norm_layer=False, use_leaky_relu=True, padding='valid')\n",
158 | " \n",
159 | " x = ZeroPadding2D(padding=(1, 1))(x)\n",
160 | " for i in range(1, hidden_layers + 1):\n",
161 | " nf = 2 ** i * ndf\n",
162 | " x = conv_block(x, nf, 4, use_leaky_relu=True, padding='valid')\n",
163 | " x = ZeroPadding2D(padding=(1, 1))(x)\n",
164 | " \n",
165 | " x = conv2d(1, (4, 4), activation='sigmoid', strides=(1, 1))(x)\n",
166 | " outputs = x\n",
167 | " return Model(inputs=inputs, outputs=outputs)"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "# Defines the generator"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "def resnet_generator(image_size=256, input_nc=3, res_blocks=6, use_conv_transpose=True):\n",
186 | " inputs = Input(shape=(image_size, image_size, input_nc))\n",
187 | " x = inputs\n",
188 | " \n",
189 | " x = conv_block(x, 64, 7, (1, 1))\n",
190 | " x = conv_block(x, 128, 3, (2, 2))\n",
191 | " x = conv_block(x, 256, 3, (2, 2))\n",
192 | " \n",
193 | " for i in range(res_blocks):\n",
194 | " x = res_block(x)\n",
195 | " \n",
196 | " x = up_block(x, 128, 3, use_conv_transpose=use_conv_transpose)\n",
197 | " x = up_block(x, 64, 3, use_conv_transpose=use_conv_transpose)\n",
198 | " \n",
199 | " x = conv2d(3, (7, 7), activation='tanh', strides=(1, 1) ,padding='same')(x) \n",
200 | " outputs = x\n",
201 | " return Model(inputs=inputs, outputs=outputs), inputs, outputs"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "def mkdirs(paths):\n",
211 | " if isinstance(paths, list) and not isinstance(paths, str):\n",
212 | " for path in paths:\n",
213 | " mkdir(path)\n",
214 | " else:\n",
215 | " mkdir(paths)\n",
216 | "\n",
217 | "def mkdir(path):\n",
218 | " if not os.path.exists(path):\n",
219 | " os.makedirs(path)"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "# global variables\n",
229 | "image_size = 128\n",
230 | "image_jitter_range = 30\n",
231 | "load_size = image_size + image_jitter_range\n",
232 | "batch_size = 16\n",
233 | "input_nc = 3\n",
234 | "path = '/home/lin/Downloads/'\n",
235 | "dpath = path + 'weights-cyclelossweight10-batchsize{}-imagesize{}/'.format(batch_size, image_size)\n",
236 | "dpath_result = dpath + 'results'\n",
237 | "mkdirs([dpath, dpath_result])"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "def criterion_GAN(output, target, use_lsgan=True):\n",
247 | " if use_lsgan:\n",
248 | " diff = output-target\n",
249 | " dims = list(range(1,K.ndim(diff)))\n",
250 | " return K.expand_dims((K.mean(diff**2, dims)), 0)\n",
251 | " else:\n",
252 | " return K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))\n",
253 | " \n",
254 | "def criterion_cycle(rec, real):\n",
255 | " diff = K.abs(rec-real)\n",
256 | " dims = list(range(1,K.ndim(diff)))\n",
257 | " return K.expand_dims((K.mean(diff, dims)), 0)"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "def netG_loss(inputs, cycle_loss_weight=10):\n",
267 | " netD_B_predict_fake, rec_A, real_A, netD_A_predict_fake, rec_B, real_B = inputs\n",
268 | " \n",
269 | " loss_G_A = criterion_GAN(netD_B_predict_fake, K.ones_like(netD_B_predict_fake))\n",
270 | " loss_cyc_A = criterion_cycle(rec_A, real_A)\n",
271 | " \n",
272 | " loss_G_B = criterion_GAN(netD_A_predict_fake, K.ones_like(netD_A_predict_fake))\n",
273 | " loss_cyc_B = criterion_cycle(rec_B, real_B)\n",
274 | " \n",
275 | " loss_G = loss_G_A + loss_G_B + cycle_loss_weight * (loss_cyc_A+loss_cyc_B)\n",
276 | " return loss_G"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "def netD_loss(netD_predict):\n",
286 | " netD_predict_real, netD_predict_fake = netD_predict\n",
287 | " \n",
288 | " netD_loss_real = criterion_GAN(netD_predict_real, K.ones_like(netD_predict_real))\n",
289 | " netD_loss_fake = criterion_GAN(netD_predict_fake, K.zeros_like(netD_predict_fake))\n",
290 | " \n",
291 | " loss_netD= 0.5 * (netD_loss_real + netD_loss_fake)\n",
292 | " return loss_netD"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "netD_A = n_layer_discriminator(image_size)\n",
302 | "netD_B = n_layer_discriminator(image_size)\n",
303 | "# netD_A.summary()\n",
304 | "# netD_B.summary()"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": null,
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "netG_A, real_A, fake_B = resnet_generator(image_size, use_conv_transpose=True)\n",
314 | "netG_B, real_B, fake_A = resnet_generator(image_size, use_conv_transpose=True)\n",
315 | "# netG_A.summary()\n",
316 | "# netG_B.summary()"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "# make generator train function"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": null,
331 | "metadata": {},
332 | "outputs": [],
333 | "source": [
334 | "netD_B_predict_fake = netD_B(fake_B)\n",
335 | "rec_A= netG_B(fake_B)\n",
336 | "netD_A_predict_fake = netD_A(fake_A)\n",
337 | "rec_B = netG_A(fake_A)\n",
338 | "lambda_layer_inputs = [netD_B_predict_fake, rec_A, real_A, netD_A_predict_fake, rec_B, real_B]\n",
339 | "\n",
340 | "for l in netG_A.layers: \n",
341 | " l.trainable=True\n",
342 | "for l in netG_B.layers: \n",
343 | " l.trainable=True\n",
344 | "for l in netD_A.layers: \n",
345 | " l.trainable=False\n",
346 | "for l in netD_B.layers: \n",
347 | " l.trainable=False\n",
348 | " \n",
349 | "netG_train_function = Model([real_A, real_B],Lambda(netG_loss)(lambda_layer_inputs))\n",
350 | "Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0)\n",
351 | "netG_train_function.compile('adam', 'mae')"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "# make discriminator A train function"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {},
367 | "outputs": [],
368 | "source": [
369 | "netD_A_predict_real = netD_A(real_A)\n",
370 | "\n",
371 | "_fake_A = Input(shape=(image_size, image_size, input_nc))\n",
372 | "_netD_A_predict_fake = netD_A(_fake_A)\n",
373 | "\n",
374 | "for l in netG_A.layers: \n",
375 | " l.trainable=False\n",
376 | "for l in netG_B.layers: \n",
377 | " l.trainable=False\n",
378 | "for l in netD_A.layers: \n",
379 | " l.trainable=True \n",
380 | "for l in netD_B.layers: \n",
381 | " l.trainable=False\n",
382 | "\n",
383 | "netD_A_train_function = Model([real_A, _fake_A], Lambda(netD_loss)([netD_A_predict_real, _netD_A_predict_fake]))\n",
384 | "netD_A_train_function.compile('adam', 'mae')"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": null,
390 | "metadata": {},
391 | "outputs": [],
392 | "source": [
393 | "# make discriminator B train function"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": null,
399 | "metadata": {},
400 | "outputs": [],
401 | "source": [
402 | "netD_B_predict_real = netD_B(real_B)\n",
403 | "\n",
404 | "_fake_B = Input(shape=(image_size, image_size, input_nc))\n",
405 | "_netD_B_predict_fake = netD_B(_fake_B)\n",
406 | "\n",
407 | "for l in netG_A.layers: \n",
408 | " l.trainable=False\n",
409 | "for l in netG_B.layers: \n",
410 | " l.trainable=False\n",
411 | "for l in netD_B.layers: \n",
412 | " l.trainable=True \n",
413 | "for l in netD_A.layers: \n",
414 | " l.trainable=False \n",
415 | " \n",
416 | "netD_B_train_function= Model([real_B, _fake_B], Lambda(netD_loss)([netD_B_predict_real, _netD_B_predict_fake]))\n",
417 | "netD_B_train_function.compile('adam', 'mae')"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "def load_data(file_pattern):\n",
427 | " return glob.glob(file_pattern)\n",
428 | "\n",
429 | "def read_image(img, loadsize=load_size, imagesize=image_size):\n",
430 | " img = Image.open(img).convert('RGB')\n",
431 | " img = img.resize((loadsize, loadsize), Image.BICUBIC)\n",
432 | " img = np.array(img)\n",
433 | " assert img.shape == (loadsize, loadsize, 3)\n",
434 | " img = img.astype(np.float32)\n",
435 | " img = (img-127.5) / 127.5\n",
436 | " # random jitter\n",
437 | " w_offset = h_offset = randint(0, max(0, loadsize - imagesize - 1))\n",
438 | " img = img[h_offset:h_offset + imagesize,\n",
439 | " w_offset:w_offset + imagesize, :]\n",
440 | " # horizontal flip\n",
441 | " if randint(0, 1):\n",
442 | " img = img[:, ::-1]\n",
443 | " return img\n",
444 | "\n",
445 | "def try_read_img(data, index):\n",
446 | " try:\n",
447 | " img = read_image(data[index])\n",
448 | " return img\n",
449 | " except:\n",
450 | " img = try_read_img(data, index + 1)\n",
451 | " return img\n",
452 | "\n",
453 | "train_A = load_data('/home/lin/Downloads/m-cycle/trainA/*')\n",
454 | "train_B = load_data('/home/lin/Downloads/m-cycle/trainB/*')\n",
455 | "print(len(train_A))\n",
456 | "print(len(train_B))\n",
457 | "\n",
458 | "val_A = load_data('/home/lin/Downloads/m-cycle/testA/*')\n",
459 | "val_B = load_data('/home/lin/Downloads/m-cycle/testB/*')"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": null,
465 | "metadata": {},
466 | "outputs": [],
467 | "source": [
468 | "def minibatch(data, batch_size):\n",
469 | " length = len(data)\n",
470 | " shuffle(data)\n",
471 | " epoch = i = 0\n",
472 | " tmpsize = None \n",
473 | " \n",
474 | " while True:\n",
475 | " size = tmpsize if tmpsize else batch_size\n",
476 | " if i+size > length:\n",
477 | " shuffle(data)\n",
478 | " i = 0\n",
479 | " epoch+=1 \n",
480 | " rtn = []\n",
481 | " for j in range(i,i+size):\n",
482 | " img = try_read_img(data, j)\n",
483 | " rtn.append(img)\n",
484 | " rtn = np.stack(rtn, axis=0) \n",
485 | " i+=size\n",
486 | " tmpsize = yield epoch, np.float32(rtn)\n",
487 | "\n",
488 | "def minibatchAB(dataA, dataB, batch_size):\n",
489 | " batchA=minibatch(dataA, batch_size)\n",
490 | " batchB=minibatch(dataB, batch_size)\n",
491 | " tmpsize = None \n",
492 | " while True:\n",
493 | " ep1, A = batchA.send(tmpsize)\n",
494 | " ep2, B = batchB.send(tmpsize)\n",
495 | " tmpsize = yield max(ep1, ep2), A, B"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": null,
501 | "metadata": {},
502 | "outputs": [],
503 | "source": [
504 | "from IPython.display import display\n",
505 | "def display_image(X, rows=1):\n",
506 | " assert X.shape[0]%rows == 0\n",
507 | " int_X = ((X*127.5+127.5).clip(0,255).astype('uint8'))\n",
508 | " int_X = int_X.reshape(-1,image_size,image_size, 3)\n",
509 | " int_X = int_X.reshape(rows, -1, image_size, image_size,3).swapaxes(1,2).reshape(rows*image_size,-1, 3)\n",
510 | " pil_X = Image.fromarray(int_X)\n",
511 | " t = str(round(time.time()))\n",
512 | " pil_X.save(dpath+'results/'+ t, 'JPEG')\n",
513 | " display(pil_X)"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": null,
519 | "metadata": {},
520 | "outputs": [],
521 | "source": [
522 | "train_batch = minibatchAB(train_A, train_B, 6)\n",
523 | "\n",
524 | "_, A, B = next(train_batch)\n",
525 | "display_image(A)\n",
526 | "display_image(B)\n",
527 | "_, A, B = next(train_batch)\n",
528 | "display_image(A)\n",
529 | "display_image(B)\n",
530 | "del train_batch, A, B"
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": null,
536 | "metadata": {},
537 | "outputs": [],
538 | "source": [
539 | "val_batch = minibatchAB(val_A, val_B, 4)\n",
540 | "\n",
541 | "_, A, B = next(val_batch)\n",
542 | "display_image(A)\n",
543 | "display_image(B)\n",
544 | "del val_batch, A, B"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [],
552 | "source": [
553 | "def get_output(netG_alpha, netG_beta, X):\n",
554 | " real_input = X\n",
555 | " fake_output = netG_alpha.predict(real_input)\n",
556 | " rec_input = netG_beta.predict(fake_output)\n",
557 | " outputs = [fake_output, rec_input]\n",
558 | " return outputs"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": null,
564 | "metadata": {},
565 | "outputs": [],
566 | "source": [
567 | "def get_combined_output(netG_alpha, netG_beta, X):\n",
568 | " r = [get_output(netG_alpha, netG_beta, X[i:i+1]) for i in range(X.shape[0])]\n",
569 | " r = np.array(r)\n",
570 | " return r.swapaxes(0,1)[:,:,0] "
571 | ]
572 | },
573 | {
574 | "cell_type": "code",
575 | "execution_count": null,
576 | "metadata": {},
577 | "outputs": [],
578 | "source": [
579 | "def show_generator_image(A,B, netG_alpha, netG_beta):\n",
580 | " assert A.shape==B.shape\n",
581 | " \n",
582 | " rA = get_combined_output(netG_alpha, netG_beta, A)\n",
583 | " rB = get_combined_output(netG_beta, netG_alpha, B)\n",
584 | " \n",
585 | " arr = np.concatenate([A,B,rA[0],rB[0],rA[1],rB[1]]) \n",
586 | " display_image(arr, 3)"
587 | ]
588 | },
589 | {
590 | "cell_type": "code",
591 | "execution_count": null,
592 | "metadata": {},
593 | "outputs": [],
594 | "source": [
595 | "def get_generater_function(netG):\n",
596 | " real_input = netG.inputs[0]\n",
597 | " fake_output = netG.outputs[0]\n",
598 | " function = K.function([real_input, K.learning_phase()], [fake_output])\n",
599 | " return function\n",
600 | "\n",
601 | "netG_A_function = get_generater_function(netG_A)\n",
602 | "netG_B_function = get_generater_function(netG_B)"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {},
609 | "outputs": [],
610 | "source": [
611 | "class ImagePool():\n",
612 | " def __init__(self, pool_size=200):\n",
613 | " self.pool_size = pool_size\n",
614 | " if self.pool_size > 0:\n",
615 | " self.num_imgs = 0\n",
616 | " self.images = []\n",
617 | "\n",
618 | " def query(self, images):\n",
619 | " if self.pool_size == 0:\n",
620 | " return images\n",
621 | " return_images = []\n",
622 | " for image in images:\n",
623 | " if self.num_imgs < self.pool_size:\n",
624 | " self.num_imgs = self.num_imgs + 1\n",
625 | " self.images.append(image)\n",
626 | " return_images.append(image)\n",
627 | " else:\n",
628 | " p = uniform(0, 1)\n",
629 | " if p > 0.5:\n",
630 | " random_id = randint(0, self.pool_size-1)\n",
631 | " tmp = self.images[random_id]\n",
632 | " self.images[random_id] = image\n",
633 | " return_images.append(tmp)\n",
634 | " else:\n",
635 | " return_images.append(image)\n",
636 | " return_images = np.stack(return_images, axis=0)\n",
637 | " return return_images"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": null,
643 | "metadata": {},
644 | "outputs": [],
645 | "source": [
646 | "K.learning_phase()"
647 | ]
648 | },
649 | {
650 | "cell_type": "code",
651 | "execution_count": null,
652 | "metadata": {},
653 | "outputs": [],
654 | "source": [
655 | "import time\n",
656 | "from IPython.display import clear_output\n",
657 | "time_start = time.time()\n",
658 | "how_many_epochs = 10\n",
659 | "iteration_count = 0\n",
660 | "epoch_count = 0\n",
661 | "display_freq = 1000 // batch_size \n",
662 | "save_freq = 20000 // batch_size\n",
663 | "val_batch = minibatchAB(val_A, val_B, batch_size=4)\n",
664 | "_, val_A, val_B = next(val_batch)\n",
665 | "train_batch = minibatchAB(train_A, train_B, batch_size)\n",
666 | " \n",
667 | "fake_A_pool = ImagePool()\n",
668 | "fake_B_pool = ImagePool()\n",
669 | "\n",
670 | "while epoch_count < how_many_epochs: \n",
671 | " target_label = np.zeros((batch_size, 1))\n",
672 | " epoch_count, A, B = next(train_batch)\n",
673 | "\n",
674 | " tmp_fake_B = netG_A_function([A, 1])[0]\n",
675 | " tmp_fake_A = netG_B_function([B, 1])[0]\n",
676 | " \n",
677 | " _fake_B = fake_B_pool.query(tmp_fake_B)\n",
678 | " _fake_A = fake_A_pool.query(tmp_fake_A)\n",
679 | "\n",
680 | " netG_train_function.train_on_batch([A, B], target_label)\n",
681 | " \n",
682 | " netD_B_train_function.train_on_batch([B, _fake_B], target_label)\n",
683 | " netD_A_train_function.train_on_batch([A, _fake_A], target_label)\n",
684 | " \n",
685 | " iteration_count+=1\n",
686 | " \n",
687 | " save_name = dpath + '{}' + str(iteration_count) + '.h5'\n",
688 | " \n",
689 | " if iteration_count%display_freq == 0:\n",
690 | " clear_output()\n",
691 | " timecost = (time.time()-time_start)/60\n",
692 | " print('epoch_count: {} iter_count: {} timecost: {}mins'.format(epoch_count, iteration_count, timecost))\n",
693 | " show_generator_image(val_A,val_B, netG_A, netG_B)\n",
694 | " netG_A.save_weights(save_name.format('tf_GA_weights'))\n",
695 | " netG_B.save_weights(save_name.format('tf_GB_weights'))\n",
696 | "\n",
697 | " if iteration_count%save_freq == 0:\n",
698 | " netD_A.save_weights(save_name.format('tf_DA_weights'))\n",
699 | " netD_B.save_weights(save_name.format('tf_DB_weights'))\n",
700 | " netG_train_function.save_weights(save_name.format('tf_G_train_weights'))\n",
701 | " netD_A_train_function.save_weights(save_name.format('tf_D_A_train_weights'))\n",
702 | " netD_B_train_function.save_weights(save_name.format('tf_D_B_train_weights'))"
703 | ]
704 | },
705 | {
706 | "cell_type": "code",
707 | "execution_count": null,
708 | "metadata": {},
709 | "outputs": [],
710 | "source": [
711 | "# inference"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": null,
717 | "metadata": {},
718 | "outputs": [],
719 | "source": [
720 | "load_name = dpath + '{}' + '1000.h5'\n",
721 | "netG_A.load_weights(load_name.format('tf_GA_weights'))\n",
722 | "netG_B.load_weights(load_name.format('tf_GB_weights'))\n",
723 | "netD_A.load_weights(load_name.format('tf_DA_weights'))\n",
724 | "netD_B.load_weights(load_name.format('tf_DB_weights'))\n",
725 | "netG_train_function.load_weights(load_name.format('tf_G_train_weights'))\n",
726 | "netD_A_train_function.load_weights(load_name.format('tf_D_A_train_weights'))\n",
727 | "netD_B_train_function.load_weights(load_name.format('tf_D_B_train_weights'))"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": null,
733 | "metadata": {},
734 | "outputs": [],
735 | "source": [
736 | "val_batch = minibatchAB(val_A, val_B, batch_size=2)"
737 | ]
738 | },
739 | {
740 | "cell_type": "code",
741 | "execution_count": null,
742 | "metadata": {},
743 | "outputs": [],
744 | "source": [
745 | "# run batch normalization layer in training mode"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "execution_count": null,
751 | "metadata": {},
752 | "outputs": [],
753 | "source": [
754 | "_,A, B = next(val_batch)\n",
755 | "show_generator_image(A,B, netG_A, netG_B)"
756 | ]
757 | },
758 | {
759 | "cell_type": "markdown",
760 | "metadata": {},
761 | "source": [
762 | "\n",
763 | "\n",
764 | "\n",
765 | "\n",
766 | "\n",
767 | "\n",
768 | "\n"
769 | ]
770 | }
771 | ],
772 | "metadata": {
773 | "kernelspec": {
774 | "display_name": "Python 3",
775 | "language": "python",
776 | "name": "python3"
777 | },
778 | "language_info": {
779 | "codemirror_mode": {
780 | "name": "ipython",
781 | "version": 3
782 | },
783 | "file_extension": ".py",
784 | "mimetype": "text/x-python",
785 | "name": "python",
786 | "nbconvert_exporter": "python",
787 | "pygments_lexer": "ipython3",
788 | "version": "3.6.3"
789 | }
790 | },
791 | "nbformat": 4,
792 | "nbformat_minor": 1
793 | }
794 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cyclegan-keras
2 |
3 | Keras implementation of CycleGAN based on [pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) (by junyanz) and the [tf/torch/keras/lasagne notebooks](https://github.com/tjwei/GANotebooks) (by tjwei).
4 |
5 | ## Prerequisites
6 | `train.py` has not been tested; `CycleGAN-keras.ipynb` is the recommended entry point and has been tested on:
7 | - Ubuntu 16.04
8 | - Python 3.6
9 | - Keras 2.1.2
10 | - Tensorflow 1.0.1
11 | - NVIDIA GPU + CUDA8.0 CuDNN6 or CuDNN5
12 |
13 |
14 |
15 | ## Demos [[manga-colorization-demo]](http://www.styletransfer.tech)
16 |
17 | Colorize manga with a CycleGAN model that runs entirely in the browser.
18 | - Built based on [Keras.js](https://github.com/transcranial/keras-js) and [keras.js demos](https://transcranial.github.io/keras-js)
19 | - Model trained with the Jupyter notebook version of this repo
20 | - Check [Demo-Introduction](https://zhuanlan.zhihu.com/p/34672860) or my [demo-repo](https://github.com/MingwangLin/manga-colorization) for more details
21 |
22 |
23 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | from PIL import Image
4 | from random import randint, shuffle
5 |
6 |
def load_data(file_pattern):
    """Return the list of paths matching the glob pattern ``file_pattern``."""
    matches = glob.glob(file_pattern)
    return matches
9 |
10 |
def read_image(img, loadsize=286, imagesize=256):
    """Load one image and return a randomly cropped / flipped float32 array.

    The image is decoded as RGB, resized to ``loadsize`` x ``loadsize`` with
    bicubic interpolation, rescaled from [0, 255] to [-1, 1], cropped to
    ``imagesize`` x ``imagesize`` and, depending on a coin flip, mirrored
    horizontally.

    img: passed straight to ``Image.open``.
    loadsize: side length the image is resized to before cropping.
    imagesize: side length of the returned crop.
    """
    resized = Image.open(img).convert('RGB').resize((loadsize, loadsize), Image.BICUBIC)
    pixels = np.array(resized)
    assert pixels.shape == (loadsize, loadsize, 3)
    pixels = (pixels.astype(np.float32) - 127.5) / 127.5
    # Random jitter: a single offset is drawn and shared by both axes, so the
    # crop window only slides along the diagonal.
    offset = randint(0, max(0, loadsize - imagesize - 1))
    stop = offset + imagesize
    pixels = pixels[offset:stop, offset:stop, :]
    # Horizontal flip when the coin flip comes up 1.
    return pixels[:, ::-1] if randint(0, 1) else pixels
25 |
26 |
def try_read_img(data, index):
    """Return the first readable image at or after ``data[index]``.

    Candidates are tried in order starting at ``index`` and wrapping around
    to the beginning of ``data``, so an unreadable file near the end of the
    list (or an index past the end) does not abort the batch. Each entry is
    tried at most once.

    data: sequence of image paths.
    index: position of the preferred image.

    Returns the array produced by ``read_image``.
    Raises RuntimeError if ``data`` is empty or no entry can be read.
    """
    count = len(data)
    for step in range(count):
        path = data[(index + step) % count]
        try:
            return read_image(path)
        except Exception:
            # Unreadable or malformed image: fall through to the next candidate.
            # ``Exception`` (not a bare except) keeps Ctrl-C working.
            continue
    raise RuntimeError('no readable image found in data')
33 |
34 |
def minibatch(data, batch_size):
    """Endless generator yielding ``(epoch, batch)`` tuples.

    ``data`` is shuffled in place at start-up and again every time the
    remaining items cannot fill the requested batch; at that point the
    cursor is reset, the epoch counter is incremented and the leftover items
    are skipped. A truthy value passed in through ``send()`` overrides
    ``batch_size`` for the next batch only.

    data: list of image paths (mutated by shuffling).
    batch_size: default number of images per batch.
    """
    total = len(data)
    shuffle(data)
    epoch = 0
    cursor = 0
    requested = None

    while True:
        size = requested or batch_size
        if cursor + size > total:
            shuffle(data)
            cursor = 0
            epoch += 1
        images = [try_read_img(data, j) for j in range(cursor, cursor + size)]
        batch = np.stack(images, axis=0)
        cursor += size
        requested = yield epoch, np.float32(batch)
54 |
55 |
def minibatchAB(dataA, dataB, batch_size):
    """Endless generator yielding ``(epoch, batch_A, batch_B)`` tuples.

    Drives one ``minibatch`` generator per domain in lockstep. The reported
    epoch is the larger of the two per-domain epoch counters. Whatever the
    consumer passes in through ``send()`` is forwarded to both underlying
    generators (``None`` for a plain ``next()``).
    """
    gen_a = minibatch(dataA, batch_size)
    gen_b = minibatch(dataB, batch_size)
    requested = None
    while True:
        epoch_a, batch_a = gen_a.send(requested)
        epoch_b, batch_b = gen_b.send(requested)
        requested = yield max(epoch_a, epoch_b), batch_a, batch_b
64 |
--------------------------------------------------------------------------------
/data/save_data.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from IPython.display import display
4 | from PIL import Image
5 |
6 |
def get_output(netG_alpha, netG_beta, X):
    """Run ``X`` through ``netG_alpha`` and feed the result to ``netG_beta``.

    Both networks only need a ``predict`` method.

    Returns a two-element list: the output of ``netG_alpha.predict(X)``
    followed by the output of ``netG_beta.predict`` applied to it.
    """
    translated = netG_alpha.predict(X)
    reconstructed = netG_beta.predict(translated)
    return [translated, reconstructed]
13 |
14 |
def get_combined_output(netG_alpha, netG_beta, X):
    """Apply ``get_output`` to every sample of ``X`` individually and stack.

    Each sample is passed as a length-1 slice ``X[i:i + 1]``. The per-sample
    results are stacked into one array, its first two axes are swapped so
    that index 0 selects the ``netG_alpha`` outputs and index 1 the
    ``netG_beta`` outputs, and index 0 of the third axis is taken
    (presumably the length-1 batch axis of each predict call - depends on
    what ``predict`` returns).
    """
    per_sample = []
    for i in range(X.shape[0]):
        # Slice rather than index so the sample keeps its leading axis.
        per_sample.append(get_output(netG_alpha, netG_beta, X[i:i + 1]))
    stacked = np.array(per_sample)
    return stacked.swapaxes(0, 1)[:, :, 0]
19 |
20 |
def save_image(X, rows=1, image_size=256, dpath=None):
    """Tile a batch of images into a grid and save it as a JPEG file.

    X: array of images scaled to [-1, 1] (the range produced by the
        ``(pixel - 127.5) / 127.5`` preprocessing used by the data loader);
        it is reshaped to ``(-1, image_size, image_size, 3)``.
    rows: number of grid rows; must divide the number of images.
    image_size: side length of each (square) image.
    dpath: prefix of the output location; the file is written to
        ``dpath + 'results/' + <timestamp>``. When omitted, a module-level
        ``dpath`` is used if one has been defined, otherwise the path is
        relative to the current working directory. The ``results/``
        directory is expected to exist already.
    """
    assert X.shape[0] % rows == 0
    if dpath is None:
        # The original code read an undefined global; keep honouring a
        # module-level ``dpath`` if a caller has injected one.
        dpath = globals().get('dpath', '')
    # Undo the [-1, 1] normalisation. Multiplying by 255 alone would clip
    # every negative value to black.
    int_X = (X * 127.5 + 127.5).clip(0, 255).astype('uint8')
    int_X = int_X.reshape(-1, image_size, image_size, 3)
    # Lay the images out as ``rows`` rows of side-by-side tiles.
    int_X = int_X.reshape(rows, -1, image_size, image_size, 3)
    int_X = int_X.swapaxes(1, 2).reshape(rows * image_size, -1, 3)
    pil_X = Image.fromarray(int_X)
    t = str(time.time())
    pil_X.save(dpath + 'results/' + t, 'JPEG')
29 |
30 |
def show_generator_image(A, B, netG_alpha, netG_beta, image_size=None):
    """Save a grid of inputs, translations and reconstructions.

    The grid has three rows: the real images (``A`` then ``B``), the
    translated images, and the reconstructions.

    A, B: image batches of identical shape.
    netG_alpha: generator applied to ``A`` (and used to reconstruct ``B``).
    netG_beta: generator applied to ``B`` (and used to reconstruct ``A``).
    image_size: side length of each image; defaults to ``A.shape[1]``
        (assumes batches are laid out as (N, H, W, C) - confirm against
        caller) instead of the fixed 256 that ``save_image`` falls back to.
    """
    assert A.shape == B.shape
    if image_size is None:
        image_size = A.shape[1]

    rA = get_combined_output(netG_alpha, netG_beta, A)
    rB = get_combined_output(netG_beta, netG_alpha, B)

    arr = np.concatenate([A, B, rA[0], rB[0], rA[1], rB[1]])
    save_image(arr, rows=3, image_size=image_size)
39 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingwangLin/cyclegan-keras/7a000382f63d4025ffdabdbc373e8e92ecae4573/models/__init__.py
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 |
3 |
def criterion_GAN(output, target, use_lsgan=True):
    """Adversarial loss between discriminator ``output`` and ``target`` labels.

    use_lsgan=True: squared error averaged over every axis except the first,
        returned with an extra leading axis of length 1.
    use_lsgan=False: binary cross-entropy averaged over all elements
        (a scalar). ``1e-12`` is added inside the logarithms to avoid
        ``log(0)``.
    """
    if use_lsgan:
        diff = output - target
        dims = list(range(1, K.ndim(diff)))
        return K.expand_dims(K.mean(diff ** 2, dims), 0)
    log_likelihood = (K.log(output + 1e-12) * target
                      + K.log(1 - output + 1e-12) * (1 - target))
    # Cross-entropy is the NEGATIVE mean log-likelihood; without the minus
    # sign, minimising this value would push predictions away from the target.
    return -K.mean(log_likelihood)
11 |
12 |
def criterion_cycle(rec, real):
    """Cycle-consistency loss: mean absolute error between ``rec`` and ``real``.

    The error is averaged over every axis except the first, and the result
    gets an extra leading axis of length 1.
    """
    abs_err = K.abs(rec - real)
    reduce_axes = list(range(1, K.ndim(abs_err)))
    per_sample = K.mean(abs_err, reduce_axes)
    return K.expand_dims(per_sample, 0)
17 |
18 |
def netG_loss(G_tensors, loss_weight=10):
    """Combined generator loss: adversarial terms plus weighted cycle terms.

    G_tensors: six tensors, in order
        ``netD_A_predict_fake, rec_A, G_A_input,
        netD_B_predict_fake, rec_B, G_B_input``.
    loss_weight: multiplier applied to the sum of both cycle losses.
    """
    (netD_A_predict_fake, rec_A, G_A_input,
     netD_B_predict_fake, rec_B, G_B_input) = G_tensors

    # Generators want both discriminators to label their fakes as real (ones).
    adversarial = (
        criterion_GAN(netD_B_predict_fake, K.ones_like(netD_B_predict_fake))
        + criterion_GAN(netD_A_predict_fake, K.ones_like(netD_A_predict_fake))
    )
    cycle = criterion_cycle(rec_A, G_A_input) + criterion_cycle(rec_B, G_B_input)

    return adversarial + loss_weight * cycle
31 |
32 |
def netD_loss(netD_predict):
    """Discriminator loss: average of the real and fake adversarial terms.

    netD_predict: pair ``(prediction_on_real, prediction_on_fake)``.
    Real predictions are scored against ones, fake predictions against zeros.
    """
    real_pred, fake_pred = netD_predict

    real_term = criterion_GAN(real_pred, K.ones_like(real_pred))
    fake_term = criterion_GAN(fake_pred, K.zeros_like(fake_pred))

    return 0.5 * (real_term + fake_term)
41 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.models import Model
3 | from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input, Dropout
4 | from keras.layers import UpSampling2D, Conv2DTranspose, Activation, Add
5 | from keras.layers.advanced_activations import LeakyReLU
6 | from keras.initializers import RandomNormal
7 | from keras_contrib.layers.normalization import InstanceNormalization
8 |
9 |
def conv2d(f, *a, **k):
    """Create a ``Conv2D`` layer with ``f`` filters and N(0, 0.02) kernel init.

    Remaining positional and keyword arguments are forwarded to ``Conv2D``.
    """
    initializer = RandomNormal(0, 0.02)
    return Conv2D(f, *a, kernel_initializer=initializer, **k)
12 |
13 |
def batchnorm():
    """Create the ``BatchNormalization`` layer used throughout the networks.

    Normalises over axis 3 with momentum 0.9, epsilon 1e-5 and gamma
    initialised from N(1, 0.02).
    """
    settings = dict(
        momentum=0.9,
        axis=3,
        epsilon=1e-5,
        gamma_initializer=RandomNormal(1., 0.02),
    )
    return BatchNormalization(**settings)
17 |
18 |
def conv_block(x, filters, size, stride=(2, 2), has_norm_layer=True, use_norm_instance=False,
               has_activation_layer=True, use_leaky_relu=False, padding='same'):
    """Apply convolution, optional normalisation and optional activation to ``x``.

    Args:
        x: input tensor.
        filters: number of convolution filters.
        size: side length of the square convolution kernel.
        stride: convolution strides.
        has_norm_layer: add a normalisation layer after the convolution.
        use_norm_instance: use instance normalisation instead of batch
            normalisation (only relevant when ``has_norm_layer`` is true).
        has_activation_layer: add an activation after normalisation.
        use_leaky_relu: use ``LeakyReLU(alpha=0.2)`` instead of ReLU.
        padding: padding mode forwarded to the convolution.

    Returns:
        The output tensor of the last layer applied.
    """
    x = conv2d(filters, (size, size), strides=stride, padding=padding)(x)
    if has_norm_layer:
        if use_norm_instance:
            # The networks in this module are channels-last (inputs are
            # (H, W, C) and batchnorm() uses axis=3), so the feature axis for
            # instance normalisation is 3 as well. axis=1 would normalise
            # along a spatial axis instead of the channels.
            x = InstanceNormalization(axis=3)(x)
        else:
            x = batchnorm()(x)
    if has_activation_layer:
        if use_leaky_relu:
            x = LeakyReLU(alpha=0.2)(x)
        else:
            x = Activation('relu')(x)
    return x
33 |
34 |
def res_block(x, filters=256, use_dropout=False):
    """Residual block: two 3x3 stride-1 conv blocks added back onto ``x``.

    The second conv block has no activation. When ``use_dropout`` is true a
    ``Dropout(0.5)`` layer is placed between the two conv blocks.
    """
    branch = conv_block(x, filters, 3, (1, 1))
    if use_dropout:
        branch = Dropout(0.5)(branch)
    branch = conv_block(branch, filters, 3, (1, 1), has_activation_layer=False)
    skip_sum = Add()([branch, x])
    return skip_sum
41 |
42 |
def up_block(x, filters, size, use_conv_transpose=True, use_norm_instance=False):
    """Upsample ``x`` by a factor of two.

    With ``use_conv_transpose`` the block is a stride-2 transposed
    convolution followed by batch normalisation and ReLU; the convolution
    gets a bias only when ``use_norm_instance`` is truthy. Otherwise the
    block is ``UpSampling2D`` followed by a stride-1 ``conv_block``.
    """
    if not use_conv_transpose:
        enlarged = UpSampling2D()(x)
        return conv_block(enlarged, filters, size, (1, 1))

    deconv = Conv2DTranspose(
        filters,
        kernel_size=size,
        strides=2,
        padding='same',
        use_bias=bool(use_norm_instance),
        kernel_initializer=RandomNormal(0, 0.02),
    )
    y = deconv(x)
    y = batchnorm()(y)
    return Activation('relu')(y)
56 |
57 |
58 | # Defines the Resnet generator
def resnet_generator(image_size=256, input_nc=3, res_blocks=6):
    """Build the ResNet generator.

    Layout: three conv blocks (64 filters, 7x7, stride 1; 128 and 256
    filters, 3x3, stride 2), ``res_blocks`` residual blocks, two upsampling
    blocks (128 then 64 filters) and a final 7x7 convolution with 3 filters
    and ``tanh`` activation.

    Args:
        image_size: height and width of the square input.
        input_nc: number of input channels.
        res_blocks: number of residual blocks.

    Returns:
        Tuple ``(model, input_tensor, output_tensor)``.
    """
    inputs = Input(shape=(image_size, image_size, input_nc))

    # (filters, kernel size, strides) of the downsampling stem
    stem = ((64, 7, (1, 1)), (128, 3, (2, 2)), (256, 3, (2, 2)))
    x = inputs
    for n_filters, kernel, strides in stem:
        x = conv_block(x, n_filters, kernel, strides)

    for _ in range(res_blocks):
        x = res_block(x)

    for n_filters in (128, 64):
        x = up_block(x, n_filters, 3)

    outputs = conv2d(3, (7, 7), activation='tanh', strides=(1, 1), padding='same')(x)

    return Model(inputs=inputs, outputs=outputs), inputs, outputs
77 |
78 |
79 | # Defines the PatchGAN discriminator
def n_layer_discriminator(image_size=256, input_nc=3, ndf=64, hidden_layers=3):
    """Build the PatchGAN discriminator.

    Every convolution is preceded by one pixel of zero padding and uses
    'valid' padding itself. The first conv block has no normalisation; each
    of the ``hidden_layers`` following blocks doubles the filter count
    (``2 ** i * ndf``). The last layer is a single-filter 4x4 stride-1
    convolution with sigmoid activation.

    Args:
        image_size: height and width of the square input.
        input_nc: input channels.
        ndf: filters of the first layer.
        hidden_layers: number of conv blocks after the first one.

    Returns:
        Tuple ``(model, input_tensor, output_tensor)``.
    """
    def pad(tensor):
        return ZeroPadding2D(padding=(1, 1))(tensor)

    inputs = Input(shape=(image_size, image_size, input_nc))

    x = conv_block(pad(inputs), ndf, 4, has_norm_layer=False, use_leaky_relu=True,
                   padding='valid')

    for level in range(1, hidden_layers + 1):
        x = conv_block(pad(x), 2 ** level * ndf, 4, use_leaky_relu=True, padding='valid')

    outputs = conv2d(1, (4, 4), activation='sigmoid', strides=(1, 1))(pad(x))

    return Model(inputs=[inputs], outputs=outputs), inputs, outputs
101 |
def get_generater_function(netG):
    """Return a backend function mapping the first input of ``netG`` to its first output.

    The returned callable takes a one-element list and returns a one-element
    list.
    """
    return K.function([netG.inputs[0]], [netG.outputs[0]])
107 |
108 |
--------------------------------------------------------------------------------
/models/train_function.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Input, BatchNormalization
2 | from keras.optimizers import Adam
3 | from keras.models import Model
4 | from keras.layers import Lambda
5 | from models.loss import netG_loss, netD_loss
6 |
7 |
def get_train_function(inputs, loss_function, lambda_layer_inputs):
    """Wrap a loss computation into a compiled Keras model.

    The model's output is ``loss_function(lambda_layer_inputs)`` evaluated in
    a ``Lambda`` layer, and the model is compiled with mean absolute error,
    so fitting it against an all-zero target drives the loss tensor itself
    towards zero.

    Args:
        inputs: input tensor(s) of the training model.
        loss_function: callable that maps ``lambda_layer_inputs`` to a loss
            tensor.
        lambda_layer_inputs: tensors handed to ``loss_function``.

    Returns:
        The compiled ``Model``.
    """
    # Pass the configured optimizer instance to compile(). The string 'adam'
    # would build a fresh Adam with Keras defaults and silently ignore the
    # learning rate and beta_1 chosen here.
    optimizer = Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0)
    train_function = Model(inputs, Lambda(loss_function)(lambda_layer_inputs))
    train_function.compile(optimizer, 'mae')
    return train_function
13 |
14 |
15 | # create generator train function
def netG_train_function_creator(netD_A, netD_B, netG_A, netG_B, real_A, real_B, fake_A, fake_B):
    """Build the compiled model that trains both generators.

    The generators' layers are marked trainable and the discriminators'
    layers frozen before the model is compiled. The model takes
    ``[real_A, real_B]`` and outputs ``netG_loss`` of the discriminator
    predictions on the fakes, the cycle reconstructions and the real inputs.
    """
    d_b_on_fake = netD_B(fake_B)
    cycled_A = netG_B(fake_B)
    d_a_on_fake = netD_A(fake_A)
    cycled_B = netG_A(fake_A)
    loss_inputs = [d_b_on_fake, cycled_A, real_A, d_a_on_fake, cycled_B, real_B]

    # Flags must be set before the model is compiled.
    trainable_flags = ((netG_A, True), (netG_B, True), (netD_A, False), (netD_B, False))
    for network, flag in trainable_flags:
        for layer in network.layers:
            layer.trainable = flag

    return get_train_function(inputs=[real_A, real_B], loss_function=netG_loss,
                              lambda_layer_inputs=loss_inputs)
33 |
34 |
35 | # create discriminator A train function
def netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_A, finesize, input_nc):
    """Build the compiled model that trains discriminator A.

    Only ``netD_A`` is left trainable; both generators and ``netD_B`` are
    frozen before compilation.

    Args:
        netD_A, netD_B: discriminator models.
        netG_A, netG_B: generator models.
        real_A: input tensor carrying real domain-A images.
        finesize: height and width of the fake-image placeholder.
        input_nc: channel count of the fake-image placeholder.

    Returns:
        A compiled model taking ``[real_A, fake_A]`` whose output is
        ``netD_loss`` of the two discriminator predictions.
    """
    netD_A_predict_real = netD_A(real_A)
    _fake_A = Input(shape=(finesize, finesize, input_nc))
    _netD_A_predict_fake = netD_A(_fake_A)
    for layer in netG_A.layers:
        layer.trainable = False
    for layer in netG_B.layers:
        layer.trainable = False
    for layer in netD_A.layers:
        layer.trainable = True
    for layer in netD_B.layers:
        layer.trainable = False
    return get_train_function(inputs=[real_A, _fake_A], loss_function=netD_loss,
                              lambda_layer_inputs=[netD_A_predict_real, _netD_A_predict_fake])


# create discriminator B train function
def netD_B_train_function(netD_A, netD_B, netG_A, netG_B, real_B, finesize, input_nc):
    """Build the compiled model that trains discriminator B.

    This builder used to be declared under the name
    ``netD_A_train_function`` as well, which replaced the discriminator-A
    builder above at import time. It now has its own name so both builders
    are reachable.

    Only ``netD_B`` is left trainable; both generators and ``netD_A`` are
    frozen before compilation.

    Args:
        netD_A, netD_B: discriminator models.
        netG_A, netG_B: generator models.
        real_B: input tensor carrying real domain-B images.
        finesize: height and width of the fake-image placeholder.
        input_nc: channel count of the fake-image placeholder.

    Returns:
        A compiled model taking ``[real_B, fake_B]`` whose output is
        ``netD_loss`` of the two discriminator predictions.
    """
    netD_B_predict_real = netD_B(real_B)
    _fake_B = Input(shape=(finesize, finesize, input_nc))
    _netD_B_predict_fake = netD_B(_fake_B)
    # NOTE(review): only this builder clears ``_per_input_updates`` on the
    # generators' BatchNormalization layers (presumably to drop their pending
    # moving-average updates while frozen); the discriminator-A builder does
    # not. Confirm whether the two should match.
    for generator in (netG_A, netG_B):
        for layer in generator.layers:
            layer.trainable = False
            if isinstance(layer, BatchNormalization):
                layer._per_input_updates = {}
    for layer in netD_B.layers:
        layer.trainable = True
    for layer in netD_A.layers:
        layer.trainable = False
    return get_train_function(inputs=[real_B, _fake_B], loss_function=netD_loss,
                              lambda_layer_inputs=[netD_B_predict_real, _netD_B_predict_fake])
73 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingwangLin/cyclegan-keras/7a000382f63d4025ffdabdbc373e8e92ecae4573/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 |
class BaseOptions():
    """Command-line options shared by training and testing.

    Subclasses extend ``initialize`` with their own arguments. ``parse``
    reads ``self.isTrain``, which this class does not define itself.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register the shared arguments on ``self.parser``."""
        add = self.parser.add_argument

        # data
        add('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        add('--batch_size', type=int, default=1, help='input batch size')
        add('--loadSize', type=int, default=286, help='scale images to this size')
        add('--fineSize', type=int, default=256, help='then crop to this size')
        add('--input_nc', type=int, default=3, help='# of input image channels')
        add('--output_nc', type=int, default=3, help='# of output image channels')

        # networks
        add('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        add('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        add('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')

        # experiment
        add('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        add('--model', type=str, default='cycle_gan', help='chooses which model to use.')
        add('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
        add('--nThreads', default=6, type=int, help='# threads for loading data')
        add('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        add('--norm', type=str, default='instance', help='instance normalization or batch normalization')
        add('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        add('--no_dropout', action='store_true', help='no dropout for the generator')
        add('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        add('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        add('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        add('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')

        self.initialized = True

    def parse(self):
        """Parse the command line, print the options and save them to disk.

        The options are written to ``<checkpoints_dir>/<name>/opt.txt``; the
        directory is created if needed.

        Returns:
            The parsed namespace, with ``isTrain`` copied from ``self``.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        opt = self.opt
        opt.isTrain = self.isTrain   # train or test

        # Format the report once; it is both printed and written to disk.
        report = ['------------ Options -------------']
        for key, value in sorted(vars(opt).items()):
            report.append('%s: %s' % (str(key), str(value)))
        report.append('-------------- End ----------------')

        for line in report:
            print(line)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        with open(os.path.join(expr_dir, 'opt.txt'), 'wt') as opt_file:
            for line in report:
                opt_file.write(line + '\n')
        return opt
60 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
class TestOptions(BaseOptions):
    """Options used only at test time."""

    def initialize(self):
        """Register the shared arguments plus the test-only ones."""
        BaseOptions.initialize(self)
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=35, help='how many test images to run')

        # BaseOptions.parse() copies self.isTrain onto the parsed options;
        # without this attribute parse() fails with AttributeError.
        self.isTrain = False
13 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
class TrainOptions(BaseOptions):
    """Options used only at training time."""

    def initialize(self):
        """Register the shared arguments plus the training-only ones and mark this as a training run."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        # logging and checkpointing
        add('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        add('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
        add('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
        add('--continue_train', action='store_true', help='continue training: load the latest model')
        add('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')
        add('--phase', type=str, default='train', help='train, val, test, etc')
        add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')

        # optimisation
        add('--niter', type=int, default=100, help='# of iter at starting learning rate')
        add('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        add('--beta1', type=float, default=0.5, help='momentum term of adam')
        add('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        add('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
        add('--lambda_param', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
        add('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        add('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
        add('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
25 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from IPython.display import clear_output
4 | from options.train_options import TrainOptions
5 | from data.data_loader import load_data, minibatchAB
6 | from data.save_data import show_generator_image
7 | from util.image_pool import ImagePool
8 | from models.networks import get_generater_function
9 | from models.networks import resnet_generator, n_layer_discriminator
10 | from models.train_function import *
opt = TrainOptions().parse()

# load data
dpath = opt.dataroot
train_A = load_data(dpath + 'trainA/*')
train_B = load_data(dpath + 'trainB/*')
train_batch = minibatchAB(train_A, train_B, batch_size=opt.batch_size)
val_A = load_data(dpath + 'valA/*')
val_B = load_data(dpath + 'valB/*')
val_batch = minibatchAB(val_A, val_B, batch_size=4)

# create generator models
netG_A, real_A, fake_B = resnet_generator()
netG_B, real_B, fake_A = resnet_generator()

# create discriminator models
# n_layer_discriminator returns (model, inputs, outputs); only the model is
# needed here. Binding the whole tuple made netD_A / netD_B non-callable.
netD_A, _, _ = n_layer_discriminator()
netD_B, _, _ = n_layer_discriminator()

# create generators train function
netG_trainer = netG_train_function_creator(netD_A, netD_B, netG_A, netG_B, real_A, real_B, fake_A, fake_B)

# The built trainers get their own names: rebinding the imported builder
# name to its result made the second builder call fail. The crop size option
# is registered as --fineSize, so the attribute is opt.fineSize.
# NOTE(review): both discriminator trainers are built through the imported
# name netD_A_train_function, as the original script did. Confirm in
# models/train_function.py which discriminator that builder wires up; if the
# module provides a dedicated builder for discriminator B, call it below.
# create discriminator A train function
netD_A_trainer = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_A, opt.fineSize, opt.input_nc)
# create discriminator B train function
netD_B_trainer = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_B, opt.fineSize, opt.input_nc)

# train loop
time_start = time.time()
how_many_epochs = 5
iteration_count = 0
epoch_count = 0
batch_size = opt.batch_size
display_freq = 10000

netG_A_function = get_generater_function(netG_A)
netG_B_function = get_generater_function(netG_B)

fake_A_pool = ImagePool()
fake_B_pool = ImagePool()

# Each trainer outputs its own loss value, so the regression target is zero.
target_label = np.zeros((batch_size, 1))

while epoch_count < how_many_epochs:
    epoch_count, A, B = next(train_batch)

    # Generate fakes with the current generators and mix them with history.
    tmp_fake_B = netG_A_function([A])[0]
    tmp_fake_A = netG_B_function([B])[0]

    _fake_B = fake_B_pool.query(tmp_fake_B)
    _fake_A = fake_A_pool.query(tmp_fake_A)

    netG_trainer.train_on_batch([A, B], target_label)

    netD_B_trainer.train_on_batch([B, _fake_B], target_label)
    netD_A_trainer.train_on_batch([A, _fake_A], target_label)

    iteration_count += 1

    if iteration_count % display_freq == 0:
        clear_output()
        traintime = (time.time() - time_start) / iteration_count
        print('epoch_count: {} iter_count: {} timecost/iter: {}s'.format(epoch_count, iteration_count, traintime))
        _, val_images_A, val_images_B = next(val_batch)
        show_generator_image(val_images_A, val_images_B, netG_A, netG_B)

        # Checkpoint the networks and the compiled training models.
        save_name = dpath + '{}' + str(iteration_count) + '.h5'

        netG_A.save(save_name.format('tf_GA'))
        netG_A.save_weights(save_name.format('tf_GA_weights'))
        netG_B.save(save_name.format('tf_GB'))
        netG_B.save_weights(save_name.format('tf_GB_weights'))
        netD_A.save(save_name.format('tf_DA'))

        netG_trainer.save(save_name.format('tf_G_train'))
        netG_trainer.save_weights(save_name.format('tf_G_train_weights'))
        netD_A_trainer.save(save_name.format('tf_D_A_train'))
        netD_A_trainer.save_weights(save_name.format('tf_D_A_train_weights'))
        netD_B_trainer.save(save_name.format('tf_D_B_train'))
        netD_B_trainer.save_weights(save_name.format('tf_D_B_train_weights'))
89 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingwangLin/cyclegan-keras/7a000382f63d4025ffdabdbc373e8e92ecae4573/util/__init__.py
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from random import randint, uniform
3 |
4 |
class ImagePool():
    """History buffer of previously generated images.

    While the buffer holds fewer than ``pool_size`` images every queried
    image is stored and returned unchanged. Once full, each queried image is
    swapped with a randomly chosen stored image when a uniform draw exceeds
    0.5 (the stored one is returned), otherwise it is returned as is. A
    ``pool_size`` of 0 disables the buffer.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a stacked array with one output image per image in ``images``."""
        if self.pool_size == 0:
            return images
        chosen = [self._exchange(image) for image in images]
        return np.stack(chosen, axis=0)

    def _exchange(self, image):
        """Store ``image`` if there is room, else possibly trade it for a stored one."""
        if self.num_imgs < self.pool_size:
            self.num_imgs += 1
            self.images.append(image)
            return image
        if uniform(0, 1) > 0.5:
            slot = randint(0, self.pool_size - 1)
            previous, self.images[slot] = self.images[slot], image
            return previous
        return image
32 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | from PIL import Image
4 | import numpy as np
5 | import os
6 |
7 |
8 | # Converts a Tensor into a Numpy array
9 | # |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first element of ``image_tensor`` into a numpy image array.

    The element is moved to CPU and cast to float. If its leading axis has
    size 1 it is tiled three times along that axis. The leading axis is then
    moved last and values are mapped with ``(v + 1) / 2 * 255`` before being
    cast to ``imtype``.
    (Presumably the input is a channel-first batch scaled to [-1, 1]; confirm
    against callers.)
    """
    array = image_tensor[0].cpu().float().numpy()
    if array.shape[0] == 1:
        array = np.tile(array, (3, 1, 1))
    channels_last = np.transpose(array, (1, 2, 0))
    scaled = (channels_last + 1) / 2.0 * 255.0
    return scaled.astype(imtype)
16 |
17 |
def diagnose_network(net, name='network'):
    """Print ``name`` and the average of the mean absolute gradients of ``net``.

    Only parameters whose ``grad`` is not ``None`` contribute. When no
    parameter has a gradient, ``0.0`` is printed.
    """
    gradients = [param.grad for param in net.parameters() if param.grad is not None]
    mean = 0.0
    for grad in gradients:
        mean += torch.mean(torch.abs(grad.data))
    count = len(gradients)
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)
29 |
30 |
def save_image(image_numpy, image_path):
    """Write the array ``image_numpy`` to ``image_path`` as an image via PIL."""
    Image.fromarray(image_numpy).save(image_path)
34 |
35 |
def print_numpy(x, val=True, shp=False):
    """Print summary information about the array ``x`` (cast to float64).

    Args:
        x: numpy array.
        val: print mean, min, max, median and standard deviation of the
            flattened values.
        shp: print the shape.
    """
    data = x.astype(np.float64)
    if shp:
        print('shape,', data.shape)
    if not val:
        return
    flat = data.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
44 |
45 |
def mkdirs(paths):
    """Create one directory, or every directory in a list, if missing.

    Args:
        paths: a single path or a ``list`` of paths; anything that is not a
            list is treated as one path.
    """
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        mkdir(target)
52 |
53 |
def mkdir(path):
    """Create ``path`` (including missing parents) unless it already exists.

    Nothing happens when ``path`` already exists. A creation failure is
    re-raised unless the path exists afterwards.
    """
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except OSError:
        # Another process may have created the path between the existence
        # check and makedirs(); that is not an error for the caller.
        if not os.path.exists(path):
            raise
57 |
--------------------------------------------------------------------------------