├── chapter12 └── images │ ├── dali.jpg │ ├── gem.jpeg │ ├── escher.jpg │ ├── modern.jpg │ ├── monet.jpg │ ├── teddy.jpeg │ ├── einstein.jpg │ ├── escher.jpeg │ ├── picasso.jpg │ ├── pollock.jpg │ ├── honest_abe.jpeg │ ├── humming_bird.jpeg │ └── vincent-van-gogh.jpg ├── chapter13 ├── images │ ├── butterfly.jpg │ └── cats_dogs.jpg └── ch13.ipynb ├── chapter06 └── ch06.ipynb ├── chapter08 └── ch08.ipynb └── chapter14 └── ch14.ipynb /chapter12/images/dali.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/dali.jpg -------------------------------------------------------------------------------- /chapter12/images/gem.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/gem.jpeg -------------------------------------------------------------------------------- /chapter12/images/escher.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/escher.jpg -------------------------------------------------------------------------------- /chapter12/images/modern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/modern.jpg -------------------------------------------------------------------------------- /chapter12/images/monet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/monet.jpg -------------------------------------------------------------------------------- /chapter12/images/teddy.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/teddy.jpeg -------------------------------------------------------------------------------- /chapter12/images/einstein.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/einstein.jpg -------------------------------------------------------------------------------- /chapter12/images/escher.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/escher.jpeg -------------------------------------------------------------------------------- /chapter12/images/picasso.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/picasso.jpg -------------------------------------------------------------------------------- /chapter12/images/pollock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/pollock.jpg -------------------------------------------------------------------------------- /chapter13/images/butterfly.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter13/images/butterfly.jpg 
-------------------------------------------------------------------------------- /chapter13/images/cats_dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter13/images/cats_dogs.jpg -------------------------------------------------------------------------------- /chapter12/images/honest_abe.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/honest_abe.jpeg -------------------------------------------------------------------------------- /chapter12/images/humming_bird.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/humming_bird.jpeg -------------------------------------------------------------------------------- /chapter12/images/vincent-van-gogh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paperd/deep-learning-models/HEAD/chapter12/images/vincent-van-gogh.jpg -------------------------------------------------------------------------------- /chapter13/ch13.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "ch13.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "E0DIZQALOE23" 21 | }, 22 | "source": [ 23 | "# Object Detection\n", 24 | "\n", 25 | "Image classification involves assigning a class label to an image, whereas object localization involves drawing a bounding box around one or more objects in an image. Object detection is more challenging: it combines these two tasks, drawing a bounding box around each object of interest in the image and assigning it a class label.\n", 26 | "\n", 27 | "For an excellent overview, peruse:\n", 28 | "\n", 29 | "https://www.fritz.ai/object-detection/\n", 30 | "\n", 31 | "Technical resources:\n", 32 | "\n", 33 | "https://machinelearningmastery.com/object-recognition-with-deep-learning/\n", 34 | "\n", 35 | "https://www.tensorflow.org/hub/tutorials/object_detection\n", 36 | "\n", 37 | "https://www.tensorflow.org/hub/tutorials/tf2_object_detection" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "KyHEcXv30Bhf" 44 | }, 45 | "source": [ 46 | "**Image classification** predicts the type (or class) of an object in an image.\n", 47 | "* Input: an image with a single object, such as a photograph.\n", 48 | "* Output: a class label (e.g. one or more integers that are mapped to class labels).\n", 49 | "\n", 50 | "**Object localization** involves locating the presence of objects in an image and indicating their location with a bounding box.\n", 51 | "* Input: an image with one or more objects, such as a photograph.\n", 52 | "* Output: one or more bounding boxes (e.g.
defined by a point, width, and height).\n", 53 | "\n", 54 | "**Object detection** involves locating the presence of objects in an image with a bounding box and assigning a type (or class) to each located object.\n", 55 | "* Input: an image with one or more objects, such as a photograph.\n", 56 | "* Output: one or more bounding boxes (e.g. defined by a point, width, and height), and a class label for each bounding box." 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": { 62 | "id": "ioHCOdEcTX1X" 63 | }, 64 | "source": [ 65 | "# Import **tensorflow** library" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": { 71 | "id": "VpygsU5NTY4O" 72 | }, 73 | "source": [ 74 | "Import the library and alias it:" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "metadata": { 80 | "id": "prmzt6Q0TdMM" 81 | }, 82 | "source": [ 83 | "import tensorflow as tf" 84 | ], 85 | "execution_count": null, 86 | "outputs": [] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "z1PaVreZTcPm" 92 | }, 93 | "source": [ 94 | "# GPU Hardware Accelerator\n", 95 | "\n", 96 | "To vastly speed up processing, we can use the GPU available from the Google Colab cloud service. Colab provides a free Tesla K80 GPU of about 12 GB. It’s very easy to enable the GPU in a Colab notebook:\n", 97 | "\n", 98 | "1.\tclick **Runtime** in the top left menu\n", 99 | "2.\tclick **Change runtime** type from the drop-down menu\n", 100 | "3.\tchoose **GPU** from the Hardware accelerator drop-down menu\n", 101 | "4.\tclick **SAVE**" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "id": "nu9mIFhNTgeW" 108 | }, 109 | "source": [ 110 | "Verify that GPU is available:" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "metadata": { 116 | "id": "P-LVWYafThG_" 117 | }, 118 | "source": [ 119 | "tf.__version__, tf.test.gpu_device_name()" 120 | ], 121 | "execution_count": null, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": { 127 | "id": "GcllIuuxVKfX" 128 | }, 129 | "source": [ 130 | "# Import Requisite Libraries" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "id": "n5eRwZX4VN8e" 137 | }, 138 | "source": [ 139 | "Enable access to the TF-hub module:" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "metadata": { 145 | "id": "64LS09veVKkf" 146 | }, 147 | "source": [ 148 | "import tensorflow_hub as hub" 149 | ], 150 | "execution_count": null, 151 | "outputs": [] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "jXmVNSq125tR" 157 | }, 158 | "source": [ 159 | "For processing an image:" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "metadata": { 165 | "id": "gfSSdjOs2xn5" 166 | }, 167 | "source": [ 168 | "import matplotlib.pyplot as plt\n", 169 | "import tempfile\n", 170 | "from six.moves.urllib.request import urlopen\n", 171 | "from six import BytesIO" 172 | ], 173 | "execution_count": null, 174 | "outputs": [] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "id": "ShgoLJx03I57" 180 | }, 181 | "source": [ 182 | "For drawing onto an image:" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "metadata": { 188 | "id": "BBklNNna2xqH" 189 | }, 190 | "source": [ 191 | "from PIL import Image\n", 192 | "from PIL import ImageColor\n", 193 | "from PIL import ImageDraw\n", 194 | "from PIL import ImageFont\n", 195 | "from PIL import ImageOps" 196 | ], 197 | "execution_count": null, 198 | "outputs": [] 199 | }, 200 | {
"cell_type": "markdown", 202 | "metadata": { 203 | "id": "cBPWSJGt3Tv-" 204 | }, 205 | "source": [ 206 | "General library:" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "metadata": { 212 | "id": "04F-SfvO2weY" 213 | }, 214 | "source": [ 215 | "import numpy as np" 216 | ], 217 | "execution_count": null, 218 | "outputs": [] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": { 223 | "id": "S4FHO1NvlCKj" 224 | }, 225 | "source": [ 226 | "# Create Functions" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "QPWJrS7A3hhw" 233 | }, 234 | "source": [ 235 | "Display an image:" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "metadata": { 241 | "id": "4rHLLedC3hnY" 242 | }, 243 | "source": [ 244 | "def display_image(image):\n", 245 | " fig = plt.figure(figsize=(20, 15))\n", 246 | " plt.grid(False)\n", 247 | " plt.imshow(image)\n", 248 | " plt.axis('off')" 249 | ], 250 | "execution_count": null, 251 | "outputs": [] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": { 256 | "id": "nBDlPpFr3h2x" 257 | }, 258 | "source": [ 259 | "Draw bounding box on image:" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "metadata": { 265 | "id": "wFw5dU2lOY6s" 266 | }, 267 | "source": [ 268 | "def draw_bounding_box_on_image(\n", 269 | " image, ymin, xmin, ymax, xmax,\n", 270 | " color, font, thickness=4, display_str_list=()):\n", 271 | " \"\"\"Adds a bounding box to an image.\"\"\"\n", 272 | " draw = ImageDraw.Draw(image)\n", 273 | " im_width, im_height = image.size\n", 274 | " (left, right, top, bottom) = (\n", 275 | " xmin * im_width, xmax * im_width,\n", 276 | " ymin * im_height, ymax * im_height)\n", 277 | " draw.line([(left, top), (left, bottom),\n", 278 | " (right, bottom), (right, top),\n", 279 | " (left, top)],\n", 280 | " width=thickness, fill=color)\n", 281 | " # If the total height of the display strings added to the top of the bounding\n", 282 | " # box exceeds the top of the image, stack the strings below the bounding box\n", 283 | " # instead of above.\n", 284 | " display_str_heights = [font.getsize(ds)[1]\n", 285 | " for ds in display_str_list]\n", 286 | " # Each display_str has a top and bottom margin of 0.05x.\n", 287 | " total_display_str_height = (\n", 288 | " 1 + 2 * 0.05) * sum(display_str_heights)\n", 289 | " if top > total_display_str_height:\n", 290 | " text_bottom = top\n", 291 | " else:\n", 292 | " text_bottom = top + total_display_str_height\n", 293 | " # Reverse list and print from bottom to top.\n", 294 | " for display_str in display_str_list[::-1]:\n", 295 | " text_width, text_height = font.getsize(display_str)\n", 296 | " margin = np.ceil(0.05 * text_height)\n", 297 | " draw.rectangle(\n", 298 | " [(left, text_bottom - text_height - 2 * margin),\n", 299 | " (left + text_width, text_bottom)], fill=color)\n", 300 | " draw.text(\n", 301 | " (left + margin, text_bottom - text_height - margin),\n", 302 | " display_str, fill='black', font=font)\n", 303 | " text_bottom -= text_height - 2 * margin" 304 | ], 305 | "execution_count": null, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": { 311 | "id": "GLuQrUEI4W7j" 312 | }, 313 | "source": [ 314 | "Draw boxes:" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "metadata": { 320 | "id": "LKYEjDCV4XDh" 321 | }, 322 | "source": [ 323 | "def draw_boxes(\n", 324 | " image, boxes, class_names, scores,\n", 325 | " max_boxes=10, min_score=0.1):\n", 326 | " # Overlay labeled boxes on an image with formatted 
scores and label names.\n", 327 | " colors = list(ImageColor.colormap.values())\n", 328 | " one = '/usr/share/fonts/truetype/liberation/'\n", 329 | " two = 'LiberationSansNarrow-Regular.ttf'\n", 330 | " font_url = one + two\n", 331 | " try:\n", 332 | " font = ImageFont.truetype(font_url, 25)\n", 333 | " except IOError:\n", 334 | " print('Font not found, using default font.')\n", 335 | " font = ImageFont.load_default()\n", 336 | " for i in range(min(boxes.shape[0], max_boxes)):\n", 337 | " if scores[i] >= min_score:\n", 338 | " ymin, xmin, ymax, xmax = tuple(boxes[i])\n", 339 | " display_str = '{}: {}%'.format(\n", 340 | " class_names[i].decode('ascii'),\n", 341 | " int(100 * scores[i]))\n", 342 | " color = colors[hash(class_names[i]) % len(colors)]\n", 343 | " image_pil = Image.fromarray(\n", 344 | " np.uint8(image)).convert('RGB')\n", 345 | " draw_bounding_box_on_image(\n", 346 | " image_pil, ymin, xmin, ymax, xmax,\n", 347 | " color, font, display_str_list=[display_str])\n", 348 | " np.copyto(image, np.array(image_pil))\n", 349 | " return image" 350 | ], 351 | "execution_count": null, 352 | "outputs": [] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": { 357 | "id": "KAJ6T8rwlG_r" 358 | }, 359 | "source": [ 360 | "# Load a Module" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "id": "UsDHqwFGume8" 367 | }, 368 | "source": [ 369 | "Load an object detection module and apply on the downloaded image:" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "metadata": { 375 | "id": "eRzfuloFch98" 376 | }, 377 | "source": [ 378 | "p1 = 'https://tfhub.dev/google/faster_rcnn/'\n", 379 | "p2 = 'openimages_v4/inception_resnet_v2/1'\n", 380 | "URL = p1 + p2\n", 381 | "module_handle = URL\n", 382 | "obj_detect = hub.load(module_handle).signatures['default']" 383 | ], 384 | "execution_count": null, 385 | "outputs": [] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": { 390 | "id": "ecLagjrvoDwf" 391 | }, 392 | "source": [ 393 | "# Load an Image from Google Drive" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": { 399 | "id": "40HKDybrap1O" 400 | }, 401 | "source": [ 402 | "Mount Google Drive to Colab:" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "metadata": { 408 | "id": "WjJU3p7ZWWA9" 409 | }, 410 | "source": [ 411 | "from google.colab import drive\n", 412 | "drive.mount('/content/gdrive')" 413 | ], 414 | "execution_count": null, 415 | "outputs": [] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": { 420 | "id": "xzORL-s8hfuB" 421 | }, 422 | "source": [ 423 | "Be sure that the image is in the *appropriate directory* in **your** Google Drive!" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "metadata": { 429 | "id": "DCJtHQI8av6u" 430 | }, 431 | "source": [ 432 | "Access and display the image:" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "metadata": { 438 | "id": "hopHBZv5WWF0" 439 | }, 440 | "source": [ 441 | "img_path = 'gdrive/My Drive/Colab Notebooks/images/cats_dogs.jpg'\n", 442 | "pil_image = Image.open(img_path)\n", 443 | "display_image(pil_image)" 444 | ], 445 | "execution_count": null, 446 | "outputs": [] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": { 451 | "id": "srFfloYEmQyf" 452 | }, 453 | "source": [ 454 | "Convert the JPEG image to a PIL image and display it. The Python Imaging Library (PIL) is a library that supports opening, manipulating, and saving many different image file formats. 
It is also known as the Pillow library." 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "metadata": { 460 | "id": "XjDrsbPcf5rb" 461 | }, 462 | "source": [ 463 | "Check image size:" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "metadata": { 469 | "id": "iLAeneBzf5xW" 470 | }, 471 | "source": [ 472 | "pil_image.size" 473 | ], 474 | "execution_count": null, 475 | "outputs": [] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": { 480 | "id": "EbguQJnPa4X_" 481 | }, 482 | "source": [ 483 | "# Prepare the Image" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": { 489 | "id": "P6aUN33EbA8X" 490 | }, 491 | "source": [ 492 | "Generate a temporary path for the image file:" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "metadata": { 498 | "id": "3XRWXtATVZ5w" 499 | }, 500 | "source": [ 501 | "_, filename = tempfile.mkstemp(suffix='.jpg')\n", 502 | "filename" 503 | ], 504 | "execution_count": null, 505 | "outputs": [] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": { 510 | "id": "p3Cm2VY1bQqO" 511 | }, 512 | "source": [ 513 | "Prepare the image for processing and save it to the temporary file path:" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "metadata": { 519 | "id": "bMGTrcS3VeAn" 520 | }, 521 | "source": [ 522 | "pil_image_rgb = pil_image.convert('RGB')\n", 523 | "pil_image_rgb.save(filename, format='JPEG', quality=90)\n", 524 | "print('Image downloaded to %s.' % filename)\n", 525 | "display_image(pil_image)" 526 | ], 527 | "execution_count": null, 528 | "outputs": [] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": { 533 | "id": "HerwlpasbpGw" 534 | }, 535 | "source": [ 536 | "# Run Object Detection on the Image" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": { 542 | "id": "tIBbI_FTbvNA" 543 | }, 544 | "source": [ 545 | "Create a function to load the image:" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "metadata": { 551 | "id": "z0GKiXdHXm8l" 552 | }, 553 | "source": [ 554 | "def load_img(path):\n", 555 | " img = tf.io.read_file(path)\n", 556 | " img = tf.image.decode_jpeg(img, channels=3)\n", 557 | " return img" 558 | ], 559 | "execution_count": null, 560 | "outputs": [] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "metadata": { 565 | "id": "aGoscR93kXO6" 566 | }, 567 | "source": [ 568 | " The function loads the image and prepares it for the pretrained model." 
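, "\n", "\n", "As a quick sanity check, you could run the loader on the temporary file and inspect the result. This is a minimal sketch (it assumes `filename` is the temporary JPEG path created above):\n", "\n", "```python\n", "img = load_img(filename)\n", "img.shape, img.dtype  # e.g. (height, width, 3), tf.uint8\n", "```"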
569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": { 574 | "id": "HVoX8yK3b1WG" 575 | }, 576 | "source": [ 577 | "Create a function to run object detection:" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "metadata": { 583 | "id": "ITZL7Hm4Xs9N" 584 | }, 585 | "source": [ 586 | "def run_detector(detector, path):\n", 587 | " img = load_img(path)\n", 588 | " converted_img = tf.image.convert_image_dtype(\n", 589 | " img, tf.float32)[tf.newaxis, ...]\n", 590 | " result = detector(converted_img)\n", 591 | " result = {key:value.numpy()\n", 592 | " for key,value in result.items()}\n", 593 | " print(\"Found %d objects.\" %\\\n", 594 | " len(result[\"detection_scores\"]))\n", 595 | " image_with_boxes = draw_boxes(\n", 596 | " img.numpy(), result[\"detection_boxes\"],\n", 597 | " result[\"detection_class_entities\"],\n", 598 | " result[\"detection_scores\"])\n", 599 | " display_image(image_with_boxes)" 600 | ], 601 | "execution_count": null, 602 | "outputs": [] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": { 607 | "id": "_5IvmbkFcDwP" 608 | }, 609 | "source": [ 610 | "Invoke the detector:" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "metadata": { 616 | "id": "BO-5-gvPVeC2" 617 | }, 618 | "source": [ 619 | "run_detector(obj_detect, filename)" 620 | ], 621 | "execution_count": null, 622 | "outputs": [] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": { 627 | "id": "NNODcAcxqM4a" 628 | }, 629 | "source": [ 630 | "The detector did really well with this image!" 631 | ] 632 | }, 633 | { 634 | "cell_type": "markdown", 635 | "metadata": { 636 | "id": "ay-rQAFT3gyf" 637 | }, 638 | "source": [ 639 | "Let's try another one:" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "metadata": { 645 | "id": "kYsnzP9_3g5h" 646 | }, 647 | "source": [ 648 | "img_path = 'gdrive/My Drive/Colab Notebooks/images/butterfly.jpg'\n", 649 | "pil_image = Image.open(img_path)\n", 650 | "display_image(pil_image)" 651 | ], 652 | "execution_count": null, 653 | "outputs": [] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "metadata": { 658 | "id": "qYJH5Aht3x_g" 659 | }, 660 | "source": [ 661 | "Process:" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "metadata": { 667 | "id": "sVNrKuPQ3yHG" 668 | }, 669 | "source": [ 670 | "_, filename = tempfile.mkstemp(suffix='.jpg')\n", 671 | "pil_image_rgb = pil_image.convert('RGB')\n", 672 | "pil_image_rgb.save(filename, format='JPEG', quality=90)\n", 673 | "print('Image downloaded to %s.' % filename)" 674 | ], 675 | "execution_count": null, 676 | "outputs": [] 677 | }, 678 | { 679 | "cell_type": "markdown", 680 | "metadata": { 681 | "id": "0wmfNIy03t46" 682 | }, 683 | "source": [ 684 | "Run detector:" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "metadata": { 690 | "id": "7XSDT4-U3uB3" 691 | }, 692 | "source": [ 693 | "run_detector(obj_detect, filename)" 694 | ], 695 | "execution_count": null, 696 | "outputs": [] 697 | }, 698 | { 699 | "cell_type": "markdown", 700 | "metadata": { 701 | "id": "1LUdwtCtk2oI" 702 | }, 703 | "source": [ 704 | "# Download Images from Wikimedia Commons\n", 705 | "\n", 706 | "We have **already located images** from Wikimedia Commons!" 707 | ] 708 | }, 709 | { 710 | "cell_type": "markdown", 711 | "metadata": { 712 | "id": "P5OIRTMT5kNn" 713 | }, 714 | "source": [ 715 | "## Get Your Own Images\n", 716 | "\n", 717 | "However, you can locate your own images from Wikimedia Commons by following a few simple steps:\n", 718 | "\n", 719 | "1. 
go to the following URL: https://commons.wikimedia.org/wiki/Main_Page\n", 720 | "2. click **Images**\n", 721 | "3. click on an image\n", 722 | "4. right click the image\n", 723 | "5. select 'Copy link address' from the drop-down menu\n", 724 | "6. paste the link address into a code cell\n", 725 | "7. surround the link address with single or double quotes\n", 726 | "8. assign to a variable " 727 | ] 728 | }, 729 | { 730 | "cell_type": "markdown", 731 | "metadata": { 732 | "id": "cu67rIHhaaFX" 733 | }, 734 | "source": [ 735 | "## Create a Function to Download an Image" 736 | ] 737 | }, 738 | { 739 | "cell_type": "markdown", 740 | "metadata": { 741 | "id": "ajBJgW9GlBCR" 742 | }, 743 | "source": [ 744 | "Create a function to download, process, and save an image to a temporary file path:" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "metadata": { 750 | "id": "1mpV9YdWRlV4" 751 | }, 752 | "source": [ 753 | "def download_and_resize_image(\n", 754 | " url, new_width=256, new_height=256,\n", 755 | " display=False):\n", 756 | " _, filename = tempfile.mkstemp(suffix='.jpg')\n", 757 | " response = urlopen(url)\n", 758 | " image_data = response.read()\n", 759 | " image_data = BytesIO(image_data)\n", 760 | " pil_image = Image.open(image_data)\n", 761 | " pil_image = ImageOps.fit(\n", 762 | " pil_image, (new_width, new_height),\n", 763 | " Image.ANTIALIAS)\n", 764 | " pil_image_rgb = pil_image.convert('RGB')\n", 765 | " pil_image_rgb.save(\n", 766 | " filename, format='JPEG', quality=90)\n", 767 | " print('Image downloaded to %s.' % filename)\n", 768 | " if display:\n", 769 | " display_image(pil_image)\n", 770 | " return filename" 771 | ], 772 | "execution_count": null, 773 | "outputs": [] 774 | }, 775 | { 776 | "cell_type": "markdown", 777 | "metadata": { 778 | "id": "Z1bTQoHOlWz7" 779 | }, 780 | "source": [ 781 | "The function generates a temporary path for the image file. It then reads the image file from the supplied URL. The function continues by converting the image file to a PIL image. The PIL image is then resized, converted to RGB, and saved to the temporary file path." 
782 | ] 783 | }, 784 | { 785 | "cell_type": "markdown", 786 | "metadata": { 787 | "id": "t2_x4X-_6TGn" 788 | }, 789 | "source": [ 790 | "## Load an Image from a URL" 791 | ] 792 | }, 793 | { 794 | "cell_type": "markdown", 795 | "metadata": { 796 | "id": "jpN8_TVilHIB" 797 | }, 798 | "source": [ 799 | "Load an image from a Wikimedia Commons URL:" 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "CoRu59zqQ-39" 806 | }, 807 | "source": [ 808 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/7/79/'\n", 809 | "p2 = 'At_taverna_under_the_church%2C_Ano_Potamia%2C_Naxos%'\n", 810 | "p3 = '2C_190574.jpg'\n", 811 | "URL = p1 + p2 + p3\n", 812 | "\n", 813 | "downloaded_image_path = download_and_resize_image(\n", 814 | " URL, 1280, 856, True)" 815 | ], 816 | "execution_count": null, 817 | "outputs": [] 818 | }, 819 | { 820 | "cell_type": "markdown", 821 | "metadata": { 822 | "id": "PJ4le3Kv5I4Q" 823 | }, 824 | "source": [ 825 | "The source for the image is located at:\n", 826 | "\n", 827 | "https://commons.wikimedia.org/wiki/File:At_taverna_under_the_church,_Ano_Potamia,_Naxos,_190574.jpg" 828 | ] 829 | }, 830 | { 831 | "cell_type": "markdown", 832 | "metadata": { 833 | "id": "7flJ9_Lf5arh" 834 | }, 835 | "source": [ 836 | "# Run Object Detection" 837 | ] 838 | }, 839 | { 840 | "cell_type": "markdown", 841 | "metadata": { 842 | "id": "eNyNz8l76nSv" 843 | }, 844 | "source": [ 845 | "Run object detection with the function we created earlier in this notebook:" 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "metadata": { 851 | "id": "OjC_YBCNQY6x" 852 | }, 853 | "source": [ 854 | "run_detector(obj_detect, downloaded_image_path)" 855 | ], 856 | "execution_count": null, 857 | "outputs": [] 858 | }, 859 | { 860 | "cell_type": "markdown", 861 | "metadata": { 862 | "id": "d3psLQ1cqfeD" 863 | }, 864 | "source": [ 865 | "Pretty good. But, not perfect." 866 | ] 867 | }, 868 | { 869 | "cell_type": "markdown", 870 | "metadata": { 871 | "id": "O0YdtejOqpQV" 872 | }, 873 | "source": [ 874 | "Let's try some more images. 
Piece together some paths:" 875 | ] 876 | }, 877 | { 878 | "cell_type": "code", 879 | "metadata": { 880 | "id": "cHoDWnFYSu5B" 881 | }, 882 | "source": [ 883 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/4/45/'\n", 884 | "p2 = 'Green_Dragon_Tavern_%2836196%29.jpg'\n", 885 | "tavern = p1 + p2\n", 886 | "\n", 887 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/3/31/'\n", 888 | "p2 = 'Circus_Circus_Hotel-Casino_sign.jpg'\n", 889 | "casino = p1 + p2\n", 890 | "\n", 891 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/9/91/'\n", 892 | "p2 = 'Leon_hot_air_balloon_festival_2010.jpg'\n", 893 | "balloon = p1 + p2\n", 894 | "\n", 895 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/d/d8/'\n", 896 | "p2 = '2012_Festival_of_Sail_-_7943922284.jpg'\n", 897 | "sail = p1 + p2\n", 898 | "\n", 899 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/a/ab/'\n", 900 | "p2 = '17_mai_2018.jpg'\n", 901 | "flag = p1 + p2\n", 902 | "\n", 903 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/4/43/'\n", 904 | "p2 = 'Fruit_baskets.jpg'\n", 905 | "basket= p1 + p2\n", 906 | "\n", 907 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/c/c7/'\n", 908 | "p2 = 'Fruit_stands%2C_Rue_de_Seine%2C_Paris_22_May_2014.jpg'\n", 909 | "stand= p1 + p2\n", 910 | "\n", 911 | "p1 = 'https://upload.wikimedia.org/wikipedia/commons/9/95/'\n", 912 | "p2 = 'Wine_tasting_%40_brown_brothers.jpg'\n", 913 | "wine = p1 + p2" 914 | ], 915 | "execution_count": null, 916 | "outputs": [] 917 | }, 918 | { 919 | "cell_type": "markdown", 920 | "metadata": { 921 | "id": "6zm_g-zDVYH8" 922 | }, 923 | "source": [ 924 | "Create a function to detect images:" 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "metadata": { 930 | "id": "WiVSKq0wqwZa" 931 | }, 932 | "source": [ 933 | "def detect_img(image_url):\n", 934 | " image_path = download_and_resize_image(image_url, 640, 480)\n", 935 | " run_detector(obj_detect, image_path)" 936 | ], 937 | "execution_count": null, 938 | "outputs": [] 939 | }, 940 | { 941 | "cell_type": "markdown", 942 | "metadata": { 943 | "id": "SH9NQ6rKWgmQ" 944 | }, 945 | "source": [ 946 | "Run object detection on one of the images:" 947 | ] 948 | }, 949 | { 950 | "cell_type": "code", 951 | "metadata": { 952 | "id": "dKcbX8XDVAIl" 953 | }, 954 | "source": [ 955 | "detect_img(wine)" 956 | ], 957 | "execution_count": null, 958 | "outputs": [] 959 | }, 960 | { 961 | "cell_type": "markdown", 962 | "metadata": { 963 | "id": "wtpHJW2HrHn8" 964 | }, 965 | "source": [ 966 | "Try another one:" 967 | ] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "metadata": { 972 | "id": "VDgw71-iaWo2" 973 | }, 974 | "source": [ 975 | "detect_img(sail)" 976 | ], 977 | "execution_count": null, 978 | "outputs": [] 979 | }, 980 | { 981 | "cell_type": "markdown", 982 | "metadata": { 983 | "id": "PSxuN78P7KRo" 984 | }, 985 | "source": [ 986 | "Try some of the other scenes." 987 | ] 988 | }, 989 | { 990 | "cell_type": "markdown", 991 | "metadata": { 992 | "id": "cz1Xs_t-r3j1" 993 | }, 994 | "source": [ 995 | "# Find the Source\n", 996 | "\n", 997 | "We can translate the JPEG link to find the source:\n", 998 | "\n", 999 | "1. substitute **commons** for *upload*\n", 1000 | "2. change *wikipedia* to **wiki**\n", 1001 | "3. substitute *commons/(number)/(number)* for **File:**\n", 1002 | "4. 
translate the *%(number)* to **HTML encoded equivalent**\n", 1003 | "\n", 1004 | "Find the encoded equivalent:\n", 1005 | "\n", 1006 | "https://krypted.com/utilities/html-encoding-reference/" 1007 | ] 1008 | }, 1009 | { 1010 | "cell_type": "markdown", 1011 | "metadata": { 1012 | "id": "k-2mAQrj2XHU" 1013 | }, 1014 | "source": [ 1015 | "## The First One\n", 1016 | "\n", 1017 | "Let's try the tavern image:\n", 1018 | "\n", 1019 | "https://upload.wikimedia.org/wikipedia/commons/4/45/Green_Dragon_Tavern_%2836196%29.jpg" 1020 | ] 1021 | }, 1022 | { 1023 | "cell_type": "markdown", 1024 | "metadata": { 1025 | "id": "U0VX_l6MxF6V" 1026 | }, 1027 | "source": [ 1028 | "Substitute **commons**:\n", 1029 | "\n", 1030 | "https://commons.wikimedia.org/wikipedia/commons/4/45/Green_Dragon_Tavern_%2836196%29.jpg" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "markdown", 1035 | "metadata": { 1036 | "id": "-qblYaQjxbTv" 1037 | }, 1038 | "source": [ 1039 | "Change to **wiki**:\n", 1040 | "\n", 1041 | "https://commons.wikimedia.org/wiki/commons/4/45/Green_Dragon_Tavern_%2836196%29.jpg" 1042 | ] 1043 | }, 1044 | { 1045 | "cell_type": "markdown", 1046 | "metadata": { 1047 | "id": "X5gXZPaIxkF4" 1048 | }, 1049 | "source": [ 1050 | "Change to **File:**\n", 1051 | "\n", 1052 | "https://commons.wikimedia.org/wiki/File:Green_Dragon_Tavern_%2836196%29.jpg" 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "markdown", 1057 | "metadata": { 1058 | "id": "RtMaoWLJxxMo" 1059 | }, 1060 | "source": [ 1061 | "Translate:\n", 1062 | "\n", 1063 | "https://commons.wikimedia.org/wiki/File:Green_Dragon_Tavern_(36196).jpg\n", 1064 | "\n", 1065 | "From the HTML Encoding Reference, **%28:(** and **%29:)**." 1066 | ] 1067 | }, 1068 | { 1069 | "cell_type": "markdown", 1070 | "metadata": { 1071 | "id": "K5X1DWYV1_7t" 1072 | }, 1073 | "source": [ 1074 | "Here is the resource for the image:" 1075 | ] 1076 | }, 1077 | { 1078 | "cell_type": "code", 1079 | "metadata": { 1080 | "id": "04sgzu3TYAJ-" 1081 | }, 1082 | "source": [ 1083 | "# https://commons.wikimedia.org/wiki/File:Green_Dragon_Tavern_(36196).jpg" 1084 | ], 1085 | "execution_count": null, 1086 | "outputs": [] 1087 | }, 1088 | { 1089 | "cell_type": "markdown", 1090 | "metadata": { 1091 | "id": "i-ZaSmf3yFAR" 1092 | }, 1093 | "source": [ 1094 | "Just copy (sans the hash symbol) and paste into your favorite browser to find the resource for the image! Sometimes the URL doesn't translate correctly in a text cell. So we placed it in a code cell and commented it out." 1095 | ] 1096 | }, 1097 | { 1098 | "cell_type": "markdown", 1099 | "metadata": { 1100 | "id": "KdUiZXRJz5xD" 1101 | }, 1102 | "source": [ 1103 | "## The Second One\n", 1104 | "\n", 1105 | "Let's try the next one." 1106 | ] 1107 | }, 1108 | { 1109 | "cell_type": "markdown", 1110 | "metadata": { 1111 | "id": "VMlvSWek2Ov0" 1112 | }, 1113 | "source": [ 1114 | "Result:" 1115 | ] 1116 | }, 1117 | { 1118 | "cell_type": "code", 1119 | "metadata": { 1120 | "id": "7e9MG7Rn2JZ8" 1121 | }, 1122 | "source": [ 1123 | "# https://commons.wikimedia.org/wiki/File:Circus_Circus_Hotel-Casino_sign.jpg" 1124 | ], 1125 | "execution_count": null, 1126 | "outputs": [] 1127 | }, 1128 | { 1129 | "cell_type": "markdown", 1130 | "metadata": { 1131 | "id": "XVCwtiIU0F00" 1132 | }, 1133 | "source": [ 1134 | "This one was easy because we didn't need to translate." 
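, "\n", "\n", "If you prefer to do the translation programmatically, here is a minimal sketch using Python's standard `urllib.parse` (the helper name `commons_source_url` is ours, not part of the original notebook):\n", "\n", "```python\n", "from urllib.parse import unquote\n", "\n", "def commons_source_url(jpeg_url):\n", "    # decode %28 -> ( , %29 -> ) , %40 -> @ , and so on\n", "    file_name = unquote(jpeg_url.rsplit('/', 1)[-1])\n", "    return 'https://commons.wikimedia.org/wiki/File:' + file_name\n", "\n", "commons_source_url(casino)\n", "```"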
1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "markdown", 1139 | "metadata": { 1140 | "id": "cJ-bw6qg2pAn" 1141 | }, 1142 | "source": [ 1143 | "## The Rest\n", 1144 | "\n", 1145 | "Results:" 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "metadata": { 1151 | "id": "PyjoS4sO2rvN" 1152 | }, 1153 | "source": [ 1154 | "'''\n", 1155 | "https://commons.wikimedia.org/wiki/File:Leon_hot_air_balloon_festival_2010.jpg\n", 1156 | "https://commons.wikimedia.org/wiki/File:2012_Festival_of_Sail_-_7943922284.jpg\n", 1157 | "https://commons.wikimedia.org/wiki/File:17_mai_2018.jpg\n", 1158 | "https://commons.wikimedia.org/wiki/File:Fruit_baskets.jpg\n", 1159 | "https://commons.wikimedia.org/wiki/File:Fruit_stands,_Rue_de_Seine,_Paris_22_May_2014.jpg\n", 1160 | "https://commons.wikimedia.org/wiki/File:Wine_tasting_@_brown_brothers.jpg\n", 1161 | "'''" 1162 | ], 1163 | "execution_count": null, 1164 | "outputs": [] 1165 | }, 1166 | { 1167 | "cell_type": "markdown", 1168 | "metadata": { 1169 | "id": "U0m6TX8uZCSb" 1170 | }, 1171 | "source": [ 1172 | "Just copy and paste to a browser." 1173 | ] 1174 | } 1175 | ] 1176 | } -------------------------------------------------------------------------------- /chapter06/ch06.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ch06.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm" 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "_EXoS-8gwxlx" 22 | }, 23 | "source": [ 24 | "# Transfer Learning\n", 25 | "\n", 26 | "Image classification models have millions of parameters. Training them from scratch requires a lot of labeled training data and a lot of computing power. Transfer learning is a technique that shortcuts much of this by taking a piece of a model that has already been trained on a related task and reusing it in a new model.\n", 27 | "\n", 28 | "**Transfer learning** is the process of applying existing machine learning models to scenarios for which they were not originally intended. This leveraging can save training time and extend the usefulness of existing machine learning models, models which may have had the available data and computation to have been trained for very long periods of time on very large datasets. If we train a model on a large set of data, we can then refine the result to be effective on our smaller amount of data. 
At least, that's the idea.\n", 29 | "\n", 30 | "Documentation:\n", 31 | "\n", 32 | "https://www.tensorflow.org/tutorials/images/transfer_learning" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": { 38 | "id": "EuZlMjQkO4TZ" 39 | }, 40 | "source": [ 41 | "# Import **tensorflow** Library" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": { 47 | "id": "3bHZBZc4K-en" 48 | }, 49 | "source": [ 50 | "Import the library and alias it:" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "metadata": { 56 | "id": "U8zzCdqgO4a4" 57 | }, 58 | "source": [ 59 | "import tensorflow as tf" 60 | ], 61 | "execution_count": null, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "id": "y7UC2JNPw3pY" 68 | }, 69 | "source": [ 70 | "# GPU Hardware Accelerator\n", 71 | "\n", 72 | "To vastly speed up processing, we can use the GPU available from the Google Colab cloud service. Colab provides a free Tesla K80 GPU of about 12 GB. It’s very easy to enable the GPU in a Colab notebook:\n", 73 | "\n", 74 | "1.\tclick **Runtime** in the top left menu\n", 75 | "2.\tclick **Change runtime** type from the drop-down menu\n", 76 | "3.\tchoose **GPU** from the Hardware accelerator drop-down menu\n", 77 | "4.\tclick **SAVE**" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": { 83 | "id": "KDNFpWCTw7L4" 84 | }, 85 | "source": [ 86 | "Verify that GPU is active:" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "metadata": { 92 | "id": "Ck2w8yEvwqpp" 93 | }, 94 | "source": [ 95 | "tf.__version__, tf.test.gpu_device_name()" 96 | ], 97 | "execution_count": null, 98 | "outputs": [] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "id": "MENnnXfKFC00" 104 | }, 105 | "source": [ 106 | "# Pre-trained Models for Transfer Learning\n", 107 | "\n", 108 | "If we don't have enough training data, it is often a good idea to reuse the lower layers of a pre-trained model. **Transfer learning** is the process of creating new AI models by fine-tuning previously trained neural networks. Instead of training a neural network from scratch, we can download a pretrained, open-source deep learning model and fine tune it for our own purpose.\n", 109 | "\n", 110 | "To implement transfer learning, we reuse parts of a pre-trained model and change the final layer (or several layers) of the model. We then retrain those layers on our own dataset." 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "UH4t09XuwbKq" 117 | }, 118 | "source": [ 119 | "# Simple Transfer Learning with TensorFlow Hub\n", 120 | "\n", 121 | "We model Flowers data by using pre-trained TF2 SavedModels from TensorFlow Hub for image feature extraction. The pre-trained models were trained on very large and general datasets.\n", 122 | "\n", 123 | "We use two pre-trained TensorFlow Hub models to do transfer learning. **TensorFlow Hub** is a repository of trained machine learning models ready for fine-tuning and deployable anywhere. We begin with the MobileNet v2 pre-trained model. 
We then use the Inception v3 pre-trained model and compare results between the two.\n", 124 | "\n", 125 | "Resources:\n", 126 | "\n", 127 | "https://www.tensorflow.org/hub\n", 128 | "\n", 129 | "https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "Y0w1bpGuu0Ph" 136 | }, 137 | "source": [ 138 | "# MobileNet v2 Example\n", 139 | "\n", 140 | "Information about MobileNet and other pre-trained models is available at the following URL:\n", 141 | "\n", 142 | "https://tfhub.dev/s?module-type=image-feature-vector&q=tf2" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": { 148 | "id": "X3T4YNQgG-N3" 149 | }, 150 | "source": [ 151 | "## Load flowers as TFDS" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "PxJ2XxwiClgK" 158 | }, 159 | "source": [ 160 | "Split the data 75% for the train set, 15% for the validation set, and 10% for the test set:" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "8wTU7RONG-TO" 167 | }, 168 | "source": [ 169 | "import tensorflow_datasets as tfds\n", 170 | "\n", 171 | "(test, valid, train), info = tfds.load(\n", 172 | " 'tf_flowers', as_supervised=True,\n", 173 | " split = ['train[:10%]', 'train[10%:25%]', 'train[25%:]'],\n", 174 | " with_info=True, try_gcs=True)" 175 | ], 176 | "execution_count": null, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": { 182 | "id": "MmkPOJ-iG-YG" 183 | }, 184 | "source": [ 185 | "## Get Metadata" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "id": "fx3xxQjEJu9q" 192 | }, 193 | "source": [ 194 | "Display general information:" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "metadata": { 200 | "id": "OTXAcajOG-dG" 201 | }, 202 | "source": [ 203 | "info" 204 | ], 205 | "execution_count": null, 206 | "outputs": [] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": { 211 | "id": "gN_s6TK3G-he" 212 | }, 213 | "source": [ 214 | "Display the number of examples in the data splits:" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "id": "BrQAa6GeG-me" 221 | }, 222 | "source": [ 223 | "num_train_img = info.splits['train[25%:]'].num_examples\n", 224 | "num_valid_img = info.splits['train[10%:25%]'].num_examples\n", 225 | "num_test_img = info.splits['train[:10%]'].num_examples\n", 226 | "print ('train images:', num_train_img)\n", 227 | "print ('valid images:', num_valid_img)\n", 228 | "print ('test images:', num_test_img)" 229 | ], 230 | "execution_count": null, 231 | "outputs": [] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "id": "hG5NLjhYKpIR" 237 | }, 238 | "source": [ 239 | "Calculate the number of examples in the data splits manually to verify:" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "metadata": { 245 | "id": "YbkjuJV0Kb4Z" 246 | }, 247 | "source": [ 248 | "num_train_examples = 0\n", 249 | "num_valid_examples = 0\n", 250 | "num_test_examples = 0\n", 251 | "\n", 252 | "for example in train:\n", 253 | " num_train_examples += 1\n", 254 | "\n", 255 | "for example in valid:\n", 256 | " num_valid_examples += 1\n", 257 | "\n", 258 | "for example in test:\n", 259 | " num_test_examples += 1\n", 260 | "\n", 261 | "print('Total Number of Training Images: {}'\\\n", 262 | " .format(num_train_examples))\n", 263 | "print('Total Number of Validation Images: {}'\\\n", 264 | " .format(num_valid_examples))\n", 265 | "print('Total
Number of Testing Images: {}'\\\n", 266 | " .format(num_test_examples))" 267 | ], 268 | "execution_count": null, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "M--BAF0kKDxK" 275 | }, 276 | "source": [ 277 | "Get labels:" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "metadata": { 283 | "id": "mSRmdmIsKD2w" 284 | }, 285 | "source": [ 286 | "class_labels = info.features['label'].names\n", 287 | "class_labels" 288 | ], 289 | "execution_count": null, 290 | "outputs": [] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": { 295 | "id": "RRK9PcsAKHex" 296 | }, 297 | "source": [ 298 | "Get number of classes:" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "metadata": { 304 | "id": "DM16A6cJKHlI" 305 | }, 306 | "source": [ 307 | "num_classes = info.features['label'].num_classes\n", 308 | "num_classes" 309 | ], 310 | "execution_count": null, 311 | "outputs": [] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": { 316 | "id": "QWMmNDhzKM-5" 317 | }, 318 | "source": [ 319 | "## Display Examples" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": { 325 | "id": "etyNc4QmC1xp" 326 | }, 327 | "source": [ 328 | "Display some examples with **show_examples**:" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "metadata": { 334 | "id": "CGHKa5kDKTxp" 335 | }, 336 | "source": [ 337 | "fig = tfds.show_examples(train, info)" 338 | ], 339 | "execution_count": null, 340 | "outputs": [] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": { 345 | "id": "-8oX9dWmLvpr" 346 | }, 347 | "source": [ 348 | "## Inspect Images" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": { 354 | "id": "0ATaMfY4C6Xv" 355 | }, 356 | "source": [ 357 | "Display shapes to check images sizes:" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "metadata": { 363 | "id": "Yv2oZaZdLvwz" 364 | }, 365 | "source": [ 366 | "for i, example in enumerate(train.take(5)):\n", 367 | " print('Image {} shape: {} label: {}'\\\n", 368 | " .format(i+1, example[0].shape,\n", 369 | " example[1]))" 370 | ], 371 | "execution_count": null, 372 | "outputs": [] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": { 377 | "id": "1lnJJm-oL5jZ" 378 | }, 379 | "source": [ 380 | "The images in the flowers dataset are not all the same size. So, we must resize images to a standard size to make them consumable by TensorFlow models." 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": { 386 | "id": "5BMod5pRK73p" 387 | }, 388 | "source": [ 389 | "## Build the Input Pipeline" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "metadata": { 395 | "id": "ae0WhPz8DGiB" 396 | }, 397 | "source": [ 398 | "Create a function to reformat all images to the resolution expected by MobileNet v2 (224, 224) and scale them. The function takes an 'image' and a 'label' as arguments and returns the new 'image' and corresponding 'label' in the desired form." 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "metadata": { 404 | "id": "h3qo8MNmK780" 405 | }, 406 | "source": [ 407 | "def format_image(image, label):\n", 408 | " image = tf.image.resize(image, (224, 224)) /255.0\n", 409 | " return image, label" 410 | ], 411 | "execution_count": null, 412 | "outputs": [] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "kIy79Wq8DUiC" 418 | }, 419 | "source": [ 420 | "Map function to train, validation, and test sets. 
And, apply other transformations:" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "metadata": { 426 | "id": "wq9zRe17M_ti" 427 | }, 428 | "source": [ 429 | "BATCH_SIZE = 367\n", 430 | "\n", 431 | "train_batches = train.shuffle(num_train_img//4).\\\n", 432 | " map(format_image).batch(BATCH_SIZE).prefetch(1)\n", 433 | "\n", 434 | "validation_batches = valid.map(format_image).\\\n", 435 | " batch(BATCH_SIZE).prefetch(1)\n", 436 | "\n", 437 | "test_batches = test.map(format_image).\\\n", 438 | " batch(BATCH_SIZE).prefetch(1) " 439 | ], 440 | "execution_count": null, 441 | "outputs": [] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": { 446 | "id": "v9ND2p2wx5ka" 447 | }, 448 | "source": [ 449 | "## Simple Transfer Learning with MobileNet-v2\n", 450 | "\n", 451 | "We begin the process by creating a feature_extractor. The partial model from TensorFlow Hub (without the final classification layer) is called a feature vector. Go to the TensorFlow Hub documentation (https://tfhub.dev/s?module-type=image-feature-vector) to see a list of available feature vectors.\n", 452 | "\n", 453 | "To get information about MobileNet-v2:\n", 454 | "\n", 455 | "https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV2\n", 456 | "\n", 457 | "Resource:\n", 458 | "\n", 459 | "https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_for_deep_learning/l06c03_exercise_flowers_with_transfer_learning_solution.ipynb" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": { 465 | "id": "jl4OEnjKO6oa" 466 | }, 467 | "source": [ 468 | "### Create a Feature Extractor\n", 469 | "\n", 470 | "Create a feature_extractor using the MobileNet-v2 feature vector. A **feature extractor** is the partial model from TensorFlow Hub (without the final classification layer).\n", 471 | "\n", 472 | "To see a list of available feature vectors, visit:\n", 473 | "\n", 474 | "https://tfhub.dev/s?module-type=image-feature-vector&q=tf2\n", 475 | "\n", 476 | "Click on one of them, read the documentation, and get the corresponding URL to get the feature vector." 477 | ] 478 | }, 479 | { 480 | "cell_type": "markdown", 481 | "metadata": { 482 | "id": "-bjuBKpnDaIr" 483 | }, 484 | "source": [ 485 | "Create the feature extractor:" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "metadata": { 491 | "id": "UGN5Jr6QPck9" 492 | }, 493 | "source": [ 494 | "import tensorflow_hub as hub\n", 495 | "\n", 496 | "piece1 = 'https://tfhub.dev/google/tf2-preview/'\n", 497 | "piece2 = 'mobilenet_v2/feature_vector/4'\n", 498 | "URL = piece1 + piece2\n", 499 | "feature_extractor_mn = hub.KerasLayer(\n", 500 | " URL, input_shape=(224, 224, 3))" 501 | ], 502 | "execution_count": null, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": { 508 | "id": "rjRWGdFxSC7x" 509 | }, 510 | "source": [ 511 | "The feature extractor is now a partial MobileNet-v2 model." 
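, "\n", "\n", "As an optional sanity check (a minimal sketch, not part of the original recipe), pass a dummy batch through the layer and inspect the output shape:\n", "\n", "```python\n", "dummy_batch = tf.zeros([1, 224, 224, 3])\n", "feature_extractor_mn(dummy_batch).shape  # expect roughly (1, 1280) for this feature vector\n", "```"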
512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "id": "Fw43lTVWPnRV" 518 | }, 519 | "source": [ 520 | "## Freeze the Pretrained Model" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": { 526 | "id": "x9kQahuoDghk" 527 | }, 528 | "source": [ 529 | "Freeze the variables in the feature extractor layer, so that the training only modifies the final classifier layer:" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "metadata": { 535 | "id": "Dgl_TlRSPnij" 536 | }, 537 | "source": [ 538 | "feature_extractor_mn.trainable = False" 539 | ], 540 | "execution_count": null, 541 | "outputs": [] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": { 546 | "id": "HWKd1iyBQCFd" 547 | }, 548 | "source": [ 549 | "## Attach a Classification Head\n", 550 | "\n", 551 | "Create a classification head to leverage the pre-trained model for the dataset, which consists of a simple sequential model that includes the pre-trained model and the new classification layer." 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": { 557 | "id": "ZwBhqoCoRAVr" 558 | }, 559 | "source": [ 560 | "Import libraries:" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "metadata": { 566 | "id": "Qnv6li8bQ22C" 567 | }, 568 | "source": [ 569 | "from tensorflow.keras.models import Sequential\n", 570 | "from tensorflow.keras.layers import Dense, Dropout" 571 | ], 572 | "execution_count": null, 573 | "outputs": [] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "metadata": { 578 | "id": "I3mJoKNPW1Vb" 579 | }, 580 | "source": [ 581 | "Clear previous models and generate seed:" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "metadata": { 587 | "id": "hXRRHQvGW1cE" 588 | }, 589 | "source": [ 590 | "import numpy as np\n", 591 | "\n", 592 | "tf.keras.backend.clear_session()\n", 593 | "np.random.seed(0)\n", 594 | "tf.random.set_seed(0)" 595 | ], 596 | "execution_count": null, 597 | "outputs": [] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": { 602 | "id": "TZMxr4WMRCIb" 603 | }, 604 | "source": [ 605 | "Build model:" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "metadata": { 611 | "id": "yypC6MPPQCMm" 612 | }, 613 | "source": [ 614 | "mobile_model = tf.keras.Sequential([\n", 615 | " feature_extractor_mn,\n", 616 | " Dropout(0.5),\n", 617 | " Dense(num_classes)])" 618 | ], 619 | "execution_count": null, 620 | "outputs": [] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "metadata": { 625 | "id": "ayOfuAe5REZb" 626 | }, 627 | "source": [ 628 | "Inspect model:" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "metadata": { 634 | "id": "IkP7ZsQ8REgT" 635 | }, 636 | "source": [ 637 | "mobile_model.summary()" 638 | ], 639 | "execution_count": null, 640 | "outputs": [] 641 | }, 642 | { 643 | "cell_type": "markdown", 644 | "metadata": { 645 | "id": "BhI6ImegRL0c" 646 | }, 647 | "source": [ 648 | "## Compile" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": { 654 | "id": "62CAXyOmFLAA" 655 | }, 656 | "source": [ 657 | "Compile **SparseCategoricalCrossentropy**:" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "metadata": { 663 | "id": "9RRZ19elRL6N" 664 | }, 665 | "source": [ 666 | "from tensorflow.keras.losses import SparseCategoricalCrossentropy\n", 667 | "\n", 668 | "mobile_model.compile(\n", 669 | " optimizer='adam',\n", 670 | " loss=SparseCategoricalCrossentropy(from_logits=True),\n", 671 | " metrics=['accuracy'])" 672 | ], 673 | 
"execution_count": null, 674 | "outputs": [] 675 | }, 676 | { 677 | "cell_type": "markdown", 678 | "metadata": { 679 | "id": "7HgFrUxHRL-l" 680 | }, 681 | "source": [ 682 | "## Train" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": { 688 | "id": "56J55Lc8FSxN" 689 | }, 690 | "source": [ 691 | "Train model on train and validation sets for six epochs:" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "metadata": { 697 | "id": "aQtBFEMBRMEF" 698 | }, 699 | "source": [ 700 | "EPOCHS = 6\n", 701 | "\n", 702 | "history = mobile_model.fit(\n", 703 | " train_batches, epochs=EPOCHS,\n", 704 | " validation_data=validation_batches)" 705 | ], 706 | "execution_count": null, 707 | "outputs": [] 708 | }, 709 | { 710 | "cell_type": "markdown", 711 | "metadata": { 712 | "id": "2BJ5n6EuRj6d" 713 | }, 714 | "source": [ 715 | "We get good accuracy with just 6 epochs because MobileNet-v2 was carefully designed over a long time by experts and then trained on a massive dataset (ImageNet)." 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": { 721 | "id": "tbH_oIP6TFlf" 722 | }, 723 | "source": [ 724 | "## Visualize Performance" 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": { 730 | "id": "sV2CInmjFY7t" 731 | }, 732 | "source": [ 733 | "Plot model performance:" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "metadata": { 739 | "id": "4f2zopmrS6nR" 740 | }, 741 | "source": [ 742 | "import matplotlib.pyplot as plt\n", 743 | "\n", 744 | "acc = history.history['accuracy']\n", 745 | "val_acc = history.history['val_accuracy']\n", 746 | "\n", 747 | "loss = history.history['loss']\n", 748 | "val_loss = history.history['val_loss']\n", 749 | "\n", 750 | "epochs_range = range(EPOCHS)\n", 751 | "\n", 752 | "plt.figure(figsize=(8, 8))\n", 753 | "plt.subplot(1, 2, 1)\n", 754 | "plt.plot(epochs_range, acc, label='Training Accuracy')\n", 755 | "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n", 756 | "plt.legend(loc='lower right')\n", 757 | "plt.title('Training and Validation Accuracy')\n", 758 | "\n", 759 | "plt.subplot(1, 2, 2)\n", 760 | "plt.plot(epochs_range, loss, label='Training Loss')\n", 761 | "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", 762 | "plt.legend(loc='upper right')\n", 763 | "plt.title('Training and Validation Loss')\n", 764 | "plt.show()" 765 | ], 766 | "execution_count": null, 767 | "outputs": [] 768 | }, 769 | { 770 | "cell_type": "markdown", 771 | "metadata": { 772 | "id": "8X88gpnaTPxY" 773 | }, 774 | "source": [ 775 | "## Make Predictions from Test Data" 776 | ] 777 | }, 778 | { 779 | "cell_type": "markdown", 780 | "metadata": { 781 | "id": "biEzYQS9Ff9b" 782 | }, 783 | "source": [ 784 | "Predict on **test_batches**:" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "metadata": { 790 | "id": "8S49mKRpTP3x" 791 | }, 792 | "source": [ 793 | "predictions = mobile_model.predict(test_batches)" 794 | ], 795 | "execution_count": null, 796 | "outputs": [] 797 | }, 798 | { 799 | "cell_type": "markdown", 800 | "metadata": { 801 | "id": "eNctj1GNZsKc" 802 | }, 803 | "source": [ 804 | "Test data is pure because we haven't seen it yet!" 
805 | ] 806 | }, 807 | { 808 | "cell_type": "markdown", 809 | "metadata": { 810 | "id": "asEZU59ufx--" 811 | }, 812 | "source": [ 813 | "Display class labels:" 814 | ] 815 | }, 816 | { 817 | "cell_type": "code", 818 | "metadata": { 819 | "id": "gM7d4ZX2fyEa" 820 | }, 821 | "source": [ 822 | "class_labels" 823 | ], 824 | "execution_count": null, 825 | "outputs": [] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "metadata": { 830 | "id": "a7CkY_hHZYp1" 831 | }, 832 | "source": [ 833 | "### Inspect the First Prediction" 834 | ] 835 | }, 836 | { 837 | "cell_type": "markdown", 838 | "metadata": { 839 | "id": "MnEOyfn5eIVp" 840 | }, 841 | "source": [ 842 | "Get the first prediction array:" 843 | ] 844 | }, 845 | { 846 | "cell_type": "code", 847 | "metadata": { 848 | "id": "_OzOaQmYSTrN" 849 | }, 850 | "source": [ 851 | "predictions[0]" 852 | ], 853 | "execution_count": null, 854 | "outputs": [] 855 | }, 856 | { 857 | "cell_type": "markdown", 858 | "metadata": { 859 | "id": "AtVIdMx2Z6RV" 860 | }, 861 | "source": [ 862 | "The returned array is the raw prediction." 863 | ] 864 | }, 865 | { 866 | "cell_type": "markdown", 867 | "metadata": { 868 | "id": "5WZprbu5eMmR" 869 | }, 870 | "source": [ 871 | "Use the np.argmax() function to get the prediction for the first image:" 872 | ] 873 | }, 874 | { 875 | "cell_type": "code", 876 | "metadata": { 877 | "id": "8JnnrBIwZ6XY" 878 | }, 879 | "source": [ 880 | "predicted_id = np.argmax(predictions[0])\n", 881 | "predicted_id" 882 | ], 883 | "execution_count": null, 884 | "outputs": [] 885 | }, 886 | { 887 | "cell_type": "markdown", 888 | "metadata": { 889 | "id": "nEeMW4HqaPXl" 890 | }, 891 | "source": [ 892 | "Convert the label to its class name:" 893 | ] 894 | }, 895 | { 896 | "cell_type": "code", 897 | "metadata": { 898 | "id": "L3N1p1FcaPeQ" 899 | }, 900 | "source": [ 901 | "class_labels[predicted_id]" 902 | ], 903 | "execution_count": null, 904 | "outputs": [] 905 | }, 906 | { 907 | "cell_type": "markdown", 908 | "metadata": { 909 | "id": "ycTecXuZGeIn" 910 | }, 911 | "source": [ 912 | "Get the actual labels from the first batch:" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "metadata": { 918 | "id": "5oSGEiX_GeOr" 919 | }, 920 | "source": [ 921 | "for img, lbl in test_batches.take(1):\n", 922 | " print (lbl)" 923 | ], 924 | "execution_count": null, 925 | "outputs": [] 926 | }, 927 | { 928 | "cell_type": "markdown", 929 | "metadata": { 930 | "id": "XaKahEURG3UR" 931 | }, 932 | "source": [ 933 | "Get the first label:" 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | "metadata": { 939 | "id": "bBxf8d_tG3gl" 940 | }, 941 | "source": [ 942 | "class_labels[lbl[0].numpy()]" 943 | ], 944 | "execution_count": null, 945 | "outputs": [] 946 | }, 947 | { 948 | "cell_type": "markdown", 949 | "metadata": { 950 | "id": "QBXIZ0vxHJEv" 951 | }, 952 | "source": [ 953 | "The prediction is correct if the actual label matches the prediction." 
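, "\n", "\n", "You can make the same comparison in code. A minimal sketch (it assumes the `predictions` array and the `lbl` tensor from the cells above, and that the test batch order matches):\n", "\n", "```python\n", "pred_ids = np.argmax(predictions, axis=1)\n", "actual = lbl.numpy()\n", "(pred_ids[:len(actual)] == actual).mean()  # fraction of this batch predicted correctly\n", "```"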
954 | ] 955 | }, 956 | { 957 | "cell_type": "markdown", 958 | "metadata": { 959 | "id": "3CcLQztMeEbj" 960 | }, 961 | "source": [ 962 | "### Inspect the First Batch of Predictions" 963 | ] 964 | }, 965 | { 966 | "cell_type": "markdown", 967 | "metadata": { 968 | "id": "SIV2BJbyaesu" 969 | }, 970 | "source": [ 971 | "Alternatively, we can convert *test_batches* to an iterator:" 972 | ] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "metadata": { 977 | "id": "nNbfATXTbP1h" 978 | }, 979 | "source": [ 980 | "image_batch, label_batch = next(iter(test_batches))\n", 981 | "\n", 982 | "images = image_batch.numpy()\n", 983 | "labels = label_batch.numpy()\n", 984 | "\n", 985 | "class_labels[labels[0]]" 986 | ], 987 | "execution_count": null, 988 | "outputs": [] 989 | }, 990 | { 991 | "cell_type": "markdown", 992 | "metadata": { 993 | "id": "Dyyopo3fHqHg" 994 | }, 995 | "source": [ 996 | "Get the first batch from test_batches iterator with **next**, convert images and labels to NumPy, and display the first label." 997 | ] 998 | }, 999 | { 1000 | "cell_type": "markdown", 1001 | "metadata": { 1002 | "id": "KynoSZMocbdk" 1003 | }, 1004 | "source": [ 1005 | "Display labels from the first batch:" 1006 | ] 1007 | }, 1008 | { 1009 | "cell_type": "code", 1010 | "metadata": { 1011 | "id": "7qWRueYicymH" 1012 | }, 1013 | "source": [ 1014 | "labels" 1015 | ], 1016 | "execution_count": null, 1017 | "outputs": [] 1018 | }, 1019 | { 1020 | "cell_type": "markdown", 1021 | "metadata": { 1022 | "id": "Hcwlb1bFe30Y" 1023 | }, 1024 | "source": [ 1025 | "Convert the batch of labels to named labels:" 1026 | ] 1027 | }, 1028 | { 1029 | "cell_type": "code", 1030 | "metadata": { 1031 | "id": "qdiK6Q4ye4I2" 1032 | }, 1033 | "source": [ 1034 | "named_labels = [class_labels[labels[i]]\n", 1035 | " for i, lbl in enumerate(range(BATCH_SIZE))]\n", 1036 | "named_labels" 1037 | ], 1038 | "execution_count": null, 1039 | "outputs": [] 1040 | }, 1041 | { 1042 | "cell_type": "markdown", 1043 | "metadata": { 1044 | "id": "q8OC80vgebDp" 1045 | }, 1046 | "source": [ 1047 | "Get predictions from the first batch:" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "code", 1052 | "metadata": { 1053 | "id": "ryhNm6SQc0tY" 1054 | }, 1055 | "source": [ 1056 | "predicted_batch = [np.argmax(predictions[i])\n", 1057 | " for i, _ in enumerate(range(BATCH_SIZE))]\n", 1058 | "predicted_batch" 1059 | ], 1060 | "execution_count": null, 1061 | "outputs": [] 1062 | }, 1063 | { 1064 | "cell_type": "markdown", 1065 | "metadata": { 1066 | "id": "4dhN7z4KgKSf" 1067 | }, 1068 | "source": [ 1069 | "Convert predictions to named predictions:" 1070 | ] 1071 | }, 1072 | { 1073 | "cell_type": "code", 1074 | "metadata": { 1075 | "id": "3DjJihdqe3Eb" 1076 | }, 1077 | "source": [ 1078 | "named_pred = [class_labels[predicted_batch[i]]\n", 1079 | " for i, lbl in enumerate(range(BATCH_SIZE))]\n", 1080 | "named_pred" 1081 | ], 1082 | "execution_count": null, 1083 | "outputs": [] 1084 | }, 1085 | { 1086 | "cell_type": "markdown", 1087 | "metadata": { 1088 | "id": "rCpGxM23ejaM" 1089 | }, 1090 | "source": [ 1091 | "## Plot Model Predictions" 1092 | ] 1093 | }, 1094 | { 1095 | "cell_type": "markdown", 1096 | "metadata": { 1097 | "id": "a56GxOAzII9x" 1098 | }, 1099 | "source": [ 1100 | "The visualization shows actual images from the first test batch. If the prediction is correct, the title is blue. If not, the title is red. If the prediction is incorrect, the prediction is displayed along with the actual label in parentheses." 
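# Sketch: before plotting, quantify how many predictions in this first batch are
# correct. This assumes test_batches is not shuffled, so the first batch of
# predictions lines up with the first batch of labels -- the same assumption the
# visualization below relies on.
first_batch_acc = np.mean(np.array(predicted_batch) == labels)
print('First-batch accuracy:', first_batch_acc)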
1101 | ] 1102 | }, 1103 | { 1104 | "cell_type": "code", 1105 | "metadata": { 1106 | "id": "WzeSIJm8ejoC" 1107 | }, 1108 | "source": [ 1109 | "plt.figure(figsize=(20,20))\n", 1110 | "for n in range(30):\n", 1111 | " plt.subplot(6,5,n+1)\n", 1112 | " plt.subplots_adjust(hspace = 0.3)\n", 1113 | " plt.imshow(images[n])\n", 1114 | " color = 'blue' if labels[n] == predicted_batch[n] else 'red'\n", 1115 | " if labels[n] != predicted_batch[n]:\n", 1116 | " t = named_pred[n].title() +\\\n", 1117 | " ' (' +named_labels[n].title() + ')'\n", 1118 | " else:\n", 1119 | " t = named_pred[n].title()\n", 1120 | " plt.title(t, color=color)\n", 1121 | " plt.axis('off')\n", 1122 | " st = 'Model predictions (blue: correct, red: incorrect)'\n", 1123 | "_ = plt.suptitle(st)" 1124 | ], 1125 | "execution_count": null, 1126 | "outputs": [] 1127 | }, 1128 | { 1129 | "cell_type": "markdown", 1130 | "metadata": { 1131 | "id": "CNJtM2ISiUrm" 1132 | }, 1133 | "source": [ 1134 | "# Perform Transfer Learning with the Inception Model\n", 1135 | "\n", 1136 | "Use the Inception model to compare against MobileNet. To get the model, peruse https://tfhub.dev/s?module-type=image-feature-vector&q=tf2 and click on 'tf2-preview/inception_v3/feature_vector'. This feature vector corresponds to the Inception v3 model. Use transfer learning to create a CNN that uses Inception v3 as the pretrained model to classify the images from the Flowers dataset. Note that Inception takes as input images that are 299 x 299 pixels." 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "markdown", 1141 | "metadata": { 1142 | "id": "76u7UnAhlh0B" 1143 | }, 1144 | "source": [ 1145 | "## Reformat Images and Create Batches" 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "markdown", 1150 | "metadata": { 1151 | "id": "koMZP7TeISAq" 1152 | }, 1153 | "source": [ 1154 | "Recreate the **format_image** function to reformat images to the resolution expected by Inception v3 (299, 299), and scale them:" 1155 | ] 1156 | }, 1157 | { 1158 | "cell_type": "code", 1159 | "metadata": { 1160 | "id": "wuoyOMQWlh8o" 1161 | }, 1162 | "source": [ 1163 | "def format_image(image, label):\n", 1164 | " image = tf.image.resize(image, (299, 299)) / 255.0\n", 1165 | " return image, label" 1166 | ], 1167 | "execution_count": null, 1168 | "outputs": [] 1169 | }, 1170 | { 1171 | "cell_type": "markdown", 1172 | "metadata": { 1173 | "id": "M7H7nCITliCg" 1174 | }, 1175 | "source": [ 1176 | "## Build an Input Pipeline for Inception" 1177 | ] 1178 | }, 1179 | { 1180 | "cell_type": "markdown", 1181 | "metadata": { 1182 | "id": "ywb5gqNqIpTI" 1183 | }, 1184 | "source": [ 1185 | "Shuffle train data, reformat, batch, and prefetch train, validation, and test data:" 1186 | ] 1187 | }, 1188 | { 1189 | "cell_type": "code", 1190 | "metadata": { 1191 | "id": "gGukQ5oQliNY" 1192 | }, 1193 | "source": [ 1194 | "BATCH_SIZE = 367\n", 1195 | "\n", 1196 | "train_im = train.shuffle(num_train_img//4).\\\n", 1197 | " map(format_image).batch(BATCH_SIZE).prefetch(1)\n", 1198 | "\n", 1199 | "validation_im = valid.map(format_image).\\\n", 1200 | " batch(BATCH_SIZE).prefetch(1)\n", 1201 | "\n", 1202 | "test_im = test.map(format_image).\\\n", 1203 | " batch(BATCH_SIZE).prefetch(1)" 1204 | ], 1205 | "execution_count": null, 1206 | "outputs": [] 1207 | }, 1208 | { 1209 | "cell_type": "markdown", 1210 | "metadata": { 1211 | "id": "cc8HwUN5kUBd" 1212 | }, 1213 | "source": [ 1214 | "## Create a Feature Extractor\n", 1215 | "\n", 1216 | "Create a feature_extractor using the Inception v3 feature vector. 
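# Optional sanity check (sketch): pull one batch from the Inception pipeline and
# confirm the images were resized to 299 x 299 with 3 color channels.
for imgs, lbls in train_im.take(1):
    print(imgs.shape, lbls.shape)   # expected: (367, 299, 299, 3) (367,)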
A **feature extractor** is the partial model from TensorFlow Hub (without the final classification layer).\n", 1217 | "\n", 1218 | "To see a list of available feature vectors, visit:\n", 1219 | "\n", 1220 | "https://tfhub.dev/s?module-type=image-feature-vector&q=tf2\n", 1221 | "\n", 1222 | "Click on one of them, read the documentation, and get the corresponding URL to get the feature vector." 1223 | ] 1224 | }, 1225 | { 1226 | "cell_type": "markdown", 1227 | "metadata": { 1228 | "id": "H2EUn646JBpS" 1229 | }, 1230 | "source": [ 1231 | "Create the feature extractor:" 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "code", 1236 | "metadata": { 1237 | "id": "sCJG_7A_iUyc" 1238 | }, 1239 | "source": [ 1240 | "piece1 = 'https://tfhub.dev/google/tf2-preview/'\n", 1241 | "piece2 = 'inception_v3/feature_vector/4'\n", 1242 | "URL = piece1 + piece2\n", 1243 | "feature_extractor_im = hub.KerasLayer(URL,\n", 1244 | " input_shape=(299, 299, 3),\n", 1245 | " trainable=False)" 1246 | ], 1247 | "execution_count": null, 1248 | "outputs": [] 1249 | }, 1250 | { 1251 | "cell_type": "markdown", 1252 | "metadata": { 1253 | "id": "Qo3LJq0uS4ak" 1254 | }, 1255 | "source": [ 1256 | "Freeze the pre-trained layers:" 1257 | ] 1258 | }, 1259 | { 1260 | "cell_type": "code", 1261 | "metadata": { 1262 | "id": "t2-3l7FrS3pF" 1263 | }, 1264 | "source": [ 1265 | "feature_extractor_im.trainable = False" 1266 | ], 1267 | "execution_count": null, 1268 | "outputs": [] 1269 | }, 1270 | { 1271 | "cell_type": "markdown", 1272 | "metadata": { 1273 | "id": "ydFK3WF7m84s" 1274 | }, 1275 | "source": [ 1276 | "Clear and seed:" 1277 | ] 1278 | }, 1279 | { 1280 | "cell_type": "code", 1281 | "metadata": { 1282 | "id": "OHDuFtVZm89k" 1283 | }, 1284 | "source": [ 1285 | "tf.keras.backend.clear_session()\n", 1286 | "np.random.seed(0)\n", 1287 | "tf.random.set_seed(0)" 1288 | ], 1289 | "execution_count": null, 1290 | "outputs": [] 1291 | }, 1292 | { 1293 | "cell_type": "markdown", 1294 | "metadata": { 1295 | "id": "IkLg09WoiU1e" 1296 | }, 1297 | "source": [ 1298 | "## Create the Inception Model" 1299 | ] 1300 | }, 1301 | { 1302 | "cell_type": "markdown", 1303 | "metadata": { 1304 | "id": "mGyFrbA-JLU4" 1305 | }, 1306 | "source": [ 1307 | "\n", 1308 | "\n", 1309 | "We've already set the stage with MobileNet. 
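# Quick check (sketch): the Inception v3 feature vector maps each 299 x 299 x 3
# image to a 2,048-dimensional vector, which is what the new classification head
# will consume.
dummy = tf.zeros([1, 299, 299, 3])
print(feature_extractor_im(dummy).shape)   # expected: (1, 2048)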
So we just need to subtitute the Inception feature extractor:" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "code", 1314 | "metadata": { 1315 | "id": "xazyJWpxiU65" 1316 | }, 1317 | "source": [ 1318 | "inception_model = tf.keras.Sequential([\n", 1319 | " feature_extractor_im,\n", 1320 | " Dropout(0.5),\n", 1321 | " Dense(num_classes)])" 1322 | ], 1323 | "execution_count": null, 1324 | "outputs": [] 1325 | }, 1326 | { 1327 | "cell_type": "markdown", 1328 | "metadata": { 1329 | "id": "4OHJnI2RiU-l" 1330 | }, 1331 | "source": [ 1332 | "## Compile" 1333 | ] 1334 | }, 1335 | { 1336 | "cell_type": "markdown", 1337 | "metadata": { 1338 | "id": "k5rcwGMnJfAp" 1339 | }, 1340 | "source": [ 1341 | "Compile with **SparseCategoricalCrossentropy(from_logits=True)**:" 1342 | ] 1343 | }, 1344 | { 1345 | "cell_type": "code", 1346 | "metadata": { 1347 | "id": "6jX0nOYXiVDk" 1348 | }, 1349 | "source": [ 1350 | "inception_model.compile(\n", 1351 | " optimizer='adam',\n", 1352 | " loss=SparseCategoricalCrossentropy(from_logits=True),\n", 1353 | " metrics=['accuracy'])" 1354 | ], 1355 | "execution_count": null, 1356 | "outputs": [] 1357 | }, 1358 | { 1359 | "cell_type": "markdown", 1360 | "metadata": { 1361 | "id": "R7pgOL3NnK9N" 1362 | }, 1363 | "source": [ 1364 | "## Train" 1365 | ] 1366 | }, 1367 | { 1368 | "cell_type": "markdown", 1369 | "metadata": { 1370 | "id": "1VQQgDTZJkch" 1371 | }, 1372 | "source": [ 1373 | "Train model for six epochs:" 1374 | ] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "metadata": { 1379 | "id": "6ka4k-dknLB3" 1380 | }, 1381 | "source": [ 1382 | "EPOCHS = 6\n", 1383 | "\n", 1384 | "history = inception_model.fit(\n", 1385 | " train_im, epochs=EPOCHS,\n", 1386 | " validation_data=validation_im)" 1387 | ], 1388 | "execution_count": null, 1389 | "outputs": [] 1390 | }, 1391 | { 1392 | "cell_type": "markdown", 1393 | "metadata": { 1394 | "id": "JjCgXfVOAijJ" 1395 | }, 1396 | "source": [ 1397 | "## Visualize" 1398 | ] 1399 | }, 1400 | { 1401 | "cell_type": "markdown", 1402 | "metadata": { 1403 | "id": "r4EAyBkEJ87N" 1404 | }, 1405 | "source": [ 1406 | "Visualize model performance:" 1407 | ] 1408 | }, 1409 | { 1410 | "cell_type": "code", 1411 | "metadata": { 1412 | "id": "jL4Sz_dNAipc" 1413 | }, 1414 | "source": [ 1415 | "import matplotlib.pyplot as plt\n", 1416 | "\n", 1417 | "acc = history.history['accuracy']\n", 1418 | "val_acc = history.history['val_accuracy']\n", 1419 | "\n", 1420 | "loss = history.history['loss']\n", 1421 | "val_loss = history.history['val_loss']\n", 1422 | "\n", 1423 | "epochs_range = range(EPOCHS)\n", 1424 | "\n", 1425 | "plt.figure(figsize=(8, 8))\n", 1426 | "plt.subplot(1, 2, 1)\n", 1427 | "plt.plot(epochs_range, acc, label='Training Accuracy')\n", 1428 | "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n", 1429 | "plt.legend(loc='lower right')\n", 1430 | "plt.title('Training and Validation Accuracy')\n", 1431 | "\n", 1432 | "plt.subplot(1, 2, 2)\n", 1433 | "plt.plot(epochs_range, loss, label='Training Loss')\n", 1434 | "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", 1435 | "plt.legend(loc='upper right')\n", 1436 | "plt.title('Training and Validation Loss')\n", 1437 | "plt.show()" 1438 | ], 1439 | "execution_count": null, 1440 | "outputs": [] 1441 | }, 1442 | { 1443 | "cell_type": "markdown", 1444 | "metadata": { 1445 | "id": "6A9NyBJyUpVQ" 1446 | }, 1447 | "source": [ 1448 | "## Predictions" 1449 | ] 1450 | }, 1451 | { 1452 | "cell_type": "markdown", 1453 | "metadata": { 1454 | "id": "Oj0AgshqUuZJ" 1455 | }, 1456 | 
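# Sketch: mirror the earlier MobileNet check with a single held-out number before
# looking at individual predictions.
incep_loss, incep_acc = inception_model.evaluate(test_im)
print('Inception v3 test accuracy:', round(incep_acc, 4))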
"source": [ 1457 | "Make predictions:" 1458 | ] 1459 | }, 1460 | { 1461 | "cell_type": "code", 1462 | "metadata": { 1463 | "id": "-FnYdw27UpcD" 1464 | }, 1465 | "source": [ 1466 | "im_predictions = inception_model.predict(test_im)" 1467 | ], 1468 | "execution_count": null, 1469 | "outputs": [] 1470 | }, 1471 | { 1472 | "cell_type": "markdown", 1473 | "metadata": { 1474 | "id": "J1STd2R6UpyY" 1475 | }, 1476 | "source": [ 1477 | "Get a batch of predictions and convert them to named predictions:" 1478 | ] 1479 | }, 1480 | { 1481 | "cell_type": "code", 1482 | "metadata": { 1483 | "id": "Y8KcxLR-Up3R" 1484 | }, 1485 | "source": [ 1486 | "im_pred_batch = [np.argmax(im_predictions[i])\n", 1487 | " for i, _ in enumerate(range(BATCH_SIZE))]\n", 1488 | "im_named_pred = [class_labels[im_pred_batch[i]]\n", 1489 | " for i, lbl in enumerate(range(BATCH_SIZE))]" 1490 | ], 1491 | "execution_count": null, 1492 | "outputs": [] 1493 | }, 1494 | { 1495 | "cell_type": "markdown", 1496 | "metadata": { 1497 | "id": "jWi7gCeJWfEZ" 1498 | }, 1499 | "source": [ 1500 | "Grab the first batch of images and labels from the test set:" 1501 | ] 1502 | }, 1503 | { 1504 | "cell_type": "code", 1505 | "metadata": { 1506 | "id": "FqjjvhjgWfLY" 1507 | }, 1508 | "source": [ 1509 | "im_image_batch, im_label_batch = next(iter(test_im))\n", 1510 | "\n", 1511 | "im_images = im_image_batch.numpy()\n", 1512 | "im_labels = im_label_batch.numpy()" 1513 | ], 1514 | "execution_count": null, 1515 | "outputs": [] 1516 | }, 1517 | { 1518 | "cell_type": "markdown", 1519 | "metadata": { 1520 | "id": "TbwYqS-3VxDn" 1521 | }, 1522 | "source": [ 1523 | "Convert the labels to named labels:" 1524 | ] 1525 | }, 1526 | { 1527 | "cell_type": "code", 1528 | "metadata": { 1529 | "id": "c1u5CoMdVxLL" 1530 | }, 1531 | "source": [ 1532 | "im_named_labels = [class_labels[im_labels[i]]\n", 1533 | " for i, lbl in enumerate(range(BATCH_SIZE))]" 1534 | ], 1535 | "execution_count": null, 1536 | "outputs": [] 1537 | }, 1538 | { 1539 | "cell_type": "markdown", 1540 | "metadata": { 1541 | "id": "Wb6xTPeFVaDB" 1542 | }, 1543 | "source": [ 1544 | "## Plot Model Predictions" 1545 | ] 1546 | }, 1547 | { 1548 | "cell_type": "markdown", 1549 | "metadata": { 1550 | "id": "BcEjndkhVNzS" 1551 | }, 1552 | "source": [ 1553 | "Create a function to display predictions:" 1554 | ] 1555 | }, 1556 | { 1557 | "cell_type": "code", 1558 | "metadata": { 1559 | "id": "hQYpQAUYccJX" 1560 | }, 1561 | "source": [ 1562 | "def plot_pred(images, labels, named_labels, named_pred):\n", 1563 | " plt.figure(figsize=(20,20))\n", 1564 | " for n in range(30):\n", 1565 | " plt.subplot(6,5,n+1)\n", 1566 | " plt.subplots_adjust(hspace = 0.3)\n", 1567 | " plt.imshow(images[n])\n", 1568 | " color = 'blue' if named_labels[n] == named_pred[n] else 'red'\n", 1569 | " if named_labels[n] != named_pred[n]:\n", 1570 | " t = named_pred[n].title() +\\\n", 1571 | " ' (' +named_labels[n].title() + ')'\n", 1572 | " else:\n", 1573 | " t = named_pred[n].title()\n", 1574 | " plt.title(t, color=color)\n", 1575 | " plt.axis('off')\n", 1576 | " st = 'Model predictions (blue: correct, red: incorrect)'\n", 1577 | " _ = plt.suptitle(st)" 1578 | ], 1579 | "execution_count": null, 1580 | "outputs": [] 1581 | }, 1582 | { 1583 | "cell_type": "markdown", 1584 | "metadata": { 1585 | "id": "ZekwPu8bbo19" 1586 | }, 1587 | "source": [ 1588 | "Invoke the function:" 1589 | ] 1590 | }, 1591 | { 1592 | "cell_type": "code", 1593 | "metadata": { 1594 | "id": "ERKUroLfbXOm" 1595 | }, 1596 | "source": [ 1597 | "plot_pred(im_images, 
im_labels, im_named_labels, im_named_pred)" 1598 | ], 1599 | "execution_count": null, 1600 | "outputs": [] 1601 | } 1602 | ] 1603 | } -------------------------------------------------------------------------------- /chapter08/ch08.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ch08.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm" 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "_EXoS-8gwxlx" 22 | }, 23 | "source": [ 24 | "# Stacked Autoencoders\n", 25 | "\n", 26 | "**Autoencoders** are artificial neural networks that learn dense representations of the input data without any supervision. The dense representations are called *latent representations* or *codings*. The codings typically have much lower dimensionality than the input data, which makes autoencoders useful for dimensionality reduction. They can also act as feature detectors (feature extraction), unsupervised pretraining of deep neural networks, and generative models. As generative models, they can randomly generate new data that looks very similar to the training data.\n", 27 | "\n", 28 | "Simply, autoencoders are trained in an unsupervised manner to learn the low-level features of an input (latent representations or codings), which are then used to reconstruct the original input. So, an autoencoder consists of 3 components: encoder, latent representations (or codings), and decoder. The encoder compresses the input and produces the codings. The decoder then reconstructs the input from the codings.\n", 29 | "\n", 30 | "Resources:\n", 31 | "\n", 32 | "https://towardsdatascience.com/applied-deep-learning-part-3-autoencoders-1c083af4d798\n", 33 | "\n", 34 | "https://www.tensorflow.org/tutorials/generative/autoencoder\n", 35 | "\n", 36 | "https://blog.keras.io/building-autoencoders-in-keras.html\n", 37 | "\n", 38 | "https://www.datacamp.com/community/tutorials/autoencoder-keras-tutorial" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "Ew6ILT-7Rbkc" 45 | }, 46 | "source": [ 47 | "# Import **tensorflow** library" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "id": "7_V_r2slQPxI" 54 | }, 55 | "source": [ 56 | "Import library and alias it:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "metadata": { 62 | "id": "P1B43zq5Rb2b" 63 | }, 64 | "source": [ 65 | "import tensorflow as tf" 66 | ], 67 | "execution_count": null, 68 | "outputs": [] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "y7UC2JNPw3pY" 74 | }, 75 | "source": [ 76 | "# GPU Hardware Accelerator\n", 77 | "\n", 78 | "To vastly speed up processing, we can use the GPU available from the Google Colab cloud service. Colab provides a free Tesla K80 GPU of about 12 GB. 
It’s very easy to enable the GPU in a Colab notebook:\n", 79 | "\n", 80 | "1.\tclick **Runtime** in the top left menu\n", 81 | "2.\tclick **Change runtime** type from the drop-down menu\n", 82 | "3.\tchoose **GPU** from the Hardware accelerator drop-down menu\n", 83 | "4.\tclick **SAVE**" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "KDNFpWCTw7L4" 90 | }, 91 | "source": [ 92 | "Verify that GPU is active:" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "Ck2w8yEvwqpp" 99 | }, 100 | "source": [ 101 | "tf.__version__, tf.test.gpu_device_name()" 102 | ], 103 | "execution_count": null, 104 | "outputs": [] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": { 109 | "id": "bu30Ldwuknh0" 110 | }, 111 | "source": [ 112 | "# Stacked Autoencoders\n", 113 | "\n", 114 | "**Stacked encoders** have multiple hidden layers. The architecture is typically symmetrical with regard to the central hidden layer, which is called the *coding layer*. " 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "id": "7YmINgYRlvZw" 121 | }, 122 | "source": [ 123 | "## Load Data\n", 124 | "\n", 125 | "Load Fashion-Mnist as Numpy arrays:" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "metadata": { 131 | "id": "ApZDmWIDknms" 132 | }, 133 | "source": [ 134 | "import tensorflow_datasets as tfds\n", 135 | "\n", 136 | "(x_train_img, _), (x_test_img, _) = tfds.as_numpy(\n", 137 | " tfds.load('fashion_mnist', split=['train','test'],\n", 138 | " batch_size=-1, as_supervised=True,\n", 139 | " try_gcs=True))" 140 | ], 141 | "execution_count": null, 142 | "outputs": [] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "id": "4GCte2cfmWzA" 148 | }, 149 | "source": [ 150 | "Notice that we don't load the labels because autoencoders are unsupervised models." 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": { 156 | "id": "4M6qeXremM6o" 157 | }, 158 | "source": [ 159 | "## Scale" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": { 165 | "id": "VtnDyBlaQTsT" 166 | }, 167 | "source": [ 168 | "Scale by dividing datasets by the number of pixels that represent an image:" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "metadata": { 174 | "id": "E_2AvNWMmM_g" 175 | }, 176 | "source": [ 177 | "import numpy as np\n", 178 | "\n", 179 | "x_train, x_test = x_train_img.astype(np.float32) / 255.,\\\n", 180 | " x_test_img.astype(np.float32) / 255." 
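# Sanity check (sketch): after scaling, pixel values should fall in [0, 1] and the
# arrays keep their (samples, 28, 28, 1) shape.
print(x_train.shape, x_test.shape)
print(x_train.min(), x_train.max())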
181 | ], 182 | "execution_count": null, 183 | "outputs": [] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": { 188 | "id": "Q1Vqfe3vkn0T" 189 | }, 190 | "source": [ 191 | "## Clear Previous Models and Generate Seed" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": { 197 | "id": "QKEPn9d0QdBf" 198 | }, 199 | "source": [ 200 | "Clear previous model sessions and generate a seed for reproducibility:" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "oJxvYwiNkn5k" 207 | }, 208 | "source": [ 209 | "tf.keras.backend.clear_session()\n", 210 | "np.random.seed(0)\n", 211 | "tf.random.set_seed(0)" 212 | ], 213 | "execution_count": null, 214 | "outputs": [] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": { 219 | "id": "iEgBWvh0mfqA" 220 | }, 221 | "source": [ 222 | "## Get Input Shape" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": { 228 | "id": "q5MYhJRXQuNI" 229 | }, 230 | "source": [ 231 | "Get input shape for use in the model:" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "metadata": { 237 | "id": "ag2ff2YCmfuw" 238 | }, 239 | "source": [ 240 | "in_shape = x_train.shape[1:]\n", 241 | "in_shape" 242 | ], 243 | "execution_count": null, 244 | "outputs": [] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": { 249 | "id": "mEp-duponITa" 250 | }, 251 | "source": [ 252 | "## Build Stacked Autoencoder\n", 253 | "\n", 254 | "Stacked encoders have multiple hidden layers. The architecture is typically symmetrical with regard to the central hidden layer, which is the coding layer. We split the autoencoder model into the encoder and decoder." 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": { 260 | "id": "lcOoNB8gnLa4" 261 | }, 262 | "source": [ 263 | "Import libraries:" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "metadata": { 269 | "id": "TEc2w0SknIYE" 270 | }, 271 | "source": [ 272 | "from tensorflow.keras.models import Sequential\n", 273 | "from tensorflow.keras.layers import Dense, Flatten,\\\n", 274 | " Reshape" 275 | ], 276 | "execution_count": null, 277 | "outputs": [] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": { 282 | "id": "DPUvM3ILNooj" 283 | }, 284 | "source": [ 285 | "In our example, the encoder accepts 28 x 28 pixel grayscale images, flattens them so that each image is represented as a vector of size 784, and processes the vectors through three Dense layers of diminishing sizes (128 units to 64 units to 32 units). The 32 unit layer is the coding layer (central hidden layer). For each input image, the encoder outputs a vector of size 32. " 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "metadata": { 291 | "id": "tWBXI4mrnNbA" 292 | }, 293 | "source": [ 294 | "stacked_encoder = Sequential([\n", 295 | " Flatten(input_shape=in_shape),\n", 296 | " Dense(128, activation='relu'),\n", 297 | " Dense(64, activation='relu'),\n", 298 | " Dense(32, activation='relu')\n", 299 | "])" 300 | ], 301 | "execution_count": null, 302 | "outputs": [] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": { 307 | "id": "B-PIKnrjCz1g" 308 | }, 309 | "source": [ 310 | "" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": { 316 | "id": "V8iLbmAMP5GF" 317 | }, 318 | "source": [ 319 | "The decoder accepts codings of size 32 (output by the encoder) and processes them through three Dense layers of increasing sizes (64 units to 128 units to 784 units). 
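# Quick check (sketch): pass a few training images through the encoder just built
# to confirm that each image is compressed to a 32-dimensional coding.
codings = stacked_encoder(x_train[:3])
print(codings.shape)   # expected: (3, 32)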
It then reshapes the final vectors into 28 x 28 arrays so the decoder's outputs have the same shape as the encoder's inputs. " 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "metadata": { 325 | "id": "hclEUvaGn3sy" 326 | }, 327 | "source": [ 328 | "stacked_decoder = Sequential([\n", 329 | " Dense(64, activation='relu'),\n", 330 | " Dense(128, activation='relu'),\n", 331 | " Dense(28 * 28, activation='sigmoid'),\n", 332 | " Reshape(in_shape)\n", 333 | "])" 334 | ], 335 | "execution_count": null, 336 | "outputs": [] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": { 341 | "id": "liyEWDUvRJhp" 342 | }, 343 | "source": [ 344 | "## Create Stacked Autoencoder" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": { 350 | "id": "ZdIXWN52QxGo" 351 | }, 352 | "source": [ 353 | "Create stacked autoencoder based on stacked encoder and decoder:" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "metadata": { 359 | "id": "ifu3mFoQRJn5" 360 | }, 361 | "source": [ 362 | "stacked_ae = Sequential([stacked_encoder, stacked_decoder])" 363 | ], 364 | "execution_count": null, 365 | "outputs": [] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": { 370 | "id": "eeRzOzGPShGZ" 371 | }, 372 | "source": [ 373 | "## Create Appropriate Metric" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": { 379 | "id": "cpvOgfwKQ8h4" 380 | }, 381 | "source": [ 382 | "Create metric to track model performance:" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "metadata": { 388 | "id": "OJv88Ux4ShK-" 389 | }, 390 | "source": [ 391 | "def rounded_accuracy(y_true, y_pred):\n", 392 | " return tf.keras.metrics.binary_accuracy(tf.round(y_true),\n", 393 | " tf.round(y_pred))" 394 | ], 395 | "execution_count": null, 396 | "outputs": [] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": { 401 | "id": "O_fljpRtS9cf" 402 | }, 403 | "source": [ 404 | "The *accuracy* metric won't work properly since it expects labels to be either 0 or 1 for each pixel." 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": { 410 | "id": "2CTCiWJGRPRd" 411 | }, 412 | "source": [ 413 | "## Compile" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": { 419 | "id": "Y4LNvTCKRHGI" 420 | }, 421 | "source": [ 422 | "Use **binary crossentropy** as the loss function because the reconstruction task is a multilabel binary classification problem since each pixel intensity represents the probability that the pixel should be black." 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "metadata": { 428 | "id": "TUBTFvjvRUYN" 429 | }, 430 | "source": [ 431 | "opt = tf.keras.optimizers.SGD(lr=1.5)\n", 432 | "\n", 433 | "stacked_ae.compile(\n", 434 | " loss='binary_crossentropy',\n", 435 | " optimizer=opt, metrics=[rounded_accuracy])" 436 | ], 437 | "execution_count": null, 438 | "outputs": [] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": { 443 | "id": "ue8jUqULRPX8" 444 | }, 445 | "source": [ 446 | "## Train" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": { 452 | "id": "sEpZXTyvRNVQ" 453 | }, 454 | "source": [ 455 | "Train the model using x_train as both the input and the target. The encoder will learn to compress the dataset from 784 dimensions to the latent space, and the decoder will learn to reconstruct the original images." 
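# Tiny illustration (sketch) of the metric reported during training: pixel targets
# are continuous in [0, 1], so both targets and outputs are rounded to 0/1 before
# the binary-accuracy comparison.
demo_true = tf.constant([[0.1, 0.9, 0.4]])
demo_pred = tf.constant([[0.2, 0.8, 0.6]])
print(rounded_accuracy(demo_true, demo_pred).numpy())   # 2 of 3 pixels agree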
456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "metadata": { 461 | "id": "ewb2_9qfnZ7g" 462 | }, 463 | "source": [ 464 | "sae_history = stacked_ae.fit(\n", 465 | " x_train, x_train, epochs=10,\n", 466 | " validation_data=(x_test, x_test))" 467 | ], 468 | "execution_count": null, 469 | "outputs": [] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": { 474 | "id": "c_OVtA3CfBtP" 475 | }, 476 | "source": [ 477 | "## Visualize Performance" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "metadata": { 483 | "id": "hDiMBqXjgGMT" 484 | }, 485 | "source": [ 486 | "Import a plotting library:" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "metadata": { 492 | "id": "EaNUbjaAgGTF" 493 | }, 494 | "source": [ 495 | "import matplotlib.pyplot as plt" 496 | ], 497 | "execution_count": null, 498 | "outputs": [] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": { 503 | "id": "k6L_GZqWfPsA" 504 | }, 505 | "source": [ 506 | "Create a visualization function:" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "metadata": { 512 | "id": "dw-J5VyifB9v" 513 | }, 514 | "source": [ 515 | "def viz_history(training_history):\n", 516 | " loss = training_history.history['loss']\n", 517 | " val_loss = training_history.history['val_loss']\n", 518 | " accuracy = training_history.history['rounded_accuracy']\n", 519 | " val_accuracy = training_history.history['val_rounded_accuracy']\n", 520 | " plt.figure(figsize=(14, 4))\n", 521 | " plt.subplot(1, 2, 1)\n", 522 | " plt.title('Loss')\n", 523 | " plt.xlabel('Epoch')\n", 524 | " plt.ylabel('Loss')\n", 525 | " plt.plot(loss, label='Training set')\n", 526 | " plt.plot(val_loss, label='Test set', linestyle='--')\n", 527 | " plt.legend()\n", 528 | " plt.grid(linestyle='--', linewidth=1, alpha=0.5)\n", 529 | " plt.subplot(1, 2, 2)\n", 530 | " plt.title('Accuracy')\n", 531 | " plt.xlabel('Epoch')\n", 532 | " plt.ylabel('Accuracy')\n", 533 | " plt.plot(accuracy, label='Training set')\n", 534 | " plt.plot(val_accuracy, label='Test set', linestyle='--')\n", 535 | " plt.legend()\n", 536 | " plt.grid(linestyle='--', linewidth=1, alpha=0.5)\n", 537 | " plt.show()" 538 | ], 539 | "execution_count": null, 540 | "outputs": [] 541 | }, 542 | { 543 | "cell_type": "markdown", 544 | "metadata": { 545 | "id": "2Eupv-GufWHQ" 546 | }, 547 | "source": [ 548 | "Visualize:" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "metadata": { 554 | "id": "kbmjTPrWfWTr" 555 | }, 556 | "source": [ 557 | "viz_history(sae_history)" 558 | ], 559 | "execution_count": null, 560 | "outputs": [] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "metadata": { 565 | "id": "XJ6-Si7hpKkj" 566 | }, 567 | "source": [ 568 | "## Visualize the Reconstructions" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": { 574 | "id": "SHxfAA6YYdkO" 575 | }, 576 | "source": [ 577 | "Create a function to plot a grayscale 28x28 image:" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "metadata": { 583 | "id": "LUbVPLwMnNiw" 584 | }, 585 | "source": [ 586 | "import matplotlib.pyplot as plt\n", 587 | "\n", 588 | "def plot_image(image):\n", 589 | " plt.imshow(image, cmap='binary')\n", 590 | " plt.axis('off')" 591 | ], 592 | "execution_count": null, 593 | "outputs": [] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": { 598 | "id": "EQNfjeUoYl_V" 599 | }, 600 | "source": [ 601 | "Create a function to visualize original images and reconstructions:" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | 
"metadata": { 607 | "id": "qw0VMXZWYmD3" 608 | }, 609 | "source": [ 610 | "def show_reconstructions(model, images, n_images):\n", 611 | " reconstructions = model.predict(images[:n_images])\n", 612 | " reconstructions = tf.squeeze(reconstructions) # drop '1' dimension\n", 613 | " fig = plt.figure(figsize=(n_images * 1.5, 3))\n", 614 | " for image_index in range(n_images):\n", 615 | " plt.subplot(2, n_images, 1 + image_index)\n", 616 | " plot_image(images[image_index])\n", 617 | " plt.subplot(2, n_images, 1 + n_images + image_index)\n", 618 | " plot_image(reconstructions[image_index])" 619 | ], 620 | "execution_count": null, 621 | "outputs": [] 622 | }, 623 | { 624 | "cell_type": "markdown", 625 | "metadata": { 626 | "id": "ivpkpPKorIK-" 627 | }, 628 | "source": [ 629 | "The predict() function adds the *1* dimension back." 630 | ] 631 | }, 632 | { 633 | "cell_type": "markdown", 634 | "metadata": { 635 | "id": "41eHnXvPXjQt" 636 | }, 637 | "source": [ 638 | "Check dimensionality of test data:" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "metadata": { 644 | "id": "Ld6X1U0mXj2r" 645 | }, 646 | "source": [ 647 | "x_test.shape" 648 | ], 649 | "execution_count": null, 650 | "outputs": [] 651 | }, 652 | { 653 | "cell_type": "markdown", 654 | "metadata": { 655 | "id": "iM0YrgNUY1k8" 656 | }, 657 | "source": [ 658 | "To visualize with imshow(), we must remove dimensions of size 1 from the shape of a tensor:" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "metadata": { 664 | "id": "iP7PY-F3W4Zn" 665 | }, 666 | "source": [ 667 | "x_test_imgs = tf.squeeze(x_test)\n", 668 | "x_test_imgs.shape" 669 | ], 670 | "execution_count": null, 671 | "outputs": [] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "metadata": { 676 | "id": "66BnQMsgY9c8" 677 | }, 678 | "source": [ 679 | "Visualize:" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "metadata": { 685 | "id": "eHU-BTaHY9hN" 686 | }, 687 | "source": [ 688 | "show_reconstructions(stacked_ae, x_test_imgs, 6)" 689 | ], 690 | "execution_count": null, 691 | "outputs": [] 692 | }, 693 | { 694 | "cell_type": "markdown", 695 | "metadata": { 696 | "id": "toLFhTAeVMVT" 697 | }, 698 | "source": [ 699 | "Reconstructed images are generated from **test images** based on predictions from the trained model." 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "metadata": { 705 | "id": "QzvvHELFQddh" 706 | }, 707 | "source": [ 708 | "## Breakdown" 709 | ] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "metadata": { 714 | "id": "TQCqMJgJWaPh" 715 | }, 716 | "source": [ 717 | "Grab an image from the test set:" 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "metadata": { 723 | "id": "b7u4MV4wWaXf" 724 | }, 725 | "source": [ 726 | "img = x_test[:1]" 727 | ], 728 | "execution_count": null, 729 | "outputs": [] 730 | }, 731 | { 732 | "cell_type": "markdown", 733 | "metadata": { 734 | "id": "iZ5ReIMgWhDn" 735 | }, 736 | "source": [ 737 | "Since the prediction method computations are done in batches, we grab the first image as a batch of one." 
738 | ] 739 | }, 740 | { 741 | "cell_type": "markdown", 742 | "metadata": { 743 | "id": "RV1zv9G7W3C4" 744 | }, 745 | "source": [ 746 | "Make a prediction based on the image batch:" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "metadata": { 752 | "id": "Vtv5Om7mW3Jo" 753 | }, 754 | "source": [ 755 | "reconstruction = stacked_ae.predict(img)" 756 | ], 757 | "execution_count": null, 758 | "outputs": [] 759 | }, 760 | { 761 | "cell_type": "markdown", 762 | "metadata": { 763 | "id": "UGWFUAHWXA8x" 764 | }, 765 | "source": [ 766 | "Drop the '1' dimension:" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "metadata": { 772 | "id": "biRGfidGXBDo" 773 | }, 774 | "source": [ 775 | "reconstruction = tf.squeeze(reconstruction)" 776 | ], 777 | "execution_count": null, 778 | "outputs": [] 779 | }, 780 | { 781 | "cell_type": "markdown", 782 | "metadata": { 783 | "id": "SnAMprQpXiJ6" 784 | }, 785 | "source": [ 786 | "Plot reconstruction:" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "metadata": { 792 | "id": "jUuXMexZXiPi" 793 | }, 794 | "source": [ 795 | "plot_image(reconstruction)" 796 | ], 797 | "execution_count": null, 798 | "outputs": [] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "metadata": { 803 | "id": "GG_ijIv_X30i" 804 | }, 805 | "source": [ 806 | "Plot actual image:" 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "metadata": { 812 | "id": "UcHDU2BbX39O" 813 | }, 814 | "source": [ 815 | "plot_image(tf.squeeze(x_test[0]))" 816 | ], 817 | "execution_count": null, 818 | "outputs": [] 819 | }, 820 | { 821 | "cell_type": "markdown", 822 | "metadata": { 823 | "id": "pmXnnHOrYHh7" 824 | }, 825 | "source": [ 826 | "We squeeze out the '1' dimension from the image to plot." 827 | ] 828 | }, 829 | { 830 | "cell_type": "markdown", 831 | "metadata": { 832 | "id": "8cA-2bN3ZQfG" 833 | }, 834 | "source": [ 835 | "## Visualize with Dimensionality Reduction" 836 | ] 837 | }, 838 | { 839 | "cell_type": "markdown", 840 | "metadata": { 841 | "id": "u9nwWHBUbvtw" 842 | }, 843 | "source": [ 844 | "To perform dimensionality reduction, we need labels. 
So load labels from the **test data** set:" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "metadata": { 850 | "id": "nmi0J281a1b_" 851 | }, 852 | "source": [ 853 | "test = tfds.as_numpy(\n", 854 | " tfds.load('fashion_mnist', split=['test'],\n", 855 | " batch_size=-1, as_supervised=True,\n", 856 | " try_gcs=True))" 857 | ], 858 | "execution_count": null, 859 | "outputs": [] 860 | }, 861 | { 862 | "cell_type": "markdown", 863 | "metadata": { 864 | "id": "B6TjU308b6j5" 865 | }, 866 | "source": [ 867 | "Slice test labels from the test data set:" 868 | ] 869 | }, 870 | { 871 | "cell_type": "code", 872 | "metadata": { 873 | "id": "N-13pGxGbBzu" 874 | }, 875 | "source": [ 876 | "y_test = test[0][1]" 877 | ], 878 | "execution_count": null, 879 | "outputs": [] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": { 884 | "id": "QPPjWPuucBro" 885 | }, 886 | "source": [ 887 | "Use the encoder to reduce dimensionality to 32:" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "metadata": { 893 | "id": "ANm8WaY8ZWAu" 894 | }, 895 | "source": [ 896 | "from sklearn.manifold import TSNE\n", 897 | "\n", 898 | "np.random.seed(0)\n", 899 | "x_test_compressed = stacked_encoder.predict(x_test_imgs)\n", 900 | "tsne = TSNE()\n", 901 | "x_test_2D = tsne.fit_transform(x_test_compressed)\n", 902 | "x_test_2D = (x_test_2D - x_test_2D.min()) /\\\n", 903 | " (x_test_2D.max() - x_test_2D.min())" 904 | ], 905 | "execution_count": null, 906 | "outputs": [] 907 | }, 908 | { 909 | "cell_type": "markdown", 910 | "metadata": { 911 | "id": "TwbwnD_fcnFc" 912 | }, 913 | "source": [ 914 | "We used Scikit-Learn's implementation of the t-SNE algorithm to reduce dimensionality to 2D for visualization." 915 | ] 916 | }, 917 | { 918 | "cell_type": "markdown", 919 | "metadata": { 920 | "id": "eDFda0egZWGn" 921 | }, 922 | "source": [ 923 | "Visualize:" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "metadata": { 929 | "id": "gDty2umlZWL3" 930 | }, 931 | "source": [ 932 | "plt.scatter(x_test_2D[:, 0], x_test_2D[:, 1],\n", 933 | " c=y_test, s=10, cmap='tab10')\n", 934 | "plt.axis('off')\n", 935 | "plt.show()" 936 | ], 937 | "execution_count": null, 938 | "outputs": [] 939 | }, 940 | { 941 | "cell_type": "markdown", 942 | "metadata": { 943 | "id": "ColxRICW3mcp" 944 | }, 945 | "source": [ 946 | "Each class is represented by a different color." 
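# Sketch: the same scatter plot with a legend, using the standard Fashion-MNIST
# class names (label indices 0-9), to make the colors easier to read.
fm_classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
              'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
plt.figure(figsize=(8, 6))
for idx, name in enumerate(fm_classes):
    mask = y_test == idx
    plt.scatter(x_test_2D[mask, 0], x_test_2D[mask, 1], s=10, label=name)
plt.legend(loc='best', fontsize=8)
plt.axis('off')
plt.show()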
947 | ] 948 | }, 949 | { 950 | "cell_type": "markdown", 951 | "metadata": { 952 | "id": "skDn5YtdZWQE" 953 | }, 954 | "source": [ 955 | "Display a prettier visualization:" 956 | ] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "metadata": { 961 | "id": "Waf-0o9sZWUU" 962 | }, 963 | "source": [ 964 | "import matplotlib as mpl\n", 965 | "\n", 966 | "plt.figure(figsize=(10, 8))\n", 967 | "cmap = plt.cm.tab10\n", 968 | "plt.scatter(x_test_2D[:, 0], x_test_2D[:, 1],\n", 969 | " c=y_test, s=10, cmap=cmap)\n", 970 | "image_positions = np.array([[1., 1.]])\n", 971 | "for index, position in enumerate(x_test_2D):\n", 972 | " dist = np.sum((position - image_positions) ** 2, axis=1)\n", 973 | " if np.min(dist) > 0.02: # if far enough from other images\n", 974 | " image_positions = np.r_[image_positions, [position]]\n", 975 | " imagebox = mpl.offsetbox.AnnotationBbox(\n", 976 | " mpl.offsetbox.OffsetImage(x_test_imgs[index],\n", 977 | " cmap='binary'),\n", 978 | " position, bboxprops={\n", 979 | " 'edgecolor': cmap(y_test[index]), 'lw': 2})\n", 980 | " plt.gca().add_artist(imagebox)\n", 981 | "plt.axis('off')\n", 982 | "plt.show()" 983 | ], 984 | "execution_count": null, 985 | "outputs": [] 986 | }, 987 | { 988 | "cell_type": "markdown", 989 | "metadata": { 990 | "id": "n1KvNqRDdxQT" 991 | }, 992 | "source": [ 993 | "Adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html" 994 | ] 995 | }, 996 | { 997 | "cell_type": "markdown", 998 | "metadata": { 999 | "id": "81OIsQqwgW_E" 1000 | }, 1001 | "source": [ 1002 | "# Tying Weights\n", 1003 | "\n", 1004 | "When an autoencoder is neatly symmetrical, we can tie the weights of the decoder layers to the weights of the encoder layers. As a result, we halve the number of weights in the model, which speeds training and reduces overfitting. " 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "markdown", 1009 | "metadata": { 1010 | "id": "Ia_hCvPvg0rW" 1011 | }, 1012 | "source": [ 1013 | "## Define a Custom Layer\n", 1014 | "\n", 1015 | "To tie the weights of the encoder and the decoder, we use the transpose of the encoder's weights as the decoder weights:" 1016 | ] 1017 | }, 1018 | { 1019 | "cell_type": "code", 1020 | "metadata": { 1021 | "id": "5ezaYhKWg0ws" 1022 | }, 1023 | "source": [ 1024 | "class DenseTranspose(tf.keras.layers.Layer):\n", 1025 | " def __init__(self, dense, activation=None, **kwargs):\n", 1026 | " self.dense = dense\n", 1027 | " self.activation = tf.keras.activations.get(activation)\n", 1028 | " super().__init__(**kwargs)\n", 1029 | " def build(self, batch_input_shape):\n", 1030 | " self.biases = self.add_weight(\n", 1031 | " name='bias', shape=[self.dense.input_shape[-1]],\n", 1032 | " initializer='zeros')\n", 1033 | " super().build(batch_input_shape)\n", 1034 | " def call(self, inputs):\n", 1035 | " z = tf.matmul(\n", 1036 | " inputs, self.dense.weights[0], transpose_b=True)\n", 1037 | " return self.activation(z + self.biases)" 1038 | ], 1039 | "execution_count": null, 1040 | "outputs": [] 1041 | }, 1042 | { 1043 | "cell_type": "markdown", 1044 | "metadata": { 1045 | "id": "9kJ76gtmh31c" 1046 | }, 1047 | "source": [ 1048 | "The class accepts a layer from a model, an activation function (if included in a layer), and transposes the data. A lot of times we have to preprocess data fed into machine learning algorithms. The reason is that data may be stored as rows, but the machine learning algorithm expects input as columns or vice versa. 
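# What the custom layer relies on (sketch): a Dense layer's kernel has shape
# (inputs, units), so its transpose provides the (units, inputs) mapping the
# decoder needs without allocating a second weight matrix.
demo_dense = Dense(3)
demo_dense.build((None, 5))
print(demo_dense.weights[0].shape)                 # (5, 3)
print(tf.transpose(demo_dense.weights[0]).shape)   # (3, 5)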
So transposition is a very useful operation in machine learning.\n", 1049 | "\n", 1050 | "Resource:\n", 1051 | "\n", 1052 | "https://www.youtube.com/watch?v=QDpeRUIrb6U" 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "markdown", 1057 | "metadata": { 1058 | "id": "kvvJoUA5hmYF" 1059 | }, 1060 | "source": [ 1061 | "## Clear Models and Generate Seed" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": { 1067 | "id": "vDe1cgGrTWjL" 1068 | }, 1069 | "source": [ 1070 | "Clear previous model sessions and generate a seed for reproducibility:" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "code", 1075 | "metadata": { 1076 | "id": "eEC-qA_nhtJZ" 1077 | }, 1078 | "source": [ 1079 | "tf.keras.backend.clear_session()\n", 1080 | "np.random.seed(0)\n", 1081 | "tf.random.set_seed(0)" 1082 | ], 1083 | "execution_count": null, 1084 | "outputs": [] 1085 | }, 1086 | { 1087 | "cell_type": "markdown", 1088 | "metadata": { 1089 | "id": "3EgEtc9EjR8e" 1090 | }, 1091 | "source": [ 1092 | "## Create Dense Layers" 1093 | ] 1094 | }, 1095 | { 1096 | "cell_type": "markdown", 1097 | "metadata": { 1098 | "id": "0G8IdolcTYH-" 1099 | }, 1100 | "source": [ 1101 | "Create three dense layers for the model:" 1102 | ] 1103 | }, 1104 | { 1105 | "cell_type": "code", 1106 | "metadata": { 1107 | "id": "Ea_TPIb0hmek" 1108 | }, 1109 | "source": [ 1110 | "dense_1 = Dense(128, activation='relu')\n", 1111 | "dense_2 = Dense(64, activation='relu')\n", 1112 | "dense_3 = Dense(32, activation='relu')" 1113 | ], 1114 | "execution_count": null, 1115 | "outputs": [] 1116 | }, 1117 | { 1118 | "cell_type": "markdown", 1119 | "metadata": { 1120 | "id": "F2nWgj_XiqIX" 1121 | }, 1122 | "source": [ 1123 | "## Build the Encoder" 1124 | ] 1125 | }, 1126 | { 1127 | "cell_type": "markdown", 1128 | "metadata": { 1129 | "id": "VTgCa35dTdOL" 1130 | }, 1131 | "source": [ 1132 | "Build the encoder with three dense layers:" 1133 | ] 1134 | }, 1135 | { 1136 | "cell_type": "code", 1137 | "metadata": { 1138 | "id": "GkFh3niCiqM-" 1139 | }, 1140 | "source": [ 1141 | "tied_encoder = Sequential([\n", 1142 | " Flatten(input_shape=in_shape),\n", 1143 | " dense_1,\n", 1144 | " dense_2,\n", 1145 | " dense_3\n", 1146 | "])" 1147 | ], 1148 | "execution_count": null, 1149 | "outputs": [] 1150 | }, 1151 | { 1152 | "cell_type": "markdown", 1153 | "metadata": { 1154 | "id": "Mp2_HywNi-zf" 1155 | }, 1156 | "source": [ 1157 | "## Build the Decoder" 1158 | ] 1159 | }, 1160 | { 1161 | "cell_type": "markdown", 1162 | "metadata": { 1163 | "id": "k2afrMWGTk6J" 1164 | }, 1165 | "source": [ 1166 | "Build the decoder and tie weights with the encoder:" 1167 | ] 1168 | }, 1169 | { 1170 | "cell_type": "code", 1171 | "metadata": { 1172 | "id": "VfF5Aog8i-3w" 1173 | }, 1174 | "source": [ 1175 | "tied_decoder = Sequential([\n", 1176 | " DenseTranspose(dense_3, activation='relu'),\n", 1177 | " DenseTranspose(dense_2, activation='relu'),\n", 1178 | " DenseTranspose(dense_1, activation='sigmoid'),\n", 1179 | " Reshape([28, 28])\n", 1180 | "])" 1181 | ], 1182 | "execution_count": null, 1183 | "outputs": [] 1184 | }, 1185 | { 1186 | "cell_type": "markdown", 1187 | "metadata": { 1188 | "id": "zmf6CEqthEQk" 1189 | }, 1190 | "source": [ 1191 | "## Build Tied Model" 1192 | ] 1193 | }, 1194 | { 1195 | "cell_type": "markdown", 1196 | "metadata": { 1197 | "id": "r55Rl20lT9cz" 1198 | }, 1199 | "source": [ 1200 | "Build the model with tied weights between the encoder and decoder:" 1201 | ] 1202 | }, 1203 | { 1204 | "cell_type": "code", 1205 | "metadata": { 1206 | "id": 
"hzpPA3YngXDU" 1207 | }, 1208 | "source": [ 1209 | "tied_ae = Sequential([tied_encoder, tied_decoder])" 1210 | ], 1211 | "execution_count": null, 1212 | "outputs": [] 1213 | }, 1214 | { 1215 | "cell_type": "markdown", 1216 | "metadata": { 1217 | "id": "YX57YLgHgXIE" 1218 | }, 1219 | "source": [ 1220 | "## Compile" 1221 | ] 1222 | }, 1223 | { 1224 | "cell_type": "markdown", 1225 | "metadata": { 1226 | "id": "bPOx6UuKUJBa" 1227 | }, 1228 | "source": [ 1229 | "Compile with **binary crossentropy**:" 1230 | ] 1231 | }, 1232 | { 1233 | "cell_type": "code", 1234 | "metadata": { 1235 | "id": "hsWbEuuDgXMY" 1236 | }, 1237 | "source": [ 1238 | "tied_ae.compile(loss='binary_crossentropy',\n", 1239 | " optimizer=opt, metrics=[rounded_accuracy])" 1240 | ], 1241 | "execution_count": null, 1242 | "outputs": [] 1243 | }, 1244 | { 1245 | "cell_type": "markdown", 1246 | "metadata": { 1247 | "id": "4McSoQXwgXQM" 1248 | }, 1249 | "source": [ 1250 | "## Train" 1251 | ] 1252 | }, 1253 | { 1254 | "cell_type": "markdown", 1255 | "metadata": { 1256 | "id": "AG5slaILUNQ6" 1257 | }, 1258 | "source": [ 1259 | "Train for ten epochs:" 1260 | ] 1261 | }, 1262 | { 1263 | "cell_type": "code", 1264 | "metadata": { 1265 | "id": "mqwtXHEdj1qh" 1266 | }, 1267 | "source": [ 1268 | "tied_history = tied_ae.fit(\n", 1269 | " x_train, x_train, epochs=10,\n", 1270 | " validation_data=(x_test, x_test))" 1271 | ], 1272 | "execution_count": null, 1273 | "outputs": [] 1274 | }, 1275 | { 1276 | "cell_type": "markdown", 1277 | "metadata": { 1278 | "id": "m1NMH9sOnUCI" 1279 | }, 1280 | "source": [ 1281 | "Visualize training performance:" 1282 | ] 1283 | }, 1284 | { 1285 | "cell_type": "code", 1286 | "metadata": { 1287 | "id": "GyzfeJeqnUKn" 1288 | }, 1289 | "source": [ 1290 | "viz_history(tied_history)" 1291 | ], 1292 | "execution_count": null, 1293 | "outputs": [] 1294 | }, 1295 | { 1296 | "cell_type": "markdown", 1297 | "metadata": { 1298 | "id": "sVdclL6Ms7ma" 1299 | }, 1300 | "source": [ 1301 | "## Visualize Reconstructions" 1302 | ] 1303 | }, 1304 | { 1305 | "cell_type": "markdown", 1306 | "metadata": { 1307 | "id": "JnAgL14hUToB" 1308 | }, 1309 | "source": [ 1310 | "Show test image reconstructions based on predictions from the trained model:" 1311 | ] 1312 | }, 1313 | { 1314 | "cell_type": "code", 1315 | "metadata": { 1316 | "id": "lrq46303s7x5" 1317 | }, 1318 | "source": [ 1319 | "show_reconstructions(tied_ae, x_test_imgs, 6)\n", 1320 | "plt.show()" 1321 | ], 1322 | "execution_count": null, 1323 | "outputs": [] 1324 | }, 1325 | { 1326 | "cell_type": "markdown", 1327 | "metadata": { 1328 | "id": "8sgF6zHqtQRN" 1329 | }, 1330 | "source": [ 1331 | "# Denoising Autoencoders\n", 1332 | "\n", 1333 | "An autoencoder can also be trained to remove noise from images. We can add noise to inputs and train to recover the original noise-free inputs." 
1334 | ] 1335 | }, 1336 | { 1337 | "cell_type": "markdown", 1338 | "metadata": { 1339 | "id": "FWqaa8yc2ire" 1340 | }, 1341 | "source": [ 1342 | "## Clear Model and Generate Seed" 1343 | ] 1344 | }, 1345 | { 1346 | "cell_type": "markdown", 1347 | "metadata": { 1348 | "id": "GAzFQUZPUiMF" 1349 | }, 1350 | "source": [ 1351 | "Clear previous model sessions and generate a seed for reproducibility:" 1352 | ] 1353 | }, 1354 | { 1355 | "cell_type": "code", 1356 | "metadata": { 1357 | "id": "0uzM_ovf2ix4" 1358 | }, 1359 | "source": [ 1360 | "tf.keras.backend.clear_session()\n", 1361 | "np.random.seed(0)\n", 1362 | "tf.random.set_seed(0)" 1363 | ], 1364 | "execution_count": null, 1365 | "outputs": [] 1366 | }, 1367 | { 1368 | "cell_type": "markdown", 1369 | "metadata": { 1370 | "id": "AOGGlaVKuTCV" 1371 | }, 1372 | "source": [ 1373 | "## Build the Encoder with Gaussian Noise" 1374 | ] 1375 | }, 1376 | { 1377 | "cell_type": "markdown", 1378 | "metadata": { 1379 | "id": "JB3f0LzfUkWa" 1380 | }, 1381 | "source": [ 1382 | "Add pure Gaussian noise directly in the encoder:" 1383 | ] 1384 | }, 1385 | { 1386 | "cell_type": "code", 1387 | "metadata": { 1388 | "id": "5QUF0DKhtQfr" 1389 | }, 1390 | "source": [ 1391 | "from tensorflow.keras.layers import GaussianNoise\n", 1392 | "\n", 1393 | "gaussian_encoder = Sequential([\n", 1394 | " Flatten(input_shape=in_shape),\n", 1395 | " GaussianNoise(0.2),\n", 1396 | " dense_1,\n", 1397 | " dense_2,\n", 1398 | " dense_3\n", 1399 | "])" 1400 | ], 1401 | "execution_count": null, 1402 | "outputs": [] 1403 | }, 1404 | { 1405 | "cell_type": "markdown", 1406 | "metadata": { 1407 | "id": "tdjwbm_-vPYK" 1408 | }, 1409 | "source": [ 1410 | "## Build the Decoder" 1411 | ] 1412 | }, 1413 | { 1414 | "cell_type": "markdown", 1415 | "metadata": { 1416 | "id": "Jpwe3K5BUrLl" 1417 | }, 1418 | "source": [ 1419 | "Tie the weights of the decoder layers to the weights of the encoder layers:" 1420 | ] 1421 | }, 1422 | { 1423 | "cell_type": "code", 1424 | "metadata": { 1425 | "id": "WuCsUfsOvS__" 1426 | }, 1427 | "source": [ 1428 | "gaussian_decoder = Sequential([\n", 1429 | " DenseTranspose(dense_3, activation='relu'),\n", 1430 | " DenseTranspose(dense_2, activation='relu'),\n", 1431 | " DenseTranspose(dense_1, activation='sigmoid'),\n", 1432 | " Reshape([28, 28])\n", 1433 | "])" 1434 | ], 1435 | "execution_count": null, 1436 | "outputs": [] 1437 | }, 1438 | { 1439 | "cell_type": "markdown", 1440 | "metadata": { 1441 | "id": "YIve4G3rvjla" 1442 | }, 1443 | "source": [ 1444 | "## Build the Denoising Autoencoder" 1445 | ] 1446 | }, 1447 | { 1448 | "cell_type": "markdown", 1449 | "metadata": { 1450 | "id": "iIzpTj_sUvNy" 1451 | }, 1452 | "source": [ 1453 | "Build the denoising autoencoder from the gaussian encoder and decoder:" 1454 | ] 1455 | }, 1456 | { 1457 | "cell_type": "code", 1458 | "metadata": { 1459 | "id": "t2DmmOHvvjsl" 1460 | }, 1461 | "source": [ 1462 | "gaussian_ae = Sequential([gaussian_encoder, gaussian_decoder])" 1463 | ], 1464 | "execution_count": null, 1465 | "outputs": [] 1466 | }, 1467 | { 1468 | "cell_type": "markdown", 1469 | "metadata": { 1470 | "id": "LcqdzgAyvjw5" 1471 | }, 1472 | "source": [ 1473 | "## Compile" 1474 | ] 1475 | }, 1476 | { 1477 | "cell_type": "markdown", 1478 | "metadata": { 1479 | "id": "nrb75ynCU1-7" 1480 | }, 1481 | "source": [ 1482 | "Compile with **binary crossentropy**:" 1483 | ] 1484 | }, 1485 | { 1486 | "cell_type": "code", 1487 | "metadata": { 1488 | "id": "Y9UkDRUQwJlm" 1489 | }, 1490 | "source": [ 1491 | "gaussian_ae.compile(\n", 1492 | 
" loss='binary_crossentropy',\n", 1493 | " optimizer=opt, metrics=[rounded_accuracy])" 1494 | ], 1495 | "execution_count": null, 1496 | "outputs": [] 1497 | }, 1498 | { 1499 | "cell_type": "markdown", 1500 | "metadata": { 1501 | "id": "Dr1VBHJkwNM8" 1502 | }, 1503 | "source": [ 1504 | "## Train" 1505 | ] 1506 | }, 1507 | { 1508 | "cell_type": "markdown", 1509 | "metadata": { 1510 | "id": "N46vycU7U7eK" 1511 | }, 1512 | "source": [ 1513 | "Train model for ten epochs:" 1514 | ] 1515 | }, 1516 | { 1517 | "cell_type": "code", 1518 | "metadata": { 1519 | "id": "LGRxQqPzvj1g" 1520 | }, 1521 | "source": [ 1522 | "gae_history = gaussian_ae.fit(\n", 1523 | " x_train, x_train, epochs=10,\n", 1524 | " validation_data=(x_test, x_test))" 1525 | ], 1526 | "execution_count": null, 1527 | "outputs": [] 1528 | }, 1529 | { 1530 | "cell_type": "markdown", 1531 | "metadata": { 1532 | "id": "HWyIM-xLuoYE" 1533 | }, 1534 | "source": [ 1535 | "Visualize training performance:" 1536 | ] 1537 | }, 1538 | { 1539 | "cell_type": "code", 1540 | "metadata": { 1541 | "id": "n9vrqyWouodU" 1542 | }, 1543 | "source": [ 1544 | "viz_history(tied_history)" 1545 | ], 1546 | "execution_count": null, 1547 | "outputs": [] 1548 | }, 1549 | { 1550 | "cell_type": "markdown", 1551 | "metadata": { 1552 | "id": "q4-zY0kevj5Z" 1553 | }, 1554 | "source": [ 1555 | "## Visualize Reconstructions" 1556 | ] 1557 | }, 1558 | { 1559 | "cell_type": "markdown", 1560 | "metadata": { 1561 | "id": "EWcSldeyVA9C" 1562 | }, 1563 | "source": [ 1564 | "Add the same amount of Gaussian noise to **test** images:" 1565 | ] 1566 | }, 1567 | { 1568 | "cell_type": "code", 1569 | "metadata": { 1570 | "id": "_rPYMlMovj9r" 1571 | }, 1572 | "source": [ 1573 | "noise = GaussianNoise(0.2)\n", 1574 | "show_reconstructions(gaussian_ae, noise(x_test_imgs), 6)\n", 1575 | "plt.show()" 1576 | ], 1577 | "execution_count": null, 1578 | "outputs": [] 1579 | }, 1580 | { 1581 | "cell_type": "markdown", 1582 | "metadata": { 1583 | "id": "QWuNdF9CxYHt" 1584 | }, 1585 | "source": [ 1586 | "# Build the Encoder with Dropout" 1587 | ] 1588 | }, 1589 | { 1590 | "cell_type": "markdown", 1591 | "metadata": { 1592 | "id": "IothDH0fXps2" 1593 | }, 1594 | "source": [ 1595 | "Add dropout directly into the encoder. Dropout adds random noise to the images." 
1596 | ] 1597 | }, 1598 | { 1599 | "cell_type": "code", 1600 | "metadata": { 1601 | "id": "8OkuBpiaxYOY" 1602 | }, 1603 | "source": [ 1604 | "from tensorflow.keras.layers import Dropout\n", 1605 | "\n", 1606 | "tf.keras.backend.clear_session()\n", 1607 | "np.random.seed(0)\n", 1608 | "tf.random.set_seed(0)\n", 1609 | "\n", 1610 | "dropout_encoder = Sequential([\n", 1611 | " Flatten(input_shape=in_shape),\n", 1612 | " Dropout(0.5),\n", 1613 | " dense_1,\n", 1614 | " dense_2,\n", 1615 | " dense_3\n", 1616 | "])" 1617 | ], 1618 | "execution_count": null, 1619 | "outputs": [] 1620 | }, 1621 | { 1622 | "cell_type": "markdown", 1623 | "metadata": { 1624 | "id": "k7gBrIomtQj0" 1625 | }, 1626 | "source": [ 1627 | "## Build the Decoder" 1628 | ] 1629 | }, 1630 | { 1631 | "cell_type": "markdown", 1632 | "metadata": { 1633 | "id": "pBmsk4BYXzO2" 1634 | }, 1635 | "source": [ 1636 | "Tie the weights of the decoder layers to the weights of the encoder layers:" 1637 | ] 1638 | }, 1639 | { 1640 | "cell_type": "code", 1641 | "metadata": { 1642 | "id": "7dtoaqDdtQoi" 1643 | }, 1644 | "source": [ 1645 | "dropout_decoder = Sequential([\n", 1646 | " DenseTranspose(dense_3, activation='relu'),\n", 1647 | " DenseTranspose(dense_2, activation='relu'),\n", 1648 | " DenseTranspose(dense_1, activation='sigmoid'),\n", 1649 | " Reshape([28, 28])\n", 1650 | "])" 1651 | ], 1652 | "execution_count": null, 1653 | "outputs": [] 1654 | }, 1655 | { 1656 | "cell_type": "markdown", 1657 | "metadata": { 1658 | "id": "A6daiOrUX13t" 1659 | }, 1660 | "source": [ 1661 | "We tie the weights together because the performance is better." 1662 | ] 1663 | }, 1664 | { 1665 | "cell_type": "markdown", 1666 | "metadata": { 1667 | "id": "5NmAnKSr1mPp" 1668 | }, 1669 | "source": [ 1670 | "## Build the Dropout Autoencoder" 1671 | ] 1672 | }, 1673 | { 1674 | "cell_type": "markdown", 1675 | "metadata": { 1676 | "id": "4wMXDBzcYH_l" 1677 | }, 1678 | "source": [ 1679 | "Build the autoencoder from the dropout encoder and decoder:" 1680 | ] 1681 | }, 1682 | { 1683 | "cell_type": "code", 1684 | "metadata": { 1685 | "id": "Gn-JY_mT1mU9" 1686 | }, 1687 | "source": [ 1688 | "dropout_ae = Sequential([dropout_encoder, dropout_decoder])" 1689 | ], 1690 | "execution_count": null, 1691 | "outputs": [] 1692 | }, 1693 | { 1694 | "cell_type": "markdown", 1695 | "metadata": { 1696 | "id": "V92OPjsD10Y0" 1697 | }, 1698 | "source": [ 1699 | "## Compile" 1700 | ] 1701 | }, 1702 | { 1703 | "cell_type": "markdown", 1704 | "metadata": { 1705 | "id": "0NvOgcYbYNud" 1706 | }, 1707 | "source": [ 1708 | "Compile with **binary crossentropy**:" 1709 | ] 1710 | }, 1711 | { 1712 | "cell_type": "code", 1713 | "metadata": { 1714 | "id": "nJwYCpjb10qZ" 1715 | }, 1716 | "source": [ 1717 | "dropout_ae.compile(\n", 1718 | " loss='binary_crossentropy',\n", 1719 | " optimizer=opt, metrics=[rounded_accuracy])" 1720 | ], 1721 | "execution_count": null, 1722 | "outputs": [] 1723 | }, 1724 | { 1725 | "cell_type": "markdown", 1726 | "metadata": { 1727 | "id": "2CjRl2sj104z" 1728 | }, 1729 | "source": [ 1730 | "## Train" 1731 | ] 1732 | }, 1733 | { 1734 | "cell_type": "markdown", 1735 | "metadata": { 1736 | "id": "B4Q007HoYRU-" 1737 | }, 1738 | "source": [ 1739 | "Train the model for ten epochs:" 1740 | ] 1741 | }, 1742 | { 1743 | "cell_type": "code", 1744 | "metadata": { 1745 | "id": "5mugmk_m109v" 1746 | }, 1747 | "source": [ 1748 | "drop_history = dropout_ae.fit(\n", 1749 | " x_train, x_train, epochs=10,\n", 1750 | " validation_data=(x_test, x_test))" 1751 | ], 1752 | "execution_count": 
null, 1753 | "outputs": [] 1754 | }, 1755 | { 1756 | "cell_type": "markdown", 1757 | "metadata": { 1758 | "id": "jxgz9a5x2Xv7" 1759 | }, 1760 | "source": [ 1761 | "Visualize performance:" 1762 | ] 1763 | }, 1764 | { 1765 | "cell_type": "code", 1766 | "metadata": { 1767 | "id": "0ewAbOQY2X3U" 1768 | }, 1769 | "source": [ 1770 | "viz_history(drop_history)" 1771 | ], 1772 | "execution_count": null, 1773 | "outputs": [] 1774 | }, 1775 | { 1776 | "cell_type": "markdown", 1777 | "metadata": { 1778 | "id": "IHDqIei44GzR" 1779 | }, 1780 | "source": [ 1781 | "## Visualize Reconstructions" 1782 | ] 1783 | }, 1784 | { 1785 | "cell_type": "markdown", 1786 | "metadata": { 1787 | "id": "Wje_Oz-MYWwj" 1788 | }, 1789 | "source": [ 1790 | "Add the same amount of dropout noise to test images:" 1791 | ] 1792 | }, 1793 | { 1794 | "cell_type": "code", 1795 | "metadata": { 1796 | "id": "afm97V2C4G38" 1797 | }, 1798 | "source": [ 1799 | "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", 1800 | "\n", 1801 | "dropout = Dropout(0.5)\n", 1802 | "show_reconstructions(dropout_ae, dropout(x_test_imgs), 6)\n", 1803 | "plt.show()" 1804 | ], 1805 | "execution_count": null, 1806 | "outputs": [] 1807 | } 1808 | ] 1809 | } -------------------------------------------------------------------------------- /chapter14/ch14.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "ch14.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "E0DIZQALOE23" 21 | }, 22 | "source": [ 23 | "# Reinforcement Learning\n", 24 | "\n", 25 | "Reinforcement learning is training machine learning models to make a sequence of decisions. A software **agent** makes **observations** and takes actions within an **environment** to receive rewards. Its objective is to learn to act in a way that maximizes its expected rewards over time. \n", 26 | "\n", 27 | "The model employs trial and error to come up with a solution to the problem. It gets either rewards or penalties for the actions it performs with the goal to maximize the total reward. Although the designer sets the reward policy (the rules of the game), he/she gives the model no hints or suggestions for how to solve the game. The model must figure out how to perform the task to maximize the reward." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "l7TtAR5kxYoy" 34 | }, 35 | "source": [ 36 | "# Policy Search\n", 37 | "\n", 38 | "A **policy** is the algorithm a software agent uses to determine its actions. The **policy space** is a mapping from perceived states of the environment to actions to be taken when in those states." 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "KyHEcXv30Bhf" 45 | }, 46 | "source": [ 47 | "# OpenAI Gym\n", 48 | "\n", 49 | "We need a working environment to train an agent. 
*OpenAI Gym* is a toolkit that provides a wide variety of simulated environments including Atari games, board games, 2D and 3D physical simulations, and so on.\n", 50 | "\n", 51 | "Resources:\n", 52 | "\n", 53 | "https://colab.research.google.com/drive/18LdlDDT87eb8cCTHZsXyS9ksQPzL3i6H\n", 54 | "\n", 55 | "https://colab.research.google.com/drive/1flu31ulJlgiRL1dnN2ir8wGh9p7Zij2t\n", 56 | "\n", 57 | "https://stackoverflow.com/questions/50107530/how-to-render-openai-gym-in-google-colab" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "kKvaJr-1ycnw" 64 | }, 65 | "source": [ 66 | "# Install and Configure OpenAI Gym on Colab" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "7idcsVDny7vw" 73 | }, 74 | "source": [ 75 | "Most of the required Python packages are already installed on Colab. To run Gym, install the remaining prerequisites:" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": "POuTao38ycwe" 82 | }, 83 | "source": [ 84 | "!pip install gym\n", 85 | "!apt-get install python-opengl -y\n", 86 | "!apt install xvfb -y" 87 | ], 88 | "execution_count": null, 89 | "outputs": [] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": { 94 | "id": "qeq9tJI6zWPQ" 95 | }, 96 | "source": [ 97 | "For the rendering environment, use pyvirtualdisplay:" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "id": "Lm_X9tCMzWxG" 104 | }, 105 | "source": [ 106 | "!pip install pyvirtualdisplay\n", 107 | "!pip install pyglet" 108 | ], 109 | "execution_count": null, 110 | "outputs": [] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": { 115 | "id": "ioHCOdEcTX1X" 116 | }, 117 | "source": [ 118 | "# Import **tensorflow** library" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": { 124 | "id": "VpygsU5NTY4O" 125 | }, 126 | "source": [ 127 | "Import the library and alias it:" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "metadata": { 133 | "id": "prmzt6Q0TdMM" 134 | }, 135 | "source": [ 136 | "import tensorflow as tf" 137 | ], 138 | "execution_count": null, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": { 144 | "id": "z1PaVreZTcPm" 145 | }, 146 | "source": [ 147 | "# GPU Hardware Accelerator\n", 148 | "\n", 149 | "To vastly speed up processing, we can use the GPU available from the Google Colab cloud service. Colab provides a free Tesla K80 GPU of about 12 GB. 
It’s very easy to enable the GPU in a Colab notebook:\n", 150 | "\n", 151 | "1.\tclick **Runtime** in the top left menu\n", 152 | "2.\tclick **Change runtime type** from the drop-down menu\n", 153 | "3.\tchoose **GPU** from the Hardware accelerator drop-down menu\n", 154 | "4.\tclick **SAVE**" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": { 160 | "id": "nu9mIFhNTgeW" 161 | }, 162 | "source": [ 163 | "Verify that the GPU is available:" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "P-LVWYafThG_" 170 | }, 171 | "source": [ 172 | "tf.__version__, tf.test.gpu_device_name()" 173 | ], 174 | "execution_count": null, 175 | "outputs": [] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "GcllIuuxVKfX" 181 | }, 182 | "source": [ 183 | "# Import Requisite Libraries" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": { 189 | "id": "n5eRwZX4VN8e" 190 | }, 191 | "source": [ 192 | "To activate the virtual display:" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "64LS09veVKkf" 199 | }, 200 | "source": [ 201 | "import pyvirtualdisplay\n", 202 | "\n", 203 | "display = pyvirtualdisplay.Display(\n", 204 | " visible=0, size=(1400, 900)).start()" 205 | ], 206 | "execution_count": null, 207 | "outputs": [] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": { 212 | "id": "7ZPIX52b01YV" 213 | }, 214 | "source": [ 215 | "Import the **gym** library:" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "metadata": { 221 | "id": "ITRgIvjo01ok" 222 | }, 223 | "source": [ 224 | "import gym" 225 | ], 226 | "execution_count": null, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "jXmVNSq125tR" 233 | }, 234 | "source": [ 235 | "# Create an Environment" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": { 241 | "id": "fPvqGtus1ftj" 242 | }, 243 | "source": [ 244 | "Create a **CartPole** environment:" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "metadata": { 250 | "id": "JBw199AS1mse" 251 | }, 252 | "source": [ 253 | "env = gym.make('CartPole-v1')" 254 | ], 255 | "execution_count": null, 256 | "outputs": [] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": { 261 | "id": "W51giejf2Hd6" 262 | }, 263 | "source": [ 264 | "The *CartPole* environment is a 2D simulation that accelerates a cart left or right to balance a pole placed on top of it. A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright and the goal is to prevent it from falling over." 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": { 270 | "id": "7USxVMBS1qBU" 271 | }, 272 | "source": [ 273 | "Initialize the environment by calling its reset() method, which returns an observation:" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "metadata": { 279 | "id": "NmamatDR1qNV" 280 | }, 281 | "source": [ 282 | "env.seed(0)\n", 283 | "obs = env.reset()\n", 284 | "obs" 285 | ], 286 | "execution_count": null, 287 | "outputs": [] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": { 292 | "id": "f__W2QpG18E6" 293 | }, 294 | "source": [ 295 | "Observations vary depending on the environment. 
In this case, it is a 1D NumPy array composed of 4 floats that represent the cart's horizontal position (0 = center), the cart's velocity (positive means moving to the right), the angle of the pole (0 = vertical; positive means tilted to the right), and the pole's angular velocity (positive means rotating toward the right).\n", 296 | "\n", 297 | "So the cart is slightly left of center (obs[0] is slightly negative), it is moving slowly to the right (obs[1] is slightly positive), the pole is tilted slightly to the right (obs[2] is slightly positive), and the pole is rotating toward the left (obs[3] is slightly negative)." 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": { 303 | "id": "rlZ8s-jA2xVR" 304 | }, 305 | "source": [ 306 | "An environment can be visualized by calling its render() method, and you can pick the rendering mode (the rendering options depend on the environment)." 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "metadata": { 312 | "id": "gfSSdjOs2xn5" 313 | }, 314 | "source": [ 315 | "env.render()" 316 | ], 317 | "execution_count": null, 318 | "outputs": [] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": { 323 | "id": "1_hK5kSb37LJ" 324 | }, 325 | "source": [ 326 | "# Display the Environment" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": { 332 | "id": "ShgoLJx03I57" 333 | }, 334 | "source": [ 335 | "Set mode='rgb_array' to get an image of the environment as a NumPy array:" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "metadata": { 341 | "id": "BBklNNna2xqH" 342 | }, 343 | "source": [ 344 | "img = env.render(mode='rgb_array')\n", 345 | "img.shape" 346 | ], 347 | "execution_count": null, 348 | "outputs": [] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": { 353 | "id": "cBPWSJGt3Tv-" 354 | }, 355 | "source": [ 356 | "Create a function to display the environment as configured:" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "metadata": { 362 | "id": "04F-SfvO2weY" 363 | }, 364 | "source": [ 365 | "def plot_environment(env, figsize=(5,4)):\n", 366 | " plt.figure(figsize=figsize)\n", 367 | " img = env.render(mode='rgb_array')\n", 368 | " plt.imshow(img)\n", 369 | " plt.axis('off')\n", 370 | " return img" 371 | ], 372 | "execution_count": null, 373 | "outputs": [] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": { 378 | "id": "H4Tiiw9p4KZ9" 379 | }, 380 | "source": [ 381 | "Display:" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "metadata": { 387 | "id": "BaW5Ec0k3rsI" 388 | }, 389 | "source": [ 390 | "import matplotlib.pyplot as plt\n", 391 | "\n", 392 | "plot_environment(env)\n", 393 | "plt.show()" 394 | ], 395 | "execution_count": null, 396 | "outputs": [] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": { 401 | "id": "S4FHO1NvlCKj" 402 | }, 403 | "source": [ 404 | "# Display Actions\n", 405 | "\n", 406 | "Let's see how to interact with the environment we created. The agent needs to select an action from an action space. An **action space** is the set of possible actions that an agent can take." 
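For instance, here is a small illustrative sketch (added here; it is not a cell from the original notebook) of how an agent can draw random actions from the action space. It assumes the classic Gym API used throughout this notebook, where env.action_space.sample() returns a valid random action:

import gym

env = gym.make('CartPole-v1')
env.reset()
# Draw five random actions from the action space (0 = accelerate left, 1 = accelerate right).
print([env.action_space.sample() for _ in range(5)])
env.close()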
407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": { 412 | "id": "QPWJrS7A3hhw" 413 | }, 414 | "source": [ 415 | "Ask the environment about possible actions:" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "metadata": { 421 | "id": "4rHLLedC3hnY" 422 | }, 423 | "source": [ 424 | "env.action_space" 425 | ], 426 | "execution_count": null, 427 | "outputs": [] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "metadata": { 432 | "id": "iAM4OypW4mM3" 433 | }, 434 | "source": [ 435 | "**Discrete(2)** means that the possible actions are the integers 0 and 1, which represent accelerating left (0) or right (1). So the environment's action space has two possible actions. The agent can accelerate towards the left or towards the right. Of course, other environments may have additional discrete actions or other kinds of actions like continuous ones." 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "id": "cfjOQx6b5k_h" 442 | }, 443 | "source": [ 444 | "Reset the environment and see how the pole is leaning by looking at its angle:" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "metadata": { 450 | "id": "3fPLhxvb5lFn" 451 | }, 452 | "source": [ 453 | "env.seed(0)\n", 454 | "obs = env.reset()\n", 455 | "indx = 2\n", 456 | "obs[indx]" 457 | ], 458 | "execution_count": null, 459 | "outputs": [] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": { 464 | "id": "RpHkDKBI6PsM" 465 | }, 466 | "source": [ 467 | "The third position (index of 2) in the *obs* array is the angle of the pole. If the value is below 0, the pole angles to the left. If above 0, it angles to the right. The pole is tilted slightly toward the right because **obs[2] > 0**." 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": { 473 | "id": "wZr3IKt05OE7" 474 | }, 475 | "source": [ 476 | "The CartPole environment only has two actions, left (0) or right (1). Let's accelerate the cart toward the right by setting **action=1**:" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "metadata": { 482 | "id": "ZWJzRa1K3h8L" 483 | }, 484 | "source": [ 485 | "action = 1\n", 486 | "obs, reward, done, info = env.step(action)\n", 487 | "print ('obs array:', obs)\n", 488 | "print ('reward:', reward)\n", 489 | "print ('done:', done)\n", 490 | "print ('info:', info)" 491 | ], 492 | "execution_count": null, 493 | "outputs": [] 494 | }, 495 | { 496 | "cell_type": "markdown", 497 | "metadata": { 498 | "id": "JqMUcvC--WZ2" 499 | }, 500 | "source": [ 501 | "The **step()** method executes the given action and returns four values. **obs** is the new observation. The cart is now moving toward the right because **obs[1] > 0**. The pole is still tilted toward the right because **obs[2] > 0**, but its angular velocity is now negative because **obs[3] < 0**. So it will likely be tilted toward the left after the next step. In this simple environment, *reward* is always *1.0* at every step. So the goal is to keep the episode running as long as possible. The *done* value is *True* when the episode is over. The episode is over if the pole tilts too much, if the cart goes off the screen, or if we win the game. The *info* value provides extra information. Once we finish using an environment, call the **close()** method to free resources." 
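To see how reset(), step(), and close() fit together, here is a minimal sketch (an illustration added here, not a cell from the original notebook) of one complete episode driven by random actions:

import gym

env = gym.make('CartPole-v1')
obs = env.reset()
total_reward, done = 0.0, False
while not done:
    action = env.action_space.sample()          # pick a random action (0 = left, 1 = right)
    obs, reward, done, info = env.step(action)  # advance the simulation by one step
    total_reward += reward                      # CartPole gives a reward of 1.0 per step
print('episode return:', total_reward)
env.close()                                     # free resources when finished

A random policy like this typically keeps the pole up for only a couple of dozen steps, which is why a learned policy is needed.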
502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": { 507 | "id": "Lygt0lKu8em7" 508 | }, 509 | "source": [ 510 | "After each step, the environment gives the agent the new observation, the reward, whether the game is over, and any extra information from that step." 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": { 516 | "id": "GLuQrUEI4W7j" 517 | }, 518 | "source": [ 519 | "Display the pole position:" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "metadata": { 525 | "id": "LKYEjDCV4XDh" 526 | }, 527 | "source": [ 528 | "plot_environment(env)\n", 529 | "plt.show()" 530 | ], 531 | "execution_count": null, 532 | "outputs": [] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "metadata": { 537 | "id": "zTI_Rppf8ahK" 538 | }, 539 | "source": [ 540 | "Here is the reward the agent got during the last step:" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "metadata": { 546 | "id": "BjXiCaJ98am7" 547 | }, 548 | "source": [ 549 | "reward" 550 | ], 551 | "execution_count": null, 552 | "outputs": [] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": { 557 | "id": "laPUdSGn8_4_" 558 | }, 559 | "source": [ 560 | "The game is not over yet:" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "metadata": { 566 | "id": "rvEaZPSv8_-Z" 567 | }, 568 | "source": [ 569 | "done" 570 | ], 571 | "execution_count": null, 572 | "outputs": [] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": { 577 | "id": "G2E5BsJ79Fhl" 578 | }, 579 | "source": [ 580 | "The sequence of steps from the moment the environment is reset until it is done is called an **episode**. At the end of an episode (i.e., when step() returns done=True), reset the environment before continuing to use it." 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "metadata": { 586 | "id": "I7Tl8ZCk9Fq_" 587 | }, 588 | "source": [ 589 | "if done:\n", 590 | " obs = env.reset()\n", 591 | "else:\n", 592 | " print ('game is not over!')" 593 | ], 594 | "execution_count": null, 595 | "outputs": [] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": { 600 | "id": "KAJ6T8rwlG_r" 601 | }, 602 | "source": [ 603 | "# Simple Neural Network Policy\n", 604 | "\n", 605 | "How can we make the pole remain upright? We need to define a policy, which is the strategy that the agent uses to select an action at each step. It can use all past actions and observations to decide what to do.\n", 606 | "\n", 607 | "Let's create a neural network that takes observations as inputs and outputs the action to take for each observation. To choose an action, the network estimates a probability for each action and selects an action randomly according to the estimated probabilities. In the case of the CartPole environment, there are just two possible actions (left or right). So we only need one output neuron that outputs the probability *p* of the action 0 (left), and of course the probability of action 1 (right) will be *1 - p*." 
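To make that sampling convention concrete, here is a tiny sketch (added for illustration, not from the notebook; the value of p_left is made up) of how an action can be drawn from the single probability produced by such a network:

import numpy as np

p_left = 0.7                             # probability of action 0 (left) output by the network
action = int(np.random.rand() > p_left)  # 0 (left) with probability 0.7, 1 (right) with probability 0.3
print(action)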
608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "metadata": { 613 | "id": "UsDHqwFGume8" 614 | }, 615 | "source": [ 616 | "Clear any previous models and set the random seeds:" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "metadata": { 622 | "id": "eRzfuloFch98" 623 | }, 624 | "source": [ 625 | "import numpy as np\n", 626 | "\n", 627 | "tf.keras.backend.clear_session()\n", 628 | "tf.random.set_seed(0)\n", 629 | "np.random.seed(0)" 630 | ], 631 | "execution_count": null, 632 | "outputs": [] 633 | }, 634 | { 635 | "cell_type": "markdown", 636 | "metadata": { 637 | "id": "reGJMdgl-9lB" 638 | }, 639 | "source": [ 640 | "Determine the observation space:" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "metadata": { 646 | "id": "R_SDjROO-9qx" 647 | }, 648 | "source": [ 649 | "obs_space = env.observation_space.shape\n", 650 | "obs_space" 651 | ], 652 | "execution_count": null, 653 | "outputs": [] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "metadata": { 658 | "id": "mIUGfNjUAhOQ" 659 | }, 660 | "source": [ 661 | "As shown earlier in this notebook, an observation is a 1D NumPy array composed of 4 floats that represent the cart's horizontal position, its velocity, the angle of the pole (0 = vertical), and the pole's angular velocity. So the observation space has 4 dimensions." 662 | ] 663 | }, 664 | { 665 | "cell_type": "markdown", 666 | "metadata": { 667 | "id": "scad5XmbAG1o" 668 | }, 669 | "source": [ 670 | "Set the number of inputs:" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "metadata": { 676 | "id": "Rp1JFv9uAHEH" 677 | }, 678 | "source": [ 679 | "n_inputs = env.observation_space.shape[0]\n", 680 | "n_inputs" 681 | ], 682 | "execution_count": null, 683 | "outputs": [] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "metadata": { 688 | "id": "rzF54s8ZBLoi" 689 | }, 690 | "source": [ 691 | "Create a model:" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "metadata": { 697 | "id": "QtQY2wezBLx5" 698 | }, 699 | "source": [ 700 | "from tensorflow.keras.models import Sequential\n", 701 | "from tensorflow.keras.layers import Dense\n", 702 | "\n", 703 | "model = Sequential([\n", 704 | " Dense(5, activation='elu', input_shape=[n_inputs]),\n", 705 | " Dense(1, activation='sigmoid')\n", 706 | "])" 707 | ], 708 | "execution_count": null, 709 | "outputs": [] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "metadata": { 714 | "id": "fGh5pcfvxZDv" 715 | }, 716 | "source": [ 717 | "The model is a simple *Sequential* model that defines the policy network. The number of inputs is the size of the observation space, which is 4 in our case. We include only 5 neurons in the first layer because this is such a simple problem. We output a single probability (the probability of going left) so we have a single output neuron using sigmoid activation. If we had more than two possible actions, we would use one output neuron per action and substitute softmax activation." 718 | ] 719 | }, 720 | { 721 | "cell_type": "markdown", 722 | "metadata": { 723 | "id": "YzNjnJ3lBTZF" 724 | }, 725 | "source": [ 726 | "In this particular environment, past actions and observations can safely be ignored because each observation contains the environment's full state. If there were some hidden state, we might need to consider past actions and observations to try to infer the hidden state of the environment. 
For example, if the environment only revealed the position of the cart but not its velocity, we would have to consider not only the current observation but also the previous observation in order to estimate the current velocity. Another example: if the observations are noisy, we may want to use the past few observations to estimate the most likely current state. Our problem is as simple as can be because the current observation is noise-free and contains the environment's full state.\n", 727 | "\n", 728 | "Why do we pick a random action based on the probability given by the policy network rather than just picking the action with the highest probability? This approach lets the agent find the right balance between exploring new actions and exploiting the actions that are known to work well." 729 | ] 730 | }, 731 | { 732 | "cell_type": "markdown", 733 | "metadata": { 734 | "id": "HfttE66yDx3n" 735 | }, 736 | "source": [ 737 | "# Model Predictions" 738 | ] 739 | }, 740 | { 741 | "cell_type": "markdown", 742 | "metadata": { 743 | "id": "zT2VWBh9C_pj" 744 | }, 745 | "source": [ 746 | "Create a function that runs the model to play one episode and returns the frames so we can display an animation:" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "metadata": { 752 | "id": "gOGOvhbwDAP6" 753 | }, 754 | "source": [ 755 | "def render_policy_net(model, n_max_steps=200, seed=0):\n", 756 | " frames = []\n", 757 | " env = gym.make('CartPole-v1')\n", 758 | " env.seed(seed)\n", 759 | " np.random.seed(seed)\n", 760 | " obs = env.reset()\n", 761 | " for step in range(n_max_steps):\n", 762 | " frames.append(env.render(mode='rgb_array'))\n", 763 | " left_proba = model.predict(obs.reshape(1, -1))\n", 764 | " action = int(np.random.rand() > left_proba)\n", 765 | " obs, reward, done, info = env.step(action)\n", 766 | " if done:\n", 767 | " break\n", 768 | " env.close()\n", 769 | " return frames" 770 | ], 771 | "execution_count": null, 772 | "outputs": [] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "metadata": { 777 | "id": "-7bGd7QFDAZ5" 778 | }, 779 | "source": [ 780 | "The function establishes the environment and resets it. It then loops over a number of steps until the episode is over. Each step begins by appending a visualization of the environment to the *frames* list. The model then predicts the probability of going left, an action is established based on that prediction, and the *step()* method executes the action. The loop continues until the episode is over, and the function ends by returning the list of frames." 
781 | ] 782 | }, 783 | { 784 | "cell_type": "markdown", 785 | "metadata": { 786 | "id": "XgDPjvJQGOA5" 787 | }, 788 | "source": [ 789 | "Create functions to show animation of the frames:" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "metadata": { 795 | "id": "gBQIf2HOGOHQ" 796 | }, 797 | "source": [ 798 | "import matplotlib.animation as animation\n", 799 | "import matplotlib as mpl\n", 800 | "\n", 801 | "def update_scene(num, frames, patch):\n", 802 | " patch.set_data(frames[num])\n", 803 | " return patch,\n", 804 | "\n", 805 | "def plot_animation(frames, repeat=False, interval=40):\n", 806 | " fig = plt.figure()\n", 807 | " patch = plt.imshow(frames[0])\n", 808 | " plt.axis('off')\n", 809 | " anim = animation.FuncAnimation(\n", 810 | " fig, update_scene, fargs=(frames, patch), blit=True,\n", 811 | " frames=len(frames), repeat=repeat, interval=interval)\n", 812 | " plt.close()\n", 813 | " return anim" 814 | ], 815 | "execution_count": null, 816 | "outputs": [] 817 | }, 818 | { 819 | "cell_type": "markdown", 820 | "metadata": { 821 | "id": "uFHd7hL3NQNX" 822 | }, 823 | "source": [ 824 | "Model predictions:" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "metadata": { 830 | "id": "8LKp7JecDAf5" 831 | }, 832 | "source": [ 833 | "frames = render_policy_net(model)" 834 | ], 835 | "execution_count": null, 836 | "outputs": [] 837 | }, 838 | { 839 | "cell_type": "markdown", 840 | "metadata": { 841 | "id": "pp_0qgYDNpNc" 842 | }, 843 | "source": [ 844 | "# Animate\n", 845 | "\n", 846 | "Additional resources:\n", 847 | "\n", 848 | "https://colab.research.google.com/github/jckantor/CBE30338/blob/master/docs/A.03-Animation-in-Jupyter-Notebooks.ipynb\n", 849 | "\n", 850 | "https://colab.research.google.com/github/phoebe-project/phoebe2-docs/blob/2.1/tutorials/animations.ipynb" 851 | ] 852 | }, 853 | { 854 | "cell_type": "markdown", 855 | "metadata": { 856 | "id": "eJduad0gNhjk" 857 | }, 858 | "source": [ 859 | "Create the animation:" 860 | ] 861 | }, 862 | { 863 | "cell_type": "code", 864 | "metadata": { 865 | "id": "LMySFUOQNkN9" 866 | }, 867 | "source": [ 868 | "anim = plot_animation(frames, interval=100)" 869 | ], 870 | "execution_count": null, 871 | "outputs": [] 872 | }, 873 | { 874 | "cell_type": "markdown", 875 | "metadata": { 876 | "id": "Xf6KzHroRZL3" 877 | }, 878 | "source": [ 879 | "Experiment with the **interval** parameter size." 880 | ] 881 | }, 882 | { 883 | "cell_type": "markdown", 884 | "metadata": { 885 | "id": "qio2krN7NuGk" 886 | }, 887 | "source": [ 888 | "Render and display the animation. We show two ways to accomplish this. The first method uses the HTML library to display HTML elements. 
The animation is rendered to html5 video with the to_html5_video() function and then displayed with HTML():" 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "metadata": { 894 | "id": "5qQBctubHgDr" 895 | }, 896 | "source": [ 897 | "from IPython.display import HTML\n", 898 | "\n", 899 | "method1 = HTML(anim.to_html5_video())\n", 900 | "method1" 901 | ], 902 | "execution_count": null, 903 | "outputs": [] 904 | }, 905 | { 906 | "cell_type": "markdown", 907 | "metadata": { 908 | "id": "xWFZ2GnYPMel" 909 | }, 910 | "source": [ 911 | "The second method uses the runtime configuration library:" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "metadata": { 917 | "id": "ICp6akcSIBRN" 918 | }, 919 | "source": [ 920 | "from matplotlib import rc\n", 921 | "\n", 922 | "method2 = rc('animation', html='html5')" 923 | ], 924 | "execution_count": null, 925 | "outputs": [] 926 | }, 927 | { 928 | "cell_type": "markdown", 929 | "metadata": { 930 | "id": "0fgGmae8Q8Vu" 931 | }, 932 | "source": [ 933 | "Run the animation:" 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | "metadata": { 939 | "id": "vxH4f32GIBTb" 940 | }, 941 | "source": [ 942 | "anim" 943 | ], 944 | "execution_count": null, 945 | "outputs": [] 946 | }, 947 | { 948 | "cell_type": "markdown", 949 | "metadata": { 950 | "id": "UQsn8wjOze5b" 951 | }, 952 | "source": [ 953 | "Ugh! The pole is falling to the left! But, we didn't implement our basic policy to go left if the pole is tilting left and go right if it is tilting right.\n" 954 | ] 955 | }, 956 | { 957 | "cell_type": "markdown", 958 | "metadata": { 959 | "id": "oXmO464g3oSP" 960 | }, 961 | "source": [ 962 | "# Implement a Basic Policy\n", 963 | "\n", 964 | "Make the network play in 50 different environments in parallel to give us a diverse training batch at each step. Train for 5,000 iterations. Use the RMSProp optimizer. And, use binary cross-entropy for the loss function because we only have two discrete possible actions. Finally, reset the environments when they are done to free resources. We train with a custom training loop so we can easily use the predictions at each training step to advance the environments." 965 | ] 966 | }, 967 | { 968 | "cell_type": "markdown", 969 | "metadata": { 970 | "id": "ycCLwZNG4INQ" 971 | }, 972 | "source": [ 973 | "Train with a basic policy:" 974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "metadata": { 979 | "id": "tqNtwd6c4ITJ" 980 | }, 981 | "source": [ 982 | "n_environments = 50\n", 983 | "n_iterations = 5000\n", 984 | "\n", 985 | "envs = [gym.make(\n", 986 | " 'CartPole-v1') for _ in range(n_environments)]\n", 987 | "for index, env in enumerate(envs):\n", 988 | " env.seed(index)\n", 989 | "np.random.seed(0)\n", 990 | "observations = [env.reset() for env in envs]\n", 991 | "optimizer = tf.keras.optimizers.RMSprop()\n", 992 | "loss_fn = tf.keras.losses.binary_crossentropy\n", 993 | "\n", 994 | "for iteration in range(n_iterations):\n", 995 | " # if angle < 0, we want proba(left) = 1., or else proba(left) = 0.\n", 996 | " target_probas = np.array(\n", 997 | " [([1.] 
if obs[2] < 0 else [0.]) \n", 998 | " for obs in observations])\n", 999 | " with tf.GradientTape() as tape:\n", 1000 | " left_probas = model(np.array(observations))\n", 1001 | " loss = tf.reduce_mean(\n", 1002 | " loss_fn(target_probas, left_probas))\n", 1003 | " print('\\rIteration: {}, Loss: {:.3f}'.\\\n", 1004 | " format(iteration, loss.numpy()), end='')\n", 1005 | " grads = tape.gradient(loss, model.trainable_variables)\n", 1006 | " optimizer.apply_gradients(\n", 1007 | " zip(grads, model.trainable_variables))\n", 1008 | " actions = (np.random.rand(n_environments, 1) >\\\n", 1009 | " left_probas.numpy()).astype(np.int32)\n", 1010 | " for env_index, env in enumerate(envs):\n", 1011 | " obs, reward, done, info = env.step(\n", 1012 | " actions[env_index][0])\n", 1013 | " observations[env_index] = obs if not done else env.reset()\n", 1014 | "\n", 1015 | "for env in envs:\n", 1016 | " env.close()" 1017 | ], 1018 | "execution_count": null, 1019 | "outputs": [] 1020 | }, 1021 | { 1022 | "cell_type": "markdown", 1023 | "metadata": { 1024 | "id": "M4lPmolg4IYI" 1025 | }, 1026 | "source": [ 1027 | "Create the frames for an animation:" 1028 | ] 1029 | }, 1030 | { 1031 | "cell_type": "code", 1032 | "metadata": { 1033 | "id": "NDbhNVGV4IdM" 1034 | }, 1035 | "source": [ 1036 | "frames = render_policy_net(model)" 1037 | ], 1038 | "execution_count": null, 1039 | "outputs": [] 1040 | }, 1041 | { 1042 | "cell_type": "markdown", 1043 | "metadata": { 1044 | "id": "8yhP-8aO4Ii7" 1045 | }, 1046 | "source": [ 1047 | "Animate:" 1048 | ] 1049 | }, 1050 | { 1051 | "cell_type": "code", 1052 | "metadata": { 1053 | "id": "T58qIy4m4Iny" 1054 | }, 1055 | "source": [ 1056 | "anim = plot_animation(frames, repeat=True, interval=100)\n", 1057 | "anim" 1058 | ], 1059 | "execution_count": null, 1060 | "outputs": [] 1061 | }, 1062 | { 1063 | "cell_type": "markdown", 1064 | "metadata": { 1065 | "id": "df9WOAce8KbJ" 1066 | }, 1067 | "source": [ 1068 | "Much better!" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "markdown", 1073 | "metadata": { 1074 | "id": "R3_cmcNj8RSO" 1075 | }, 1076 | "source": [ 1077 | "# REINFORCE Algorithm\n", 1078 | "\n", 1079 | "Let's see if the agent can learn a better policy on its own. Policy gradients (PG) optimize the parameters of a policy by following the gradients toward higher rewards. We use the REINFORCE PG algorithm to automate agent learning." 1080 | ] 1081 | }, 1082 | { 1083 | "cell_type": "markdown", 1084 | "metadata": { 1085 | "id": "R_xfFqrP9Zsi" 1086 | }, 1087 | "source": [ 1088 | "Let the neural network policy play the game several times. At each step, compute the gradients that make the chosen action even more likely. But don't apply the gradients yet. After running several episodes, compute each action's advantage: evaluate each action based on the sum of all the rewards that come after it, where each later reward is weighted by a *discount factor* (a number between 0 and 1) raised to the number of steps separating it from the action. If an action's advantage is positive, the action was probably good. So apply the gradients to make the action more likely to be chosen in the future. If it is negative, apply the opposite gradients to make the action less likely to be chosen. Finally, compute the mean of all resultant gradient vectors (gradient (or opposite gradient) x action advantage) and use it to perform a *Gradient Descent* step." 
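As a small arithmetic illustration of that discounting idea (added here, not notebook code): with a discount factor of 0.95, an action followed by the rewards 1, 1, 1 is credited with 1 + 0.95 + 0.95**2 = 2.8525.

discount_rate = 0.95
rewards_after_action = [1.0, 1.0, 1.0]
discounted_return = sum(r * discount_rate**k
                        for k, r in enumerate(rewards_after_action))
print(discounted_return)  # 1 + 0.95 + 0.9025 = 2.8525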
1089 | ] 1090 | }, 1091 | { 1092 | "cell_type": "markdown", 1093 | "metadata": { 1094 | "id": "YrL_XCgV_12G" 1095 | }, 1096 | "source": [ 1097 | "## Train the Model\n", 1098 | "\n", 1099 | "Train the model to learn to balance the pole on the cart." 1100 | ] 1101 | }, 1102 | { 1103 | "cell_type": "markdown", 1104 | "metadata": { 1105 | "id": "p7eDd1XMJSKI" 1106 | }, 1107 | "source": [ 1108 | "## Create Functions to Play the Game" 1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "markdown", 1113 | "metadata": { 1114 | "id": "j9KbJ0RbJQt4" 1115 | }, 1116 | "source": [ 1117 | "Create a function that plays one step: " 1118 | ] 1119 | }, 1120 | { 1121 | "cell_type": "code", 1122 | "metadata": { 1123 | "id": "wptcEIqRze_a" 1124 | }, 1125 | "source": [ 1126 | "def play_one_step(env, obs, model, loss_fn):\n", 1127 | " with tf.GradientTape() as tape:\n", 1128 | " left_proba = model(obs[np.newaxis])\n", 1129 | " action = (tf.random.uniform([1, 1]) > left_proba)\n", 1130 | " y_target = tf.constant(\n", 1131 | " [[1.]]) - tf.cast(action, tf.float32)\n", 1132 | " loss = tf.reduce_mean(loss_fn(\n", 1133 | " y_target, left_proba))\n", 1134 | " grads = tape.gradient(loss, model.trainable_variables)\n", 1135 | " obs, reward, done, info = env.step(\n", 1136 | " int(action[0, 0].numpy()))\n", 1137 | " return obs, reward, done, grads" 1138 | ], 1139 | "execution_count": null, 1140 | "outputs": [] 1141 | }, 1142 | { 1143 | "cell_type": "markdown", 1144 | "metadata": { 1145 | "id": "cov705LpATYH" 1146 | }, 1147 | "source": [ 1148 | "Within the *GradientTape* block, call the model with a single observation. We reshape the observation so that it becomes a batch containing a single instance (the model expects a batch). We get a probability of going left. Sample a random float between 0 and 1, and check if it is greater than that probability. The *action* is **False** with probability *left_proba* and **True** with probability *1 - left_proba*. Cast this Boolean to a number: 0 (left) or 1 (right) with the appropriate probabilities. We then define the target probabilities of going left (1 - the action). If the action is 0 (left), the target probability of going left is 1. If the action is 1 (right), the target probability is 0. Whew!\n", 1149 | "\n", 1150 | "We continue by computing the loss and use the tape to compute the gradient of the loss with regard to the model's trainable variables. We tweak the gradients later depending on how good or bad that action turned out to be. Finally, we play the selected action and return the new observation, reward, whether the episode is over or not, and the gradients." 
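A quick usage sketch (added for illustration, not a cell from the notebook): assuming the policy model and the play_one_step() function defined above, run a single step on a fresh environment and inspect what comes back:

env = gym.make('CartPole-v1')
obs = env.reset()
loss_fn = tf.keras.losses.binary_crossentropy
obs, reward, done, grads = play_one_step(env, obs, model, loss_fn)
# reward is 1.0 per CartPole step; grads holds one gradient tensor per trainable variable
print(reward, done, len(grads))
env.close()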
1151 | ] 1152 | }, 1153 | { 1154 | "cell_type": "markdown", 1155 | "metadata": { 1156 | "id": "sAKlb6_nCiSY" 1157 | }, 1158 | "source": [ 1159 | "Create a function to play multiple episodes and return the rewards and gradients for each episode and each step:" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "metadata": { 1165 | "id": "cmwyhWLNzrwy" 1166 | }, 1167 | "source": [ 1168 | "def play_multiple_episodes(\n", 1169 | " env, n_episodes, n_max_steps, model, loss_fn):\n", 1170 | " all_rewards = []\n", 1171 | " all_grads = []\n", 1172 | " for episode in range(n_episodes):\n", 1173 | " current_rewards = []\n", 1174 | " current_grads = []\n", 1175 | " obs = env.reset()\n", 1176 | " for step in range(n_max_steps):\n", 1177 | " obs, reward, done, grads = play_one_step(\n", 1178 | " env, obs, model, loss_fn)\n", 1179 | " current_rewards.append(reward)\n", 1180 | " current_grads.append(grads)\n", 1181 | " if done:\n", 1182 | " break\n", 1183 | " all_rewards.append(current_rewards)\n", 1184 | " all_grads.append(current_grads)\n", 1185 | " return all_rewards, all_grads" 1186 | ], 1187 | "execution_count": null, 1188 | "outputs": [] 1189 | }, 1190 | { 1191 | "cell_type": "markdown", 1192 | "metadata": { 1193 | "id": "k1fCOhBVDD9E" 1194 | }, 1195 | "source": [ 1196 | "The function returns a list of reward lists. The list contains one reward list per episode. Each reward list contains one reward per step. It also returns a list of gradient lists. The list contains one gradient list per episode. Each gradient list contains one tuple of gradients per step. Each tuple contains one gradient tensor per trainable variable." 1197 | ] 1198 | }, 1199 | { 1200 | "cell_type": "markdown", 1201 | "metadata": { 1202 | "id": "rqqnQhgnEKhk" 1203 | }, 1204 | "source": [ 1205 | "The policy gradient algorithm uses the *play_multiple_episodes()* function to play the game several times. It then goes back and looks at all the rewards to discount and normalize them." 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "markdown", 1210 | "metadata": { 1211 | "id": "Fy_sFGOgJg0x" 1212 | }, 1213 | "source": [ 1214 | "## Discount and Normalize the Rewards" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "markdown", 1219 | "metadata": { 1220 | "id": "mF-qeOVdJg4D" 1221 | }, 1222 | "source": [ 1223 | "To discount and normalize rewards, we introduce two functions: *discount_rewards* and *discount_and_normalize_rewards*." 
1224 | ] 1225 | }, 1226 | { 1227 | "cell_type": "code", 1228 | "metadata": { 1229 | "id": "oh9uVqUEzrzR" 1230 | }, 1231 | "source": [ 1232 | "def discount_rewards(rewards, discount_rate):\n", 1233 | " discounted = np.array(rewards)\n", 1234 | " for step in range(len(rewards) - 2, -1, -1):\n", 1235 | " discounted[step] += discounted[step + 1] * discount_rate\n", 1236 | " return discounted\n", 1237 | "\n", 1238 | "def discount_and_normalize_rewards(\n", 1239 | " all_rewards, discount_rate):\n", 1240 | " all_discounted_rewards =\\\n", 1241 | " [discount_rewards(rewards, discount_rate)\n", 1242 | " for rewards in all_rewards]\n", 1243 | " flat_rewards = np.concatenate(all_discounted_rewards)\n", 1244 | " reward_mean = flat_rewards.mean()\n", 1245 | " reward_std = flat_rewards.std()\n", 1246 | " return [(discounted_rewards - reward_mean) / reward_std\n", 1247 | " for discounted_rewards in all_discounted_rewards]" 1248 | ], 1249 | "execution_count": null, 1250 | "outputs": [] 1251 | }, 1252 | { 1253 | "cell_type": "markdown", 1254 | "metadata": { 1255 | "id": "eu-GdgSYFLk-" 1256 | }, 1257 | "source": [ 1258 | "Verify that the function work:" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "metadata": { 1264 | "id": "ujCGqTonzr2G" 1265 | }, 1266 | "source": [ 1267 | "discount_rewards([10, 0, -50], discount_rate=0.8)" 1268 | ], 1269 | "execution_count": null, 1270 | "outputs": [] 1271 | }, 1272 | { 1273 | "cell_type": "markdown", 1274 | "metadata": { 1275 | "id": "s4DSEjJxM8z8" 1276 | }, 1277 | "source": [ 1278 | "We give the function 3 actions. After each action, there is a reward: first 10, then 0, then -50. We use a discount factor of 80%. So, the 3rd action gets -50 (full credit for the last reward), but the 2nd action only gets -40 (80% credit for the last reward), and the 1st action will get 80% of -40 (-32) plus full credit for the first reward (+10) leading to a discounted reward of -22." 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "markdown", 1283 | "metadata": { 1284 | "id": "SGgfrJomNk1f" 1285 | }, 1286 | "source": [ 1287 | "To normalize all discounted rewards across all episodes, we compute the mean and standard deviation of all the discounted rewards. We then subtract the mean from each discounted reward and divide by the standard deviation:" 1288 | ] 1289 | }, 1290 | { 1291 | "cell_type": "code", 1292 | "metadata": { 1293 | "id": "YhvZ0-b4zxv6" 1294 | }, 1295 | "source": [ 1296 | "discount_and_normalize_rewards(\n", 1297 | " [[10, 0, -50], [10, 20]], discount_rate=0.8)" 1298 | ], 1299 | "execution_count": null, 1300 | "outputs": [] 1301 | }, 1302 | { 1303 | "cell_type": "markdown", 1304 | "metadata": { 1305 | "id": "xi4Sg8q6FPLd" 1306 | }, 1307 | "source": [ 1308 | "All actions from the first episode are considered **bad** because normalized advantages are all negative. This makes sense because the sum of the rewards is -40. Conversely, second episode actions are **good** because normalized advantages are positive. The sum of the rewards is 30." 
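To check that arithmetic by hand, here is a small illustrative sketch (added here, not notebook code; it assumes the discount_rewards() function defined above) that recomputes the normalization:

import numpy as np

ep1 = discount_rewards([10, 0, -50], discount_rate=0.8)  # [-22, -40, -50]
ep2 = discount_rewards([10, 20], discount_rate=0.8)      # [26, 20]
flat = np.concatenate([ep1, ep2])
print((ep1 - flat.mean()) / flat.std())  # all negative: every first-episode action is judged bad
print((ep2 - flat.mean()) / flat.std())  # all positive: every second-episode action is judged good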
1309 | ] 1310 | }, 1311 | { 1312 | "cell_type": "markdown", 1313 | "metadata": { 1314 | "id": "ybgAgornJ6XW" 1315 | }, 1316 | "source": [ 1317 | "## Build the Model" 1318 | ] 1319 | }, 1320 | { 1321 | "cell_type": "markdown", 1322 | "metadata": { 1323 | "id": "qyuuFFXGIWeR" 1324 | }, 1325 | "source": [ 1326 | "Define the hyperparameters:" 1327 | ] 1328 | }, 1329 | { 1330 | "cell_type": "code", 1331 | "metadata": { 1332 | "id": "FPSKsnahzxyi" 1333 | }, 1334 | "source": [ 1335 | "n_iterations = 150\n", 1336 | "n_episodes_per_update = 10\n", 1337 | "n_max_steps = 200\n", 1338 | "discount_rate = 0.95" 1339 | ], 1340 | "execution_count": null, 1341 | "outputs": [] 1342 | }, 1343 | { 1344 | "cell_type": "markdown", 1345 | "metadata": { 1346 | "id": "BAcQD0SiIlcH" 1347 | }, 1348 | "source": [ 1349 | "Run 150 training iterations, play 10 episodes of the game per iteration, make each episode last at most 200 steps, and use a discount rate of 0.95." 1350 | ] 1351 | }, 1352 | { 1353 | "cell_type": "markdown", 1354 | "metadata": { 1355 | "id": "O5RuRjF3I4UA" 1356 | }, 1357 | "source": [ 1358 | "Define an optimizer and loss function:" 1359 | ] 1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "metadata": { 1364 | "id": "Fe_7h5Sszx1Q" 1365 | }, 1366 | "source": [ 1367 | "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n", 1368 | "loss_fn = tf.keras.losses.binary_crossentropy" 1369 | ], 1370 | "execution_count": null, 1371 | "outputs": [] 1372 | }, 1373 | { 1374 | "cell_type": "markdown", 1375 | "metadata": { 1376 | "id": "kgN3EJqbI-ZI" 1377 | }, 1378 | "source": [ 1379 | "Use binary cross_entropy because we are training a binary classifier (two possible actions: left or right)." 1380 | ] 1381 | }, 1382 | { 1383 | "cell_type": "markdown", 1384 | "metadata": { 1385 | "id": "g2nbszluKA-p" 1386 | }, 1387 | "source": [ 1388 | "Create the model:" 1389 | ] 1390 | }, 1391 | { 1392 | "cell_type": "code", 1393 | "metadata": { 1394 | "id": "WB6nb9F0z_D9" 1395 | }, 1396 | "source": [ 1397 | "tf.keras.backend.clear_session()\n", 1398 | "np.random.seed(0)\n", 1399 | "tf.random.set_seed(0)\n", 1400 | "\n", 1401 | "model = Sequential([\n", 1402 | " Dense(5, activation='elu', input_shape=[4]),\n", 1403 | " Dense(1, activation='sigmoid'),\n", 1404 | "])" 1405 | ], 1406 | "execution_count": null, 1407 | "outputs": [] 1408 | }, 1409 | { 1410 | "cell_type": "markdown", 1411 | "metadata": { 1412 | "id": "bAkxWb-xKFQc" 1413 | }, 1414 | "source": [ 1415 | "Train:" 1416 | ] 1417 | }, 1418 | { 1419 | "cell_type": "code", 1420 | "metadata": { 1421 | "id": "gwjrIpCgz_GU" 1422 | }, 1423 | "source": [ 1424 | "env = gym.make('CartPole-v1')\n", 1425 | "env.seed(42);\n", 1426 | "\n", 1427 | "for iteration in range(n_iterations):\n", 1428 | " all_rewards, all_grads = play_multiple_episodes(\n", 1429 | " env, n_episodes_per_update, n_max_steps,\n", 1430 | " model, loss_fn)\n", 1431 | " total_rewards = sum(map(sum, all_rewards))\n", 1432 | " print('\\rIteration: {}, mean rewards: {:.1f}'.format(\n", 1433 | " iteration, total_rewards / n_episodes_per_update),\n", 1434 | " end='')\n", 1435 | " all_final_rewards = discount_and_normalize_rewards(\n", 1436 | " all_rewards, discount_rate)\n", 1437 | " all_mean_grads = []\n", 1438 | " for var_index in range(len(model.trainable_variables)):\n", 1439 | " mean_grads = tf.reduce_mean(\n", 1440 | " [final_reward * all_grads[episode_index][step][var_index]\n", 1441 | " for episode_index, final_rewards in enumerate(\n", 1442 | " all_final_rewards)\n", 1443 | " for step, final_reward in 
enumerate(\n", 1444 | " final_rewards)], axis=0)\n", 1445 | " all_mean_grads.append(mean_grads)\n", 1446 | " optimizer.apply_gradients(\n", 1447 | " zip(all_mean_grads, model.trainable_variables))\n", 1448 | "\n", 1449 | "env.close()" 1450 | ], 1451 | "execution_count": null, 1452 | "outputs": [] 1453 | }, 1454 | { 1455 | "cell_type": "markdown", 1456 | "metadata": { 1457 | "id": "zE0QOoVyKnC0" 1458 | }, 1459 | "source": [ 1460 | "At each training iteration, call *play_multiple_episodes()*, which plays the game 10 times and returns all the rewards and gradients for every episode and step. Call *discount_and_normalize_rewards()* to compute each action's normalized advantage, which gives us a measure of how good or bad each action actually was in hindsight. For each trainable variable, we compute the weighted mean of the gradients over all episodes and all steps weighted by the *final_reward*. The **final_reward** is each action's normalized advantage. We end by applying the mean gradients using the optimizer, which tweaks the model's trainable variables to hopefully make the policy a bit better." 1461 | ] 1462 | }, 1463 | { 1464 | "cell_type": "markdown", 1465 | "metadata": { 1466 | "id": "fm_GfCOIOF2t" 1467 | }, 1468 | "source": [ 1469 | "## Render Policy" 1470 | ] 1471 | }, 1472 | { 1473 | "cell_type": "markdown", 1474 | "metadata": { 1475 | "id": "MFKz61JMOLEV" 1476 | }, 1477 | "source": [ 1478 | "Render the reinforce algorithm policy:" 1479 | ] 1480 | }, 1481 | { 1482 | "cell_type": "code", 1483 | "metadata": { 1484 | "id": "YVZP4ilLOGBF" 1485 | }, 1486 | "source": [ 1487 | "frames_ra = render_policy_net(model)" 1488 | ], 1489 | "execution_count": null, 1490 | "outputs": [] 1491 | }, 1492 | { 1493 | "cell_type": "markdown", 1494 | "metadata": { 1495 | "id": "E2IYh3smN5sg" 1496 | }, 1497 | "source": [ 1498 | "## Animate" 1499 | ] 1500 | }, 1501 | { 1502 | "cell_type": "markdown", 1503 | "metadata": { 1504 | "id": "q58B0BXfOXCH" 1505 | }, 1506 | "source": [ 1507 | "Animate the policy:" 1508 | ] 1509 | }, 1510 | { 1511 | "cell_type": "code", 1512 | "metadata": { 1513 | "id": "D3yrbs1Iz_I2" 1514 | }, 1515 | "source": [ 1516 | "anim = plot_animation(frames_ra, repeat=True, interval=100)\n", 1517 | "anim" 1518 | ], 1519 | "execution_count": null, 1520 | "outputs": [] 1521 | }, 1522 | { 1523 | "cell_type": "markdown", 1524 | "metadata": { 1525 | "id": "czFkl3al9MJa" 1526 | }, 1527 | "source": [ 1528 | "A bit less wobbly." 1529 | ] 1530 | } 1531 | ] 1532 | } --------------------------------------------------------------------------------