├── LICENSE ├── README.md ├── environment.yml ├── notebooks ├── DeepSEA_Demo.ipynb ├── README.md ├── kipoi_screenshot.png ├── network_train_demo.ipynb └── predictor.names └── slides ├── pt1_deep_learning_intro.pdf ├── pt2_convnets.pdf └── pt3_RNN_dim_reduction.pdf /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Greene Laboratory 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DLBio Workshop 2 | This repository stores materials used in the DLBio workshop. 3 | The demo materials are in the `notebooks` directory, and the slides are in `slides`. 4 | 5 | Questions from the talk can be found here: https://docs.google.com/document/d/1cXJZb2Ygja1ACvTVOFYjlIoz-ImiP1toC_xd-d6WJis/edit?usp=sharing 6 | 7 | 8 | ## Resources about deep learning: 9 | * https://greenelab.github.io/deep-review/ 10 | * https://github.com/Benjamin-Lee/deep-rules 11 | * https://medium.com/intro-to-artificial-intelligence/deep-learning-series-1-intro-to-deep-learning-abb1780ee20 12 | * https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi 13 | 14 | ## Tips for learning biology: 15 | * Don't try to learn *everything*. Start with one specific question and use it as a foundation to expand from (e.g. "What is the role of accessory genes in *P. aeruginosa* infection?" >> "What is the definition of these accessory genes?" >> "How do people relate accessory genes to infection?"). As you try to answer that one question, you'll find related ones that expand your knowledge. 
16 | * https://www.khanacademy.org/ 17 | * Get help from others -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: workshop 2 | channels: 3 | - bioconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - mamba 8 | - pybedtools 9 | - pyfaidx 10 | - kipoi 11 | - kipoiseq 12 | - pyyaml 13 | - ipykernel 14 | - pytorch 15 | prefix: /home/heil/anaconda3/envs/workshop 16 | 17 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | This directory contains the notebooks for the two demonstrations in the workshop. 3 | 4 | The first demo, [network_train_demo.ipynb](https://colab.research.google.com/drive/1rl7DwrjGbqFF9vFd3TeusLLnnNH87eVD?usp=sharing), explains how to implement and train a basic neural network in PyTorch. 5 | The second demo, [DeepSEA_Demo.ipynb](https://colab.research.google.com/drive/1q1nK4xSZu0ma871D1XqeWF92WfB0UjH0?usp=sharing), shows how pretrained models from model repositories such as [Kipoi](http://kipoi.org/) can be applied to biological problems. 6 | 7 | ## Usage 8 | The notebooks are meant to be run on Google Colab via the links above. 9 | To run them locally, download them from this repo; if you already have conda installed, you may want to skip the conda installation cells in `DeepSEA_Demo.ipynb`. 
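When running the notebooks locally, hard-coded Colab paths such as `/content/sample_data/mnist_train_small.csv` will not exist. The sketch below shows one possible way to fall back to a local copy and to sanity-check a downloaded file before building the dataset. It assumes the CSV layout described in `network_train_demo.ipynb`: the label in the first column, followed by 784 pixel values forming a 28×28 image. The helper names (`resolve_mnist_path`, `parse_mnist_row`, `iter_mnist_csv`) and the local `data/` directory are hypothetical and are not part of this repo.

```python
import csv
from pathlib import Path
from typing import Iterator, List, Tuple

# Directory where Colab ships the sample MNIST CSVs (path taken from the notebook)
COLAB_DATA_DIR = Path("/content/sample_data")
IMAGE_SIDE = 28  # MNIST images are 28x28 pixels


def resolve_mnist_path(filename: str, local_dir: str = "data") -> Path:
    """Return the Colab copy of `filename` if it exists, else a path under
    `local_dir` (a hypothetical local download directory)."""
    colab_path = COLAB_DATA_DIR / filename
    if colab_path.exists():
        return colab_path
    return Path(local_dir) / filename


def parse_mnist_row(row: List[str]) -> Tuple[int, List[List[int]]]:
    """Split one CSV row into its label and a 28x28 grid of pixel values."""
    label = int(row[0])
    pixels = [int(value) for value in row[1:]]
    if len(pixels) != IMAGE_SIDE * IMAGE_SIDE:
        raise ValueError(
            f"expected {IMAGE_SIDE * IMAGE_SIDE} pixel values, got {len(pixels)}"
        )
    # Row-major split, matching np.reshape(pixels, [28, 28]) in the notebook
    image = [
        pixels[start:start + IMAGE_SIDE]
        for start in range(0, len(pixels), IMAGE_SIDE)
    ]
    return label, image


def iter_mnist_csv(path: Path) -> Iterator[Tuple[int, List[List[int]]]]:
    """Yield (label, image) pairs from a headerless MNIST CSV file."""
    with open(path, newline="") as handle:
        for row in csv.reader(handle):
            yield parse_mnist_row(row)
```

For example, `next(iter_mnist_csv(resolve_mnist_path("mnist_test.csv")))` returns the first test label and image, which you can compare against what the notebook plots. A malformed row raises `ValueError` immediately instead of failing later inside a `DataLoader`.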
10 | -------------------------------------------------------------------------------- /notebooks/kipoi_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ben-heil/dl_workshop/8264068771abc70fcd20a6b2b4e8153f6b1b6a53/notebooks/kipoi_screenshot.png -------------------------------------------------------------------------------- /notebooks/network_train_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "DL_intro", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "2mo42M1Oi8ef" 20 | }, 21 | "source": [ 22 | "# Intro Deep Learning Notebook\n", 23 | "This notebook demonstrates how to implement the ideas discussed in the presentation." 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "tSn5b8oXjN4J" 30 | }, 31 | "source": [ 32 | "## Step 1: Imports\n", 33 | "There are two main frameworks used for deep learning in a research setting: [PyTorch](https://pytorch.org/) and [TensorFlow](https://www.tensorflow.org/). 
\n", 34 | "Because code written directly in these frameworks can be verbose, there are also higher-level libraries that abstract away many implementation details, such as [Keras](https://keras.io), [fastai](https://fast.ai), and\n", 35 | "[Hugging Face](https://huggingface.co/).\n", 36 | "\n", 37 | "Picking a framework is usually easy: select a model from the literature that worked well\n", 38 | "on a problem similar to yours, then modify it to do what you want.\n", 39 | "If you have to start from scratch, use the highest-level library that will do what you want.\n", 40 | "In other words, prefer Keras or fastai over raw TensorFlow or PyTorch whenever possible.\n", 41 | "\n", 42 | "I've selected PyTorch because I'm familiar with it, and because the model I'm using in the next presentation is a PyTorch model." 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "metadata": { 48 | "id": "jsrNFjTbim39" 49 | }, 50 | "source": [ 51 | "from typing import Tuple\n", 52 | "\n", 53 | "import matplotlib.pyplot as plt\n", 54 | "import numpy as np\n", 55 | "import pandas as pd\n", 56 | "import torch\n", 57 | "import torch.nn as nn\n", 58 | "import torch.nn.functional as F\n", 59 | "from torch.utils.data import Dataset, DataLoader, random_split" 60 | ], 61 | "execution_count": 1, 62 | "outputs": [] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "id": "iAIYPUoJokCh" 68 | }, 69 | "source": [ 70 | "# Seed the random number generators and force deterministic cuDNN kernels for reproducibility\n", 71 | "np.random.seed(42)\n", 72 | "torch.manual_seed(42)\n", 73 | "torch.backends.cudnn.deterministic = True\n", 74 | "torch.backends.cudnn.benchmark = False" 75 | ], 76 | "execution_count": 2, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "id": "ghkthBtknjbC", 83 | "outputId": "426e9054-714a-44fd-c3d9-35ec2245d102", 84 | "colab": { 85 | "base_uri": "https://localhost:8080/", 86 | "height": 439 87 | } 88 | }, 89 | "source": [ 90 | "# The first column is the label; the other 784 columns are pixel values\n", 91 | "numbers_df = 
pd.read_csv('/content/sample_data/mnist_train_small.csv', header=None)\n", 92 | "numbers_df" 93 | ], 94 | "execution_count": 3, 95 | "outputs": [ 96 | { 97 | "output_type": "execute_result", 98 | "data": { 99 | "text/html": [ 100 | "
" 1130 | ], 1131 | "text/plain": [ 1132 | " 0 1 2 3 4 5 6 ... 778 779 780 781 782 783 784\n", 1133 | "0 6 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1134 | "1 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1135 | "2 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1136 | "3 9 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1137 | "4 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1138 | "... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...\n", 1139 | "19995 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1140 | "19996 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1141 | "19997 2 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1142 | "19998 9 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1143 | "19999 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0\n", 1144 | "\n", 1145 | "[20000 rows x 785 columns]" 1146 | ] 1147 | }, 1148 | "metadata": { 1149 | "tags": [] 1150 | }, 1151 | "execution_count": 3 1152 | } 1153 | ] 1154 | }, 1155 | { 1156 | "cell_type": "markdown", 1157 | "metadata": { 1158 | "id": "aNmXQt0EwcBH" 1159 | }, 1160 | "source": [ 1161 | "## Step 2: Data\n", 1162 | "\n", 1163 | "Pytorch uses Dataset objects to store their data, and DataLoader objects to feed the\n", 1164 | "data into models." 
1165 | ] 1166 | }, 1167 | { 1168 | "cell_type": "code", 1169 | "metadata": { 1170 | "id": "G9in-SqpwbK2" 1171 | }, 1172 | "source": [ 1173 | "class MnistDataset(Dataset):\n", 1174 | "    \"\"\" This Dataset object stores the MNIST handwritten digit dataset \"\"\"\n", 1175 | "    def __init__(self, csv_path: str) -> None:\n", 1176 | "        \"\"\" An initializer function that reads the CSV stored at csv_path \"\"\"\n", 1177 | "        numbers_df = pd.read_csv(csv_path, header=None)\n", 1178 | "\n", 1179 | "        # Pull the labels and pixel data out of the pandas DataFrame and into numpy arrays\n", 1180 | "        self.labels = numbers_df.iloc[:, 0].values\n", 1181 | "        self.pixels = numbers_df.iloc[:, 1:].values\n", 1182 | "\n", 1183 | "    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.int64]:\n", 1184 | "        \"\"\" Retrieve the (pixels, label) pair at index idx; a DataLoader converts these to tensors \"\"\"\n", 1185 | "        return self.pixels[idx, :], self.labels[idx]\n", 1186 | "\n", 1187 | "    def __len__(self) -> int:\n", 1188 | "        \"\"\" Return the number of items in the dataset \"\"\"\n", 1189 | "        return len(self.labels)" 1190 | ], 1191 | "execution_count": 4, 1192 | "outputs": [] 1193 | }, 1194 | { 1195 | "cell_type": "code", 1196 | "metadata": { 1197 | "id": "93nVZmrg0DX5" 1198 | }, 1199 | "source": [ 1200 | "# These MNIST files are built into Colab; to run this notebook locally with Jupyter,\n", 1201 | "# download them and change the paths below\n", 1202 | "train_dataset = MnistDataset('/content/sample_data/mnist_train_small.csv')\n", 1203 | "train_dataset, val_dataset = random_split(train_dataset, [18000, 2000], generator=torch.Generator().manual_seed(42))\n", 1204 | "test_dataset = MnistDataset('/content/sample_data/mnist_test.csv')" 1205 | ], 1206 | "execution_count": 5, 1207 | "outputs": [] 1208 | }, 1209 | { 1210 | "cell_type": "code", 1211 | "metadata": { 1212 | "id": "x1aJdxdN0FRo", 1213 | "outputId": "a8e55de0-28c4-4462-8bb0-515adf1fcf7f", 1214 | "colab": { 1215 | "base_uri": "https://localhost:8080/", 1216 | "height": 812 
1217 | } 1218 | }, 1219 | "source": [ 1220 | "images = []\n", 1221 | "for dataset in [train_dataset, val_dataset, test_dataset]:\n", 1222 | " pixels, label = dataset[0]\n", 1223 | "\n", 1224 | " # Reshape the long flat vector of pixel values into a square image\n", 1225 | " img_example = np.reshape(pixels, [28, 28])\n", 1226 | " print(label)\n", 1227 | " plt.figure()\n", 1228 | " plt.imshow(img_example, cmap=\"gray\")" 1229 | ], 1230 | "execution_count": 6, 1231 | "outputs": [ 1232 | { 1233 | "output_type": "stream", 1234 | "text": [ 1235 | "9\n", 1236 | "5\n", 1237 | "7\n" 1238 | ], 1239 | "name": "stdout" 1240 | }, 1241 | { 1242 | "output_type": "display_data", 1243 | "data": { 1244 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAMSElEQVR4nO3dQawV5RnG8eep2o26wJoSglCtYWOaFBtCIDWNjdFYN2iiIIuGJqbXhTaCLGrsQpemQUlXJtdoxMYqELWyMK2UmNAiGK6GImhUavDCDUKNC3Vl0beLO5gr3jNzmJlz5nDf/y+5OefMd87M21MeZ858M9/niBCAue97XRcAYDgIO5AEYQeSIOxAEoQdSOLCYW7MNqf+gQGLCM+2vNGe3fbNtt+zfcT2A03WBWCwXLef3fYFkt6XdKOk45L2S1obEe+UfIY9OzBgg9izL5d0JCI+jIgvJT0vaVWD9QEYoCZhXyjp2IzXx4tl32J7zPaE7YkG2wLQ0MBP0EXEuKRxicN4oEtN9uxTkhbNeH1FsQzACGoS9v2Slti+yvb3Jd0paUc7ZQFoW+3D+Ig4bfteSX+XdIGkpyLicGuVAWhV7a63WhvjNzswcAO5qAbA+YOwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kETt+dklyfZRSZ9L+krS6YhY1kZRANrXKOyFX0bEJy2sB8AAcRgPJNE07CHpVdtv2h6b7Q22x2xP2J5ouC0ADTgi6n/YXhgRU7Z/KGmnpN9FxO6S99ffGIC+RIRnW95ozx4RU8XjKUkvSVreZH0ABqd22G1fbPvSM88l3STpUFuFAWhXk7Px8yW9ZPvMev4SEX9rpSqcNxYtWlTafvvtt/ds27BhQ6N1b9++vbR99erVpe3Z1A57RHwo6act1gJggOh6A5Ig7EAShB1IgrADSRB2IIlGV9Cd88a4gu68U9V9tXXr1trrruo6q3LHHXeUtt9///092zZv3txo26NsIFfQATh/EHYgCcIOJEHYgSQIO5AEYQeSIOxAEm0MOInzWNVtpo899lhp+7Fjx0rbFy9efM41nVF1i2tVP/vevXtrb3suYs8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0nQzz7Hvf7666XtK1euLG2v6mffuHHjOdfUr7JhqPuxb9
++liqZG9izA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjBt/Hqi6r3vPnj21P1s2trrU7fjqTf9tll0jMMjrA7pWe9x420/ZPmX70Ixll9neafuD4nFem8UCaF8/h/FPS7r5rGUPSNoVEUsk7SpeAxhhlWGPiN2SPj1r8SpJW4rnWyTd2nJdAFpW99r4+RFxonj+saT5vd5oe0zSWM3tAGhJ4xthIiLKTrxFxLikcYkTdECX6na9nbS9QJKKx1PtlQRgEOqGfYekdcXzdZJebqccAINS2c9u+zlJ10u6XNJJSQ9J+qukbZIWS/pI0uqIOPsk3mzr4jC+hm3btpW2l42fvmbNmkbrHqSmY9Y3Yc/aFT0n9Opnr/zNHhFrezTd0KgiAEPF5bJAEoQdSIKwA0kQdiAJwg4kwVDSI2DFihWl7VVTE2/fvr1nW5dda1L5/7ZBdq1J1bfvZsOeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYCjpEdB0WuUub9esGqp6cnKy9rqPHTtW2r569erS9qxTNtceShrA3EDYgSQIO5AEYQeSIOxAEoQdSIKwA0nQzz4Cmv5/UDVcdBPr168vba+6BqBM1f3sc3la5UGinx1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkqCffQRU3Ze9devWIVXSvrK+dPrRB6N2P7vtp2yfsn1oxrKHbU/ZPlD83dJmsQDa189h/NOSbp5l+eaIWFr8vdJuWQDaVhn2iNgt6dMh1AJggJqcoLvX9sHiMH9erzfZHrM9YXuiwbYANFQ37I9LulrSUkknJD3a640RMR4RyyJiWc1tAWhBrbBHxMmI+Coivpb0hKTl7ZYFoG21wm57wYyXt0k61Ou9AEZDZT+77eckXS/pckknJT1UvF4qKSQdlXR3RJyo3Bj97LVs2LChtH1qaqr2ujdt2lTaXjUufNnc8FL1NQRoX69+9gv7+ODaWRY/2bgiAEPF5bJAEoQdSIKwA0kQdiAJwg4kwS2uc1xVt13VcM5V0yYvXrz4nGvCYDGUNJAcYQeSIOxAEoQdSIKwA0kQdiAJwg4kQT/7HFB2G2nVMNT0o8899LMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBL0s58HqoZznpycrL3ulStXlrbv27ev9rrRDfrZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJyllc0b09e/bU/uyaNWtK2+lHz6Nyz257ke3XbL9j+7Dt+4rll9neafuD4nHe4MsFUFc/h/GnJW2MiGskrZB0j+1rJD0gaVdELJG0q3gNYERVhj0iTkTEW8XzzyW9K2mhpFWSthRv2yLp1kEVCaC5c/rNbvtKSddKekPS/Ig4UTR9LGl+j8+MSRqrXyKANvR9Nt72JZJekLQ+Ij6b2RbTd9PMepNLRIxHxLKIWNaoUgCN9BV22xdpOujPRsSLxeKTthcU7QsknRpMiQDaUHkYb9uSnpT0bkTMnN93h6R1kh4pHl8eSIUJlA0FLVXf4rp3796ebdu2batVE+aefn6z/1zSryW9bftAsexBTYd8m+27JH0kqfxfLIBOVYY9Iv4ladab4SXd0G45AAaFy2WBJAg7kARhB5Ig7EAShB1IgqGkh2DFihWl7WX95P0om1a5akpmzD0MJQ0kR9iBJAg7kARhB5Ig7EAShB1IgrADSTCUdAua3G/ej6pplelLRz/YswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEvSzt2BycrLR55lWGcPAnh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuhnfvZFkp6RNF9SSBqPiD/ZfljSbyX9t3jrgxHxyqAK7VrV2O9lqvrRmUMdw9DPRTWnJW2MiLdsXyrpTds7i7bNEbFpcOUBaEs/87OfkHSieP657XclLRx0YQDadU6/2W1fKelaSW8Ui+61fdD2U7bn9fjMmO0J2xONKgXQSN9ht32JpBckrY+IzyQ9LulqSUs1ved/dLbPRcR4RC
yLiGUt1Augpr7CbvsiTQf92Yh4UZIi4mREfBURX0t6QtLywZUJoKnKsNu2pCclvRsRj81YvmDG226TdKj98gC0pXLKZtvXSfqnpLclfV0sflDSWk0fwoeko5LuLk7mla0r5ZTNwDD1mrKZ+dmBOYb52YHkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kMe8rmTyR9NOP15cWyUTSqtY1qXRK11dVmbT/q1TDU+9m/s3F7YlTHphvV2ka1Lona6hpWbRzGA0kQdiCJrsM+3vH2y4xqbaNal0RtdQ2ltk5/swMYnq737ACGhLADSXQSdts3237P9hHbD3RRQy+2j9p+2/aBruenK+bQO2X70Ixll9neafuD4nHWOfY6qu1h21PFd3fA9i0d1bbI9mu237F92PZ9xfJOv7uSuobyvQ39N7vtCyS9L+lGSccl7Ze0NiLeGWohPdg+KmlZRHR+AYbtX0j6QtIzEfGTYtkfJX0aEY8U/6GcFxG/H5HaHpb0RdfTeBezFS2YOc24pFsl/UYdfnclda3WEL63LvbsyyUdiYgPI+JLSc9LWtVBHSMvInZL+vSsxaskbSmeb9H0P5ah61HbSIiIExHxVvH8c0lnphnv9LsrqWsougj7QknHZrw+rtGa7z0kvWr7TdtjXRczi/kzptn6WNL8LouZReU03sN01jTjI/Pd1Zn+vClO0H3XdRHxM0m/knRPcbg6kmL6N9go9Z32NY33sMwyzfg3uvzu6k5/3lQXYZ+StGjG6yuKZSMhIqaKx1OSXtLoTUV98swMusXjqY7r+cYoTeM92zTjGoHvrsvpz7sI+35JS2xfZfv7ku6UtKODOr7D9sXFiRPZvljSTRq9qah3SFpXPF8n6eUOa/mWUZnGu9c04+r4u+t8+vOIGPqfpFs0fUb+P5L+0EUNPer6saR/F3+Hu65N0nOaPqz7n6bPbdwl6QeSdkn6QNI/JF02QrX9WdNTex/UdLAWdFTbdZo+RD8o6UDxd0vX311JXUP53rhcFkiCE3RAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kMT/AeoJLM428Os2AAAAAElFTkSuQmCC\n", 1245 | "text/plain": [ 1246 | "
" 1247 | ] 1248 | }, 1249 | "metadata": { 1250 | "tags": [], 1251 | "needs_background": "light" 1252 | } 1253 | }, 1254 | { 1255 | "output_type": "display_data", 1256 | "data": { 1257 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM0klEQVR4nO3db4hd9Z3H8c9n3fbJtGDcsUNIQ9ItAxoWassgKyvSpbRm9UHsk9Aga2TFCdJgoousZAkV1gVdNq0FsTDF2LS0KYVRE0qzrRuKcUGLk5DGMWlrKtEmjElEpNYnXZ3vPrgny6hzzp2ce+49N/N9v2C4957vPfd8OeST8++e+3NECMDy9xdtNwBgMAg7kARhB5Ig7EAShB1I4i8HuTDbnPoH+iwivNj0nrbsttfb/q3tk7bv7+WzAPSX615nt32ZpN9J+rKk05JelLQpIo5XzMOWHeizfmzZr5V0MiJejYg/S/qxpA09fB6APuol7Ksk/WHB69PFtA+wPWl7xvZMD8sC0KO+n6CLiClJUxK78UCbetmyn5G0esHrTxfTAAyhXsL+oqRx25+x/XFJX5O0v5m2ADSt9m58RLxne6ukn0u6TNLuiHi5sc4ANKr2pbdaC+OYHei7vnypBsClg7ADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkag/ZjEvDXXfdVVl/7LHHKuu9jvJ7/vz50trhw4d7Wvajjz5aWT9w4EBlPZuewm77lKR3JL0v6b2ImGiiKQDNa2LL/vcR8WYDnwOgjzhmB5LoNewh6Re2D9ueXOwNtidtz9ie6XFZAHrQ62789RFxxvanJD1j+zcRcWjhGyJiStKUJNnu7WwPgNp62rJHxJni8ZykpyRd20RTAJpXO+y2R2x/8sJzSV+RNNtUYwCa1ctu/Jikp2xf+JwfRcR/NdIVLsrIyEhpbfv27ZXzzs5W///cbf5169ZV1kdHR0trExPVV2rn5+cr6ydOnKis44Nqhz0iXpX0uQZ7AdBHXHoDkiDsQBKEHUiCsANJEHYgCW5xXQauu+660tr4+HjlvLfffntl/eDBgz3VMTzYsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAElxnXwauvvrq0lq3n2Ou+qlnLC9s2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCa6zLwNV19mnp6cr52VY4zzYsgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAElxnXwbWrFlTWtu2bdsAO8Ew67plt73b9jnbswumXWH7GduvFI8r+tsmgF4tZTf+e5LWf2ja/ZIORsS4pIPFawBDrGvYI+KQpLc+NHmDpD3F8z2Sbmm4LwANq3vMPhYRc8XzNySNlb3R9qSkyZrLAdCQnk/QRUTYLv1Vw4iYkjQlSVXvA9BfdS+9nbW9UpKKx3PNtQSgH+qGfb+kzcXzzZL2NdMOgH7puhtve6+kL0oatX1a0jckPSTpJ7bvkPSapI39bDK7nTt3VtZPnjxZq4ZcuoY9IjaVlL7UcC8A+oivywJJEHYgCcIOJEHYgSQIO5AEt7heAkZHR9tuAcsAW3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSILr7JeA119/vbK+devW0trevXsr533hhRdq9YRLD1t2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEYMbpIUR
YepZu3ZtZf2JJ54ora1bt65y3mPHjlXWjx8/Xlnv9h2A6enp0tqpU6cq50U9EeHFprNlB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkuM6+DIyMjJTWbrjhhsp5169fX1nvNv+qVasq61W/eT8zM1M57913311Z5178xdW+zm57t+1ztmcXTHvA9hnbR4u/m5psFkDzlrIb/z1Ji/33/62IuKb4+1mzbQFoWtewR8QhSW8NoBcAfdTLCbqtto8Vu/kryt5ke9L2jO3qAzQAfVU37N+R9FlJ10iak7Sr7I0RMRURExExUXNZABpQK+wRcTYi3o+IeUnflXRts20BaFqtsNteueDlVyXNlr0XwHDoep3d9l5JX5Q0KumspG8Ur6+RFJJOSdoSEXNdF8Z19mXn8ssvr6xfddVVpbV9+/ZVzvvss89W1jdu3FhZz6rsOnvXQSIiYtMikx/vuSMAA8XXZYEkCDuQBGEHkiDsQBKEHUiCW1zRmm63sN54442V9ZtvvrnJdpYNfkoaSI6wA0kQdiAJwg4kQdiBJAg7kARhB5Loetcb0C/dvuMxPz8/oE5yYMsOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0lwnR2tWbNmTWX9yJEjA+okB7bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE19kvAVdeeWVl/emnny6tPfjgg5XzHjhwoFZPS1XV+5YtWyrn3bFjR9PtpNZ1y257te1f2j5u+2Xb24rpV9h+xvYrxeOK/rcLoK6l7Ma/J+mfI2KdpL+V9HXb6yTdL+lgRIxLOli8BjCkuoY9IuYi4kjx/B1JJyStkrRB0p7ibXsk3dKvJgH07qKO2W2vlfR5Sb+SNBYRc0XpDUljJfNMSpqs3yKAJiz5bLztT0ialrQ9Iv64sBadXw5c9NcDI2IqIiYiYqKnTgH0ZElht/0xdYL+w4h4sph81vbKor5S0rn+tAigCV13421b0uOSTkTENxeU9kvaLOmh4nFfXzqEJiaqd4rGx8dLa88//3zT7VyUe++9t7Q2NzdXWpOk3bt3N91Oaks5Zv87Sf8o6SXbR4tpO9QJ+U9s3yHpNUkb+9MigCZ0DXtE/I+kRQd3l/SlZtsB0C98XRZIgrADSRB2IAnCDiRB2IEkuMX1EnDo0KHKetXQx91uI3344Ydr9XTBbbfdVlm/5557Smu33npr5bzvvvturZ6wOLbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE19kvAd2uN993332ltUceeaRy3m7DJp8/f76yvnPnzsr6rl27SmvT09OV86JZbNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAlX3Qvd+MLswS0MkroP9zw7O1tZ7/bv47nnnqus33nnnaW1t99+u3Je1BMRi/4aNFt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii63V226slfV/SmKSQNBUR37b9gKQ7JV244XlHRPysy2dxnR3os7Lr7EsJ+0pJKyPiiO1PSjos6RZ1xmP/U0T851KbIOxA/5WFfSnjs89Jmiuev2P7hKRVzbYHoN8u6pjd9lpJn5f0q2LSVtvHbO+2vaJknknbM7ZneuoUQE+W/N1425+Q9Kykf4+IJ22PSXpTneP4f1NnV/+funwGu/FAn9U+Zpck2x+T9FNJP4+Iby5SXyvppxHxN10+h7ADfVb7RhjblvS4pBMLg16cuLvgq5Kqb58C0KqlnI2/XtJzkl6SNF9M3iFpk6Rr1NmNPyVpS3Eyr+qz2LIDfdbTbnxTCDvQf9zPDiRH2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLrD0427E1Jry14PVpMG0bD2tuw9iXRW11N9ramrDDQ+9k/snB7JiImWmugwrD2Nqx9SfRW16B6YzceSIKwA0m0HfaplpdfZVh7G9a+JHqrayC9tXrMDmBw2t6yAxgQwg4k0UrYba+3/VvbJ23f
30YPZWyfsv2S7aNtj09XjKF3zvbsgmlX2H7G9ivF46Jj7LXU2wO2zxTr7qjtm1rqbbXtX9o+bvtl29uK6a2uu4q+BrLeBn7MbvsySb+T9GVJpyW9KGlTRBwfaCMlbJ+SNBERrX8Bw/YNkv4k6fsXhtay/R+S3oqIh4r/KFdExL8MSW8P6CKH8e5Tb2XDjN+uFtddk8Of19HGlv1aSScj4tWI+LOkH0va0EIfQy8iDkl660OTN0jaUzzfo84/loEr6W0oRMRcRBwpnr8j6cIw462uu4q+BqKNsK+S9IcFr09ruMZ7D0m/sH3Y9mTbzSxibMEwW29IGmuzmUV0HcZ7kD40zPjQrLs6w5/3ihN0H3V9RHxB0j9I+nqxuzqUonMMNkzXTr8j6bPqjAE4J2lXm80Uw4xPS9oeEX9cWGtz3S3S10DWWxthPyNp9YLXny6mDYWIOFM8npP0lDqHHcPk7IURdIvHcy338/8i4mxEvB8R85K+qxbXXTHM+LSkH0bEk8Xk1tfdYn0Nar21EfYXJY3b/oztj0v6mqT9LfTxEbZHihMnsj0i6SsavqGo90vaXDzfLGlfi718wLAM4102zLhaXnetD38eEQP/k3STOmfkfy/pX9vooaSvv5b06+Lv5bZ7k7RXnd26/1Xn3MYdkv5K0kFJr0j6b0lXDFFvP1BnaO9j6gRrZUu9Xa/OLvoxSUeLv5vaXncVfQ1kvfF1WSAJTtABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL/B8rWF1EiWOSHAAAAAElFTkSuQmCC\n", 1258 | "text/plain": [ 1259 | "
" 1260 | ] 1261 | }, 1262 | "metadata": { 1263 | "tags": [], 1264 | "needs_background": "light" 1265 | } 1266 | }, 1267 | { 1268 | "output_type": "display_data", 1269 | "data": { 1270 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAM3ElEQVR4nO3dXahc9bnH8d/vpCmI6UXiS9ik0bTBC8tBEo1BSCxbQktOvIjFIM1FyYHi7kWUFkuo2It4WaQv1JvALkrTkmMJpGoQscmJxVDU4o5Es2NIjCGaxLxYIjQRJMY+vdjLso0za8ZZa2ZN8nw/sJmZ9cya9bDMz7VmvczfESEAV77/aroBAINB2IEkCDuQBGEHkiDsQBJfGeTCbHPoH+iziHCr6ZW27LZX2j5o+7Dth6t8FoD+cq/n2W3PkHRI0nckHZf0mqS1EfFWyTxs2YE+68eWfamkwxFxJCIuSPqTpNUVPg9AH1UJ+zxJx6a9Pl5M+xzbY7YnbE9UWBaAivp+gC4ixiWNS+zGA02qsmU/IWn+tNdfL6YBGEJVwv6apJtsf8P2VyV9X9L2etoCULeed+Mj4qLtByT9RdIMSU9GxP7aOgNQq55PvfW0ML6zA33Xl4tqAFw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9Dw+uyTZPirpnKRPJV2MiCV1NAWgfpXCXrgrIv5Rw+cA6CN244EkqoY9JO2wvcf2WKs32B6zPWF7ouKyAFTgiOh9ZnteRJywfb2knZIejIjdJe/vfWEAuhIRbjW90pY9Ik4Uj2ckPS1paZXPA9A/PYfd9tW2v/bZc0nflTRZV2MA6lXlaPxcSU/b/uxz/i8iXqilKwC1q/Sd/UsvjO/sQN/15Ts7gMsHYQeSIOxAEoQdSIKwA0nUcSNMCmvWrGlbu//++0vnff/990vrH3/8cWl9y5YtpfVTp061rR0+fLh0XuTBlh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuCuty4dOXKkbW3BggWDa6SFc+fOta3t379/gJ0Ml+PHj7etPfbYY6XzTkxcvr+ixl1vQHKEHUiCsANJEHYgCcIOJEHYgSQIO5AE97N3qeye9VtuuaV03gMHDpTWb7755tL6rbfeWlofHR1tW7vjjjtK5z127Fhpff78+aX1Ki5evFha/+CDD0rrIyMjPS/7vffeK61fzufZ22HLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcD/7FWD27Nlta4sWLSqdd8+ePaX122+/vaeeutHp9/IPHTpUWu90/cKcOXPa1tavX18676ZNm0rrw6zn+9ltP2n7jO3JadPm2N5p++3isf2/NgBDoZvd+N9LWnnJtIcl7YqImyTtKl4DGGIdwx4RuyWdvWTyakmbi+ebJd1Tc18AatbrtfFzI+Jk8fyUpLnt3mh7TNJYj8sBUJPKN8JERJQdeIuIcUnjEgfogCb1eurttO0RSSoez9TXEoB+6DXs2yWtK56vk/RsPe0A6JeO59ltPyVpVNK1kk5L2ijpGUlbJd0g6V1J90XEpQfxWn0Wu/Ho2r333lta37p1a2l9cnKybe2uu+4qnffs2Y7/nIdWu/PsHb+zR8TaNqUVlToCMFBcLgskQdiBJAg7kARhB5Ig7EAS3OKKxlx//fWl9X379lWa
f82aNW1r27ZtK533csaQzUByhB1IgrADSRB2IAnCDiRB2IEkCDuQBEM2ozGdfs75uuuuK61/+OGHpfWDBw9+6Z6uZGzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJ7mdHXy1btqxt7cUXXyydd+bMmaX10dHR0vru3btL61cq7mcHkiPsQBKEHUiCsANJEHYgCcIOJEHYgSS4nx19tWrVqra1TufRd+3aVVp/5ZVXeuopq45bdttP2j5je3LatEdtn7C9t/hr/18UwFDoZjf+95JWtpj+m4hYVPw9X29bAOrWMewRsVvS2QH0AqCPqhyge8D2m8Vu/ux2b7I9ZnvC9kSFZQGoqNewb5K0UNIiSScl/ardGyNiPCKWRMSSHpcFoAY9hT0iTkfEpxHxL0m/k7S03rYA1K2nsNsemfbye5Im270XwHDoeJ7d9lOSRiVda/u4pI2SRm0vkhSSjkr6UR97xBC76qqrSusrV7Y6kTPlwoULpfNu3LixtP7JJ5+U1vF5HcMeEWtbTH6iD70A6CMulwWSIOxAEoQdSIKwA0kQdiAJbnFFJRs2bCitL168uG3thRdeKJ335Zdf7qkntMaWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYMhmlLr77rtL688880xp/aOPPmpbK7v9VZJeffXV0jpaY8hmIDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiC+9mTu+aaa0rrjz/+eGl9xowZpfXnn28/5ifn0QeLLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMH97Fe4TufBO53rvu2220rr77zzTmm97J71TvOiNz3fz257vu2/2n7L9n7bPy6mz7G90/bbxePsupsGUJ9uduMvSvppRHxL0h2S1tv+lqSHJe2KiJsk7SpeAxhSHcMeEScj4vXi+TlJByTNk7Ra0ubibZsl3dOvJgFU96Wujbe9QNJiSX+XNDciThalU5LmtplnTNJY7y0CqEPXR+Ntz5K0TdJPIuKf02sxdZSv5cG3iBiPiCURsaRSpwAq6SrstmdqKuhbIuLPxeTTtkeK+oikM/1pEUAdOu7G27akJyQdiIhfTyttl7RO0i+Kx2f70iEqWbhwYWm906m1Th566KHSOqfXhkc339mXSfqBpH229xbTHtFUyLfa/qGkdyXd158WAdShY9gj4m+SWp6kl7Si3nYA9AuXywJJEHYgCcIOJEHYgSQIO5AEPyV9Bbjxxhvb1nbs2FHpszds2FBaf+655yp9PgaHLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59ivA2Fj7X/264YYbKn32Sy+9VFof5E+Roxq27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZLwPLly8vrT/44IMD6gSXM7bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEN+Ozz5f0B0lzJYWk8Yj4re1HJd0v6YPirY9ExPP9ajSzO++8s7Q+a9asnj+70/jp58+f7/mzMVy6uajmoqSfRsTrtr8maY/tnUXtNxHxy/61B6Au3YzPflLSyeL5OdsHJM3rd2MA6vWlvrPbXiBpsaS/F5MesP2m7Sdtz24zz5jtCdsTlToFUEnXYbc9S9I2ST+JiH9K2iRpoaRFmtry/6rVfBExHhFLImJJDf0C6FFXYbc9U1NB3xIRf5akiDgdEZ9GxL8k/U7S0v61CaCqjmG3bUlPSDoQEb+eNn1k2tu+J2my/vYA1KWbo/HLJP1A0j7be4tpj0haa3uRpk7HHZX0o750iEreeOON0vqKFStK62fPnq2zHTSom6Pxf5PkFiXOqQOXEa6gA5Ig7EAShB1IgrADSRB2IAnCDiThQQ65a5vxfYE+i4hWp8rZsgNZEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoMesvkfkt6d9vraYtowGtbehrUvid56VWdvN7YrDPSimi8s3J4Y1t+mG9behrUvid56Naje2I0HkiDsQBJNh3284eWXGdbe
hrUvid56NZDeGv3ODmBwmt6yAxgQwg4k0UjYba+0fdD2YdsPN9FDO7aP2t5ne2/T49MVY+idsT05bdoc2zttv108thxjr6HeHrV9olh3e22vaqi3+bb/avst2/tt/7iY3ui6K+lrIOtt4N/Zbc+QdEjSdyQdl/SapLUR8dZAG2nD9lFJSyKi8QswbH9b0nlJf4iI/y6mPSbpbET8ovgf5eyI+NmQ9PaopPNND+NdjFY0Mn2YcUn3SPpfNbjuSvq6TwNYb01s2ZdKOhwRRyLigqQ/SVrdQB9DLyJ2S7p0SJbVkjYXzzdr6h/LwLXpbShExMmIeL14fk7SZ8OMN7ruSvoaiCbCPk/SsWmvj2u4xnsPSTts77E91nQzLcyNiJPF81OS5jbZTAsdh/EepEuGGR+addfL8OdVcYDui5ZHxK2S/kfS+mJ3dSjF1HewYTp32tUw3oPSYpjx/2hy3fU6/HlVTYT9hKT5015/vZg2FCLiRPF4RtLTGr6hqE9/NoJu8Xim4X7+Y5iG8W41zLiGYN01Ofx5E2F/TdJNtr9h+6uSvi9pewN9fIHtq4sDJ7J9taTvaviGot4uaV3xfJ2kZxvs5XOGZRjvdsOMq+F11/jw5xEx8D9JqzR1RP4dST9vooc2fX1T0hvF3/6me5P0lKZ26z7R1LGNH0q6RtIuSW9L+n9Jc4aotz9K2ifpTU0Fa6Sh3pZrahf9TUl7i79VTa+7kr4Gst64XBZIggN0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DEvwEvYRv57rmVLgAAAABJRU5ErkJggg==\n", 1271 | "text/plain": [ 1272 | "
" 1273 | ] 1274 | }, 1275 | "metadata": { 1276 | "tags": [], 1277 | "needs_background": "light" 1278 | } 1279 | } 1280 | ] 1281 | }, 1282 | { 1283 | "cell_type": "code", 1284 | "metadata": { 1285 | "id": "PzRhgAXRO5HL" 1286 | }, 1287 | "source": [ 1288 | "# PyTorch uses objects called DataLoaders to feed data into models.\n", 1289 | "# DataLoaders handle details like shuffling the data and how many items to include in each batch\n", 1290 | "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)\n", 1291 | "val_loader = DataLoader(val_dataset, batch_size=1)\n", 1292 | "test_loader = DataLoader(test_dataset, batch_size=1)" 1293 | ], 1294 | "execution_count": 7, 1295 | "outputs": [] 1296 | }, 1297 | { 1298 | "cell_type": "markdown", 1299 | "metadata": { 1300 | "id": "xkthXdDK6GIe" 1301 | }, 1302 | "source": [ 1303 | "## Step 3: Model\n", 1304 | "We'll create a simple fully connected network that takes in images and predicts which digit they represent." 1305 | ] 1306 | }, 1307 | { 1308 | "cell_type": "code", 1309 | "metadata": { 1310 | "id": "iKE0n87-6Po_" 1311 | }, 1312 | "source": [ 1313 | "class FullyConnectedNetwork(nn.Module):\n", 1314 | " \"\"\" A neural network designed to predict which digit is depicted in an image \"\"\"\n", 1315 | " def __init__(self):\n", 1316 | " \"\"\" This function initializes the layers for the network \"\"\"\n", 1317 | " # Call the nn.Module init function\n", 1318 | " super(FullyConnectedNetwork, self).__init__()\n", 1319 | " \n", 1320 | " # Create the neural network layers\n", 1321 | " self.fc1 = nn.Linear(784, 128)\n", 1322 | " self.fc2 = nn.Linear(128, 10)\n", 1323 | "\n", 1324 | " def forward(self, pixels: torch.Tensor):\n", 1325 | " \"\"\" \n", 1326 | " The forward function takes a flattened image (784 pixel values) or a batch of them as input and returns \n", 1327 | " the predicted log-probability of each class (digit)\n", 1328 | " \"\"\"\n", 1329 | " # Feed pixels into first fully connected layer\n", 1330 | " x = self.fc1(pixels)\n", 1331 |
"\n", 1332 | " # Apply the ReLU nonlinearity\n", 1333 | " x = F.relu(x)\n", 1334 | " \n", 1335 | " # Feed the output of the first layer into the second layer\n", 1336 | " x = self.fc2(x)\n", 1337 | "\n", 1338 | " # Apply a log-softmax so the outputs are log-probabilities, which NLLLoss expects (exp() turns them into probabilities that sum to 1)\n", 1339 | " x = F.log_softmax(x, dim=-1)\n", 1340 | " \n", 1341 | " return x" 1342 | ], 1343 | "execution_count": 8, 1344 | "outputs": [] 1345 | }, 1346 | { 1347 | "cell_type": "code", 1348 | "metadata": { 1349 | "id": "_ASGwgHQH7KI", 1350 | "outputId": "f666e725-4a80-4e75-c1b9-656a52c07735", 1351 | "colab": { 1352 | "base_uri": "https://localhost:8080/", 1353 | "height": 85 1354 | } 1355 | }, 1356 | "source": [ 1357 | "model = FullyConnectedNetwork()\n", 1358 | "\n", 1359 | "example, label = train_dataset[0]\n", 1360 | "print(model(torch.Tensor(example)).exp())\n", 1361 | "pred = torch.argmax(model(torch.Tensor(example)))\n", 1362 | "print('An untrained model predicted {}, but the real number was {}'.format(pred, label))" 1363 | ], 1364 | "execution_count": 9, 1365 | "outputs": [ 1366 | { 1367 | "output_type": "stream", 1368 | "text": [ 1369 | "tensor([2.3221e-11, 6.7203e-34, 2.3421e-26, 1.4276e-25, 1.6089e-37, 9.9330e-23,\n", 1370 | " 1.6785e-19, 5.6315e-28, 1.0000e+00, 4.5462e-28],\n", 1371 | " grad_fn=<SoftmaxBackward>)\n", 1372 | "An untrained model predicted 8, but the real number was 9\n" 1373 | ], 1374 | "name": "stdout" 1375 | } 1376 | ] 1377 | }, 1378 | { 1379 | "cell_type": "markdown", 1380 | "metadata": { 1381 | "id": "KO0Nq0xr6lwU" 1382 | }, 1383 | "source": [ 1384 | "## Step 4: Training\n", 1385 | "Now we'll train the model to make better predictions." 1386 | ] 1387 | }, 1388 | { 1389 | "cell_type": "code", 1390 | "metadata": { 1391 | "id": "JYpiahY7IDvG", 1392 | "outputId": "c381bc81-ce44-44cf-81a4-7fda4a6ea34a", 1393 | "colab": { 1394 | "base_uri": "https://localhost:8080/", 1395 | "height": 187 1396 | } 1397 | }, 1398 | "source": [ 1399 | "model = FullyConnectedNetwork()\n", 1400 | "\n", 1401 |
"# Create the loss function and optimizer (NLLLoss expects log-probabilities as input)\n", 1402 | "loss_fn = torch.nn.NLLLoss()\n", 1403 | "\n", 1404 | "optimizer = torch.optim.Adam(model.parameters())\n", 1405 | "\n", 1406 | "# An epoch is one full pass over the training data, i.e. enough iterations for the model to see every training point once\n", 1407 | "epochs = 10\n", 1408 | "for epoch in range(epochs):\n", 1409 | " for i, batch in enumerate(train_loader):\n", 1410 | " images, labels = batch\n", 1411 | " images = images.float()\n", 1412 | "\n", 1413 | " # Reset the gradients accumulated from the previous batch\n", 1414 | " optimizer.zero_grad()\n", 1415 | " output = model(images)\n", 1416 | " loss = loss_fn(output, labels)\n", 1417 | " \n", 1418 | " # Backpropagate: compute the gradient of the loss with respect to each model parameter\n", 1419 | " loss.backward()\n", 1420 | " \n", 1421 | " # Update the model's weights using those gradients\n", 1422 | " optimizer.step()\n", 1423 | "\n", 1424 | " num_correct = 0\n", 1425 | " for batch in val_loader:\n", 1426 | " image, label = batch\n", 1427 | " image = image.float()\n", 1428 | "\n", 1429 | " output = model(image)\n", 1430 | " prediction = np.argmax(output.detach().numpy())\n", 1431 | " if prediction == label:\n", 1432 | " num_correct += 1\n", 1433 | " print('Val acc = {}'.format(num_correct / len(val_dataset)))\n", 1434 | "\n", 1435 | " \n" 1436 | ], 1437 | "execution_count": 10, 1438 | "outputs": [ 1439 | { 1440 | "output_type": "stream", 1441 | "text": [ 1442 | "Val acc = 0.4815\n", 1443 | "Val acc = 0.511\n", 1444 | "Val acc = 0.527\n", 1445 | "Val acc = 0.5235\n", 1446 | "Val acc = 0.5415\n", 1447 | "Val acc = 0.525\n", 1448 | "Val acc = 0.5205\n", 1449 | "Val acc = 0.5385\n", 1450 | "Val acc = 0.5415\n", 1451 | "Val acc = 0.52\n" 1452 | ], 1453 | "name": "stdout" 1454 | } 1455 | ] 1456 | }, 1457 | { 1458 | "cell_type": "markdown", 1459 | "metadata": { 1460 | "id": "TpJiAh4g6vTh" 1461 | }, 1462 | "source": [ 1463 | "## Step 5: Evaluation\n", 1464 | "Finally, we'll measure the trained model's performance on a held-out test set. 
The test set measures the final model's ability to make predictions on data it hasn't seen before." 1465 | ] 1466 | }, 1467 | { 1468 | "cell_type": "code", 1469 | "metadata": { 1470 | "id": "cOYV8_GXby90", 1471 | "outputId": "48af4966-b9ec-445d-d1fa-09681e67caa1", 1472 | "colab": { 1473 | "base_uri": "https://localhost:8080/", 1474 | "height": 34 1475 | } 1476 | }, 1477 | "source": [ 1478 | "num_correct = 0\n", 1479 | "for batch in test_loader:\n", 1480 | " image, label = batch\n", 1481 | " image = image.float()\n", 1482 | "\n", 1483 | " output = model(image)\n", 1484 | " prediction = np.argmax(output.detach().numpy())\n", 1485 | " if prediction == label:\n", 1486 | " num_correct += 1\n", 1487 | "\n", 1488 | "print('Test accuracy: {}'.format(num_correct / len(test_dataset)))" 1489 | ], 1490 | "execution_count": 11, 1491 | "outputs": [ 1492 | { 1493 | "output_type": "stream", 1494 | "text": [ 1495 | "Test accuracy: 0.5275\n" 1496 | ], 1497 | "name": "stdout" 1498 | } 1499 | ] 1500 | } 1501 | ] 1502 | } -------------------------------------------------------------------------------- /notebooks/predictor.names: -------------------------------------------------------------------------------- 1 | 8988T|DNase|None 2 | AoSMC|DNase|None 3 | Chorion|DNase|None 4 | CLL|DNase|None 5 | Fibrobl|DNase|None 6 | FibroP|DNase|None 7 | Gliobla|DNase|None 8 | GM12891|DNase|None 9 | GM12892|DNase|None 10 | GM18507|DNase|None 11 | GM19238|DNase|None 12 | GM19239|DNase|None 13 | GM19240|DNase|None 14 | H9ES|DNase|None 15 | HeLa-S3|DNase|IFNa4h 16 | Hepatocytes|DNase|None 17 | HPDE6-E6E7|DNase|None 18 | HSMM_emb|DNase|None 19 | HTR8svn|DNase|None 20 | Huh-7.5|DNase|None 21 | Huh-7|DNase|None 22 | iPS|DNase|None 23 | Ishikawa|DNase|Estradiol_100nM_1hr 24 | Ishikawa|DNase|4OHTAM_20nM_72hr 25 | LNCaP|DNase|androgen 26 | MCF-7|DNase|Hypoxia_LacAcid 27 | Medullo|DNase|None 28 | Melano|DNase|None 29 | Myometr|DNase|None 30 | Osteobl|DNase|None 31 | PanIsletD|DNase|None 32 | 
PanIslets|DNase|None 33 | pHTE|DNase|None 34 | ProgFib|DNase|None 35 | RWPE1|DNase|None 36 | Stellate|DNase|None 37 | T-47D|DNase|None 38 | Adult_CD4_Th0|DNase|None 39 | Urothelia|DNase|None 40 | Urothelia|DNase|UT189 41 | AG04449|DNase|None 42 | AG04450|DNase|None 43 | AG09309|DNase|None 44 | AG09319|DNase|None 45 | AG10803|DNase|None 46 | AoAF|DNase|None 47 | BE2_C|DNase|None 48 | BJ|DNase|None 49 | Caco-2|DNase|None 50 | CD20+|DNase|None 51 | CD34+_Mobilized|DNase|None 52 | CMK|DNase|None 53 | A549|DNase|None 54 | GM12878|DNase|None 55 | H1-hESC|DNase|None 56 | HeLa-S3|DNase|None 57 | HepG2|DNase|None 58 | HMEC|DNase|None 59 | HSMMtube|DNase|None 60 | HSMM|DNase|None 61 | HUVEC|DNase|None 62 | K562|DNase|None 63 | LNCaP|DNase|None 64 | MCF-7|DNase|None 65 | NHEK|DNase|None 66 | Th1|DNase|None 67 | GM06990|DNase|None 68 | GM12864|DNase|None 69 | GM12865|DNase|None 70 | H7-hESC|DNase|None 71 | HAc|DNase|None 72 | HAEpiC|DNase|None 73 | HA-h|DNase|None 74 | HA-sp|DNase|None 75 | HBMEC|DNase|None 76 | HCFaa|DNase|None 77 | HCF|DNase|None 78 | HCM|DNase|None 79 | HConF|DNase|None 80 | HCPEpiC|DNase|None 81 | HCT-116|DNase|None 82 | HEEpiC|DNase|None 83 | HFF-Myc|DNase|None 84 | HFF|DNase|None 85 | HGF|DNase|None 86 | HIPEpiC|DNase|None 87 | HL-60|DNase|None 88 | HMF|DNase|None 89 | HMVEC-dAd|DNase|None 90 | HMVEC-dBl-Ad|DNase|None 91 | HMVEC-dBl-Neo|DNase|None 92 | HMVEC-dLy-Ad|DNase|None 93 | HMVEC-dLy-Neo|DNase|None 94 | HMVEC-dNeo|DNase|None 95 | HMVEC-LBl|DNase|None 96 | HMVEC-LLy|DNase|None 97 | HNPCEpiC|DNase|None 98 | HPAEC|DNase|None 99 | HPAF|DNase|None 100 | HPdLF|DNase|None 101 | HPF|DNase|None 102 | HRCEpiC|DNase|None 103 | HRE|DNase|None 104 | HRGEC|DNase|None 105 | HRPEpiC|DNase|None 106 | HVMF|DNase|None 107 | Jurkat|DNase|None 108 | Monocytes-CD14+RO01746|DNase|None 109 | NB4|DNase|None 110 | NH-A|DNase|None 111 | NHDF-Ad|DNase|None 112 | NHDF-neo|DNase|None 113 | NHLF|DNase|None 114 | NT2-D1|DNase|None 115 | PANC-1|DNase|None 116 | PrEC|DNase|None 
117 | RPTEC|DNase|None 118 | SAEC|DNase|None 119 | SKMC|DNase|None 120 | SK-N-MC|DNase|None 121 | SK-N-SH_RA|DNase|None 122 | Th2|DNase|None 123 | WERI-Rb-1|DNase|None 124 | WI-38|DNase|4OHTAM_20nM_72hr 125 | WI-38|DNase|None 126 | Dnd41|CTCF|None 127 | Dnd41|EZH2|None 128 | GM12878|CTCF|None 129 | GM12878|EZH2|None 130 | H1-hESC|CHD1|None 131 | H1-hESC|CTCF|None 132 | H1-hESC|EZH2|None 133 | H1-hESC|JARID1A|None 134 | H1-hESC|RBBP5|None 135 | HeLa-S3|CTCF|None 136 | HeLa-S3|EZH2|None 137 | HeLa-S3|Pol2(b)|None 138 | HepG2|CTCF|None 139 | HepG2|EZH2|None 140 | HMEC|CTCF|None 141 | HMEC|EZH2|None 142 | HSMM|CTCF|None 143 | HSMM|EZH2|None 144 | HSMMtube|CTCF|None 145 | HSMMtube|EZH2|None 146 | HUVEC|CTCF|None 147 | HUVEC|EZH2|None 148 | HUVEC|Pol2(b)|None 149 | K562|CHD1|None 150 | K562|CTCF|None 151 | K562|EZH2|None 152 | K562|HDAC1|None 153 | K562|HDAC2|None 154 | K562|HDAC6|None 155 | K562|p300|None 156 | K562|PHF8|None 157 | K562|PLU1|None 158 | K562|Pol2(b)|None 159 | K562|RBBP5|None 160 | K562|SAP30|None 161 | NH-A|CTCF|None 162 | NH-A|EZH2|None 163 | NHDF-Ad|CTCF|None 164 | NHDF-Ad|EZH2|None 165 | NHEK|CTCF|None 166 | NHEK|EZH2|None 167 | NHEK|Pol2(b)|None 168 | NHLF|CTCF|None 169 | NHLF|EZH2|None 170 | Osteobl|CTCF|None 171 | A549|ATF3|EtOH_0.02pct 172 | A549|BCL3|EtOH_0.02pct 173 | A549|CREB1|DEX_100nM 174 | A549|CTCF|DEX_100nM 175 | A549|CTCF|EtOH_0.02pct 176 | A549|ELF1|EtOH_0.02pct 177 | A549|ETS1|EtOH_0.02pct 178 | A549|FOSL2|EtOH_0.02pct 179 | A549|FOXA1|DEX_100nM 180 | A549|GABP|EtOH_0.02pct 181 | A549|GR|DEX_500pM 182 | A549|GR|DEX_50nM 183 | A549|GR|DEX_5nM 184 | A549|GR|DEX_100nM 185 | A549|NRSF|EtOH_0.02pct 186 | A549|p300|EtOH_0.02pct 187 | A549|Pol2|DEX_100nM 188 | A549|Pol2|EtOH_0.02pct 189 | A549|Sin3Ak-20|EtOH_0.02pct 190 | A549|SIX5|EtOH_0.02pct 191 | A549|TAF1|EtOH_0.02pct 192 | A549|TCF12|EtOH_0.02pct 193 | A549|USF-1|DEX_100nM 194 | A549|USF-1|EtOH_0.02pct 195 | A549|USF-1|EtOH_0.02pct 196 | A549|YY1|EtOH_0.02pct 197 | 
A549|ZBTB33|EtOH_0.02pct 198 | ECC-1|CTCF|DMSO_0.02pct 199 | ECC-1|ERalpha|BPA_100nM 200 | ECC-1|ERalpha|Estradiol_10nM 201 | ECC-1|ERalpha|Genistein_100nM 202 | ECC-1|FOXA1|DMSO_0.02pct 203 | ECC-1|GR|DEX_100nM 204 | ECC-1|Pol2|DMSO_0.02pct 205 | GM12878|ATF2|None 206 | GM12878|ATF3|None 207 | GM12878|BATF|None 208 | GM12878|BCL11A|None 209 | GM12878|BCL3|None 210 | GM12878|BCLAF1|None 211 | GM12878|CEBPB|None 212 | GM12878|EBF1|None 213 | GM12878|Egr-1|None 214 | GM12878|ELF1|None 215 | GM12878|ETS1|None 216 | GM12878|FOXM1|None 217 | GM12878|GABP|None 218 | GM12878|IRF4|None 219 | GM12878|MEF2A|None 220 | GM12878|MEF2C|None 221 | GM12878|MTA3|None 222 | GM12878|NFATC1|None 223 | GM12878|NFIC|None 224 | GM12878|NRSF|None 225 | GM12878|p300|None 226 | GM12878|PAX5-C20|None 227 | GM12878|PAX5-N19|None 228 | GM12878|Pbx3|None 229 | GM12878|PML|None 230 | GM12878|Pol2-4H8|None 231 | GM12878|Pol2|None 232 | GM12878|POU2F2|None 233 | GM12878|PU.1|None 234 | GM12878|Rad21|None 235 | GM12878|RUNX3|None 236 | GM12878|RXRA|None 237 | GM12878|SIX5|None 238 | GM12878|SP1|None 239 | GM12878|SRF|None 240 | GM12878|STAT5A|None 241 | GM12878|TAF1|None 242 | GM12878|TCF12|None 243 | GM12878|TCF3|None 244 | GM12878|USF-1|None 245 | GM12878|YY1|None 246 | GM12878|ZBTB33|None 247 | GM12878|ZEB1|None 248 | GM12891|PAX5-C20|None 249 | GM12891|Pol2-4H8|None 250 | GM12891|Pol2|None 251 | GM12891|POU2F2|None 252 | GM12891|PU.1|None 253 | GM12891|TAF1|None 254 | GM12891|YY1|None 255 | GM12892|PAX5-C20|None 256 | GM12892|Pol2-4H8|None 257 | GM12892|Pol2|None 258 | GM12892|TAF1|None 259 | GM12892|YY1|None 260 | H1-hESC|ATF2|None 261 | H1-hESC|ATF3|None 262 | H1-hESC|BCL11A|None 263 | H1-hESC|CTCF|None 264 | H1-hESC|Egr-1|None 265 | H1-hESC|FOSL1|None 266 | H1-hESC|GABP|None 267 | H1-hESC|HDAC2|None 268 | H1-hESC|JunD|None 269 | H1-hESC|NANOG|None 270 | H1-hESC|NRSF|None 271 | H1-hESC|p300|None 272 | H1-hESC|Pol2-4H8|None 273 | H1-hESC|Pol2|None 274 | H1-hESC|POU5F1|None 275 | 
H1-hESC|Rad21|None 276 | H1-hESC|RXRA|None 277 | H1-hESC|Sin3Ak-20|None 278 | H1-hESC|SIX5|None 279 | H1-hESC|SP1|None 280 | H1-hESC|SP2|None 281 | H1-hESC|SP4|None 282 | H1-hESC|SRF|None 283 | H1-hESC|TAF1|None 284 | H1-hESC|TAF7|None 285 | H1-hESC|TCF12|None 286 | H1-hESC|TEAD4|None 287 | H1-hESC|USF-1|None 288 | H1-hESC|YY1|None 289 | HCT-116|Pol2-4H8|None 290 | HCT-116|YY1|None 291 | HCT-116|ZBTB33|None 292 | HeLa-S3|GABP|None 293 | HeLa-S3|NRSF|None 294 | HeLa-S3|Pol2|None 295 | HeLa-S3|TAF1|None 296 | HepG2|ATF3|None 297 | HepG2|BHLHE40|None 298 | HepG2|CEBPB|None 299 | HepG2|CEBPD|None 300 | HepG2|CTCF|None 301 | HepG2|ELF1|None 302 | HepG2|FOSL2|None 303 | HepG2|FOXA1|None 304 | HepG2|FOXA1|None 305 | HepG2|FOXA2|None 306 | HepG2|GABP|None 307 | HepG2|HDAC2|None 308 | HepG2|HNF4A|None 309 | HepG2|HNF4G|None 310 | HepG2|JunD|None 311 | HepG2|MBD4|None 312 | HepG2|MYBL2|None 313 | HepG2|NFIC|None 314 | HepG2|NRSF|None 315 | HepG2|NRSF|None 316 | HepG2|p300|None 317 | HepG2|Pol2-4H8|None 318 | HepG2|Pol2|None 319 | HepG2|Rad21|None 320 | HepG2|RXRA|None 321 | HepG2|Sin3Ak-20|None 322 | HepG2|SP1|None 323 | HepG2|SP2|None 324 | HepG2|SRF|None 325 | HepG2|TAF1|None 326 | HepG2|TCF12|None 327 | HepG2|TEAD4|None 328 | HepG2|USF-1|None 329 | HepG2|YY1|None 330 | HepG2|ZBTB33|None 331 | HepG2|ZBTB7A|None 332 | HUVEC|Pol2-4H8|None 333 | HUVEC|Pol2|None 334 | K562|ATF3|None 335 | K562|BCL3|None 336 | K562|BCLAF1|None 337 | K562|CBX3|None 338 | K562|CEBPB|None 339 | K562|CTCF|None 340 | K562|CTCFL|None 341 | K562|E2F6|None 342 | K562|Egr-1|None 343 | K562|ELF1|None 344 | K562|ETS1|None 345 | K562|FOSL1|None 346 | K562|GABP|None 347 | K562|GATA2|None 348 | K562|HDAC2|None 349 | K562|Max|None 350 | K562|MEF2A|None 351 | K562|NR2F2|None 352 | K562|NRSF|None 353 | K562|PML|None 354 | K562|Pol2-4H8|None 355 | K562|Pol2|None 356 | K562|PU.1|None 357 | K562|Rad21|None 358 | K562|Sin3Ak-20|None 359 | K562|SIX5|None 360 | K562|SP1|None 361 | K562|SP2|None 362 | K562|SRF|None 
363 | K562|STAT5A|None 364 | K562|TAF1|None 365 | K562|TAF7|None 366 | K562|TEAD4|None 367 | K562|THAP1|None 368 | K562|TRIM28|None 369 | K562|USF-1|None 370 | K562|YY1|None 371 | K562|YY1|None 372 | K562|ZBTB33|None 373 | K562|ZBTB7A|None 374 | PANC-1|NRSF|None 375 | PANC-1|Pol2-4H8|None 376 | PANC-1|Sin3Ak-20|None 377 | PFSK-1|FOXP2|None 378 | PFSK-1|NRSF|None 379 | PFSK-1|Sin3Ak-20|None 380 | PFSK-1|TAF1|None 381 | SK-N-MC|FOXP2|None 382 | SK-N-MC|Pol2-4H8|None 383 | SK-N-SH|NRSF|None 384 | SK-N-SH|NRSF|None 385 | SK-N-SH|Pol2-4H8|None 386 | SK-N-SH_RA|CTCF|None 387 | SK-N-SH_RA|p300|None 388 | SK-N-SH_RA|Rad21|None 389 | SK-N-SH_RA|USF1|None 390 | SK-N-SH_RA|YY1|None 391 | SK-N-SH|Sin3Ak-20|None 392 | SK-N-SH|TAF1|None 393 | T-47D|CTCF|DMSO_0.02pct 394 | T-47D|ERalpha|BPA_100nM 395 | T-47D|ERalpha|Genistein_100nM 396 | T-47D|ERalpha|Estradiol_10nM 397 | T-47D|FOXA1|DMSO_0.02pct 398 | T-47D|GATA3|DMSO_0.02pct 399 | T-47D|p300|DMSO_0.02pct 400 | U87|NRSF|None 401 | U87|Pol2-4H8|None 402 | A549|BHLHE40|None 403 | A549|CEBPB|None 404 | A549|Max|None 405 | A549|Pol2(phosphoS2)|None 406 | A549|Rad21|None 407 | GM08714|ZNF274|None 408 | GM10847|NFKB|TNFa 409 | GM10847|Pol2|None 410 | GM12878|BHLHE40|None 411 | GM12878|BRCA1|None 412 | GM12878|c-Fos|None 413 | GM12878|CHD1|None 414 | GM12878|CHD2|None 415 | GM12878|COREST|None 416 | GM12878|CTCF|None 417 | GM12878|E2F4|None 418 | GM12878|EBF1|None 419 | GM12878|ELK1|None 420 | GM12878|IKZF1|None 421 | GM12878|JunD|None 422 | GM12878|Max|None 423 | GM12878|MAZ|None 424 | GM12878|Mxi1|None 425 | GM12878|NF-E2|None 426 | GM12878|NFKB|TNFa 427 | GM12878|NF-YA|None 428 | GM12878|NF-YB|None 429 | GM12878|Nrf1|None 430 | GM12878|p300|None 431 | GM12878|p300|None 432 | GM12878|Pol2|None 433 | GM12878|Pol2(phosphoS2)|None 434 | GM12878|Pol2|None 435 | GM12878|Pol3|None 436 | GM12878|Rad21|None 437 | GM12878|RFX5|None 438 | GM12878|SIN3A|None 439 | GM12878|SMC3|None 440 | GM12878|STAT1|None 441 | GM12878|STAT3|None 442 | 
GM12878|TBLR1|None 443 | GM12878|TBP|None 444 | GM12878|TR4|None 445 | GM12878|USF2|None 446 | GM12878|WHIP|None 447 | GM12878|YY1|None 448 | GM12878|Znf143|None 449 | GM12878|ZNF274|None 450 | GM12878|ZZZ3|None 451 | GM12891|NFKB|TNFa 452 | GM12891|Pol2|None 453 | GM12892|NFKB|TNFa 454 | GM12892|Pol2|None 455 | GM15510|NFKB|TNFa 456 | GM15510|Pol2|None 457 | GM18505|NFKB|TNFa 458 | GM18505|Pol2|None 459 | GM18526|NFKB|TNFa 460 | GM18526|Pol2|None 461 | GM18951|NFKB|TNFa 462 | GM18951|Pol2|None 463 | GM19099|NFKB|TNFa 464 | GM19099|Pol2|None 465 | GM19193|NFKB|TNFa 466 | GM19193|Pol2|None 467 | H1-hESC|Bach1|None 468 | H1-hESC|BRCA1|None 469 | H1-hESC|CEBPB|None 470 | H1-hESC|CHD1|None 471 | H1-hESC|CHD2|None 472 | H1-hESC|c-Jun|None 473 | H1-hESC|c-Myc|None 474 | H1-hESC|CtBP2|None 475 | H1-hESC|GTF2F1|None 476 | H1-hESC|JunD|None 477 | H1-hESC|MafK|None 478 | H1-hESC|Max|None 479 | H1-hESC|Mxi1|None 480 | H1-hESC|Nrf1|None 481 | H1-hESC|Rad21|None 482 | H1-hESC|RFX5|None 483 | H1-hESC|SIN3A|None 484 | H1-hESC|SUZ12|None 485 | H1-hESC|TBP|None 486 | H1-hESC|USF2|None 487 | H1-hESC|Znf143|None 488 | HCT-116|Pol2|None 489 | HCT-116|TCF7L2|None 490 | HEK293|ELK4|None 491 | HEK293|KAP1|None 492 | HEK293|Pol2|None 493 | HEK293|TCF7L2|None 494 | HEK293-T-REx|ZNF263|None 495 | HeLa-S3|AP-2alpha|None 496 | HeLa-S3|AP-2gamma|None 497 | HeLa-S3|BAF155|None 498 | HeLa-S3|BAF170|None 499 | HeLa-S3|BDP1|None 500 | HeLa-S3|BRCA1|None 501 | HeLa-S3|BRF1|None 502 | HeLa-S3|BRF2|None 503 | HeLa-S3|Brg1|None 504 | HeLa-S3|CEBPB|None 505 | HeLa-S3|c-Fos|None 506 | HeLa-S3|CHD2|None 507 | HeLa-S3|c-Jun|None 508 | HeLa-S3|c-Myc|None 509 | HeLa-S3|COREST|None 510 | HeLa-S3|E2F1|None 511 | HeLa-S3|E2F4|None 512 | HeLa-S3|E2F6|None 513 | HeLa-S3|ELK1|None 514 | HeLa-S3|ELK4|None 515 | HeLa-S3|GTF2F1|None 516 | HeLa-S3|HA-E2F1|None 517 | HeLa-S3|Ini1|None 518 | HeLa-S3|IRF3|None 519 | HeLa-S3|JunD|None 520 | HeLa-S3|MafK|None 521 | HeLa-S3|Max|None 522 | HeLa-S3|MAZ|None 523 | 
HeLa-S3|Mxi1|None 524 | HeLa-S3|NF-YA|None 525 | HeLa-S3|NF-YB|None 526 | HeLa-S3|Nrf1|None 527 | HeLa-S3|p300|None 528 | HeLa-S3|Pol2(phosphoS2)|None 529 | HeLa-S3|Pol2|None 530 | HeLa-S3|PRDM1|None 531 | HeLa-S3|Rad21|None 532 | HeLa-S3|RFX5|None 533 | HeLa-S3|RPC155|None 534 | HeLa-S3|SMC3|None 535 | HeLa-S3|SPT20|None 536 | HeLa-S3|STAT1|IFNg30 537 | HeLa-S3|STAT3|None 538 | HeLa-S3|TBP|None 539 | HeLa-S3|TCF7L2|None 540 | HeLa-S3|TCF7L2|None 541 | HeLa-S3|TFIIIC-110|None 542 | HeLa-S3|TR4|None 543 | HeLa-S3|USF2|None 544 | HeLa-S3|ZKSCAN1|None 545 | HeLa-S3|Znf143|None 546 | HeLa-S3|ZNF274|None 547 | HeLa-S3|ZZZ3|None 548 | HepG2|ARID3A|None 549 | HepG2|BHLHE40|None 550 | HepG2|BRCA1|None 551 | HepG2|CEBPB|forskolin 552 | HepG2|CEBPB|None 553 | HepG2|CHD2|None 554 | HepG2|c-Jun|None 555 | HepG2|COREST|None 556 | HepG2|ERRA|forskolin 557 | HepG2|GRp20|forskolin 558 | HepG2|HNF4A|forskolin 559 | HepG2|HSF1|forskolin 560 | HepG2|IRF3|None 561 | HepG2|JunD|None 562 | HepG2|MafF|None 563 | HepG2|MafK|None 564 | HepG2|MafK|None 565 | HepG2|Max|None 566 | HepG2|MAZ|None 567 | HepG2|Mxi1|None 568 | HepG2|Nrf1|None 569 | HepG2|p300|None 570 | HepG2|PGC1A|forskolin 571 | HepG2|Pol2|forskolin 572 | HepG2|Pol2|None 573 | HepG2|Pol2(phosphoS2)|None 574 | HepG2|Rad21|None 575 | HepG2|RFX5|None 576 | HepG2|SMC3|None 577 | NA|NA|NA 578 | HepG2|TBP|None 579 | HepG2|TCF7L2|None 580 | HepG2|TR4|None 581 | HepG2|USF2|None 582 | HepG2|ZNF274|None 583 | HUVEC|c-Fos|None 584 | HUVEC|c-Jun|None 585 | HUVEC|GATA-2|None 586 | HUVEC|Max|None 587 | HUVEC|Pol2|None 588 | IMR90|CEBPB|None 589 | IMR90|CTCF|None 590 | IMR90|MafK|None 591 | IMR90|Pol2|None 592 | IMR90|Rad21|None 593 | K562|ARID3A|None 594 | K562|ATF1|None 595 | K562|ATF3|None 596 | K562|Bach1|None 597 | K562|BDP1|None 598 | K562|BHLHE40|None 599 | K562|BRF1|None 600 | K562|BRF2|None 601 | K562|Brg1|None 602 | K562|CCNT2|None 603 | K562|CEBPB|None 604 | K562|c-Fos|None 605 | K562|CHD2|None 606 | K562|c-Jun|IFNa30 607 | 
K562|c-Jun|IFNa6h 608 | K562|c-Jun|IFNg30 609 | K562|c-Jun|IFNg6h 610 | K562|c-Jun|None 611 | K562|c-Myc|IFNa30 612 | K562|c-Myc|IFNa6h 613 | K562|c-Myc|IFNg30 614 | K562|c-Myc|IFNg6h 615 | K562|c-Myc|None 616 | K562|c-Myc|None 617 | K562|COREST|None 618 | K562|COREST|None 619 | K562|CTCF|None 620 | K562|E2F4|None 621 | K562|E2F6|None 622 | K562|ELK1|None 623 | K562|GATA-1|None 624 | K562|GATA-2|None 625 | K562|GTF2B|None 626 | K562|GTF2F1|None 627 | K562|HMGN3|None 628 | K562|Ini1|None 629 | K562|IRF1|IFNa30 630 | K562|IRF1|IFNa6h 631 | K562|IRF1|IFNg30 632 | K562|IRF1|IFNg6h 633 | K562|JunD|None 634 | K562|KAP1|None 635 | K562|MafF|None 636 | K562|MafK|None 637 | K562|Max|None 638 | K562|MAZ|None 639 | K562|Mxi1|None 640 | K562|NELFe|None 641 | K562|NF-E2|None 642 | K562|NF-YA|None 643 | K562|NF-YB|None 644 | K562|Nrf1|None 645 | K562|p300|None 646 | K562|Pol2|IFNa30 647 | K562|Pol2|IFNa6h 648 | K562|Pol2|IFNg30 649 | K562|Pol2|IFNg6h 650 | K562|Pol2|None 651 | K562|Pol2(phosphoS2)|None 652 | K562|Pol2(phosphoS2)|None 653 | K562|Pol2|None 654 | K562|Pol3|None 655 | K562|Rad21|None 656 | K562|RFX5|None 657 | K562|RPC155|None 658 | K562|SETDB1|MNaseD 659 | K562|SETDB1|None 660 | K562|SIRT6|None 661 | K562|SMC3|None 662 | K562|STAT1|IFNa30 663 | K562|STAT1|IFNa6h 664 | K562|STAT1|IFNg30 665 | K562|STAT1|IFNg6h 666 | K562|STAT2|IFNa30 667 | K562|STAT2|IFNa6h 668 | K562|TAL1|None 669 | K562|TBLR1|None 670 | K562|TBLR1|None 671 | K562|TBP|None 672 | K562|TFIIIC-110|None 673 | K562|TR4|None 674 | K562|UBF|None 675 | K562|UBTF|None 676 | K562|USF2|None 677 | K562|YY1|None 678 | K562|Znf143|None 679 | K562|ZNF263|None 680 | K562|ZNF274|None 681 | K562|ZNF274|None 682 | MCF10A-Er-Src|c-Fos|EtOH_0.01pct 683 | MCF10A-Er-Src|c-Fos|4OHTAM_1uM_12hr 684 | MCF10A-Er-Src|c-Fos|4OHTAM_1uM_4hr 685 | MCF10A-Er-Src|c-Fos|4OHTAM_1uM_36hr 686 | MCF10A-Er-Src|c-Myc|EtOH_0.01pct 687 | MCF10A-Er-Src|c-Myc|4OHTAM_1uM_4hr 688 | MCF10A-Er-Src|E2F4|4OHTAM_1uM_36hr 689 | 
MCF10A-Er-Src|Pol2|EtOH_0.01pct 690 | MCF10A-Er-Src|Pol2|4OHTAM_1uM_36hr 691 | MCF10A-Er-Src|STAT3|EtOH_0.01pct_4hr 692 | MCF10A-Er-Src|STAT3|EtOH_0.01pct_12hr 693 | MCF10A-Er-Src|STAT3|EtOH_0.01pct 694 | MCF10A-Er-Src|STAT3|4OHTAM_1uM_12hr 695 | MCF10A-Er-Src|STAT3|4OHTAM_1uM_36hr 696 | MCF-7|GATA3|None 697 | MCF-7|GATA3|None 698 | MCF-7|HA-E2F1|None 699 | MCF-7|TCF7L2|None 700 | MCF-7|ZNF217|None 701 | NB4|c-Myc|None 702 | NB4|Max|None 703 | NB4|Pol2|None 704 | NT2-D1|SUZ12|None 705 | NT2-D1|YY1|None 706 | NT2-D1|ZNF274|None 707 | PANC-1|TCF7L2|None 708 | PBDEFetal|GATA-1|None 709 | PBDE|GATA-1|None 710 | PBDE|Pol2|None 711 | Raji|Pol2|None 712 | SH-SY5Y|GATA-2|None 713 | SH-SY5Y|GATA3|None 714 | U2OS|KAP1|None 715 | U2OS|SETDB1|None 716 | K562|eGFP-FOS|None 717 | K562|eGFP-GATA2|None 718 | K562|eGFP-HDAC8|None 719 | K562|eGFP-JunB|None 720 | K562|eGFP-JunD|None 721 | A549|CTCF|None 722 | A549|Pol2|None 723 | Fibrobl|CTCF|None 724 | Gliobla|CTCF|None 725 | Gliobla|Pol2|None 726 | GM12878|c-Myc|None 727 | GM12878|CTCF|None 728 | GM12878|Pol2|None 729 | GM12891|CTCF|None 730 | GM12892|CTCF|None 731 | GM19238|CTCF|None 732 | GM19239|CTCF|None 733 | GM19240|CTCF|None 734 | H1-hESC|c-Myc|None 735 | H1-hESC|CTCF|None 736 | H1-hESC|Pol2|None 737 | HeLa-S3|c-Myc|None 738 | HeLa-S3|CTCF|None 739 | HeLa-S3|Pol2|None 740 | HepG2|c-Myc|None 741 | HepG2|CTCF|None 742 | HepG2|Pol2|None 743 | HUVEC|c-Myc|None 744 | HUVEC|CTCF|None 745 | HUVEC|Pol2|None 746 | K562|c-Myc|None 747 | K562|CTCF|None 748 | K562|Pol2|None 749 | MCF-7|c-Myc|estrogen 750 | MCF-7|c-Myc|serum_stimulated_media 751 | MCF-7|c-Myc|serum_starved_media 752 | MCF-7|c-Myc|vehicle 753 | MCF-7|CTCF|estrogen 754 | MCF-7|CTCF|serum_stimulated_media 755 | MCF-7|CTCF|serum_starved_media 756 | MCF-7|CTCF|None 757 | MCF-7|CTCF|vehicle 758 | MCF-7|Pol2|serum_stimulated_media 759 | MCF-7|Pol2|serum_starved_media 760 | MCF-7|Pol2|None 761 | NHEK|CTCF|None 762 | ProgFib|CTCF|None 763 | ProgFib|Pol2|None 764 | A549|CTCF|None 
765 | AG04449|CTCF|None 766 | AG04450|CTCF|None 767 | AG09309|CTCF|None 768 | AG09319|CTCF|None 769 | AG10803|CTCF|None 770 | AoAF|CTCF|None 771 | BE2_C|CTCF|None 772 | BJ|CTCF|None 773 | Caco-2|CTCF|None 774 | GM06990|CTCF|None 775 | GM12801|CTCF|None 776 | GM12864|CTCF|None 777 | GM12865|CTCF|None 778 | GM12872|CTCF|None 779 | GM12873|CTCF|None 780 | GM12874|CTCF|None 781 | GM12875|CTCF|None 782 | GM12878|CTCF|None 783 | HAc|CTCF|None 784 | HA-sp|CTCF|None 785 | HBMEC|CTCF|None 786 | HCFaa|CTCF|None 787 | HCM|CTCF|None 788 | HCPEpiC|CTCF|None 789 | HCT-116|CTCF|None 790 | HEEpiC|CTCF|None 791 | HEK293|CTCF|None 792 | HeLa-S3|CTCF|None 793 | HepG2|CTCF|None 794 | HFF|CTCF|None 795 | HFF-Myc|CTCF|None 796 | HL-60|CTCF|None 797 | HMEC|CTCF|None 798 | HMF|CTCF|None 799 | HPAF|CTCF|None 800 | HPF|CTCF|None 801 | HRE|CTCF|None 802 | HRPEpiC|CTCF|None 803 | HUVEC|CTCF|None 804 | HVMF|CTCF|None 805 | K562|CTCF|None 806 | MCF-7|CTCF|None 807 | NB4|CTCF|None 808 | NHDF-neo|CTCF|None 809 | NHEK|CTCF|None 810 | NHLF|CTCF|None 811 | RPTEC|CTCF|None 812 | SAEC|CTCF|None 813 | SK-N-SH_RA|CTCF|None 814 | WERI-Rb-1|CTCF|None 815 | WI-38|CTCF|None 816 | H1-hESC|H2AK5ac|None 817 | H1-hESC|H2AZ|None 818 | H1-hESC|H2BK120ac|None 819 | H1-hESC|H2BK12ac|None 820 | H1-hESC|H2BK15ac|None 821 | H1-hESC|H2BK20ac|None 822 | H1-hESC|H2BK5ac|None 823 | H1-hESC|H3K14ac|None 824 | H1-hESC|H3K18ac|None 825 | H1-hESC|H3K23ac|None 826 | H1-hESC|H3K23me2|None 827 | H1-hESC|H3K27ac|None 828 | H1-hESC|H3K27me3|None 829 | H1-hESC|H3K36me3|None 830 | H1-hESC|H3K4ac|None 831 | H1-hESC|H3K4me1|None 832 | H1-hESC|H3K4me2|None 833 | H1-hESC|H3K4me3|None 834 | H1-hESC|H3K56ac|None 835 | H1-hESC|H3K79me1|None 836 | H1-hESC|H3K79me2|None 837 | H1-hESC|H3K9ac|None 838 | H1-hESC|H3K9me3|None 839 | H1-hESC|H4K20me1|None 840 | H1-hESC|H4K5ac|None 841 | H1-hESC|H4K8ac|None 842 | H1-hESC|H4K91ac|None 843 | K562|H2AZ|None 844 | K562|H3K27ac|None 845 | K562|H3K27me3|None 846 | K562|H3K36me3|None 847 | 
K562|H3K4me1|None 848 | K562|H3K4me2|None 849 | K562|H3K4me3|None 850 | K562|H3K79me2|None 851 | K562|H3K9ac|None 852 | K562|H3K9me1|None 853 | K562|H3K9me3|None 854 | K562|H4K20me1|None 855 | Monocytes-CD14+RO01746|H2AZ|None 856 | Monocytes-CD14+RO01746|H3K27ac|None 857 | Monocytes-CD14+RO01746|H3K27me3|None 858 | Monocytes-CD14+RO01746|H3K36me3|None 859 | Monocytes-CD14+RO01746|H3K4me1|None 860 | Monocytes-CD14+RO01746|H3K4me2|None 861 | Monocytes-CD14+RO01746|H3K4me3|None 862 | Monocytes-CD14+RO01746|H3K79me2|None 863 | Monocytes-CD14+RO01746|H3K9ac|None 864 | Monocytes-CD14+RO01746|H3K9me3|None 865 | Monocytes-CD14+RO01746|H4K20me1|None 866 | NH-A|H2AZ|None 867 | NH-A|H3K27ac|None 868 | NH-A|H3K27me3|None 869 | NH-A|H3K36me3|None 870 | NH-A|H3K4me1|None 871 | NH-A|H3K4me2|None 872 | NH-A|H3K4me3|None 873 | NH-A|H3K79me2|None 874 | NH-A|H3K9ac|None 875 | NH-A|H3K9me3|None 876 | NH-A|H4K20me1|None 877 | NHDF-Ad|H2AZ|None 878 | NHDF-Ad|H3K27ac|None 879 | NHDF-Ad|H3K27me3|None 880 | NHDF-Ad|H3K36me3|None 881 | NHDF-Ad|H3K4me1|None 882 | NHDF-Ad|H3K4me2|None 883 | NHDF-Ad|H3K4me3|None 884 | NHDF-Ad|H3K79me2|None 885 | NHDF-Ad|H3K9ac|None 886 | NHDF-Ad|H3K9me3|None 887 | NHDF-Ad|H4K20me1|None 888 | NHEK|H2AZ|None 889 | NHEK|H3K27ac|None 890 | NHEK|H3K27me3|None 891 | NHEK|H3K36me3|None 892 | NHEK|H3K4me1|None 893 | NHEK|H3K4me2|None 894 | NHEK|H3K4me3|None 895 | NHEK|H3K79me2|None 896 | NHEK|H3K9ac|None 897 | NHEK|H3K9me1|None 898 | NHEK|H3K9me3|None 899 | NHEK|H4K20me1|None 900 | NHLF|H2AZ|None 901 | NHLF|H3K27ac|None 902 | NHLF|H3K27me3|None 903 | NHLF|H3K36me3|None 904 | NHLF|H3K4me1|None 905 | NHLF|H3K4me2|None 906 | NHLF|H3K4me3|None 907 | NHLF|H3K79me2|None 908 | NHLF|H3K9ac|None 909 | NHLF|H3K9me3|None 910 | NHLF|H4K20me1|None 911 | Osteoblasts|H2AZ|None 912 | Osteoblasts|H3K27ac|None 913 | Osteoblasts|H3K27me3|None 914 | Osteoblasts|H3K36me3|None 915 | Osteoblasts|H3K4me1|None 916 | Osteoblasts|H3K4me2|None 917 | Osteoblasts|H3K4me3|None 918 | 
Osteoblasts|H3K79me2|None 919 | Osteoblasts|H3K9me3|None 920 | -------------------------------------------------------------------------------- /slides/pt1_deep_learning_intro.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ben-heil/dl_workshop/8264068771abc70fcd20a6b2b4e8153f6b1b6a53/slides/pt1_deep_learning_intro.pdf -------------------------------------------------------------------------------- /slides/pt2_convnets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ben-heil/dl_workshop/8264068771abc70fcd20a6b2b4e8153f6b1b6a53/slides/pt2_convnets.pdf -------------------------------------------------------------------------------- /slides/pt3_RNN_dim_reduction.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ben-heil/dl_workshop/8264068771abc70fcd20a6b2b4e8153f6b1b6a53/slides/pt3_RNN_dim_reduction.pdf --------------------------------------------------------------------------------