├── FashionGAN-Tutorial.ipynb └── generatormodel.h5 /FashionGAN-Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "64f04b6a-fe54-41ce-92fe-ae94fc587387", 6 | "metadata": { 7 | "id": "64f04b6a-fe54-41ce-92fe-ae94fc587387" 8 | }, 9 | "source": [ 10 | "# 1. Import Dependencies and Data" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "c3c355f3-39f9-4f73-ac4e-5ad62198a580", 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install tensorflow tensorflow-gpu matplotlib tensorflow-datasets ipywidgets" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "e6112401-2397-423d-b8c2-113a9d323ab0", 29 | "metadata": { 30 | "tags": [] 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "!pip list" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "id": "b84be907-35e2-43db-a645-b6b164302aaa", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# Bringing in tensorflow\n", 45 | "import tensorflow as tf\n", 46 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 47 | "for gpu in gpus: \n", 48 | " tf.config.experimental.set_memory_growth(gpu, True)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 6, 54 | "id": "a0f2aa32-064b-448c-bb27-f19a48c40115", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# Brining in tensorflow datasets for fashion mnist \n", 59 | "import tensorflow_datasets as tfds\n", 60 | "# Bringing in matplotlib for viz stuff\n", 61 | "from matplotlib import pyplot as plt" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 7, 67 | "id": "c933f988-d1ee-4d4d-8028-368a158c27e2", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# Use the tensorflow datasets api to bring in the data source\n", 72 | "ds = 
tfds.load('fashion_mnist', split='train')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 14, 78 | "id": "c361db0d-8e7b-43e1-97f9-5e3f7cb01ffe", 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "2" 85 | ] 86 | }, 87 | "execution_count": 14, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "ds.as_numpy_iterator().next()['label']" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "ea1635e4-4beb-493d-92c1-b106c806ca70", 99 | "metadata": { 100 | "id": "ea1635e4-4beb-493d-92c1-b106c806ca70" 101 | }, 102 | "source": [ 103 | "# 2. Viz Data and Build Dataset" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 16, 109 | "id": "b0c62caf-e406-4d12-af31-6f4848155844", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# Do some data transformation\n", 114 | "import numpy as np" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 17, 120 | "id": "3215c900-6e85-4b39-b300-ea18faf30e5c", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Setup connection aka iterator\n", 125 | "dataiterator = ds.as_numpy_iterator()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "c1d6e079-46da-43ca-80d2-b3c864d90360", 132 | "metadata": { 133 | "scrolled": true, 134 | "tags": [] 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "# Getting data out of the pipeline\n", 139 | "dataiterator.next()['image']" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 67, 145 | "id": "deb5fca0-fd8a-4557-9c72-1a60c289a2e5", 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "image/png": "\n", 151 | "text/plain": [ 152 | "
" 153 | ] 154 | }, 155 | "metadata": { 156 | "needs_background": "light" 157 | }, 158 | "output_type": "display_data" 159 | } 160 | ], 161 | "source": [ 162 | "# Setup the subplot formatting \n", 163 | "fig, ax = plt.subplots(ncols=4, figsize=(20,20))\n", 164 | "# Loop four times and get images \n", 165 | "for idx in range(4): \n", 166 | " # Grab an image and label\n", 167 | " sample = dataiterator.next()\n", 168 | " # Plot the image using a specific subplot \n", 169 | " ax[idx].imshow(np.squeeze(sample['image']))\n", 170 | " # Appending the image label as the plot title \n", 171 | " ax[idx].title.set_text(sample['label'])" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 69, 177 | "id": "66c9d901-6a5c-42fd-ad06-cc03f7829728", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "# Scale and return images only \n", 182 | "def scale_images(data): \n", 183 | " image = data['image']\n", 184 | " return image / 255" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 72, 190 | "id": "dfc9b6b1-e06e-421c-9c5c-bfc3b3e3be77", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# Reload the dataset \n", 195 | "ds = tfds.load('fashion_mnist', split='train')\n", 196 | "# Running the dataset through the scale_images preprocessing step\n", 197 | "ds = ds.map(scale_images) \n", 198 | "# Cache the preprocessed (scaled) images in memory\n", 199 | "ds = ds.cache()\n", 200 | "# Shuffle using a 60,000-element buffer\n", 201 | "ds = ds.shuffle(60000)\n", 202 | "# Group into batches of 128 images\n", 203 | "ds = ds.batch(128)\n", 204 | "# Prefetch batches ahead of time to reduce input-pipeline bottlenecks\n", 205 | "ds = ds.prefetch(64)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 74, 211 | "id": "fbb52952-faa1-445f-8931-2f0f37224bfb", 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "(128, 28, 28, 1)" 218 | ] 219 | }, 220 | "execution_count": 74, 221 | "metadata": {}, 222 | "output_type":
"execute_result" 223 | } 224 | ], 225 | "source": [ 226 | "ds.as_numpy_iterator().next().shape" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "id": "9a5b08df-7b20-41f4-a8ff-112dface1cb0", 232 | "metadata": { 233 | "id": "9a5b08df-7b20-41f4-a8ff-112dface1cb0" 234 | }, 235 | "source": [ 236 | "# 3. Build Neural Network" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "38f66add-a3db-467f-96c3-f87b9f880159", 242 | "metadata": { 243 | "id": "38f66add-a3db-467f-96c3-f87b9f880159" 244 | }, 245 | "source": [ 246 | "### 3.1 Import Modelling Components" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 75, 252 | "id": "bb72da39-377f-4264-b525-c87f49fb0356", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "# Bring in the sequential api for the generator and discriminator\n", 257 | "from tensorflow.keras.models import Sequential\n", 258 | "# Bring in the layers for the neural network\n", 259 | "from tensorflow.keras.layers import Conv2D, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "id": "c40405df-1439-4661-8785-d76698df8152", 265 | "metadata": { 266 | "id": "c40405df-1439-4661-8785-d76698df8152" 267 | }, 268 | "source": [ 269 | "### 3.2 Build Generator" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 101, 275 | "id": "5d29d43a-e02a-4031-a0ec-de8aa810c118", 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def build_generator(): \n", 280 | " model = Sequential()\n", 281 | " \n", 282 | " # Takes in random values and reshapes it to 7x7x128\n", 283 | " # Beginnings of a generated image\n", 284 | " model.add(Dense(7*7*128, input_dim=128))\n", 285 | " model.add(LeakyReLU(0.2))\n", 286 | " model.add(Reshape((7,7,128)))\n", 287 | " \n", 288 | " # Upsampling block 1 \n", 289 | " model.add(UpSampling2D())\n", 290 | " model.add(Conv2D(128, 5, padding='same'))\n", 291 | " 
model.add(LeakyReLU(0.2))\n", 292 | " \n", 293 | " # Upsampling block 2 \n", 294 | " model.add(UpSampling2D())\n", 295 | " model.add(Conv2D(128, 5, padding='same'))\n", 296 | " model.add(LeakyReLU(0.2))\n", 297 | " \n", 298 | " # Convolutional block 1\n", 299 | " model.add(Conv2D(128, 4, padding='same'))\n", 300 | " model.add(LeakyReLU(0.2))\n", 301 | " \n", 302 | " # Convolutional block 2\n", 303 | " model.add(Conv2D(128, 4, padding='same'))\n", 304 | " model.add(LeakyReLU(0.2))\n", 305 | " \n", 306 | " # Conv layer to get to one channel\n", 307 | " model.add(Conv2D(1, 4, padding='same', activation='sigmoid'))\n", 308 | " \n", 309 | " return model" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 107, 315 | "id": "741b0d58-1b9f-4260-8405-dc400c73f843", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "generator = build_generator()" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 132, 325 | "id": "259ab9c1-6d6c-49a0-b0c4-f45b7c68f588", 326 | "metadata": { 327 | "scrolled": true, 328 | "tags": [] 329 | }, 330 | "outputs": [ 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "Model: \"sequential_10\"\n", 336 | "_________________________________________________________________\n", 337 | " Layer (type) Output Shape Param # \n", 338 | "=================================================================\n", 339 | " dense_10 (Dense) (None, 6272) 809088 \n", 340 | " \n", 341 | " leaky_re_lu_26 (LeakyReLU) (None, 6272) 0 \n", 342 | " \n", 343 | " reshape_9 (Reshape) (None, 7, 7, 128) 0 \n", 344 | " \n", 345 | " up_sampling2d_12 (UpSamplin (None, 14, 14, 128) 0 \n", 346 | " g2D) \n", 347 | " \n", 348 | " conv2d_19 (Conv2D) (None, 14, 14, 128) 409728 \n", 349 | " \n", 350 | " leaky_re_lu_27 (LeakyReLU) (None, 14, 14, 128) 0 \n", 351 | " \n", 352 | " up_sampling2d_13 (UpSamplin (None, 28, 28, 128) 0 \n", 353 | " g2D) \n", 354 | " \n", 355 | " conv2d_20 (Conv2D) (None, 28, 
28, 128) 409728 \n", 356 | " \n", 357 | " leaky_re_lu_28 (LeakyReLU) (None, 28, 28, 128) 0 \n", 358 | " \n", 359 | " conv2d_21 (Conv2D) (None, 28, 28, 128) 262272 \n", 360 | " \n", 361 | " leaky_re_lu_29 (LeakyReLU) (None, 28, 28, 128) 0 \n", 362 | " \n", 363 | " conv2d_22 (Conv2D) (None, 28, 28, 128) 262272 \n", 364 | " \n", 365 | " leaky_re_lu_30 (LeakyReLU) (None, 28, 28, 128) 0 \n", 366 | " \n", 367 | " conv2d_23 (Conv2D) (None, 28, 28, 1) 2049 \n", 368 | " \n", 369 | "=================================================================\n", 370 | "Total params: 2,155,137\n", 371 | "Trainable params: 2,155,137\n", 372 | "Non-trainable params: 0\n", 373 | "_________________________________________________________________\n" 374 | ] 375 | } 376 | ], 377 | "source": [ 378 | "generator.summary()" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 166, 384 | "id": "10ba4d1c-6a15-4097-bf63-5fe6ddb404b6", 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "img = generator.predict(np.random.randn(4,128,1))" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 136, 394 | "id": "9b4e0cb6-d741-4d43-b845-2a8f2615765b", 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "data": { 399 | "image/png": "\n", 400 | "text/plain": [ 401 | "
" 402 | ] 403 | }, 404 | "metadata": { 405 | "needs_background": "light" 406 | }, 407 | "output_type": "display_data" 408 | } 409 | ], 410 | "source": [ 411 | "# Generate four new images from random latent vectors\n", 412 | "img = generator.predict(np.random.randn(4,128,1))\n", 413 | "# Setup the subplot formatting \n", 414 | "fig, ax = plt.subplots(ncols=4, figsize=(20,20))\n", 415 | "# Loop over the generated images\n", 416 | "for idx, img in enumerate(img): \n", 417 | " # Plot the image using a specific subplot \n", 418 | " ax[idx].imshow(np.squeeze(img))\n", 419 | " # Use the image index as the plot title\n", 420 | " ax[idx].title.set_text(idx)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "id": "2415abbf-24ed-4bac-8fb8-12c65017ec22", 426 | "metadata": { 427 | "id": "2415abbf-24ed-4bac-8fb8-12c65017ec22" 428 | }, 429 | "source": [ 430 | "### 3.3 Build Discriminator" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 150, 436 | "id": "b4e70bcb-cfd5-42bb-aed0-79f19bb38d17", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "def build_discriminator(): \n", 441 | " model = Sequential()\n", 442 | " \n", 443 | " # First Conv Block\n", 444 | " model.add(Conv2D(32, 5, input_shape = (28,28,1)))\n", 445 | " model.add(LeakyReLU(0.2))\n", 446 | " model.add(Dropout(0.4))\n", 447 | " \n", 448 | " # Second Conv Block\n", 449 | " model.add(Conv2D(64, 5))\n", 450 | " model.add(LeakyReLU(0.2))\n", 451 | " model.add(Dropout(0.4))\n", 452 | " \n", 453 | " # Third Conv Block\n", 454 | " model.add(Conv2D(128, 5))\n", 455 | " model.add(LeakyReLU(0.2))\n", 456 | " model.add(Dropout(0.4))\n", 457 | " \n", 458 | " # Fourth Conv Block\n", 459 | " model.add(Conv2D(256, 5))\n", 460 | " model.add(LeakyReLU(0.2))\n", 461 | " model.add(Dropout(0.4))\n", 462 | " \n", 463 | " # Flatten then pass to dense layer\n", 464 | " model.add(Flatten())\n", 465 | " model.add(Dropout(0.4))\n", 466 | " model.add(Dense(1, activation='sigmoid'))\n", 467 | " \n", 468 | "
return model " 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 151, 474 | "id": "7173eb57-250b-4d21-9b37-de842c4552ac", 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "discriminator = build_discriminator()" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 152, 484 | "id": "ed6fecbc-f214-4f50-865c-91887b2430e7", 485 | "metadata": { 486 | "scrolled": true, 487 | "tags": [] 488 | }, 489 | "outputs": [ 490 | { 491 | "name": "stdout", 492 | "output_type": "stream", 493 | "text": [ 494 | "Model: \"sequential_15\"\n", 495 | "_________________________________________________________________\n", 496 | " Layer (type) Output Shape Param # \n", 497 | "=================================================================\n", 498 | " conv2d_32 (Conv2D) (None, 24, 24, 32) 832 \n", 499 | " \n", 500 | " leaky_re_lu_39 (LeakyReLU) (None, 24, 24, 32) 0 \n", 501 | " \n", 502 | " dropout_8 (Dropout) (None, 24, 24, 32) 0 \n", 503 | " \n", 504 | " conv2d_33 (Conv2D) (None, 20, 20, 64) 51264 \n", 505 | " \n", 506 | " leaky_re_lu_40 (LeakyReLU) (None, 20, 20, 64) 0 \n", 507 | " \n", 508 | " dropout_9 (Dropout) (None, 20, 20, 64) 0 \n", 509 | " \n", 510 | " conv2d_34 (Conv2D) (None, 16, 16, 128) 204928 \n", 511 | " \n", 512 | " leaky_re_lu_41 (LeakyReLU) (None, 16, 16, 128) 0 \n", 513 | " \n", 514 | " dropout_10 (Dropout) (None, 16, 16, 128) 0 \n", 515 | " \n", 516 | " conv2d_35 (Conv2D) (None, 12, 12, 256) 819456 \n", 517 | " \n", 518 | " leaky_re_lu_42 (LeakyReLU) (None, 12, 12, 256) 0 \n", 519 | " \n", 520 | " dropout_11 (Dropout) (None, 12, 12, 256) 0 \n", 521 | " \n", 522 | " flatten (Flatten) (None, 36864) 0 \n", 523 | " \n", 524 | " dropout_12 (Dropout) (None, 36864) 0 \n", 525 | " \n", 526 | " dense_11 (Dense) (None, 1) 36865 \n", 527 | " \n", 528 | "=================================================================\n", 529 | "Total params: 1,113,345\n", 530 | "Trainable params: 1,113,345\n", 531 | 
"Non-trainable params: 0\n", 532 | "_________________________________________________________________\n" 533 | ] 534 | } 535 | ], 536 | "source": [ 537 | "discriminator.summary()" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 161, 543 | "id": "19e32424-f9c5-499c-a13f-b450bc525bdc", 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "img = img[0]" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 163, 553 | "id": "9ce3acc9-02c8-468f-915a-0efd52da0bad", 554 | "metadata": {}, 555 | "outputs": [ 556 | { 557 | "data": { 558 | "text/plain": [ 559 | "(28, 28, 1)" 560 | ] 561 | }, 562 | "execution_count": 163, 563 | "metadata": {}, 564 | "output_type": "execute_result" 565 | } 566 | ], 567 | "source": [ 568 | "img.shape" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 167, 574 | "id": "8cd15246-b40c-4c7a-912d-b88a1c5c463b", 575 | "metadata": {}, 576 | "outputs": [ 577 | { 578 | "data": { 579 | "text/plain": [ 580 | "array([[0.50057805],\n", 581 | " [0.5006448 ],\n", 582 | " [0.50065 ],\n", 583 | " [0.5005127 ]], dtype=float32)" 584 | ] 585 | }, 586 | "execution_count": 167, 587 | "metadata": {}, 588 | "output_type": "execute_result" 589 | } 590 | ], 591 | "source": [ 592 | "discriminator.predict(img)" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "id": "39b343b0-38d3-4281-bedb-72099a18097e", 598 | "metadata": { 599 | "id": "39b343b0-38d3-4281-bedb-72099a18097e" 600 | }, 601 | "source": [ 602 | "# 4. 
Construct Training Loop" 603 | ] 604 | }, 605 | { 606 | "cell_type": "markdown", 607 | "id": "884abab3-2f74-442d-856f-e104ef1ac8ef", 608 | "metadata": { 609 | "id": "884abab3-2f74-442d-856f-e104ef1ac8ef" 610 | }, 611 | "source": [ 612 | "### 4.1 Setup Losses and Optimizers" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 168, 618 | "id": "0bb1d23a-ea68-451a-bb38-e7795dc24311", 619 | "metadata": {}, 620 | "outputs": [], 621 | "source": [ 622 | "# Adam is going to be the optimizer for both\n", 623 | "from tensorflow.keras.optimizers import Adam\n", 624 | "# Binary cross entropy is going to be the loss for both \n", 625 | "from tensorflow.keras.losses import BinaryCrossentropy" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 169, 631 | "id": "198b2d4e-d6b9-4b6c-a98c-65cd1b81da26", 632 | "metadata": {}, 633 | "outputs": [], 634 | "source": [ 635 | "g_opt = Adam(learning_rate=0.0001) \n", 636 | "d_opt = Adam(learning_rate=0.00001) \n", 637 | "g_loss = BinaryCrossentropy()\n", 638 | "d_loss = BinaryCrossentropy()" 639 | ] 640 | }, 641 | { 642 | "cell_type": "markdown", 643 | "id": "9f170b0e-f731-4cbd-8068-24896f462c08", 644 | "metadata": { 645 | "id": "9f170b0e-f731-4cbd-8068-24896f462c08" 646 | }, 647 | "source": [ 648 | "### 4.2 Build Subclassed Model" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 170, 654 | "id": "9e2f5654-ed22-462d-be32-6c43d8b99b74", 655 | "metadata": {}, 656 | "outputs": [], 657 | "source": [ 658 | "# Importing the base model class to subclass our training step \n", 659 | "from tensorflow.keras.models import Model" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 193, 665 | "id": "40a0af46-0243-4396-94d6-c1316d984de9", 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [ 669 | "class FashionGAN(Model): \n", 670 | " def __init__(self, generator, discriminator, *args, **kwargs):\n", 671 | " # Pass through args and kwargs to 
base class \n", 672 | " super().__init__(*args, **kwargs)\n", 673 | " \n", 674 | " # Create attributes for gen and disc\n", 675 | " self.generator = generator \n", 676 | " self.discriminator = discriminator \n", 677 | " \n", 678 | " def compile(self, g_opt, d_opt, g_loss, d_loss, *args, **kwargs): \n", 679 | " # Compile with base class\n", 680 | " super().compile(*args, **kwargs)\n", 681 | " \n", 682 | " # Create attributes for losses and optimizers\n", 683 | " self.g_opt = g_opt\n", 684 | " self.d_opt = d_opt\n", 685 | " self.g_loss = g_loss\n", 686 | " self.d_loss = d_loss \n", 687 | "\n", 688 | " def train_step(self, batch):\n", 689 | " # Get the data \n", 690 | " real_images = batch\n", 691 | " fake_images = self.generator(tf.random.normal((128, 128, 1)), training=False)\n", 692 | " \n", 693 | " # Train the discriminator\n", 694 | " with tf.GradientTape() as d_tape: \n", 695 | " # Pass the real and fake images to the discriminator model\n", 696 | " yhat_real = self.discriminator(real_images, training=True) \n", 697 | " yhat_fake = self.discriminator(fake_images, training=True)\n", 698 | " yhat_realfake = tf.concat([yhat_real, yhat_fake], axis=0)\n", 699 | " \n", 700 | " # Create labels: 0 for real images, 1 for fake images\n", 701 | " y_realfake = tf.concat([tf.zeros_like(yhat_real), tf.ones_like(yhat_fake)], axis=0)\n", 702 | " \n", 703 | " # Add label-smoothing noise to both targets: real labels move up from 0 and fake labels move down from 1 (by up to 0.15)\n", 704 | " noise_real = 0.15*tf.random.uniform(tf.shape(yhat_real))\n", 705 | " noise_fake = -0.15*tf.random.uniform(tf.shape(yhat_fake))\n", 706 | " y_realfake += tf.concat([noise_real, noise_fake], axis=0)\n", 707 | " \n", 708 | " # Calculate the discriminator loss (binary cross-entropy in this notebook)\n", 709 | " total_d_loss = self.d_loss(y_realfake, yhat_realfake)\n", 710 | " \n", 711 | " # Backpropagate and update the discriminator weights\n", 712 | " dgrad = d_tape.gradient(total_d_loss, self.discriminator.trainable_variables) \n", 713 | " self.d_opt.apply_gradients(zip(dgrad, self.discriminator.trainable_variables))\n", 714 | " \n", 715 | " #
Train the generator \n", 716 | " with tf.GradientTape() as g_tape: \n", 717 | " # Generate some new images\n", 718 | " gen_images = self.generator(tf.random.normal((128,128,1)), training=True)\n", 719 | " \n", 720 | " # Get the discriminator's predictions for the generated images\n", 721 | " predicted_labels = self.discriminator(gen_images, training=False)\n", 722 | " \n", 723 | " # Calculate loss against the real label (0) so the generator learns to fool the discriminator\n", 724 | " total_g_loss = self.g_loss(tf.zeros_like(predicted_labels), predicted_labels) \n", 725 | " \n", 726 | " # Backpropagate and update the generator weights\n", 727 | " ggrad = g_tape.gradient(total_g_loss, self.generator.trainable_variables)\n", 728 | " self.g_opt.apply_gradients(zip(ggrad, self.generator.trainable_variables))\n", 729 | " \n", 730 | " return {\"d_loss\":total_d_loss, \"g_loss\":total_g_loss}" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": 194, 736 | "id": "24d248c3-f4c1-4478-a699-a5811a7b1fd0", 737 | "metadata": {}, 738 | "outputs": [], 739 | "source": [ 740 | "# Create instance of subclassed model\n", 741 | "fashgan = FashionGAN(generator, discriminator)" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": 195, 747 | "id": "e1cf7e02-ee1a-4901-bdf0-9aa2301f8cfc", 748 | "metadata": {}, 749 | "outputs": [], 750 | "source": [ 751 | "# Compile the model\n", 752 | "fashgan.compile(g_opt, d_opt, g_loss, d_loss)" 753 | ] 754 | }, 755 | { 756 | "cell_type": "markdown", 757 | "id": "e06d0adb-38d0-4558-b824-7416cf880082", 758 | "metadata": { 759 | "id": "e06d0adb-38d0-4558-b824-7416cf880082" 760 | }, 761 | "source": [ 762 | "### 4.3 Build Callback" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": 184, 768 | "id": "548f6918-366c-4799-9dac-1acedaab40c4", 769 | "metadata": {}, 770 | "outputs": [], 771 | "source": [ 772 | "import os\n", 773 | "from tensorflow.keras.preprocessing.image import array_to_img\n", 774 | "from tensorflow.keras.callbacks import Callback" 775 | ] 776 | }, 777 | { 778 |
"cell_type": "code", 779 | "execution_count": 185, 780 | "id": "d3e2bb77-2d7d-40d0-809f-526b8fd34170", 781 | "metadata": {}, 782 | "outputs": [], 783 | "source": [ 784 | "class ModelMonitor(Callback):\n", 785 | " def __init__(self, num_img=3, latent_dim=128):\n", 786 | " self.num_img = num_img\n", 787 | " self.latent_dim = latent_dim\n", 788 | "\n", 789 | " def on_epoch_end(self, epoch, logs=None):\n", 790 | " random_latent_vectors = tf.random.uniform((self.num_img, self.latent_dim,1))\n", 791 | " generated_images = self.model.generator(random_latent_vectors)\n", 792 | " generated_images *= 255\n", 793 | " generated_images.numpy()\n", 794 | " for i in range(self.num_img):\n", 795 | " img = array_to_img(generated_images[i])\n", 796 | " img.save(os.path.join('images', f'generated_img_{epoch}_{i}.png'))" 797 | ] 798 | }, 799 | { 800 | "cell_type": "markdown", 801 | "id": "16e2f159-25e7-4e35-95ef-f0fd18ac5897", 802 | "metadata": { 803 | "id": "16e2f159-25e7-4e35-95ef-f0fd18ac5897" 804 | }, 805 | "source": [ 806 | "### 4.3 Train " 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "execution_count": 196, 812 | "id": "a779dceb-aba6-4bf3-af49-0d32a76dd2f7", 813 | "metadata": { 814 | "scrolled": true, 815 | "tags": [] 816 | }, 817 | "outputs": [ 818 | { 819 | "name": "stdout", 820 | "output_type": "stream", 821 | "text": [ 822 | "Epoch 1/20\n", 823 | "469/469 [==============================] - 43s 88ms/step - d_loss: 0.5911 - g_loss: 0.6832\n", 824 | "Epoch 2/20\n", 825 | "469/469 [==============================] - 41s 88ms/step - d_loss: 0.4158 - g_loss: 3.2742\n", 826 | "Epoch 3/20\n", 827 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2878 - g_loss: 6.5225\n", 828 | "Epoch 4/20\n", 829 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2780 - g_loss: 6.5009\n", 830 | "Epoch 5/20\n", 831 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2755 - g_loss: 6.2294\n", 832 | "Epoch 6/20\n", 833 | 
"469/469 [==============================] - 40s 85ms/step - d_loss: 0.2732 - g_loss: 5.9429\n", 834 | "Epoch 7/20\n", 835 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2728 - g_loss: 5.6715\n", 836 | "Epoch 8/20\n", 837 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2708 - g_loss: 5.4097\n", 838 | "Epoch 9/20\n", 839 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2704 - g_loss: 5.1352\n", 840 | "Epoch 10/20\n", 841 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2701 - g_loss: 4.8858\n", 842 | "Epoch 11/20\n", 843 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2694 - g_loss: 4.6344\n", 844 | "Epoch 12/20\n", 845 | "469/469 [==============================] - 40s 85ms/step - d_loss: 0.2691 - g_loss: 4.3753\n", 846 | "Epoch 13/20\n", 847 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2694 - g_loss: 4.1332\n", 848 | "Epoch 14/20\n", 849 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2687 - g_loss: 3.9182\n", 850 | "Epoch 15/20\n", 851 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2679 - g_loss: 3.7484\n", 852 | "Epoch 16/20\n", 853 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2681 - g_loss: 3.5493\n", 854 | "Epoch 17/20\n", 855 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2683 - g_loss: 3.3835\n", 856 | "Epoch 18/20\n", 857 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2675 - g_loss: 3.2472\n", 858 | "Epoch 19/20\n", 859 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.2679 - g_loss: 3.1084\n", 860 | "Epoch 20/20\n", 861 | "469/469 [==============================] - 40s 86ms/step - d_loss: 0.3128 - g_loss: 2.6171\n" 862 | ] 863 | } 864 | ], 865 | "source": [ 866 | "# Recommend 2000 epochs\n", 867 | "hist = fashgan.fit(ds, epochs=20, callbacks=[ModelMonitor()])" 868 | ] 
869 | }, 870 | { 871 | "cell_type": "markdown", 872 | "id": "39c665a1-a4cc-41ac-a08a-2e14ba64e88d", 873 | "metadata": { 874 | "id": "39c665a1-a4cc-41ac-a08a-2e14ba64e88d" 875 | }, 876 | "source": [ 877 | "### 4.4 Review Performance" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": 198, 883 | "id": "54381e8c-93ee-4022-9df6-24c4356720fe", 884 | "metadata": {}, 885 | "outputs": [ 886 | { 887 | "data": { 888 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAEVCAYAAADJrK/3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAApYUlEQVR4nO3deXyU1b3H8c9vZkLCvoZdBBcQ2SG4YXEBFZHivrRWRfGiVqy9aq/Y2talm7dWW1tKi3vdcEEQcEFrxV0hIKAsrlcroiRCWcKaZM7948yEAAkZIDPPM8n3/Xo9r9memfnlmclvzpw5v3PMOYeIiIRXJOgARERk95SoRURCTolaRCTklKhFREJOiVpEJOSUqEVEQk6JWkQk5JSoJWuY2edmNjzoOEQyTYlaRCTklKglq5lZrpn90cxWJrY/mllu4rY2ZjbLzNaa2Roze93MIonbrjezr8xsg5l9aGbDgv1LRKoXCzoAkX30M+AIoD/ggGeAG4GfA9cCK4D8xL5HAM7MegDjgcHOuZVm1hWIZjZskdSpRS3Z7nzgFudckXOuGLgZuCBxWynQAdjfOVfqnHvd+cltyoFc4FAzy3HOfe6c+zSQ6EVSoEQt2a4j8EWly18krgP4PfAJ8KKZfWZmEwCcc58APwZuAorMbIqZdUQkpJSoJdutBPavdLlL4jqccxucc9c65w4ARgPXJPuinXOPOueOTtzXAbdlNmyR1ClRS7bJMbO85AY8BtxoZvlm1gb4BfAwgJmNMrODzMyAdfguj7iZ9TCz4xM/Om4BNgPxYP4ckZopUUu2eQ6fWJNbHlAILAbeBxYAv0rsezDwT6AEeBv4q3PuFXz/9O+Ab4FvgLbADZn7E0T2jGnhABGRcFOLWkQk5JSoRURCTolaRCTklKhFREJOiVpEJOSUqEVEQk6JWkQk5JSoRURCTolaRCTklKhFREJOiVpEJOSUqEVEQk6JWkQk5JSoRURCTolaRCTklKhFREJOiVpEJORi6XjQNm3auK5du6bjoUVE6qT58+d/65zLr+q2tCTqrl27UlhYmI6HFhGpk8zsi+puU9eHiEjIKVGLiIScErWISMilpY9aRKQqpaWlrFixgi1btgQdSmDy8vLo3LkzOTk5Kd9HiVpEMmbFihU0bdqUrl27YmZBh5NxzjlWr17NihUr6NatW8r3U9eHiGTMli1baN26db1M0gBmRuvWrff4G4UStYhkVH1N0kl78/crUSdtWAWF98O/3wk6EhGRHdTvPuqtG2DZLHj/CfhsDri4v77XGXDCLdBiv0DDExGB+tiiLi+FD1+Apy6B3x8M0y+H1Z/Cd66Fy16HYybAh8/BXwbDnN/Btk1BRywiaXLTTTdx++23V3nbmDFjeOqppzIcUdXqR4vaOfhyrm85L5kGm1ZDw1Yw4Hzocw7sdxgk+4069PXXv/hzmPNbeO9hOPFWOPS07fuIiGRQ3U7UxR/55Lz4CVj7BcTyoMdI6HsuHHg8xBpUfb8WXeCcB+HzN+D5CfDkGNj/aDj5d9C+T0b/BJG66uaZ
S1i6cn2tPuahHZvxy+/22u0+v/71r3nwwQdp27Yt++23H4MGDarxcV9++WWuu+46ysrKGDx4MJMmTSI3N5cJEyYwY8YMYrEYJ554IrfffjtPPvkkN998M9FolObNm/Paa6/t899V9xL1hm/gg6mw+HH4ehFYBLodA8feAD1HQW7T1B+r69Fw2asw/wH416/g70Nh0Bg47kZo3Dpdf4GIpMn8+fOZMmUKCxcupKysjIEDB9aYqLds2cKYMWN4+eWX6d69OxdeeCGTJk3iggsuYNq0aSxfvhwzY+3atQDccsstzJ49m06dOlVct6/qVqKedgUsnuJ/FOzQH076DfQ+E5q23/vHjERh8FjofYbvs557t/8gOPan/vpo6tVFIrJdTS3fdHj99dc5/fTTadSoEQCjR4+u8T4ffvgh3bp1o3v37gBcdNFFTJw4kfHjx5OXl8fYsWMZNWoUo0aNAmDIkCGMGTOGc845hzPOOKNW4q47PybG474VfeAwuHKebwkfeeW+JenKGraEk2+DK96EjgPghevhb0fDp6/UzuOLSFaJxWLMnTuXs846i1mzZjFixAgA/va3v/GrX/2KL7/8kkGDBrF69ep9fq66k6g3rwFXDgefAPnd0/c8bXvCBdPhvEehbAs8dBo89n1Y83/pe04RqRVDhw5l+vTpbN68mQ0bNjBz5swa79OjRw8+//xzPvnkEwAeeughjjnmGEpKSli3bh0jR47kzjvvZNGiRQB8+umnHH744dxyyy3k5+fz5Zdf7nPcdafro2SVP23SNv3PZQaHnOJb7+9MhNf+ABMPgyPHw9E/hrzm6Y9BRPbYwIEDOffcc+nXrx9t27Zl8ODBNd4nLy+P+++/n7PPPrvix8TLL7+cNWvWcOqpp7Jlyxacc9xxxx0A/OQnP+Hjjz/GOcewYcPo16/fPsdtzrl9fpCdFRQUuIyv8PLpK751O+Y56Doks8+9fiX882bfPx7Lgx4n+2F/Bw2vfmSJSD20bNkyevbsGXQYgavqOJjZfOdcQVX7150W9cZif5qJFvXOmnWEM/4OR1zhx10vedqP127Y0o+/7nsO7HcEROpOT5OIZE7dSdQlRf60cZVrQ2ZGx/5+G/Fb+PRffvz24sdh/v3QvAv0Ocsn7bZqUYiExZVXXsmbb765w3VXX301F198cUAR7aoOJepVEM0NR/9wNAe6n+S3rSWw/FlfePPmn+CNO6BdH+h7NvQ+C5p3CjpakXpt4sSJQYdQo7qTqDcW+26PsJV55zaBfuf6raTId4ksfgJe+gW89EtfVNPnbDj0VGjYIuhoRSSE6k6naUlRsN0eqWjSFg6/DP7rZbhqga+W3PA1zPwR3H4wPP4D+Pe7QUcpIiFTtxJ1ED8k7q3WB8Kx18P4Qvivf0HBWD8X9n0nwYs3Qmn9XVNORHaUUqI2sxZm9pSZLTezZWZ2ZLoD22MbsyxRJ5lBp0F+wqcfLYSCi+GtP8PkY2HlwoCDE5EwSLVF/SfgBefcIUA/YFn6QtoL8Ths/BYaZ2Giriy3CYy6E86fClvWwj3DYM5tfg5tEcmoMM1HXWOiNrPmwFDgXgDn3Dbn3No0x7VnkuXj2diirsrBw+GHb0Ov02HOb+DeE6D4w6CjEpGApDLqoxtQDNxvZv2A+cDVzrmNlXcys3HAOIAuXbrUdpy7l8ny8Uxp2BLOvAcOGQWz/ttPsTrsl3D45Sqckbrh+Qnwzfu1+5jt+/huxN249dZbefjhh8nPz6+Yj/q6667b7X2Cno86lf/4GDAQmOScGwBsBCbsvJNzbrJzrsA5V5Cfn+HRFxXFLnUoUSf1Og1++A4ccCzMvgH+MRr+80XQUYlkpXnz5jF16lQWLVrE888/TypTXSTno3788cd5//33KSsrY9KkSaxevZpp06axZMkSFi9ezI033ghsn4960aJFzJgxo1biTqVFvQJY4ZxLjht7iioSdaCCLB/PhKbt4HtTfHn6CzfApCEw4jcw4ILw
jRsXSVUNLd90ePPNNzn11FPJy8sjLy+P7373uzXeJyvmo3bOfQN8aWY9ElcNA5bWyrPXlmTXR9jHUe8LMxh4AfzwLV+mPuMqeOw8v6KNiGRcGOejvgp4xMwWA/2B3+zzM9emkqLwlI+nW4sucOEMGPE7+GwO/PUI+ODpoKMSyQpDhgxh5syZbNmyhZKSEmbNmlXjfbJmPmrn3EKgyun3QiGs5ePpEon4mfoOHAbTLoOnLobls2Dk7dCoVdDRiYTW4MGDGT16NH379qVdu3b06dOH5s1338DTfNS15aEzYPN/YFw9XBarvAzeuBNe/R00agMn/dpPrRqtO9O4SN0RhvmoS0pKaNKkCZs2bWLo0KFMnjyZgQMHZjSGPZ2Pum6M8yopgibtgo4iGNEYHPMTX4beqDVMHQt3DYB3JsHWDUFHJxI648aNo3///gwcOJAzzzwz40l6b9SNZtfGIug0IOgogtWhH1z+Onz4PLz9F3hhAsz5LQy62E8E1axj0BGKhMKjjz66w2XNR50J8fK6UT5eGyJR6DnKbysK/Zwhb90Fb0/0ixYcOR7a9w46SqnnnHNYiH5PyvR81HvT3Zz9XR+b6lj5eG3pXADnPOinUy24BJbOgL8NgYdO96vPpOG3CZGa5OXlsXr16r1KVnWBc47Vq1eTl5e3R/fL/hb1xkRVohJ11Vp1g5H/C8dO8EuCvft3n6zb9fYt7N5nagFeyZjOnTuzYsUKiouLgw4lMHl5eXTu3HmP7pP9iboul4/Xpkat4DvX+uT8/lO+W2T65fDyzb4Pe9DFWmFG0i4nJ4du3boFHUbWyf6uj7pePl7bYrkw4Hw/O98PpkJ+D/jnTXBnLz9Jzrqvgo5QRHaS/Ym6Ls6clwlmcNBwuPAZuOx1P0vfvLvhL4P9j4/lZUFHKCIJdSBRJ8rHc5sFHUn26tAXzvg7XDUfug6B2T+Fe46Hle8FHZmIUBcSdX0rH0+nll3h+0/AWff7yZ7uPt7P1re1JOjIROq17E/UJavq9qx5mWYGvc+AK+fCoDHwzl9h4uG+kEZEAlEHEnVx/S0fT6eGLfz6jZe8CLlN/ZSqj18A678OOjKReif7E/XGImiiFnXadDkcLnsNjv85fPwiTDwM5t7tFxQWkYzI7kSt8vHMiDWAodfBFW9Bp4Hw3HVw34mwaknQkYnUC9mdqCvKx9X1kRGtD4QLpsPpk2HNZ37B3Zd+Cds2BR2ZSJ2W3Ym6onxcXR8ZYwb9zoXxhdD3PHjzjzDpSPjk5aAjE6mzsjtRq3w8OI1awWkT4aJZEInBw2fA1Eth/cqgIxOpc+pGolZVYnC6fcf3XR8zAZY+40vRHxwNCx6CzWuDjk6kTsjuRK2Z88IhlgvH3eDHXg/9Caz7EmaMh9sPhinn+8V31Y8tsteye/Y8lY+HS6tucNxP4dgbYOUCeH8qfDDVL7zboImfT6TPWXDAsRDNCTpakayR3Yla5ePhZAadBvntxFvhizfh/Sd918jiKX5tx0NPgz5nw36H+1XVRaRaKSVqM/sc2ACUA2XVrZSbcSWr1O0RdpEodBvqt5G3+9Eh7z8JCx+FwnuhWWfoc6ZP2u1660NXpAp70qI+zjn3bdoi2RslxdB8z1ZKkADFcuGQkX7bWgIfPueT9tsT4c0/QZseftmwgku06oxIJdn9nVPl49krtwn0PQfOfxKu/QhOucPPL/LC9TBxsO/brqfr6onsLNVE7YAXzWy+mY2ragczG2dmhWZWmJH10OLlvo9aY6izX+PWMHgsjH0RfvC0/+HxqUv8NKufvxl0dCKBSzVRH+2cGwicDFxpZkN33sE5N9k5V+CcK8jPz0Ard9MacHGVj9c1Bw3zk0CdNsn/BvHASHj0PCj+MOjIRAKTUqJ2zn2VOC0CpgGHpTOolKh8vO6KRKH/9/2KM8N+6UeN/PUImHm1X9BApJ6pMVGbWWMza5o8D5wIfJDuwGqk8vG6L6chfOca+NFCOOwyeO8RuGsg
vPJbrToj9UoqLep2wBtmtgiYCzzrnHshvWGlQOXj9Ufj1nDy72D8XOh+Irz6O7hrAMy7V4vwSr1QY6J2zn3mnOuX2Ho5536dicBqpPLx+qfVAXD2A3Dpy9D6IHj2Gt8lsvxZjRCROi17h+epfLz+6lwAFz8H5z3mC2SmfB/uHwkrCoOOTCQtsjtRq3y8/jLzhTNXvO3Xdlz9CdwzDB4+C5ZMh7KtQUcoUmuyd66PjUXq9hCIxnwlY59z/Irp8x+AJy+CvBa+LH3A+dChvz7QJatlb6JW+bhUltsEjvkf+M618NkcP5fIew/BvLuh7aF+uF+fc6Cpxt1L9sneRL2xCDoNCDoKCZtI1BfNHDTML1ywZJpP2i/e6Nd3PPgEn7S7n6z5RCRrZGeiVvm4pKJhCyi42G/FH8GiR2HRFPjoBWjYyneN9P8+dOinrhEJtexM1Coflz2V3x2G3wTH/xw+e8UXz8x/AOb+Hdr28gm77zn63UNCKTsTtcrHZW9FonDQcL9t/o9fJmzho/Diz+ClX0CPk2HwpX4VGrWyJSSyM1GXrPKn6vqQfdGwpZ+1b/BYKFoOCx/x2/JZ0Ppgn7D7fw/ymgcdqdRz2TmOuiQxjaq6PqS2tD3ELxv230vh9L/75PzC9fCHnjDzx7BqSdARSj2WnS1qdX1IuuTkQb/z/LbyPZh7Dyx6DObfD12OgsMuhUO+qxEjklFZ2qJW+bhkQMcBcNpEuGYZnHArbFjpFzT4Y2945TewfmXQEUo9kb2JWuXjkimNWsGQH8FV78H3n4T2feHV/4U7e8MTF8L/va5JoSStsrfrQ8OoJNMiET/NavcTYc1nUHgfLHgIlj4D+T39j5L9zoPcpkFHKnVMlraoVewiAWt1AJz4K7h2OZw60fdZP3ednyd74aMQjwcdodQhWZqoV+mHRAmHnIYw4Acw7lW45EVosT9MvwLuHwFfLw46Oqkjsi9Rx8th07camifhYgZdDoexL8Hov/hpVycfA89e5wtrRPZB9iXqZPm4uj4kjCIRGHiBX5h38KVQeC/8eRAs+Ie6Q2SvZV+i1hhqyQYNW8LI3/sukdYHw4yr4N7h8NWCoCOTLJR9iVrl45JNOvSFS17w1Y5rv4S7j4eZV/tvhiIpysJErfJxyTJmftjeVYVwxBV+SN+fB/pV1OPlQUcnWSD7ErW6PiRb5TWHEb+Fy9/wU6s+e41vYX85L+jIJORSTtRmFjWz98xsVjoDqlHJKpWPS3ZrdyiMmQVn3uvfz/cOh+lXbv+2KLKTPWlRXw0sS1cgKSsp9t0eKh+XbGYGfc6C8fPgqB/B4inwl0Hw9kQo3RJ0dBIyKSVqM+sMnALck95wUrCxSN0eUnfkNvXTq17xlp8EavZP4a7+MPduKNsadHQSEqm2qP8I/A9Q7UBQMxtnZoVmVlhcnMavcCofl7oovwdc+AxcNBNadt1ejj7vXijbFnR0ErAaE7WZjQKKnHPzd7efc26yc67AOVeQn5/GFq/Kx6Uu6zYULn4eLpgOzTr5Hxz/PNCv71heGnR0EpBUWtRDgNFm9jkwBTjezB5Oa1TVUfm41AdmcOBxMPZF+MFUP1PkzKsTFY4PKWHXQzUmaufcDc65zs65rsB5wL+ccz9Ie2RVUfm41CdmfhHeS1/282A3bAkzxsNfBsPCx6C8LOgIJUOyaxx1sipRXR9Sn5j5ObDHzYHzHoPcJjD9cph4GCx+QkUz9cAeJWrn3Bzn3Kh0BVOjimIXdX1IPWQGh4yEy16Hcx/2U6w+/V/w1yPg/ac06VMdlmUt6sRoEnV9SH1mBj2/6xP22Q+CRWHqWJh0lB8l8s0H6hapY7JrKS6Vj4tsF4lAr9Og52hYOg3m3OZHiQDkNIIO/aDTIOg00J+22F+FYlkquxK1ysdFdhWJQO8zodcZfi3HrxbAV/P9NvduKE8UzjRslUjcg7Yn8MZtgo1dUpJliVrl4yLVMoPWB/qt79n+uvJSWLUkkbgXwMoF
8Mk/gcSq6S267Ji8Ow7wfd8SKtmVqFU+LrJnojnQsb/fBo/1123dAF8v2t7qXlEIS6b523Kb+SlZCy6Btj2Dilp2kl2JuqQImu8XdBQi2S23KXQ92m9JJUU+aX/wtK+CnDsZuhzlE/ahoyGWG1i4knWjPtSiFkmLJm2hx8lw5t1wzXI44VbY8DU8fSnc0RNe+oXv/5ZAZE+iVvm4SGY0bg1DfgRXLYALpsH+R8Fbf/GTRD10OiybpeF/GZY9XR+bVqt8XCSTIhE48Hi/rV/p5xmZ/wA8fj407QgDL4RBF0GzjkFHWudlT4u6RGOoRQLTrCMcez38+H0471G/Ss2rt8GdvWHK+X4kiSoj0yZ7WtQqHxcJXjQGh5zitzX/51vY7z0My2f5ebR7nQFtD4X87tD6YGjQKOiI64TsSdQqHxcJl1bd4ISb4bifwrKZUHgfvPkncMlJosyP084/xC+MkN/Dn2/THfJUtLYnsihRa+Y8kVCK5fr1H/uc5ZcPW/0pfPshFFfaPnsFyiutVNO0o291JxN3MpmrUrJK2ZOoNxapfFwk7GK5vv+63aE7Xl9eBmu/SCTu5fDtR/50wUNQunH7fi26wKGn+ZL4Dv1UhZyQPYla5eMi2Ssa217efsjI7dc7B+tWbG+BfzYH3vkrvHUXtDrA93n3TvR71+P//SxK1ForUaTOMYMW+/ntoOFw5JV+Jafls+CDqfDGHfD67dCmh29l9z4D2hwcdNQZlz2JemOxysdF6oNGrfwY7YEX+m/Sy56BD6bBnN/CnN9Auz7Q+3Tf2m7VLehoMyJ7EnVJkZ/dS0Tqjyb5MPhSv63/GpZO9/ORvHyL3zoO9K3sXqdD885BR5s22ZGoK8rHNTRPpN5q1gGOuMJva//tZ/z74Gl48Ua/7Xe4b2X3Og2atg862lqVHYla5eMiUlmLLjDkar+t/hSWPO27R164Hl6YAPsP8d0jPU+tE79tZUcJucrHRaQ6rQ+EoT+BH74FV86FY673w3mfvRb+0B3+caqvoNy0JuhI91p2JGqVj4tIKvJ7wHE3+IR9xVtw9DW+m2Tm1XD7wfDwmfDeI7B5bdCR7pEauz7MLA94DchN7P+Uc+6X6Q5sB8kWtbo+RCQVZtCul9+Ov9GvaLPkad+v/cwPYVYDOHCY/yGy+4jQl7Sn0ke9FTjeOVdiZjnAG2b2vHPunTTHtp26PkRkb5ltX45s+M1+7chk0v7oeV/xfPAJ25N2g8ZBR7yLGhO1c84BJYmLOYnNpTOoXWwsglieysdFZN+YQedBfjvhVlgx1yfsJdN9kU2T9nD2A7D/kUFHuoOU+qjNLGpmC4Ei4CXn3LtV7DPOzArNrLC4uLh2oywp9t0e9biEVERqWSQCXY6Ak2+Da5bChTN8a/rBUfDOJF/eHhIpJWrnXLlzrj/QGTjMzHpXsc9k51yBc64gP7+WuyhUPi4i6RSJwgHHwLhXfPfHCxPgqUtga0nN982APRr14ZxbC7wCjEhLNNXZWKwfEkUk/fKaw7kPw/CbfBXkPcPg24+DjqrmRG1m+WbWInG+IXACsDzNce2opEhViSKSGWZw9H/7hX03fguTj4OlMwINKZUWdQfgFTNbDMzD91HPSm9Ylah8XESCcMCxcNmrfmz2ExfAS78IbPX1VEZ9LAYGZCCWqql8XESC0rwzXPwczP6pX2bsqwVw1v0Z/80s/JWJFWOolahFJACxXDjlD3Da32DFPPj7UPhyXkZDCH+i3qhELSIh0P97MPYliDWA+0+GuXdnbAhf+BO1ysdFJCw69IVxc+DA4+G562DaZbBtU9qfNnsStcZRi0gYNGwJ35sCx90Ii5+Ae0/wU62mUfgTtcrHRSRsIhE45ifwg6dg/Vd+CN+Hz6fv6dL2yLWlpEjl4yISTgcNh3Gv+rUbHzsPXr7VDymuZdmRqNXtISJh1XJ/uGS2X4z3o9lQtrXWnyL8S3FtLPbL7oiIhFVOHoz+M2zdAA0a1frDZ0eL
urFa1CKSBXKbpuVhw52oVT4uIhLyRK3ycRGRkCdqlY+LiIQ9Ua/yp0rUIlKPhTtRb0ws6aWuDxGpx8KdqFU+LiIS8kSt8nERkZAnapWPi4hkQaJWt4eI1HPhTtQbi6FJu6CjEBEJVLgTdckqlY+LSL0X3kQdL/eViRpDLSL1XHgTtcrHRUSAFBK1me1nZq+Y2VIzW2JmV2ciMJWPi4h4qcxHXQZc65xbYGZNgflm9pJzbmlaI1P5uIgIkEKL2jn3tXNuQeL8BmAZ0Cndgal8XETE26M+ajPrCgwA3k1LNJWp60NEBNiDRG1mTYCpwI+dc+uruH2cmRWaWWFxcfG+R1ayKlE+np4VE0REskVKidrMcvBJ+hHn3NNV7eOcm+ycK3DOFeTn18LY543FKh8XESG1UR8G3Assc87dkf6QElQ+LiICpNaiHgJcABxvZgsT28g0x6XycRGRhBqH5znn3gAy3/9Qsgo6Dcr404qIhE04KxNVPi4iUiGciTpZPq6uDxGRkCbqZFWiZs4TEQlrolaxi4hIUjgTtcrHRUQqhDNRq0UtIlIhpIla5eMiIknhTNQqHxcRqRDORF1SpG4PEZEEJWoRkZALZ6LeWKQx1CIiCeFL1CofFxHZQfgStcrHRUR2EL5ErfJxEZEdhDBRq9hFRKSy8CXqZPm4uj5ERIAwJmp1fYiI7CCEibpI5eMiIpWEL1GrfFxEZAfhS9SqShQR2YEStYhIyIUvUat8XERkBzUmajO7z8yKzOyDtEej8nERkV2k0qJ+ABiR5ji8jd+qfFxEZCc1Jmrn3GvAmgzE4rs9QF0fIiKVhKuPWuXjIiK7qLVEbWbjzKzQzAqLi4v37kFUPi4isotaS9TOucnOuQLnXEF+/l52Xah8XERkF+Hr+lD5uIjIDlIZnvcY8DbQw8xWmNnYtEWj8nERkV3EatrBOfe9TAQC+K4P/ZAoIrKDkHV9FCtRi4jsJFyJWuXjIiK7CE+idg6atofWBwYdiYhIqNTYR50xZnD5G0FHISISOuFpUYuISJWUqEVEQk6JWkQk5JSoRURCTolaRCTkQpWoV67djHMu6DBEREIlNMPzysrjnHLX6zRqEOPEXu04qVd7BndtRTSieT9EpH4LTaIud44bTu7J7CXf8Mi7/+b+Nz+nVeMGDO/ZlhG923PUgW3Iy4kGHaaISMZZOroaCgoKXGFh4V7fv2RrGa9+WMzsJd/wyvIiNmwto3GDKMce0paTerXnuB75NM3LqcWIRUSCZWbznXMFVd0WmhZ1ZU1yY5zStwOn9O3A1rJy3v50NbOXrOKlpat4dvHXNIhGOOqg1pzUqz3De7Yjv2lu0CGLiKRNKFvU1SmPOxb8+z/M/uAbZi/9hi/XbMYMCvZvyUm92nNSr/bs16pRrT+viEi67a5FnVWJujLnHMu+3sDsJd8we8k3LP9mA+Bb480b5lRsLRolzidOWzRsUOVtTXNjmBYsEJGA1MlEvbMvVm/kn8uKWPGfTazbXMq6TaWs21zK2s2lFZe3lcervX/EoHnDHBrnxmiYE6Vhgyh5OX5rmBPZ4bqGya3y5Qb+NCcaIRqxii2WPI0aUUteFyEa3X5b1Kzictz5ETDbyuOUlTvKyh2l8Tilicul5XHK4o7SsjilcUdZeZzSiuvjlMch7hw4fxpPnLpK5+POf9C5KvZJflhFzIiYPzWrdDliGGBmVe4Tixq5sSh5OZFdT3Oi5MX8aW4sQk60+tGh5XFHyZYy1m0uZf0W/xquT7yW2y+X7XKbc/jXJPF6+NcwssNrlLw9LxbZ5TWMmFEWd5TH/fEuj7vEZX+MK1+u2C/uKC/3lwFyokZONEIsGqFB1IhFI8QiRoNYhFgkUnG73yd53r8vAErL/eu9LfnaliXPx9lWtv313lbp+tJK74N43L+m5YnXtDxx2V/vKK98Pl75feGImpHXYPt7PHnsKl9X1fs+eUyjZrvEsyfnnXPEIv6YVT6OOYnrYlEjJ5I8
bonbE8c0Fo1gFcev5ucrK3cVxy95vjxxXJL/Izv+f1BxbKv7n2qcG+Pnow7dqxyWdX3Ue2P/1o0Ze3S3am93zrGlNJ5I3ttYt2nHJJ68ftPWcraUlbN5WzmbS8tZt7mUVev8+c2l5WzZVs6m0nLK4xrvva+iESM3FiEvkbjzcqJsK4uzfnMpG7aW1XjfZnkxmiW+HTXLy6F98zwiZmwpLWdLaZzNpeWs2biNLcnXruK0+g/sbBWLJJOWEYn4D9BoJPlh6hsEtsN1O94eiUA08SFV+fgl/w/qo0hFA2WnxkrycqVjaYnTVo3T83tZnUnUNTEz/+nfIEr75nn7/Hil5fGKxJ18U2/aVkZZ3FVqicWJO7dLy6y8qhZZ4nLE2KUVsb3lsHNLzBKtie0tilhi3Hnyn3XHN1niTYVhkV3feL6tnGiRs1NrO7691RV34NipRZ5oiZSWO7aW+WOytaycrWVxtpb60y2VT0vju1y3pSxOTtRolpdTKQH7rqyKy4nTxg2ie91VFY87tpbFd0jeyYTkHLt8E/LnI4mWXqVvRZX3S5z694Z/7UvLdvw2lPyWVNGiq/TNqLTMfyMCaBCLVLzOOdEIDaIRcmLmT6ORSrdbxe2RNNYbOOcqXqPKx2pLaTmbt8V3aMSUxR05UauIMRYxcmKRithjUav2fE40ghmJb5KVvzH6lnDy22W1t8fjOMf24xbb/v9R5flY4hhWOh+t+FALVzdovUnUtS35ZmimYYJZJxLZ/qGdDg1iRgMi0CAtD59xZlbRDdgi6GDqqVCVkIuIyK6UqEVEQi6lRG1mI8zsQzP7xMwmpDsoERHZrsZEbWZRYCJwMnAo8D0z27vxJyIissdSaVEfBnzinPvMObcNmAKcmt6wREQkKZVE3Qn4stLlFYnrREQkA2rtx0QzG2dmhWZWWFxcXFsPKyJS76WSqL8C9qt0uXPiuh045yY75wqccwX5+fm1FZ+ISL1X41wfZhYDPgKG4RP0POD7zrklu7lPMfDFXsbUBvh2L++bCYpv3yi+faP49k2Y49vfOVdlK7fGykTnXJmZjQdmA1Hgvt0l6cR99rpJbWaF1U1MEgaKb98ovn2j+PZN2OOrTkol5M6554Dn0hyLiIhUQZWJIiIhF8ZEPTnoAGqg+PaN4ts3im/fhD2+KqVl4QAREak9YWxRi4hIJYEl6pomejKzXDN7PHH7u2bWNYOx7Wdmr5jZUjNbYmZXV7HPsWa2zswWJrZfZCq+xPN/bmbvJ557l3XPzLsrcfwWm9nADMbWo9JxWWhm683sxzvtk9HjZ2b3mVmRmX1Q6bpWZvaSmX2cOG1ZzX0vSuzzsZldlMH4fm9myxOv3zQza1HNfXf7XkhjfDeZ2VeVXsOR1dw37ZO6VRPf45Vi+9zMFlZz37Qfv33mEmt/ZXLDD/P7FDgAP736IuDQnfb5IfC3xPnzgMczGF8HYGDifFP8OPKd4zsWmBXE8Us8/+dAm93cPhJ4HjDgCODdAF/rb/BjRAM7fsBQYCDwQaXr/heYkDg/Abitivu1Aj5LnLZMnG+ZofhOBGKJ87dVFV8q74U0xncTcF0Kr/9u/9fTFd9Ot/8B+EVQx29ft6Ba1KlM9HQq8GDi/FPAMMvQ+jjOua+dcwsS5zcAy8i++U1OBf7hvHeAFmbWIYA4hgGfOuf2tgCqVjjnXgPW7HR15ffYg8BpVdz1JOAl59wa59x/gJeAEZmIzzn3onMuuXjkO/iq4EBUc/xSkZFJ3XYXXyJvnAM8VtvPmylBJepUJnqq2CfxZl0HtM5IdJUkulwGAO9WcfORZrbIzJ43s16ZjQwHvGhm881sXBW3h2UyrfOo/h8kyOMH0M4593Xi/DdAuyr2CctxvAT/DakqNb0X0ml8omvmvmq6jsJw/L4DrHLOfVzN7UEev5Tox8TdMLMmwFTgx8659TvdvAD/db4f8GdgeobDO9o5NxA/T/iVZjY0w89fIzNrAIwGnqzi5qCP3w6c
/w4cyiFQZvYzoAx4pJpdgnovTAIOBPoDX+O7F8Loe+y+NR36/6WgEnUqEz1V7GN+vpHmwOqMROefMwefpB9xzj298+3OufXOuZLE+eeAHDNrk6n4nHNfJU6LgGn4r5iVpTSZVpqdDCxwzq3a+Yagj1/CqmR3UOK0qIp9Aj2OZjYGGAWcn/gw2UUK74W0cM6tcs6VO+fiwN3VPG/Qxy8GnAE8Xt0+QR2/PRFUop4HHGxm3RKtrvOAGTvtMwNI/sJ+FvCv6t6otS3Rp3UvsMw5d0c1+7RP9pmb2WH4Y5mRDxIza2xmTZPn8T86fbDTbjOACxOjP44A1lX6mp8p1bZkgjx+lVR+j10EPFPFPrOBE82sZeKr/YmJ69LOzEYA/wOMds5tqmafVN4L6Yqv8m8ep1fzvKn8r6fTcGC5c25FVTcGefz2SFC/YuJHJXyE/0X4Z4nrbsG/KQHy8F+ZPwHmAgdkMLaj8V+DFwMLE9tI4HLg8sQ+44El+F+x3wGOymB8BySed1EihuTxqxyf4ZdQ+xR4HyjI8OvbGJ94m1e6LrDjh//A+BooxfeTjsX/5vEy8DHwT6BVYt8C4J5K970k8T78BLg4g/F9gu/fTb4Hk6OgOgLP7e69kKH4Hkq8txbjk2+HneNLXN7lfz0T8SWufyD5nqu0b8aP375uqkwUEQk5/ZgoIhJyStQiIiGnRC0iEnJK1CIiIadELSISckrUIiIhp0QtIhJyStQiIiH3/770Au3UqfX0AAAAAElFTkSuQmCC\n", 889 | "text/plain": [ 890 | "
" 891 | ] 892 | }, 893 | "metadata": { 894 | "needs_background": "light" 895 | }, 896 | "output_type": "display_data" 897 | } 898 | ], 899 | "source": [ 900 | "plt.suptitle('Loss')\n", 901 | "plt.plot(hist.history['d_loss'], label='d_loss')\n", 902 | "plt.plot(hist.history['g_loss'], label='g_loss')\n", 903 | "plt.legend()\n", 904 | "plt.show()" 905 | ] 906 | }, 907 | { 908 | "cell_type": "markdown", 909 | "id": "d319a982-7ae5-4754-adcf-b490f17a79d6", 910 | "metadata": { 911 | "id": "d319a982-7ae5-4754-adcf-b490f17a79d6" 912 | }, 913 | "source": [ 914 | "# 5. Test Out the Generator" 915 | ] 916 | }, 917 | { 918 | "cell_type": "markdown", 919 | "id": "206ba81f-978a-4c31-9c3d-6ebe5a5bfc29", 920 | "metadata": { 921 | "id": "206ba81f-978a-4c31-9c3d-6ebe5a5bfc29" 922 | }, 923 | "source": [ 924 | "### 5.1 Generate Images" 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "execution_count": 211, 930 | "id": "c46f3d6a-8aa5-40d2-a5ac-67a0606a82f0", 931 | "metadata": {}, 932 | "outputs": [], 933 | "source": [ 934 | "generator.load_weights('generatormodel.h5')  # pretrained weights are committed at the repo root, next to this notebook" 935 | ] 936 | }, 937 | { 938 | "cell_type": "code", 939 | "execution_count": 228, 940 | "id": "14cde11f-cb26-4ebf-ad04-2c64a54f871e", 941 | "metadata": {}, 942 | "outputs": [], 943 | "source": [ 944 | "imgs = generator.predict(tf.random.normal((16, 128, 1)))" 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": 229, 950 | "id": "f745982f-c4d7-451f-91a7-f7c4341cb7b7", 951 | "metadata": {}, 952 | "outputs": [ 953 | { 954 | "data": { 955 | "image/png": "\n", 956 | "text/plain": [ 957 | "
" 958 | ] 959 | }, 960 | "metadata": { 961 | "needs_background": "light" 962 | }, 963 | "output_type": "display_data" 964 | } 965 | ], 966 | "source": [ 967 | "fig, ax = plt.subplots(ncols=4, nrows=4, figsize=(10,10))\n", 968 | "for r in range(4): \n", 969 | "    for c in range(4): \n", 970 | "        ax[r][c].imshow(imgs[r*4 + c])  # row-major offset: each of the 16 generated samples is shown exactly once" 971 | ] 972 | }, 973 | { 974 | "cell_type": "markdown", 975 | "id": "5137cffa-784d-4076-beef-0a067b86d3aa", 976 | "metadata": { 977 | "id": "5137cffa-784d-4076-beef-0a067b86d3aa" 978 | }, 979 | "source": [ 980 | "### 5.2 Save the Model" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": null, 986 | "id": "a7011d68-ef71-4377-91e2-e26a02fab382", 987 | "metadata": {}, 988 | "outputs": [], 989 | "source": [ 990 | "generator.save('generator.h5')\n", 991 | "discriminator.save('discriminator.h5')" 992 | ] 993 | }, 994 | { 995 | "cell_type": "code", 996 | "execution_count": null, 997 | "id": "d14c2bd3-a344-4ac1-b2ee-6c90420368e6", 998 | "metadata": {}, 999 | "outputs": [], 1000 | "source": [] 1001 | } 1002 | ], 1003 | "metadata": { 1004 | "accelerator": "GPU", 1005 | "colab": { 1006 | "background_execution": "on", 1007 | "collapsed_sections": [ 1008 | "206ba81f-978a-4c31-9c3d-6ebe5a5bfc29" 1009 | ], 1010 | "name": "FashionGAN.ipynb", 1011 | "provenance": [] 1012 | }, 1013 | "kernelspec": { 1014 | "display_name": "fashgan", 1015 | "language": "python", 1016 | "name": "fashgan" 1017 | }, 1018 | "language_info": { 1019 | "codemirror_mode": { 1020 | "name": "ipython", 1021 | "version": 3 1022 | }, 1023 | "file_extension": ".py", 1024 | "mimetype": "text/x-python", 1025 | "name": "python", 1026 | "nbconvert_exporter": "python", 1027 | "pygments_lexer": "ipython3", 1028 | "version": "3.9.7" 1029 | } 1030 | }, 1031 | "nbformat": 4, 1032 | "nbformat_minor": 5 1033 | } 1034 | -------------------------------------------------------------------------------- /generatormodel.h5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicknochnack/GANBasics/b0eabe8d22fd1e87f0f4408fbce1b6bd4b3d91b1/generatormodel.h5 --------------------------------------------------------------------------------