├── CycleGAN-keras.ipynb ├── README.md ├── data ├── data_loader.py └── save_data.py ├── models ├── __init__.py ├── loss.py ├── networks.py └── train_function.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── train.py └── util ├── __init__.py ├── image_pool.py └── util.py /CycleGAN-keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Keras implementation of https://github.com/junyanz/CycleGAN" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# tf.Session(config=tf.ConfigProto(log_device_placement=True))" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# import numpy as np\n", 26 | "# np.random.seed(9999)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "scrolled": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "import os\n", 38 | "import keras.backend as K\n", 39 | "import tensorflow as tf\n", 40 | "import numpy as np\n", 41 | "import glob\n", 42 | "import time\n", 43 | "import warnings\n", 44 | "from PIL import Image\n", 45 | "from random import randint, shuffle, uniform\n", 46 | "warnings.simplefilter('error', Image.DecompressionBombWarning)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "from keras.optimizers import RMSprop, SGD, Adam\n", 56 | "from keras.models import Sequential, Model\n", 57 | "from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input, Dropout\n", 58 | "from keras.layers import Conv2DTranspose, UpSampling2D, Activation, Add, Lambda\n", 59 | "from keras.layers.advanced_activations import LeakyReLU\n", 60 | "from keras.activations import relu\n", 61 | "from keras.initializers import RandomNormal\n", 62 | "from keras_contrib.layers.normalization import InstanceNormalization" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# Weights initializations\n", 72 | "\n", 73 | "# for convolution kernel\n", 74 | "conv_init = RandomNormal(0, 0.02)\n", 75 | "# for batch normalization\n", 76 | "gamma_init = RandomNormal(1., 0.02) " 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "def conv2d(f, *a, **k):\n", 86 | " return Conv2D(f, kernel_initializer = conv_init, *a, **k)\n", 87 | "def batchnorm():\n", 88 | " return BatchNormalization(momentum=0.9, axis=3, epsilon=1e-5, gamma_initializer = gamma_init)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "def conv_block(x, filters, size, stride=(2, 2), has_norm_layer=True, use_norm_instance=False,\n", 98 | " has_activation_layer=True, use_leaky_relu=False, padding='same'):\n", 99 | " x = conv2d(filters, (size, size), strides=stride, padding=padding)(x)\n", 100 | " if has_norm_layer:\n", 101 | " if not use_norm_instance:\n", 102 | " x = batchnorm()(x)\n", 103 | " else:\n", 104 | " x = InstanceNormalization(axis=1)(x)\n", 105 | " if has_activation_layer:\n", 106 | " if not use_leaky_relu:\n", 107 | " x = Activation('relu')(x)\n", 108 | " else:\n", 109 | " x = LeakyReLU(alpha=0.2)(x)\n", 110 | " return x\n", 111 | "\n", 112 | "def res_block(x, filters=256, use_dropout=False):\n", 113 | " y = conv_block(x, filters, 3, (1, 1))\n", 114 | " if use_dropout:\n", 115 | " y = Dropout(0.5)(y)\n", 116 | " y = conv_block(y, filters, 3, (1, 1), has_activation_layer=False)\n", 117 | " return Add()([y, x])\n", 118 | "\n", 119 | "# decoder block\n", 120 | "def up_block(x, filters, size, use_conv_transpose=True, use_norm_instance=False):\n", 121 | " if use_conv_transpose:\n", 122 | " x = Conv2DTranspose(filters, kernel_size=size, strides=2, padding='same',\n", 123 | " use_bias=True if use_norm_instance else False,\n", 124 | " kernel_initializer=RandomNormal(0, 0.02))(x)\n", 125 | " x = batchnorm()(x)\n", 126 | " x = Activation('relu')(x)\n", 127 | " else:\n", 128 | " x = UpSampling2D()(x)\n", 129 | " x = conv_block(x, filters, size, (1, 1))\n", 130 | " return x" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "# Defines the PatchGAN discriminator" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "def n_layer_discriminator(image_size=256, input_nc=3, ndf=64, hidden_layers=3):\n", 149 | " \"\"\"\n", 150 | " input_nc: input channels\n", 151 | " ndf: filters of the first layer\n", 152 | " \"\"\"\n", 153 | " inputs = Input(shape=(image_size, image_size, input_nc))\n", 154 | " x = inputs\n", 155 | " \n", 156 | " x = ZeroPadding2D(padding=(1, 1))(x)\n", 157 | " x = conv_block(x, ndf, 4, has_norm_layer=False, use_leaky_relu=True, padding='valid')\n", 158 | " \n", 159 | " x = ZeroPadding2D(padding=(1, 1))(x)\n", 160 | " for i in range(1, hidden_layers + 1):\n", 161 | " nf = 2 ** i * ndf\n", 162 | " x = conv_block(x, nf, 4, use_leaky_relu=True, padding='valid')\n", 163 | " x = ZeroPadding2D(padding=(1, 1))(x)\n", 164 | " \n", 165 | " x = conv2d(1, (4, 4), activation='sigmoid', strides=(1, 1))(x)\n", 166 | " outputs = x\n", 167 | " return Model(inputs=inputs, outputs=outputs)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# Defines the generator" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "def resnet_generator(image_size=256, input_nc=3, res_blocks=6, use_conv_transpose=True):\n", 186 | " inputs = Input(shape=(image_size, image_size, input_nc))\n", 187 | " x = inputs\n", 188 | " \n", 189 | " x = conv_block(x, 64, 7, (1, 1))\n", 190 | " x = conv_block(x, 128, 3, (2, 2))\n", 191 | " x = conv_block(x, 256, 3, (2, 2))\n", 192 | " \n", 193 | " for i in range(res_blocks):\n", 194 | " x = res_block(x)\n", 195 | " \n", 196 | " x = up_block(x, 128, 3, use_conv_transpose=use_conv_transpose)\n", 197 | " x = up_block(x, 64, 3, use_conv_transpose=use_conv_transpose)\n", 198 | " \n", 199 | " x = conv2d(3, (7, 7), activation='tanh', strides=(1, 1) ,padding='same')(x) \n", 200 | " outputs = x\n", 201 | " return Model(inputs=inputs, outputs=outputs), inputs, outputs" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "def mkdirs(paths):\n", 211 | " if isinstance(paths, list) and not isinstance(paths, str):\n", 212 | " for path in paths:\n", 213 | " mkdir(path)\n", 214 | " else:\n", 215 | " mkdir(paths)\n", 216 | "\n", 217 | "def mkdir(path):\n", 218 | " if not os.path.exists(path):\n", 219 | " os.makedirs(path)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "# gloabal variables\n", 229 | "image_size = 128\n", 230 | "image_jitter_range = 30\n", 231 | "load_size = image_size + image_jitter_range\n", 232 | "batch_size = 16\n", 233 | "input_nc = 3\n", 234 | "path = '/home/lin/Downloads/'\n", 235 | "dpath = path + 'weights-cyclelossweight10-batchsize{}-imagesize{}/'.format(batch_size, image_size)\n", 236 | "dpath_result = dpath + 'results'\n", 237 | "mkdirs([dpath, dpath_result])" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "def criterion_GAN(output, target, use_lsgan=True):\n", 247 | " if use_lsgan:\n", 248 | " diff = output-target\n", 249 | " dims = list(range(1,K.ndim(diff)))\n", 250 | " return K.expand_dims((K.mean(diff**2, dims)), 0)\n", 251 | " else:\n", 252 | " return K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))\n", 253 | " \n", 254 | "def criterion_cycle(rec, real):\n", 255 | " diff = K.abs(rec-real)\n", 256 | " dims = list(range(1,K.ndim(diff)))\n", 257 | " return K.expand_dims((K.mean(diff, dims)), 0)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "def netG_loss(inputs, cycle_loss_weight=10):\n", 267 | " netD_B_predict_fake, rec_A, real_A, netD_A_predict_fake, rec_B, real_B = inputs\n", 268 | " \n", 269 | " loss_G_A = criterion_GAN(netD_B_predict_fake, K.ones_like(netD_B_predict_fake))\n", 270 | " loss_cyc_A = criterion_cycle(rec_A, real_A)\n", 271 | " \n", 272 | " loss_G_B = criterion_GAN(netD_A_predict_fake, K.ones_like(netD_A_predict_fake))\n", 273 | " loss_cyc_B = criterion_cycle(rec_B, real_B)\n", 274 | " \n", 275 | " loss_G = loss_G_A + loss_G_B + cycle_loss_weight * (loss_cyc_A+loss_cyc_B)\n", 276 | " return loss_G" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "def netD_loss(netD_predict):\n", 286 | " netD_predict_real, netD_predict_fake = netD_predict\n", 287 | " \n", 288 | " netD_loss_real = criterion_GAN(netD_predict_real, K.ones_like(netD_predict_real))\n", 289 | " netD_loss_fake = criterion_GAN(netD_predict_fake, K.zeros_like(netD_predict_fake))\n", 290 | " \n", 291 | " loss_netD= 0.5 * (netD_loss_real + netD_loss_fake)\n", 292 | " return loss_netD" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "netD_A = n_layer_discriminator(image_size)\n", 302 | "netD_B = n_layer_discriminator(image_size)\n", 303 | "# netD_A.summary()\n", 304 | "# netD_B.summary()" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "netG_A, real_A, fake_B = resnet_generator(image_size, use_conv_transpose=True)\n", 314 | "netG_B, real_B, fake_A = resnet_generator(image_size, use_conv_transpose=True)\n", 315 | "# netG_A.summary()\n", 316 | "# netG_B.summary()" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# make generater train function" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "netD_B_predict_fake = netD_B(fake_B)\n", 335 | "rec_A= netG_B(fake_B)\n", 336 | "netD_A_predict_fake = netD_A(fake_A)\n", 337 | "rec_B = netG_A(fake_A)\n", 338 | "lambda_layer_inputs = [netD_B_predict_fake, rec_A, real_A, netD_A_predict_fake, rec_B, real_B]\n", 339 | "\n", 340 | "for l in netG_A.layers: \n", 341 | " l.trainable=True\n", 342 | "for l in netG_B.layers: \n", 343 | " l.trainable=True\n", 344 | "for l in netD_A.layers: \n", 345 | " l.trainable=False\n", 346 | "for l in netD_B.layers: \n", 347 | " l.trainable=False\n", 348 | " \n", 349 | "netG_train_function = Model([real_A, real_B],Lambda(netG_loss)(lambda_layer_inputs))\n", 350 | "Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0)\n", 351 | "netG_train_function.compile('adam', 'mae')" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "# make discriminator A train function" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "netD_A_predict_real = netD_A(real_A)\n", 370 | "\n", 371 | "_fake_A = Input(shape=(image_size, image_size, input_nc))\n", 372 | "_netD_A_predict_fake = netD_A(_fake_A)\n", 373 | "\n", 374 | "for l in netG_A.layers: \n", 375 | " l.trainable=False\n", 376 | "for l in netG_B.layers: \n", 377 | " l.trainable=False\n", 378 | "for l in netD_A.layers: \n", 379 | " l.trainable=True \n", 380 | "for l in netD_B.layers: \n", 381 | " l.trainable=False\n", 382 | "\n", 383 | "netD_A_train_function = Model([real_A, _fake_A], Lambda(netD_loss)([netD_A_predict_real, _netD_A_predict_fake]))\n", 384 | "netD_A_train_function.compile('adam', 'mae')" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "# make discriminator B train function" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "netD_B_predict_real = netD_B(real_B)\n", 403 | "\n", 404 | "_fake_B = Input(shape=(image_size, image_size, input_nc))\n", 405 | "_netD_B_predict_fake = netD_B(_fake_B)\n", 406 | "\n", 407 | "for l in netG_A.layers: \n", 408 | " l.trainable=False\n", 409 | "for l in netG_B.layers: \n", 410 | " l.trainable=False\n", 411 | "for l in netD_B.layers: \n", 412 | " l.trainable=True \n", 413 | "for l in netD_A.layers: \n", 414 | " l.trainable=False \n", 415 | " \n", 416 | "netD_B_train_function= Model([real_B, _fake_B], Lambda(netD_loss)([netD_B_predict_real, _netD_B_predict_fake]))\n", 417 | "netD_B_train_function.compile('adam', 'mae')" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "def load_data(file_pattern):\n", 427 | " return glob.glob(file_pattern)\n", 428 | "\n", 429 | "def read_image(img, loadsize=load_size, imagesize=image_size):\n", 430 | " img = Image.open(img).convert('RGB')\n", 431 | " img = img.resize((loadsize, loadsize), Image.BICUBIC)\n", 432 | " img = np.array(img)\n", 433 | " assert img.shape == (loadsize, loadsize, 3)\n", 434 | " img = img.astype(np.float32)\n", 435 | " img = (img-127.5) / 127.5\n", 436 | " # random jitter\n", 437 | " w_offset = h_offset = randint(0, max(0, loadsize - imagesize - 1))\n", 438 | " img = img[h_offset:h_offset + imagesize,\n", 439 | " w_offset:w_offset + imagesize, :]\n", 440 | " # horizontal flip\n", 441 | " if randint(0, 1):\n", 442 | " img = img[:, ::-1]\n", 443 | " return img\n", 444 | "\n", 445 | "def try_read_img(data, index):\n", 446 | " try:\n", 447 | " img = read_image(data[index])\n", 448 | " return img\n", 449 | " except:\n", 450 | " img = try_read_img(data, index + 1)\n", 451 | " return img\n", 452 | "\n", 453 | "train_A = load_data('/home/lin/Downloads/m-cycle/trainA/*')\n", 454 | "train_B = load_data('/home/lin/Downloads/m-cycle/trainB/*')\n", 455 | "print(len(train_A))\n", 456 | "print(len(train_B))\n", 457 | "\n", 458 | "val_A = load_data('/home/lin/Downloads/m-cycle/testA/*')\n", 459 | "val_B = load_data('/home/lin/Downloads/m-cycle/testB/*')" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "def minibatch(data, batch_size):\n", 469 | " length = len(data)\n", 470 | " shuffle(data)\n", 471 | " epoch = i = 0\n", 472 | " tmpsize = None \n", 473 | " \n", 474 | " while True:\n", 475 | " size = tmpsize if tmpsize else batch_size\n", 476 | " if i+size > length:\n", 477 | " shuffle(data)\n", 478 | " i = 0\n", 479 | " epoch+=1 \n", 480 | " rtn = []\n", 481 | " for j in range(i,i+size):\n", 482 | " img = try_read_img(data, j)\n", 483 | " rtn.append(img)\n", 484 | " rtn = np.stack(rtn, axis=0) \n", 485 | " i+=size\n", 486 | " tmpsize = yield epoch, np.float32(rtn)\n", 487 | "\n", 488 | "def minibatchAB(dataA, dataB, batch_size):\n", 489 | " batchA=minibatch(dataA, batch_size)\n", 490 | " batchB=minibatch(dataB, batch_size)\n", 491 | " tmpsize = None \n", 492 | " while True:\n", 493 | " ep1, A = batchA.send(tmpsize)\n", 494 | " ep2, B = batchB.send(tmpsize)\n", 495 | " tmpsize = yield max(ep1, ep2), A, B" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "from IPython.display import display\n", 505 | "def display_image(X, rows=1):\n", 506 | " assert X.shape[0]%rows == 0\n", 507 | " int_X = ((X*127.5+127.5).clip(0,255).astype('uint8'))\n", 508 | " int_X = int_X.reshape(-1,image_size,image_size, 3)\n", 509 | " int_X = int_X.reshape(rows, -1, image_size, image_size,3).swapaxes(1,2).reshape(rows*image_size,-1, 3)\n", 510 | " pil_X = Image.fromarray(int_X)\n", 511 | " t = str(round(time.time()))\n", 512 | " pil_X.save(dpath+'results/'+ t, 'JPEG')\n", 513 | " display(pil_X)" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [ 522 | "train_batch = minibatchAB(train_A, train_B, 6)\n", 523 | "\n", 524 | "_, A, B = next(train_batch)\n", 525 | "display_image(A)\n", 526 | "display_image(B)\n", 527 | "_, A, B = next(train_batch)\n", 528 | "display_image(A)\n", 529 | "display_image(B)\n", 530 | "del train_batch, A, B" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": null, 536 | "metadata": {}, 537 | "outputs": [], 538 | "source": [ 539 | "val_batch = minibatchAB(val_A, val_B, 4)\n", 540 | "\n", 541 | "_, A, B = next(val_batch)\n", 542 | "display_image(A)\n", 543 | "display_image(B)\n", 544 | "del val_batch, A, B" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "def get_output(netG_alpha, netG_beta, X):\n", 554 | " real_input = X\n", 555 | " fake_output = netG_alpha.predict(real_input)\n", 556 | " rec_input = netG_beta.predict(fake_output)\n", 557 | " outputs = [fake_output, rec_input]\n", 558 | " return outputs" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": null, 564 | "metadata": {}, 565 | "outputs": [], 566 | "source": [ 567 | "def get_combined_output(netG_alpha, netG_beta, X):\n", 568 | " r = [get_output(netG_alpha, netG_beta, X[i:i+1]) for i in range(X.shape[0])]\n", 569 | " r = np.array(r)\n", 570 | " return r.swapaxes(0,1)[:,:,0] " 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "def show_generator_image(A,B, netG_alpha, netG_beta):\n", 580 | " assert A.shape==B.shape\n", 581 | " \n", 582 | " rA = get_combined_output(netG_alpha, netG_beta, A)\n", 583 | " rB = get_combined_output(netG_beta, netG_alpha, B)\n", 584 | " \n", 585 | " arr = np.concatenate([A,B,rA[0],rB[0],rA[1],rB[1]]) \n", 586 | " display_image(arr, 3)" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": null, 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "def get_generater_function(netG):\n", 596 | " real_input = netG.inputs[0]\n", 597 | " fake_output = netG.outputs[0]\n", 598 | " function = K.function([real_input, K.learning_phase()], [fake_output])\n", 599 | " return function\n", 600 | "\n", 601 | "netG_A_function = get_generater_function(netG_A)\n", 602 | "netG_B_function = get_generater_function(netG_B)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "class ImagePool():\n", 612 | " def __init__(self, pool_size=200):\n", 613 | " self.pool_size = pool_size\n", 614 | " if self.pool_size > 0:\n", 615 | " self.num_imgs = 0\n", 616 | " self.images = []\n", 617 | "\n", 618 | " def query(self, images):\n", 619 | " if self.pool_size == 0:\n", 620 | " return images\n", 621 | " return_images = []\n", 622 | " for image in images:\n", 623 | " if self.num_imgs < self.pool_size:\n", 624 | " self.num_imgs = self.num_imgs + 1\n", 625 | " self.images.append(image)\n", 626 | " return_images.append(image)\n", 627 | " else:\n", 628 | " p = uniform(0, 1)\n", 629 | " if p > 0.5:\n", 630 | " random_id = randint(0, self.pool_size-1)\n", 631 | " tmp = self.images[random_id]\n", 632 | " self.images[random_id] = image\n", 633 | " return_images.append(tmp)\n", 634 | " else:\n", 635 | " return_images.append(image)\n", 636 | " return_images = np.stack(return_images, axis=0)\n", 637 | " return return_images" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "metadata": {}, 644 | "outputs": [], 645 | "source": [ 646 | "K.learning_phase()" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [ 655 | "import time\n", 656 | "from IPython.display import clear_output\n", 657 | "time_start = time.time()\n", 658 | "how_many_epochs = 10\n", 659 | "iteration_count = 0\n", 660 | "epoch_count = 0\n", 661 | "display_freq = 1000 // batch_size \n", 662 | "save_freq = 20000 // batch_size\n", 663 | "val_batch = minibatchAB(val_A, val_B, batch_size=4)\n", 664 | "_, val_A, val_B = next(val_batch)\n", 665 | "train_batch = minibatchAB(train_A, train_B, batch_size)\n", 666 | " \n", 667 | "fake_A_pool = ImagePool()\n", 668 | "fake_B_pool = ImagePool()\n", 669 | "\n", 670 | "while epoch_count < how_many_epochs: \n", 671 | " target_label = np.zeros((batch_size, 1))\n", 672 | " epoch_count, A, B = next(train_batch)\n", 673 | "\n", 674 | " tmp_fake_B = netG_A_function([A, 1])[0]\n", 675 | " tmp_fake_A = netG_B_function([B, 1])[0]\n", 676 | " \n", 677 | " _fake_B = fake_B_pool.query(tmp_fake_B)\n", 678 | " _fake_A = fake_A_pool.query(tmp_fake_A)\n", 679 | "\n", 680 | " netG_train_function.train_on_batch([A, B], target_label)\n", 681 | " \n", 682 | " netD_B_train_function.train_on_batch([B, _fake_B], target_label)\n", 683 | " netD_A_train_function.train_on_batch([A, _fake_A], target_label)\n", 684 | " \n", 685 | " iteration_count+=1\n", 686 | " \n", 687 | " save_name = dpath + '{}' + str(iteration_count) + '.h5'\n", 688 | " \n", 689 | " if iteration_count%display_freq == 0:\n", 690 | " clear_output()\n", 691 | " timecost = (time.time()-time_start)/60\n", 692 | " print('epoch_count: {} iter_count: {} timecost: {}mins'.format(epoch_count, iteration_count, timecost))\n", 693 | " show_generator_image(val_A,val_B, netG_A, netG_B)\n", 694 | " netG_A.save_weights(save_name.format('tf_GA_weights'))\n", 695 | " netG_B.save_weights(save_name.format('tf_GB_weights'))\n", 696 | "\n", 697 | " if iteration_count%save_freq == 0:\n", 698 | " netD_A.save_weights(save_name.format('tf_DA_weights'))\n", 699 | " netD_B.save_weights(save_name.format('tf_DB_weights'))\n", 700 | " netG_train_function.save_weights(save_name.format('tf_G_train_weights'))\n", 701 | " netD_A_train_function.save_weights(save_name.format('tf_D_A_train_weights'))\n", 702 | " netD_B_train_function.save_weights(save_name.format('tf_D_B_train_weights'))" 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": null, 708 | "metadata": {}, 709 | "outputs": [], 710 | "source": [ 711 | "# inference" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": null, 717 | "metadata": {}, 718 | "outputs": [], 719 | "source": [ 720 | "load_name = dpath + '{}' + '1000.h5'\n", 721 | "netG_A.load_weights(load_name.format('tf_GA_weights'))\n", 722 | "netG_B.load_weights(load_name.format('tf_GB_weights'))\n", 723 | "netD_A.load_weights(load_name.format('tf_DA_weights'))\n", 724 | "netD_B.load_weights(load_name.format('tf_DB_weights'))\n", 725 | "netG_train_function.load_weights(load_name.format('tf_G_train_weights'))\n", 726 | "netD_A_train_function.load_weights(load_name.format('tf_D_A_train_weights'))\n", 727 | "netD_B_train_function.load_weights(load_name.format('tf_D_B_train_weights'))" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": {}, 734 | "outputs": [], 735 | "source": [ 736 | "val_batch = minibatchAB(val_A, val_B, batch_size=2)" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "# run batch normalization layer in training mode" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": null, 751 | "metadata": {}, 752 | "outputs": [], 753 | "source": [ 754 | "_,A, B = next(val_batch)\n", 755 | "show_generator_image(A,B, netG_A, netG_B)" 756 | ] 757 | }, 758 | { 759 | "cell_type": "markdown", 760 | "metadata": {}, 761 | "source": [ 762 | "\n", 763 | "\n", 764 | "\n", 765 | "\n", 766 | "\n", 767 | "\n", 768 | "\n" 769 | ] 770 | } 771 | ], 772 | "metadata": { 773 | "kernelspec": { 774 | "display_name": "Python 3", 775 | "language": "python", 776 | "name": "python3" 777 | }, 778 | "language_info": { 779 | "codemirror_mode": { 780 | "name": "ipython", 781 | "version": 3 782 | }, 783 | "file_extension": ".py", 784 | "mimetype": "text/x-python", 785 | "name": "python", 786 | "nbconvert_exporter": "python", 787 | "pygments_lexer": "ipython3", 788 | "version": "3.6.3" 789 | } 790 | }, 791 | "nbformat": 4, 792 | "nbformat_minor": 1 793 | } 794 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cyclegan-keras 2 | 3 | keras implementation of cycle-gan based on [pytorch-CycleGan](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) (by junyanz) and [tf/torch/keras/lasagne] (by tjwei) 4 | 5 | ## Prerequisites 6 | train.py has not been tested, CycleGAN-keras.ipynb is recommended and tested OK on 7 | - Ubuntu 16.04 8 | - Python 3.6 9 | - Keras 2.1.2 10 | - Tensorflow 1.0.1 11 | - NVIDIA GPU + CUDA8.0 CuDNN6 or CuDNN5 12 | 13 | 14 | 15 | ## Demos [[manga-colorization-demo]](http://www.styletransfer.tech) 16 | 17 | Colorize manga with Cycle-GAN model totally run in browser. 18 | - Built based on [Keras.js](https://github.com/transcranial/keras-js) and [keras.js demos](https://transcranial.github.io/keras-js) 19 | - Model trained by juyter notebook version of this git repo 20 | - Check [Demo-Introduction](https://zhuanlan.zhihu.com/p/34672860) or my [demo-repo](https://github.com/MingwangLin/manga-colorization) for more details 21 | 22 | 23 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | from PIL import Image 4 | from random import randint, shuffle 5 | 6 | 7 | def load_data(file_pattern): 8 | return glob.glob(file_pattern) 9 | 10 | 11 | def read_image(img, loadsize=286, imagesize=256): 12 | img = Image.open(img).convert('RGB') 13 | img = img.resize((loadsize, loadsize), Image.BICUBIC) 14 | img = np.array(img) 15 | assert img.shape == (loadsize, loadsize, 3) 16 | img = img.astype(np.float32) 17 | img = (img - 127.5) / 127.5 18 | # random jitter 19 | w_offset = h_offset = randint(0, max(0, loadsize - imagesize - 1)) 20 | img = img[h_offset:h_offset + imagesize, w_offset:w_offset + imagesize, :] 21 | # horizontal flip 22 | if randint(0, 1): 23 | img = img[:, ::-1] 24 | return img 25 | 26 | 27 | def try_read_img(data, index): 28 | try: 29 | img = read_image(data[index]) 30 | return img 31 | except: 32 | try_read_img(data, index + 1) 33 | 34 | 35 | def minibatch(data, batch_size): 36 | length = len(data) 37 | shuffle(data) 38 | epoch = i = 0 39 | tmpsize = None 40 | 41 | while True: 42 | size = tmpsize if tmpsize else batch_size 43 | if i + size > length: 44 | shuffle(data) 45 | i = 0 46 | epoch += 1 47 | rtn = [] 48 | for j in range(i, i + size): 49 | img = try_read_img(data, j) 50 | rtn.append(img) 51 | rtn = np.stack(rtn, axis=0) 52 | i += size 53 | tmpsize = yield epoch, np.float32(rtn) 54 | 55 | 56 | def minibatchAB(dataA, dataB, batch_size): 57 | batchA = minibatch(dataA, batch_size) 58 | batchB = minibatch(dataB, batch_size) 59 | tmpsize = None 60 | while True: 61 | ep1, A = batchA.send(tmpsize) 62 | ep2, B = batchB.send(tmpsize) 63 | tmpsize = yield max(ep1, ep2), A, B 64 | -------------------------------------------------------------------------------- /data/save_data.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from IPython.display import display 4 | from PIL import Image 5 | 6 | 7 | def get_output(netG_alpha, netG_beta, X): 8 | real_input = X 9 | fake_output = netG_alpha.predict(real_input) 10 | rec_input = netG_beta.predict(fake_output) 11 | outputs = [fake_output, rec_input] 12 | return outputs 13 | 14 | 15 | def get_combined_output(netG_alpha, netG_beta, X): 16 | r = [get_output(netG_alpha, netG_beta, X[i:i + 1]) for i in range(X.shape[0])] 17 | r = np.array(r) 18 | return r.swapaxes(0, 1)[:, :, 0] 19 | 20 | 21 | def save_image(X, rows=1, image_size=256): 22 | assert X.shape[0] % rows == 0 23 | int_X = ((X * 255).clip(0, 255).astype('uint8')) 24 | int_X = int_X.reshape(-1, image_size, image_size, 3) 25 | int_X = int_X.reshape(rows, -1, image_size, image_size, 3).swapaxes(1, 2).reshape(rows * image_size, -1, 3) 26 | pil_X = Image.fromarray(int_X) 27 | t = str(time.time()) 28 | pil_X.save(dpath + 'results/' + t, 'JPEG') 29 | 30 | 31 | def show_generator_image(A, B, netG_alpha, netG_beta): 32 | assert A.shape == B.shape 33 | 34 | rA = get_combined_output(netG_alpha, netG_beta, A) 35 | rB = get_combined_output(netG_beta, netG_alpha, B) 36 | 37 | arr = np.concatenate([A, B, rA[0], rB[0], rA[1], rB[1]]) 38 | save_image(arr, rows=3) 39 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingwangLin/cyclegan-keras/7a000382f63d4025ffdabdbc373e8e92ecae4573/models/__init__.py -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | 3 | 4 | def criterion_GAN(output, target, use_lsgan=True): 5 | if use_lsgan: 6 | diff = output - target 7 | dims = list(range(1, K.ndim(diff))) 8 | return K.expand_dims((K.mean(diff ** 2, dims)), 0) 9 | else: 10 | return K.mean(K.log(output + 1e-12) * target + K.log(1 - output + 1e-12) * (1 - target)) 11 | 12 | 13 | def criterion_cycle(rec, real): 14 | diff = K.abs(rec - real) 15 | dims = list(range(1, K.ndim(diff))) 16 | return K.expand_dims((K.mean(diff, dims)), 0) 17 | 18 | 19 | def netG_loss(G_tensors, loss_weight=10): 20 | netD_A_predict_fake, rec_A, G_A_input, netD_B_predict_fake, rec_B, G_B_input = G_tensors 21 | 22 | loss_G_B = criterion_GAN(netD_A_predict_fake, K.ones_like(netD_A_predict_fake)) 23 | loss_cyc_A = criterion_cycle(rec_A, G_A_input) 24 | 25 | loss_G_A = criterion_GAN(netD_B_predict_fake, K.ones_like(netD_B_predict_fake)) 26 | loss_cyc_B = criterion_cycle(rec_B, G_B_input) 27 | 28 | loss_G = loss_G_A + loss_G_B + loss_weight * (loss_cyc_A + loss_cyc_B) 29 | 30 | return loss_G 31 | 32 | 33 | def netD_loss(netD_predict): 34 | netD_predict_real, netD_predict_fake = netD_predict 35 | 36 | netD_loss_real = criterion_GAN(netD_predict_real, K.ones_like(netD_predict_real)) 37 | netD_loss_fake = criterion_GAN(netD_predict_fake, K.zeros_like(netD_predict_fake)) 38 | 39 | loss_netD = (1 / 2) * (netD_loss_real + netD_loss_fake) 40 | return loss_netD 41 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.models import Model 3 | from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input, Dropout 4 | from keras.layers import UpSampling2D, Conv2DTranspose, Activation, Add 5 | from keras.layers.advanced_activations import LeakyReLU 6 | from keras.initializers import RandomNormal 7 | from keras_contrib.layers.normalization import InstanceNormalization 8 | 9 | 10 | def conv2d(f, *a, **k): 11 | return Conv2D(f, kernel_initializer=RandomNormal(0, 0.02), *a, **k) 12 | 13 | 14 | def batchnorm(): 15 | return BatchNormalization(momentum=0.9, axis=3, epsilon=1e-5, 16 | gamma_initializer=RandomNormal(1., 0.02)) 17 | 18 | 19 | def conv_block(x, filters, size, stride=(2, 2), has_norm_layer=True, use_norm_instance=False, 20 | has_activation_layer=True, use_leaky_relu=False, padding='same'): 21 | x = conv2d(filters, (size, size), strides=stride, padding=padding)(x) 22 | if has_norm_layer: 23 | if not use_norm_instance: 24 | x = batchnorm()(x) 25 | else: 26 | x = InstanceNormalization(axis=1)(x) 27 | if has_activation_layer: 28 | if not use_leaky_relu: 29 | x = Activation('relu')(x) 30 | else: 31 | x = LeakyReLU(alpha=0.2)(x) 32 | return x 33 | 34 | 35 | def res_block(x, filters=256, use_dropout=False): 36 | y = conv_block(x, filters, 3, (1, 1)) 37 | if use_dropout: 38 | y = Dropout(0.5)(y) 39 | y = conv_block(y, filters, 3, (1, 1), has_activation_layer=False) 40 | return Add()([y, x]) 41 | 42 | 43 | def up_block(x, filters, size, use_conv_transpose=True, use_norm_instance=False): 44 | if use_conv_transpose: 45 | x = Conv2DTranspose(filters, kernel_size=size, strides=2, padding='same', 46 | use_bias=True if use_norm_instance else False, 47 | kernel_initializer=RandomNormal(0, 0.02))(x) 48 | x = batchnorm()(x) 49 | x = Activation('relu')(x) 50 | 51 | else: 52 | x = UpSampling2D()(x) 53 | x = conv_block(x, filters, size, (1, 1)) 54 | 55 | return x 56 | 57 | 58 | # Defines the Resnet generator 59 | def resnet_generator(image_size=256, input_nc=3, res_blocks=6): 60 | inputs = Input(shape=(image_size, image_size, input_nc)) 61 | x = inputs 62 | 63 | x = conv_block(x, 64, 7, (1, 1)) 64 | x = conv_block(x, 128, 3, (2, 2)) 65 | x = conv_block(x, 256, 3, (2, 2)) 66 | 67 | for i in range(res_blocks): 68 | x = res_block(x) 69 | 70 | x = up_block(x, 128, 3) 71 | x = up_block(x, 64, 3) 72 | 73 | x = conv2d(3, (7, 7), activation='tanh', strides=(1, 1), padding='same')(x) 74 | outputs = x 75 | 76 | return Model(inputs=inputs, outputs=outputs), inputs, outputs 77 | 78 | 79 | # Defines the PatchGAN discriminator 80 | def n_layer_discriminator(image_size=256, input_nc=3, ndf=64, hidden_layers=3): 81 | """ 82 | input_nc: input channels 83 | ndf: filters of the first layer 84 | """ 85 | inputs = Input(shape=(image_size, image_size, input_nc)) 86 | x = inputs 87 | 88 | x = ZeroPadding2D(padding=(1, 1))(x) 89 | x = conv_block(x, ndf, 4, has_norm_layer=False, use_leaky_relu=True, padding='valid') 90 | 91 | x = ZeroPadding2D(padding=(1, 1))(x) 92 | for i in range(1, hidden_layers + 1): 93 | nf = 2 ** i * ndf 94 | x = conv_block(x, nf, 4, use_leaky_relu=True, padding='valid') 95 | x = ZeroPadding2D(padding=(1, 1))(x) 96 | 97 | x = conv2d(1, (4, 4), activation='sigmoid', strides=(1, 1))(x) 98 | outputs = x 99 | 100 | return Model(inputs=[inputs], outputs=outputs), inputs, outputs 101 | 102 | def get_generater_function(netG): 103 | real_input = netG.inputs[0] 104 | fake_output = netG.outputs[0] 105 | function = K.function([real_input], [fake_output]) 106 | return function 107 | 108 | -------------------------------------------------------------------------------- /models/train_function.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Input, BatchNormalization 2 | from keras.optimizers import Adam 3 | from keras.models import Model 4 | from keras.layers import Lambda 5 | from models.loss import netG_loss, netD_loss 6 | 7 | 8 | def get_train_function(inputs, loss_function, lambda_layer_inputs): 9 | Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0) 10 | train_function = Model(inputs, Lambda(loss_function)(lambda_layer_inputs)) 11 | train_function.compile('adam', 'mae') 12 | return train_function 13 | 14 | 15 | # create generator train function 16 | def netG_train_function_creator(netD_A, netD_B, netG_A, netG_B, real_A, real_B, fake_A, fake_B): 17 | netD_B_predict_fake = netD_B(fake_B) 18 | rec_A = netG_B(fake_B) 19 | netD_A_predict_fake = netD_A(fake_A) 20 | rec_B = netG_A(fake_A) 21 | lambda_layer_inputs = [netD_B_predict_fake, rec_A, real_A, netD_A_predict_fake, rec_B, real_B] 22 | for l in netG_A.layers: 23 | l.trainable = True 24 | for l in netG_B.layers: 25 | l.trainable = True 26 | for l in netD_A.layers: 27 | l.trainable = False 28 | for l in netD_B.layers: 29 | l.trainable = False 30 | netG_train_function = get_train_function(inputs=[real_A, real_B], loss_function=netG_loss, 31 | lambda_layer_inputs=lambda_layer_inputs) 32 | return netG_train_function 33 | 34 | 35 | # create discriminator A train function 36 | def netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_A, finesize, input_nc): 37 | netD_A_predict_real = netD_A(real_A) 38 | _fake_A = Input(shape=(finesize, finesize, input_nc)) 39 | _netD_A_predict_fake = netD_A(_fake_A) 40 | for l in netG_A.layers: 41 | l.trainable = False 42 | for l in netG_B.layers: 43 | l.trainable = False 44 | for l in netD_A.layers: 45 | l.trainable = True 46 | for l in netD_B.layers: 47 | l.trainable = False 48 | netD_A_train_function = get_train_function(inputs=[real_A, _fake_A], loss_function=netD_loss, 49 | lambda_layer_inputs=[netD_A_predict_real, _netD_A_predict_fake]) 50 | return netD_A_train_function 51 | 52 | 53 | # create discriminator B train function 54 | def netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_B, finesize, input_nc): 55 | netD_B_predict_real = netD_B(real_B) 56 | _fake_B = Input(shape=(finesize, finesize, input_nc)) 57 | _netD_B_predict_fake = netD_B(_fake_B) 58 | for l in netG_A.layers: 59 | l.trainable = False 60 | if isinstance(l, BatchNormalization): 61 | l._per_input_updates = {} 62 | for l in netG_B.layers: 63 | l.trainable = False 64 | if isinstance(l, BatchNormalization): 65 | l._per_input_updates = {} 66 | for l in netD_B.layers: 67 | l.trainable = True 68 | for l in netD_A.layers: 69 | l.trainable = False 70 | netD_B_train_function = get_train_function(inputs=[real_B, _fake_B], loss_function=netD_loss, 71 | lambda_layer_inputs=[netD_B_predict_real, _netD_B_predict_fake]) 72 | return netD_B_train_function 73 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingwangLin/cyclegan-keras/7a000382f63d4025ffdabdbc373e8e92ecae4573/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | 7 | class BaseOptions(): 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialized = False 11 | 12 | def initialize(self): 13 | self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 14 | self.parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 15 | self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') 16 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 17 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 18 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 19 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 20 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 21 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 22 | self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 23 | self.parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use.') 24 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 25 | self.parser.add_argument('--nThreads', default=6, type=int, help='# threads for loading data') 26 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 27 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 28 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 29 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 30 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 31 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 32 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 33 | self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 34 | 35 | self.initialized = True 36 | 37 | def parse(self): 38 | if not self.initialized: 39 | self.initialize() 40 | self.opt = self.parser.parse_args() 41 | self.opt.isTrain = self.isTrain # train or test 42 | 43 | args = vars(self.opt) 44 | 45 | print('------------ Options -------------') 46 | for k, v in sorted(args.items()): 47 | print('%s: %s' % (str(k), str(v))) 48 | print('-------------- End ----------------') 49 | 50 | # save to the disk 51 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 52 | util.mkdirs(expr_dir) 53 | file_name = os.path.join(expr_dir, 'opt.txt') 54 | with open(file_name, 'wt') as opt_file: 55 | opt_file.write('------------ Options -------------\n') 56 | for k, v in sorted(args.items()): 57 | opt_file.write('%s: %s\n' % (str(k), str(v))) 58 | opt_file.write('-------------- End ----------------\n') 59 | return self.opt 60 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 12 | self.parser.add_argument('--how_many', type=int, default=35, help='how many test images to run') 13 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 8 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 9 | self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 10 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 11 | self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 12 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 13 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 14 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 15 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 16 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 17 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 18 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 19 | self.parser.add_argument('--lambda_param', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') 20 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 21 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 22 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 23 | 24 | self.isTrain = True 25 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from IPython.display import clear_output 4 | from options.train_options import TrainOptions 5 | from data.data_loader import load_data, minibatchAB 6 | from data.save_data import show_generator_image 7 | from util.image_pool import ImagePool 8 | from models.networks import get_generater_function 9 | from models.networks import resnet_generator, n_layer_discriminator 10 | from models.train_function import * 11 | opt = TrainOptions().parse() 12 | 13 | # load data 14 | dpath = opt.dataroot 15 | train_A = load_data(dpath + 'trainA/*') 16 | train_B = load_data(dpath + 'trainB/*') 17 | train_batch = minibatchAB(train_A, train_B, batch_size=opt.batch_size) 18 | val_A = load_data(dpath + 'valA/*') 19 | val_B = load_data(dpath + 'valB/*') 20 | val_batch = minibatchAB(val_A, val_B, batch_size=4) 21 | 22 | # create gennerator models 23 | netG_A, real_A, fake_B = resnet_generator() 24 | netG_B, real_B, fake_A = resnet_generator() 25 | 26 | # create discriminator models 27 | netD_A = n_layer_discriminator() 28 | netD_B = n_layer_discriminator() 29 | 30 | # create generators train function 31 | netG_train_function = netG_train_function_creator(netD_A, netD_B, netG_A, netG_B, real_A, real_B, fake_A, fake_B) 32 | # create discriminator A train function 33 | netD_A_train_function = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_A, opt.finesize, opt.input_nc) 34 | # create discriminator B train function 35 | netD_B_train_function = netD_A_train_function(netD_A, netD_B, netG_A, netG_B, real_B, opt.finesize, opt.input_nc) 36 | 37 | # train loop 38 | time_start = time.time() 39 | how_many_epochs = 5 40 | iteration_count = 0 41 | epoch_count = 0 42 | batch_size = opt.batch_size 43 | display_freq = 10000 44 | 45 | netG_A_function = get_generater_function(netG_A) 46 | netG_B_functionr = get_generater_function(netG_B) 47 | 48 | fake_A_pool = ImagePool() 49 | fake_B_pool = ImagePool() 50 | 51 | while epoch_count < how_many_epochs: 52 | target_label = np.zeros((batch_size, 1)) 53 | epoch_count, A, B = next(train_batch) 54 | 55 | tmp_fake_B = netG_A_function([A])[0] 56 | tmp_fake_A = netG_B_functionr([B])[0] 57 | 58 | _fake_B = fake_B_pool.query(tmp_fake_B) 59 | _fake_A = fake_A_pool.query(tmp_fake_A) 60 | 61 | netG_train_function.train_on_batch([A, B], target_label) 62 | 63 | netD_B_train_function.train_on_batch([B, _fake_B], target_label) 64 | netD_A_train_function.train_on_batch([A, _fake_A], target_label) 65 | 66 | iteration_count += 1 67 | 68 | if iteration_count % display_freq == 0: 69 | clear_output() 70 | traintime = (time.time() - time_start) / iteration_count 71 | print('epoch_count: {} iter_count: {} timecost/iter: {}s'.format(epoch_count, iteration_count, traintime)) 72 | _, val_A, val_B = next(val_batch) 73 | show_generator_image(val_A, val_B, netG_A, netG_B) 74 | 75 | save_name = dpath + '{}' + str(iteration_count) + '.h5' 76 | 77 | netG_A.save(save_name.format('tf_GA')) 78 | netG_A.save_weights(save_name.format('tf_GA_weights')) 79 | netG_B.save(save_name.format('tf_GB')) 80 | netG_B.save_weights(save_name.format('tf_GB_weights')) 81 | netD_A.save(save_name.format('tf_DA')) 82 | 83 | netG_train_function.save(save_name.format('tf_G_train')) 84 | netG_train_function.save_weights(save_name.format('tf_G_train_weights')) 85 | netD_A_train_function.save(save_name.format('tf_D_A_train')) 86 | netD_A_train_function.save_weights(save_name.format('tf_D_A_train_weights')) 87 | netD_B_train_function.save(save_name.format('tf_D_B_train')) 88 | netD_B_train_function.save_weights(save_name.format('tf_D_B_train_weights')) 89 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingwangLin/cyclegan-keras/7a000382f63d4025ffdabdbc373e8e92ecae4573/util/__init__.py -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from random import randint, uniform 3 | 4 | 5 | class ImagePool(): 6 | def __init__(self, pool_size=50): 7 | self.pool_size = pool_size 8 | if self.pool_size > 0: 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | return_images = [] 16 | for image in images: 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = uniform(0, 1) 23 | if p > 0.5: 24 | random_id = randint(0, self.pool_size - 1) 25 | tmp = self.images[random_id] 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = np.stack(return_images, axis=0) 31 | return return_images 32 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from PIL import Image 4 | import numpy as np 5 | import os 6 | 7 | 8 | # Converts a Tensor into a Numpy array 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(image_tensor, imtype=np.uint8): 11 | image_numpy = image_tensor[0].cpu().float().numpy() 12 | if image_numpy.shape[0] == 1: 13 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 14 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 15 | return image_numpy.astype(imtype) 16 | 17 | 18 | def diagnose_network(net, name='network'): 19 | mean = 0.0 20 | count = 0 21 | for param in net.parameters(): 22 | if param.grad is not None: 23 | mean += torch.mean(torch.abs(param.grad.data)) 24 | count += 1 25 | if count > 0: 26 | mean = mean / count 27 | print(name) 28 | print(mean) 29 | 30 | 31 | def save_image(image_numpy, image_path): 32 | image_pil = Image.fromarray(image_numpy) 33 | image_pil.save(image_path) 34 | 35 | 36 | def print_numpy(x, val=True, shp=False): 37 | x = x.astype(np.float64) 38 | if shp: 39 | print('shape,', x.shape) 40 | if val: 41 | x = x.flatten() 42 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 43 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 44 | 45 | 46 | def mkdirs(paths): 47 | if isinstance(paths, list) and not isinstance(paths, str): 48 | for path in paths: 49 | mkdir(path) 50 | else: 51 | mkdir(paths) 52 | 53 | 54 | def mkdir(path): 55 | if not os.path.exists(path): 56 | os.makedirs(path) 57 | --------------------------------------------------------------------------------