├── Introduction_to_Jupyter_notebooks.ipynb ├── README.md ├── dragonn_tutorial1.ipynb ├── dragonn_tutorial2.ipynb ├── keras_tutorial.ipynb ├── pytorch_mnist_tutorial.ipynb ├── tensorflow_tutorial.ipynb └── tutorial_images ├── CTCF.Tut4.png ├── CTCF_known1.png ├── ChangeRuntime.png ├── GenomeWideModel.png ├── MultiLayerTraining.png ├── RunAllCollab.png ├── RunCellArrow.png ├── RuntimeType.png ├── SIX5_known1.png ├── SPI1.Tut4.png ├── SPIB.Kat.png ├── SimArch1Layer.png ├── TAL1_known4.png ├── ZNF143_known2.png ├── classification_task.jpg ├── comp_graph_eval.png ├── dnn_figure.png ├── dragonn_and_pssm.jpg ├── dragonn_model_figure.jpg ├── heterodimer_simulation.jpg ├── homotypic_motif_density_localization.jpg ├── homotypic_motif_density_localization_task.jpg ├── inspecting_code.png ├── multi-input-multi-output-graph.png ├── numpy_to_tensorflow.png ├── one_hot_encoding.png ├── placeholder_feedforward_dict.png ├── play_all_button.png ├── play_button.png ├── sequence_properties_1.jpg ├── sequence_properties_2.jpg ├── sequence_properties_3.jpg ├── sequence_properties_4.jpg ├── sequence_simulations.png ├── tensor_definition.png ├── tensorflow.png └── tf_binding.jpg /Introduction_to_Jupyter_notebooks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Introduction to Jupyter notebooks #\n", 8 | "## (Formerly known as IPython notebooks) ##" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "The IPython Notebook is now known as the Jupyter Notebook. It is an interactive computational environment, in which you can combine code execution, rich text, mathematics, plots and rich media. For more details on the Jupyter Notebook, please see the [Jupyter](http://jupyter.org/) website." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "A Jupyter notebook (like this one) is a web-based interface that allows you to execute code (for example, shell commands or Python) in each input cell. You can also save an entire session as a document in a file with the .ipynb extension." 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Cells ##" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "* Text cells can be written using [Markdown syntax](http://markdown-guide.readthedocs.io/en/latest/basics.html) (Click on Cell -> Cell Type -> Markdown)\n", 37 | "* Code cells take input in the language of the notebook's kernel: shell commands for the Bash kernel used here, or IPython input (i.e. Python code, %magics, !system calls, etc.) for a Python kernel" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "To execute a cell, you must use **Shift-Enter**, as pressing Enter adds a new line of text to the cell. When you type **Shift-Enter**, the cell content is executed, the output is displayed, and the cursor moves to the next cell (a new cell is created below if there is none). Try it now by putting your cursor in the next cell and typing **Shift-Enter**:" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "echo \"Hello World\"" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "You can re-execute the same cell as many times as you want. Simply put your cursor in the cell again, edit at will, and type Shift-Enter to execute." 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "Tip: A cell can also be executed in-place, where the kernel executes its content but leaves the cursor in the same cell.
This is done by typing Ctrl-Enter instead, and is useful if you want to quickly run a command to check something before typing the real content you want to leave in the cell. For example, in the next cell, try issuing several system commands in-place with Ctrl-Enter, such as pwd (which prints the current working directory) and then ls (which lists the contents of the current directory):" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "ls" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "In a cell, you can type anything from a single command to an arbitrarily long block of code (although for readability, you should probably limit this to a few dozen lines):" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "COUNTER=0\n", 93 | "while [ $COUNTER -lt 10 ]; do\n", 94 | " echo The counter is $COUNTER\n", 95 | " let COUNTER=COUNTER+1 \n", 96 | "done\n", 97 | " " 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## User interface ##" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "source": [ 113 | "Click on **Control Panel** in the top-left corner of this screen. Then click on **My Server**. This will bring you to the dashboard.
\n", 114 | "\n", 115 | "From the dashboard, you can do one of several things: \n", 116 | "\n", 117 | "* Navigate the file system \n", 118 | "\n", 119 | "* Create a new notebook \n", 120 | "\n", 121 | "* Open a terminal instance\n", 122 | "\n", 123 | "![Jupyter dashboard](https://github.com/kundajelab/training_camp/blob/master/images/dashboard.png?raw=true \"Dashboard\")" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "## Kernels ##" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "Under **Notebooks**, you are asked to select the kernel for your new notebook. There are 2 options: \n", 138 | "\n", 139 | "* Bash -- we will mostly be using this kernel. It allows us to execute command-line shell scripts from inside the notebook.\n", 140 | "* Python3" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "You can always change the kernel that is used to load a notebook by clicking on **Kernel** -> **Change Kernel** in the menu at the top of the notebook. " 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "Sometimes, you may find that a command runs longer than you expected, or you otherwise need to interrupt a cell. \n", 155 | "You can do this by clicking the black square button (\"interrupt kernel\") at the top of this page. You can also restart a kernel completely, but this discards all state (such as variables) created by the cells you have already executed."
156 | ] 157 | } 158 | ], 159 | "metadata": { 160 | "kernelspec": { 161 | "display_name": "Bash", 162 | "language": "bash", 163 | "name": "bash" 164 | }, 165 | "language_info": { 166 | "codemirror_mode": "shell", 167 | "file_extension": ".sh", 168 | "mimetype": "text/x-sh", 169 | "name": "bash" 170 | } 171 | }, 172 | "nbformat": 4, 173 | "nbformat_minor": 1 174 | } 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cs273b 2 | CS273B Deep Learning for Genomics Course Materials 3 | -------------------------------------------------------------------------------- /dragonn_tutorial1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# How to train your DragoNN tutorial \n", 8 | "\n", 9 | "## How to use this tutorial\n", 10 | "\n", 11 | "This tutorial utilizes a Jupyter Notebook - an interactive computational environment that combines live code, visualizations, and explanatory text. The notebook is organized into a series of cells. You can run the next cell by clicking the play button:\n", 12 | "![play button](./tutorial_images/play_button.png)\n", 13 | "You can also run all cells in a series by clicking \"run all\" in the Cell drop-down menu:\n", 14 | "![play all button](./tutorial_images/play_all_button.png)\n", 15 | "About half of the cells in this tutorial contain code; the other half contain visualizations and explanatory text. Code, visualizations, and text in cells can be modified - you are encouraged to modify the code as you advance through the tutorial.
You can inspect the implementation of a function used in a cell by following these steps:\n", 16 | "![inspecting code](./tutorial_images/inspecting_code.png)\n", 17 | "\n", 18 | "## Tutorial Overview\n", 19 | "In this tutorial, we will:\n", 20 | "\n", 21 | " 1) Simulate a regulatory DNA sequence classification task\n", 22 | " 2) Train DragoNN models of varying complexity to solve the simulation\n", 23 | " 3) Interpret trained DragoNN models\n", 24 | " 4) Show how to train your DragoNN on your own, non-simulated data and use it to interpret data\n", 25 | "\n", 26 | "This tutorial is implemented in Python (see this [online Python course](https://www.udacity.com/course/programming-foundations-with-python--ud036) for an introduction).\n", 27 | "\n", 28 | "We start by loading dragonn's tutorial utilities. Let's review properties of regulatory sequences while the utilities are loading." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%reload_ext autoreload\n", 38 | "%autoreload 2\n", 39 | "%matplotlib inline\n", 40 | "import os\n", 41 | "os.environ['KERAS_BACKEND']=\"theano\"\n", 42 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\"\n", 43 | "os.environ['THEANO_FLAGS']='device=gpu,floatX=float32'\n", 44 | "\n", 45 | "from dragonn.tutorial_utils import *\n", 46 | "\n", 47 | "one_filter_dragonn_parameters = {\n", 48 | " 'seq_length': 500,\n", 49 | " 'num_filters': [1],\n", 50 | " 'conv_width': [45],\n", 51 | " 'pool_width': 45}\n", 52 | "one_filter_dragonn = get_SequenceDNN(one_filter_dragonn_parameters)\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "![sequence properties 1](./tutorial_images/sequence_properties_1.jpg)\n", 60 | "![sequence properties 2](./tutorial_images/sequence_properties_2.jpg)\n", 61 | "![sequence properties 3](./tutorial_images/sequence_properties_3.jpg)\n", 62 | "![sequence properties
4](./tutorial_images/sequence_properties_4.jpg)\n", 63 | "\n", 64 | "In this tutorial, we will simulate a heterodimer motif grammar detection task. Specifically, we will simulate a \"positive\" class of sequences with a SIX5-ZNF143 grammar with relatively fixed spacing between the motifs and a \"negative\" class of sequences containing both motifs positioned independently:\n", 65 | "![heterodimer simulation](./tutorial_images/heterodimer_simulation.jpg)\n", 66 | "Here is an overview of the sequence simulation functions in the dragonn tutorial:\n", 67 | "![sequence simulations](./tutorial_images/sequence_simulations.png)\n", 68 | "\n", 69 | "Let's run the print_available_simulations function and see it in action." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "print_available_simulations()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "## Getting simulation data\n", 86 | "\n", 87 | "To get simulation data we:\n", 88 | " \n", 89 | " 1) Define the simulation parameters\n", 90 | " - obtain a description of the simulation parameters using the print_simulation_info function\n", 91 | " 2) Call the get_simulation_data function, which takes as input the simulation name and the simulation\n", 92 | " parameters, and outputs the simulation data.\n", 93 | "\n", 94 | "We simulate the SIX5-ZNF143 heterodimer motif grammar using the \"simulate_heterodimer_grammar\" simulation function.
To get a description of the simulation parameters, we use the print_simulation_info function, which takes the simulation function name as input and outputs documentation for the simulation, including the simulation parameters:" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "print_simulation_info(\"simulate_heterodimer_grammar\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "Next, we define parameters for a heterodimer grammar simulation of 500bp sequences with a GC fraction of 0.4, 10000 positive and 10000 negative sequences, and SIX5 and ZNF143 motifs spaced 2-10 bp apart in the positive sequences:" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "heterodimer_grammar_simulation_parameters = {\n", 120 | " \"seq_length\": 500,\n", 121 | " \"GC_fraction\": 0.4,\n", 122 | " \"num_pos\": 10000,\n", 123 | " \"num_neg\": 10000,\n", 124 | " \"motif1\": \"SIX5_known5\",\n", 125 | " \"motif2\": \"ZNF143_known2\",\n", 126 | " \"min_spacing\": 2,\n", 127 | " \"max_spacing\": 10}" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "We get the simulation data by calling the get_simulation_data function with the simulation name and the simulation parameters as inputs." 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "simulation_data = get_simulation_data(\"simulate_heterodimer_grammar\", heterodimer_grammar_simulation_parameters)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "simulation_data provides training, validation, and test sets of input sequences X and sequence labels y.
The inputs X are matrices with a one-hot encoding of the sequences:\n", 151 | "![one hot encoding](./tutorial_images/one_hot_encoding.png)\n", 152 | "Here are the first 10bp of a sequence in our training data:" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "simulation_data.X_train[0, :, :, :10]" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "This matrix represents the 10bp sequence TTGGTAGATA.\n", 169 | "\n", 170 | "Next, we will provide a brief overview of DragoNNs and proceed to train a DragoNN to classify the sequences we simulated:\n", 171 | "![classification task](./tutorial_images/classification_task.jpg)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "# DragoNN Models\n", 179 | "\n", 180 | "A locally connected linear unit in a DragoNN model can represent a PSSM (part a). A sequence PSSM score is obtained by scanning the PSSM across the sequence (taking a dot product at each position), thresholding the PSSM scores, and taking the max (part b).
A PSSM score can also be computed by a DragoNN model with tiled locally connected linear units, amounting to a convolutional layer with a single convolutional filter representing the PSSM, followed by ReLU thresholding and max pooling (part c).\n", 181 | "![dragonn vs pssm](./tutorial_images/dragonn_and_pssm.jpg)\n", 182 | "By using multiple convolutional layers with multiple convolutional filters, DragoNN models can represent a wide range of sequence features in a compositional fashion:\n", 183 | "![dragonn model figure](./tutorial_images/dragonn_model_figure.jpg)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "# Getting a DragoNN model\n", 191 | "\n", 192 | "The main DragoNN model class is SequenceDNN, which provides a simple interface to a range of models and methods to train, test, and interpret DragoNNs. SequenceDNN uses [Keras](http://keras.io/), a deep learning library that runs on top of [Theano](https://github.com/Theano/Theano) or [TensorFlow](https://github.com/tensorflow/tensorflow), two popular deep learning frameworks.\n", 193 | "\n", 194 | "To get a DragoNN model we:\n", 195 | " \n", 196 | " 1) Define the DragoNN architecture parameters\n", 197 | " - obtain a description of the architecture parameters using the inspect_SequenceDNN() function\n", 198 | " 2) Call the get_SequenceDNN function, which takes as input the DragoNN architecture parameters, and outputs a \n", 199 | " randomly initialized DragoNN model."
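 ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "Before building a model with SequenceDNN, it may help to see what a single-filter DragoNN computes (part c of the figure above). The cell below is a minimal NumPy sketch, not dragonn code. `one_hot` and `pssm_scan_score` are hypothetical helpers, and the GATA filter is a made-up toy rather than a real PSSM. A single global max also stands in for the windowed max pooling of width `pool_width` that SequenceDNN applies. The filter slides along the one-hot sequence, taking a dot product at each position. The resulting scores are ReLU-thresholded and the maximum is kept."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "BASES = 'ACGT'\n",
 "\n",
 "def one_hot(seq):\n",
 "    # Hypothetical helper: (4, L) array with a 1 in row BASES.index(base) at each position\n",
 "    x = np.zeros((4, len(seq)))\n",
 "    for j, base in enumerate(seq):\n",
 "        x[BASES.index(base), j] = 1.0\n",
 "    return x\n",
 "\n",
 "def pssm_scan_score(x, pssm):\n",
 "    # pssm: (4, w) weights, playing the role of a single convolutional filter\n",
 "    w = pssm.shape[1]\n",
 "    # Unpadded scan: one dot product per window start position\n",
 "    scores = np.array([np.sum(x[:, j:j + w] * pssm) for j in range(x.shape[1] - w + 1)])\n",
 "    scores = np.maximum(scores, 0.0)  # ReLU thresholding\n",
 "    return scores.max()  # global max pooling over all positions\n",
 "\n",
 "# Toy filter for the 4-mer GATA: +1 for the matching base, -1 for the other bases\n",
 "pssm = 2 * one_hot('GATA') - 1\n",
 "print(pssm_scan_score(one_hot('TTGGTAGATA'), pssm))  # 4.0, the sequence contains GATA\n",
 "print(pssm_scan_score(one_hot('CCCCCCCCCC'), pssm))  # 0.0, every window scores -4 before the ReLU"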
200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "To get a description of the architecture parameters, we use the inspect_SequenceDNN function, which outputs documentation for the model class, including the architecture parameters:" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "inspect_SequenceDNN()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "\"Available methods\" lists what can be done with a SequenceDNN model. These include common operations such as training and testing the model, and more complex operations such as extracting insight from trained models. We define a simple DragoNN model with one convolutional layer containing one convolutional filter of width 45, followed by max pooling of width 45. " 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "one_filter_dragonn_parameters = {\n", 232 | " 'seq_length': 500,\n", 233 | " 'num_filters': [1],\n", 234 | " 'conv_width': [45],\n", 235 | " 'pool_width': 45}" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "We get a randomly initialized DragoNN model by calling the get_SequenceDNN function with one_filter_dragonn_parameters as the input:" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "import warnings\n", 252 | "warnings.filterwarnings('ignore')\n", 253 | "one_filter_dragonn = get_SequenceDNN(one_filter_dragonn_parameters)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "## Training a DragoNN model\n", 261 | "\n", 262 | "Next, we train the one_filter_dragonn by calling train_SequenceDNN with one_filter_dragonn
and simulation_data as the inputs. In each epoch, the one_filter_dragonn performs a complete pass over the training data and updates its parameters to minimize the loss, which quantifies the error in the model predictions. After each epoch, the code prints performance metrics for the one_filter_dragonn on the validation data. Training stops once the validation loss stops improving for multiple consecutive epochs. The performance metrics include balanced accuracy, area under the receiver operating characteristic curve ([auROC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)), area under the precision-recall curve ([auPRC](https://en.wikipedia.org/wiki/Precision_and_recall)), area under the precision-recall-gain curve ([auPRG](https://papers.nips.cc/paper/5867-precision-recall-gain-curves-pr-analysis-done-right.pdf)), and recall at multiple false discovery rates (Recall at [FDR](https://en.wikipedia.org/wiki/False_discovery_rate))." 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "train_SequenceDNN(one_filter_dragonn, simulation_data)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "We can see that the validation loss is not decreasing and the auROC is not increasing, which indicates that this model is not learning. A simple plot of the learning curve, showing the loss function on the training and validation data over the course of training, demonstrates this visually:" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "SequenceDNN_learning_curve(one_filter_dragonn)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "# A multi-filter DragoNN model \n", 295 | "Next, we modify the model to have 15 convolutional filters instead of just one filter.
Will the model learn now?" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "multi_filter_dragonn_parameters = {\n", 305 | " 'seq_length': 500,\n", 306 | " 'num_filters': [15], ## notice the change from 1 filter to 15 filters\n", 307 | " 'conv_width': [45],\n", 308 | " 'pool_width': 45,\n", 309 | " 'dropout': 0.1}\n", 310 | "multi_filter_dragonn = get_SequenceDNN(multi_filter_dragonn_parameters)\n", 311 | "train_SequenceDNN(multi_filter_dragonn, simulation_data)\n", 312 | "SequenceDNN_learning_curve(multi_filter_dragonn)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "## Interpreting a DragoNN model using filter visualization\n", 320 | "We can see that this model has not learned much: the validation loss has hardly decreased over the course of training, and the auROC is only 0.586 (the exact value will vary between runs). Let's see what the sequence filters of this model look like." 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "interpret_SequenceDNN_filters(multi_filter_dragonn, simulation_data)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "As expected, the sequence filters don't reveal patterns that resemble the simulated motifs. Next, we explore methods to interpret specific sequences with this DragoNN model.\n", 337 | "\n", 338 | "# Interpreting data with a DragoNN model\n", 339 | "\n", 340 | "Using in-silico mutagenesis (ISM) and [DeepLIFT](https://arxiv.org/pdf/1605.01713v2.pdf), we can obtain scores for a specific sequence indicating the importance of each position in the sequence. To assess these methods, we compare ISM and DeepLIFT scores to motif scores for each simulated motif at each position in the sequence.
These motif scores represent the \"ground truth\" importance of each position because they are based on the motifs used to simulate the data. We plot comparisons for a positive-class sequence on the left and a negative-class sequence on the right." 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "interpret_data_with_SequenceDNN(multi_filter_dragonn, simulation_data)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "We can see that neither DeepLIFT nor ISM highlights the locations of the simulated motifs (highlighted in grey). This is expected because this model doesn't perform well on this simulation." 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "# A multi-layer DragoNN model\n", 364 | "Next, we extend multi_filter_dragonn to have 3 convolutional layers, with 15 convolutional filters in each layer, to learn the heterodimer grammar compositionally across multiple layers."
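 ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "Stacking convolutional layers widens the stretch of sequence that each unit in the last layer sees (its receptive field), which is one way higher layers can combine features from both motifs. The cell below is a minimal sketch of that arithmetic, not dragonn code. `conv_stack_geometry` is a hypothetical helper. It assumes stride-1 unpadded convolutions, followed by one non-overlapping max pooling of width `pool_width` after the last convolutional layer. That is our reading of how SequenceDNN stacks its layers, so check the dragonn source before relying on the exact numbers."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def conv_stack_geometry(seq_length, conv_width, pool_width):\n",
 "    # Stacked stride-1 convolutions: receptive field = 1 + sum of (w - 1)\n",
 "    receptive_field = 1 + sum(w - 1 for w in conv_width)\n",
 "    # Each unpadded convolution shortens the output by (w - 1) positions\n",
 "    conv_length = seq_length - (receptive_field - 1)\n",
 "    # Non-overlapping pooling (stride = pool_width); a partial last window is dropped\n",
 "    pooled_length = conv_length // pool_width\n",
 "    return receptive_field, conv_length, pooled_length\n",
 "\n",
 "print(conv_stack_geometry(500, [45], 45))  # (45, 456, 10): one_filter_dragonn and multi_filter_dragonn\n",
 "print(conv_stack_geometry(500, [25, 25, 25], 45))  # (73, 428, 9): multi_layer_dragonn"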
365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "multi_layer_dragonn_parameters = {\n", 374 | " 'seq_length': 500,\n", 375 | " 'num_filters': [15, 15, 15], ## notice the change to multiple filter values, one for each layer\n", 376 | " 'conv_width': [25, 25, 25], ## convolutional filter width has been modified to 25 from 45\n", 377 | " 'pool_width': 45,\n", 378 | " 'dropout': 0.1}\n", 379 | "multi_layer_dragonn = get_SequenceDNN(multi_layer_dragonn_parameters)\n", 380 | "train_SequenceDNN(multi_layer_dragonn, simulation_data)\n", 381 | "SequenceDNN_learning_curve(multi_layer_dragonn)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "The multi-layered DragoNN model achieves a higher auROC and a lower training and validation loss than the multi-filter DragoNN model. Try the same model without dropout regularization: how important is dropout?\n", 389 | "\n", 390 | "Let's see what the model learns in its sequence filters." 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "interpret_SequenceDNN_filters(multi_layer_dragonn, simulation_data)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "The sequence filters here are not amenable to interpretation based on visualization alone. In multi-layered models, sequence features are learned compositionally across the layers. As a result, first-layer filters tend to capture simpler features that higher layers combine into motif features more efficiently, which makes those filters harder to interpret from simple visualizations. Let's see where ISM and DeepLIFT get us with this model."
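 ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "For reference, in-silico mutagenesis can be sketched in a few lines. Substitute every base at every position, re-score each mutant, and record the change from the prediction on the unmutated sequence. The cell below is a minimal sketch, not dragonn's implementation. `in_silico_mutagenesis` is a hypothetical helper. The toy `predict` function is also hypothetical: it stands in for a trained model and returns 1 if the sequence contains GATA. Sequences here are (4, L) one-hot arrays. simulation_data.X_train uses a different layout, which appears from the indexing used above to be (N, 1, 4, L)."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "def in_silico_mutagenesis(predict, x):\n",
 "    # predict maps a batch of shape (N, 4, L) to N scores; x is one one-hot sequence of shape (4, L)\n",
 "    n_bases, length = x.shape\n",
 "    reference = predict(x[None])[0]\n",
 "    mutants = []\n",
 "    for pos in range(length):\n",
 "        for base in range(n_bases):\n",
 "            mutant = x.copy()\n",
 "            mutant[:, pos] = 0.0\n",
 "            mutant[base, pos] = 1.0  # substituting the reference base leaves x unchanged\n",
 "            mutants.append(mutant)\n",
 "    scores = predict(np.stack(mutants)).reshape(length, n_bases).T\n",
 "    return scores - reference  # (4, L): change in score for each substitution\n",
 "\n",
 "def predict(batch):\n",
 "    # Toy stand-in for a trained model: 1.0 if the sequence contains GATA, else 0.0\n",
 "    seqs = [''.join('ACGT'[i] for i in onehot.argmax(axis=0)) for onehot in batch]\n",
 "    return np.array([float('GATA' in s) for s in seqs])\n",
 "\n",
 "seq = 'TTGGTAGATA'\n",
 "x = np.array([[float(b == base) for b in seq] for base in 'ACGT'])\n",
 "delta = in_silico_mutagenesis(predict, x)\n",
 "print(np.where(delta.min(axis=0) < 0)[0])  # [6 7 8 9]: only mutations inside GATA lower the score"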
407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "interpret_data_with_SequenceDNN(multi_layer_dragonn, simulation_data)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "DeepLIFT and ISM scores for this model on representative positive (left) and negative (right) sequences expose what the model is doing. The SIX5-ZNF143 grammar is clearly highlighted by both methods in the positive-class sequence. However, in our run ISM assigns higher scores to false features around position 250, so we would not be able to distinguish between false and true features in this example based on ISM score magnitude. DeepLIFT, on the other hand, assigns the highest scores to the true features, so it could be used in this example to detect the SIX5-ZNF143 grammar." 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "# Using DragoNN on your own non-simulated data\n", 430 | "\n", 431 | "The dragonn package provides a command-line interface to train and test DragoNN models, and use them to predict and interpret new data. We start by training a dragonn model on positive and negative sequences:" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "#### Important: If you are running this notebook on the Kundaje Lab public AWS image, please click on Kernel -> Restart in the menu at the top before running the commands below.
\n", 439 | "(This prevents process contention for GPUs)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "!dragonn train --pos-sequences example_pos_sequences.fa --neg-sequences example_neg_sequences.fa --prefix training_example" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "Based on the provided prefix, this command stores an architecture file, training_example.arch.json, with the model architecture and a weights file, training_example.weights.h5, with the parameters of the trained model. We test the model by running:" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "!dragonn test --pos-sequences example_pos_sequences.fa --neg-sequences example_neg_sequences.fa \\\n", 465 | "--arch-file training_example.arch.json --weights-file training_example.weights.h5" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": {}, 471 | "source": [ 472 | "This command prints the model's test performance metrics on the provided data. Model predictions on sequence data can be obtained by running:" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "!dragonn predict --sequences example_pos_sequences.fa --arch-file training_example.arch.json \\\n", 482 | "--weights-file training_example.weights.h5 --output-file example_predictions.txt" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "metadata": {}, 488 | "source": [ 489 | "This command stores the model predictions for sequences in example_pos_sequences.fa in the output file example_predictions.txt.
We can interpret sequence data with a dragonn model by running:" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "metadata": {}, 496 | "outputs": [], 497 | "source": [ 498 | "!dragonn interpret --sequences example_pos_sequences.fa --arch-file training_example.arch.json \\\n", 499 | "--weights-file training_example.weights.h5 --prefix example_interpretation" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": {}, 505 | "source": [ 506 | "This writes the most important subsequence of each input sequence, along with its location, to the file example_interpretation.task_0.important_sequences.txt. Note: by default, only examples with predicted positive class probability >0.5 are interpreted. Examples below this threshold yield an important subsequence of Ns with location -1. Let's look at the first few lines of this file:" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "!head example_interpretation.task_0.important_sequences.txt" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "## Extras for HW\n", 523 | "\n", 524 | "The tutorial example here touches on general principles of DragoNN model development and interpretation. To gain deeper insight into the difference between DeepLIFT and ISM for model interpretation, consider the following exercise:\n", 525 | "\n", 526 | "Train, test, and run sequence-centric interpretation for the one layered CNN model used here for the following\n", 527 | "simulations:\n", 528 | " 1. single motif detection simulation of TAL1 in 1000bp sequences with 40% GC content\n", 529 | " (run print_simulation_info(\"simulate_single_motif_detection\") to see the exact simulation parameters)\n", 530 | " 2.
motif density localization simulation of 2-4 TAL1 motif instances in the central 150bp of a 1000bp\n", 531 | " sequence with 40% GC\n", 532 | " content\n", 533 | " (run print_simulation_info(\"simulate_motif_density_localization\") to see the exact simulation parameters)\n", 534 | "\n", 535 | "Key questions:\n", 536 | "\n", 537 | " 1) What could explain the difference in ISM's sensitivity to the TAL1 motif sequence between the simulations?\n", 538 | " 2) What does that tell us about the scope of ISM for feature discovery? Under what conditions is it likely\n", 539 | " to show sensitivity to sequence features?\n", 540 | " \n", 541 | "Starter code is provided below to get the data for each simulation and a new DragoNN model.\n" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [ 550 | "single_motif_detection_simulation_parameters = {\n", 551 | " \"motif_name\": \"TAL1_known4\",\n", 552 | " \"seq_length\": 1000,\n", 553 | " \"num_pos\": 10000,\n", 554 | " \"num_neg\": 10000,\n", 555 | " \"GC_fraction\": 0.4}\n", 556 | "\n", 557 | "density_localization_simulation_parameters = {\n", 558 | " \"motif_name\": \"TAL1_known4\",\n", 559 | " \"seq_length\": 1000,\n", 560 | " \"center_size\": 150,\n", 561 | " \"min_motif_counts\": 2,\n", 562 | " \"max_motif_counts\": 4,\n", 563 | " \"num_pos\": 10000,\n", 564 | " \"num_neg\": 10000,\n", 565 | " \"GC_fraction\": 0.4}\n", 566 | "\n", 567 | "single_motif_detection_simulation_data = get_simulation_data(\n", 568 | " \"simulate_single_motif_detection\", single_motif_detection_simulation_parameters)\n", 569 | "\n", 570 | "density_localization_simulation_data = get_simulation_data(\n", 571 | " \"simulate_motif_density_localization\", density_localization_simulation_parameters)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 |
"new_dragonn_model = get_SequenceDNN(multi_layer_dragonn_parameters)" 581 | ] 582 | } 583 | ], 584 | "metadata": { 585 | "kernelspec": { 586 | "display_name": "dragonn", 587 | "language": "python", 588 | "name": "dragonn" 589 | }, 590 | "language_info": { 591 | "codemirror_mode": { 592 | "name": "ipython", 593 | "version": 2 594 | }, 595 | "file_extension": ".py", 596 | "mimetype": "text/x-python", 597 | "name": "python", 598 | "nbconvert_exporter": "python", 599 | "pygments_lexer": "ipython2", 600 | "version": "2.7.15" 601 | } 602 | }, 603 | "nbformat": 4, 604 | "nbformat_minor": 1 605 | } 606 | -------------------------------------------------------------------------------- /dragonn_tutorial2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# How to train your DragoNN tutorial\n", 8 | "\n", 9 | "**Tutorial length**: 25-30 minutes with a CPU.\n", 10 | "\n", 11 | "## Outline\n", 12 | " * How to use this tutorial\n", 13 | " * Review of patterns in transcription factor binding sites\n", 14 | " * Learning to localize homotypic motif density\n", 15 | " * Sequence model definition\n", 16 | " * Training and interpretation of\n", 17 | " - single layer, single filter DragoNN\n", 18 | " - single layer, multiple filters DragoNN\n", 19 | " - Multi-layer DragoNN\n", 20 | " - Regularized multi-layer DragoNN\n", 21 | " * Critical questions in this tutorial:\n", 22 | " - What is the \"right\" way to get insight from a DragoNN model?\n", 23 | " - What are the limitations of different interpretation methods?\n", 24 | " - Do those limitations depend on the model and the target pattern?\n", 25 | " * Suggestions for further exploration\n", 26 | "\n", 27 | "Github issues on the dragonn repository with feedback, questions, and discussion are always welcome.\n", 28 | "\n", 29 | "\n", 30 | "## How to use this tutorial\n", 31 | "\n", 32 | "This tutorial 
utilizes a Jupyter/IPython Notebook, an interactive computational environment that combines live code, visualizations, and explanatory text. The notebook is organized into a series of cells. You can run the next cell by clicking the play button:\n", 33 | "![play button](./tutorial_images/play_button.png)\n", 34 | "You can also run all cells in a series by clicking \"run all\" in the Cell drop-down menu:\n", 35 | "![play all button](./tutorial_images/play_all_button.png)\n", 36 | "Half of the cells in this tutorial contain code; the other half contain visualizations and explanatory text. Code, visualizations, and text in cells can be modified, and you are encouraged to modify the code as you advance through the tutorial. You can inspect the implementation of a function used in a cell by following these steps:\n", 37 | "![inspecting code](./tutorial_images/inspecting_code.png)\n", 38 | "\n", 39 | "We start by loading dragonn's tutorial utilities and reviewing properties of regulatory sequences that transcription factors bind." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "%reload_ext autoreload\n", 49 | "%autoreload 2\n", 50 | "from dragonn.tutorial_utils import *\n", 51 | "%matplotlib inline" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "![sequence properties 1](./tutorial_images/sequence_properties_1.jpg)\n", 59 | "![sequence properties 2](./tutorial_images/sequence_properties_2.jpg)\n", 60 | "\n", 61 | "# Learning to localize a homotypic motif density\n", 62 | "In this tutorial we will learn how to localize a homotypic motif cluster. 
We will simulate a positive set of sequences with multiple instances of a motif in the center and a negative set of sequences with multiple motif instances positioned anywhere in the sequence:\n", 63 | "![homotypic motif density localization](./tutorial_images/homotypic_motif_density_localization.jpg)\n", 64 | "We will then train a binary classification model to classify the simulated sequences. To solve this task, the model will need to learn the motif pattern and whether instances of that pattern are present in the central part of the sequence.\n", 65 | "\n", 66 | "We start by getting the simulation data." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Getting simulation data\n", 74 | "\n", 75 | "DragoNN provides a set of simulation functions. We will use the simulate_motif_density_localization function to simulate homotypic motif density localization. First, we obtain documentation for the simulation parameters." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "print_simulation_info(\"simulate_motif_density_localization\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "Next, we define parameters for a TAL1 motif density localization simulation in 1000bp-long sequences with a 0.4 GC fraction and, for the positive sequences, 2-4 instances of the motif in the central 150bp. We simulate a total of 3000 positive and 3000 negative sequences." 
92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "motif_density_localization_simulation_parameters = {\n", 101 | " \"motif_name\": \"TAL1_known4\",\n", 102 | " \"seq_length\": 1000,\n", 103 | " \"center_size\": 150,\n", 104 | " \"min_motif_counts\": 2,\n", 105 | " \"max_motif_counts\": 4, \n", 106 | " \"num_pos\": 3000,\n", 107 | " \"num_neg\": 3000,\n", 108 | " \"GC_fraction\": 0.4}" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "We get the simulation data by calling the get_simulation_data function with the simulation name and the simulation parameters as inputs. 1000 sequences are held out for a test set, 1000 sequences for a validation set, and the remaining 4000 sequences are in the training set." 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "simulation_data = get_simulation_data(\"simulate_motif_density_localization\",\n", 125 | " motif_density_localization_simulation_parameters,\n", 126 | " validation_set_size=1000, test_set_size=1000)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "simulation_data provides training, validation, and test sets of input sequences X and sequence labels y. 
The inputs X are matrices with a one-hot encoding of the sequences:\n", 134 | "![one hot encoding](./tutorial_images/one_hot_encoding.png)\n", 135 | "Here are the first 10bp of a sequence in our training data:" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "simulation_data.X_train[0, :, :, :10]" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "This matrix represents the 10bp sequence AAATGGGCCG.\n", 152 | "\n", 153 | "# The homotypic motif density localization task\n", 154 | "The goal of the model is to take the positive and negative sequences simulated above and classify them:\n", 155 | "![classification task](./tutorial_images/homotypic_motif_density_localization_task.jpg)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "# DragoNN Models\n", 163 | "\n", 164 | "A locally connected linear unit in a DragoNN model can represent a PSSM (part a). A sequence PSSM score is obtained by scanning the PSSM across the sequence (taking its dot product with each window), thresholding the PSSM scores, and taking the max (part b). 
A PSSM score can also be computed by a DragoNN model with tiled locally connected linear units, amounting to a convolutional layer with a single convolutional filter representing the PSSM, followed by ReLU thresholding and max pooling (part c).\n", 165 | "![dragonn vs pssm](./tutorial_images/dragonn_and_pssm.jpg)\n", 166 | "By utilizing multiple convolutional layers with multiple convolutional filters, DragoNN models can represent a wide range of sequence features in a compositional fashion:\n", 167 | "![dragonn model figure](./tutorial_images/dragonn_model_figure.jpg)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "# Getting a DragoNN model\n", 175 | "\n", 176 | "The main DragoNN model class is SequenceDNN, which provides a simple interface to a range of models and methods to train, test, and interpret DragoNNs. SequenceDNN uses [keras](http://keras.io/), a high-level deep learning library that runs on top of [Theano](https://github.com/Theano/Theano) and [TensorFlow](https://github.com/tensorflow/tensorflow), which are popular software packages for deep learning.\n", 177 | "\n", 178 | "To get a DragoNN model, we:\n", 179 | " \n", 180 | " 1) Define the DragoNN architecture parameters\n", 181 | " - obtain a description of the architecture parameters using the inspect_SequenceDNN() function\n", 182 | " 2) Call the get_SequenceDNN function, which takes the DragoNN architecture parameters as input and outputs a \n", 183 | " randomly initialized DragoNN model." 
184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "To get a description of the architecture parameters, we use the inspect_SequenceDNN function, which outputs documentation for the model class, including the architecture parameters:" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "inspect_SequenceDNN()" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "\"Available methods\" lists what can be done with a SequenceDNN model. These include common operations, such as training and testing the model, and more complex operations, such as extracting insight from trained models. We define a simple DragoNN model with one convolutional layer containing one convolutional filter, followed by max pooling of width 35. " 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "one_filter_dragonn_parameters = {\n", 216 | " 'seq_length': 1000,\n", 217 | " 'num_filters': [1],\n", 218 | " 'conv_width': [10],\n", 219 | " 'pool_width': 35}" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "We get a randomly initialized DragoNN model by calling the get_SequenceDNN function with one_filter_dragonn_parameters as the input:" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "import warnings\n", 236 | "warnings.filterwarnings('ignore')\n", 237 | "one_filter_dragonn = get_SequenceDNN(one_filter_dragonn_parameters)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "## Training a DragoNN model\n", 245 | "\n", 246 | "Next, we train the one_filter_dragonn by calling train_SequenceDNN with 
one_filter_dragonn 
and simulation_data as the inputs. In each epoch, the one_filter_dragonn will perform a complete pass over the training data and update its parameters to minimize the loss, which quantifies the error in the model predictions. After each epoch, the code prints performance metrics for the one_filter_dragonn on the validation data. Training stops once the loss on the validation set stops improving for multiple consecutive epochs. The performance metrics include balanced accuracy, area under the receiver operating characteristic curve ([auROC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)), area under the precision-recall curve ([auPRC](https://en.wikipedia.org/wiki/Precision_and_recall)), and recall at multiple false discovery rates (Recall at [FDR](https://en.wikipedia.org/wiki/False_discovery_rate))." 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "train_SequenceDNN(one_filter_dragonn, simulation_data)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "A single-layer, single-filter model gets good performance and doesn't overfit much. Let's look at the learning curve to demonstrate this visually:" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "SequenceDNN_learning_curve(one_filter_dragonn)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "# A multi-filter DragoNN model \n", 279 | "Next, we modify the model to have 15 convolutional filters instead of just one filter. How does this model compare to the single filter model?" 
280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "multi_filter_dragonn_parameters = {\n", 289 | " 'seq_length': 1000,\n", 290 | " 'num_filters': [15], ## notice the change from 1 filter to 15 filters\n", 291 | " 'conv_width': [10],\n", 292 | " 'pool_width': 35}\n", 293 | "multi_filter_dragonn = get_SequenceDNN(multi_filter_dragonn_parameters)\n", 294 | "train_SequenceDNN(multi_filter_dragonn, simulation_data)\n", 295 | "SequenceDNN_learning_curve(multi_filter_dragonn)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "It slightly outperforms the single filter model. Let's check whether the learned filters capture the simulated pattern." 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "interpret_SequenceDNN_filters(multi_filter_dragonn, simulation_data)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": {}, 317 | "source": [ 318 | "Only some of the filters closely match the simulated pattern. This illustrates that directly interpreting model parameters only partially works for multi-filter models. Another way to deduce learned patterns is to examine feature importances for specific examples. Next, we explore methods for feature importance scoring." 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "# Interpreting data with a DragoNN model\n", 326 | "\n", 327 | "Using in-silico mutagenesis (ISM) and [DeepLIFT](https://arxiv.org/pdf/1605.01713v2.pdf), we can obtain scores for a specific sequence indicating the importance of each position in the sequence. To assess these methods, we compare ISM and DeepLIFT scores to motif scores for each simulated motif at each position in the sequence. 
These motif scores represent the \"ground truth\" importance of each position because they are based on the motifs used to simulate the data. We plot comparisons for a positive class sequence on the left and a negative class sequence on the right." 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "interpret_data_with_SequenceDNN(multi_filter_dragonn, simulation_data)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "In the positive example (left side), ISM correctly highlights the two motif instances in the central 150bp. DeepLIFT highlights them as well. DeepLIFT also slightly highlights a false positive feature on the left side, but its score is small enough that we can distinguish the false positive feature from the true positive features. In the negative example (right side), ISM doesn't highlight anything, but DeepLIFT highlights a couple of false positive features almost as strongly as it highlights the true positive features in the positive example." 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "# A multi-layer DragoNN model\n", 351 | "Next, we train a 3-layer model for this task. Will it outperform the single layer model, and to what extent will it overfit?" 
352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "multi_layer_dragonn_parameters = {\n", 361 | " 'seq_length': 1000,\n", 362 | " 'num_filters': [15, 15, 15], ## notice the change to multiple filter values, one for each layer\n", 363 | " 'conv_width': [10, 10, 10], ## one convolutional filter width per layer (width 10 in every layer)\n", 364 | " 'pool_width': 35}\n", 365 | "\n", 366 | "multi_layer_dragonn = get_SequenceDNN(multi_layer_dragonn_parameters)\n", 367 | "train_SequenceDNN(multi_layer_dragonn, simulation_data)\n", 368 | "SequenceDNN_learning_curve(multi_layer_dragonn)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "This model performs about the same as the single layer model, but it overfits more. We will try to address that with dropout regularization. But first, what do the first layer filters look like?" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "interpret_SequenceDNN_filters(multi_layer_dragonn, simulation_data)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "The filters now make less sense than in the single layer model case. In multi-layered models, sequence features are learned compositionally across the layers. As a result, sequence filters in the first layer focus more on simple features that can be combined in higher layers to learn motif features more efficiently, and their interpretation becomes less clear based on simple visualizations. Let's see where ISM and DeepLIFT get us with this model." 
392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "interpret_data_with_SequenceDNN(multi_layer_dragonn, simulation_data)" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "As in the single layer model case, ISM correctly highlights the two true positive features in the positive example (left side) and correctly ignores features in the negative example (right side). DeepLIFT still highlights the same false positive feature in the positive example as before, but we can still separate it from the true positive features. In the negative example, it still highlights some false positive features." 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "# A regularized multi-layer DragoNN model\n", 415 | "Next, we regularize the 3-layer model using 0.2 dropout on every convolutional layer. Will dropout improve validation performance?" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "metadata": { 422 | "scrolled": true 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "regularized_multi_layer_dragonn_parameters = {\n", 427 | " 'seq_length': 1000,\n", 428 | " 'num_filters': [15, 15, 15],\n", 429 | " 'conv_width': [10, 10, 10],\n", 430 | " 'pool_width': 35,\n", 431 | " 'dropout': 0.2} ## we introduce dropout of 0.2 on every convolutional layer for regularization\n", 432 | "regularized_multi_layer_dragonn = get_SequenceDNN(regularized_multi_layer_dragonn_parameters)\n", 433 | "train_SequenceDNN(regularized_multi_layer_dragonn, simulation_data)\n", 434 | "SequenceDNN_learning_curve(regularized_multi_layer_dragonn)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "As expected, dropout decreased the overfitting this model displayed previously and increased validation performance. 
Let's see the effect on feature discovery." 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "interpret_data_with_SequenceDNN(regularized_multi_layer_dragonn, simulation_data)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": {}, 456 | "source": [ 457 | "ISM now highlights a false positive feature in the positive example (left side) more than the true positive features. What happened? A sufficiently accurate model should not change its confidence that there are 2 or more features in the central 150 base pairs (bps) due to a single bp change. So it makes sense that, in the limit of the \"perfect\" model, ISM will actually lose its power to discover features in this example.\n", 458 | "\n", 459 | "How about DeepLIFT? DeepLIFT correctly highlights only the two true positive features in the positive example. So it seems that in the limit of the \"perfect\" model, DeepLIFT gets closer to the true positive features.\n", 460 | "\n", 461 | "Why did this happen? Why does ISM fail to highlight the true positive features as we regularize the model and improve its performance? Here is a hint: in the limit of the \"perfect\" model for this simulation, will a single base pair perturbation to the positive example here change its confidence that it is still a positive example? I encourage you to open GitHub issues on the dragonn repo to discuss these questions.\n", 462 | "\n", 463 | "Below is an overview of patterns and simulations for further exploration." 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "# For further exploration \n", 471 | "In this tutorial we explored modeling of homotypic motif density. 
Other properties of regulatory DNA sequence include:\n", 472 | "![sequence properties 3](./tutorial_images/sequence_properties_3.jpg)\n", 473 | "![sequence properties 4](./tutorial_images/sequence_properties_4.jpg)\n", 474 | "\n", 475 | "DragoNN provides simulations that formulate learning each of these patterns as a classification problem:\n", 476 | "![sequence](./tutorial_images/sequence_simulations.png)\n", 477 | "\n", 478 | "You can view the available simulation functions by running print_available_simulations:" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "print_available_simulations()" 488 | ] 489 | } 490 | ], 491 | "metadata": { 492 | "kernelspec": { 493 | "display_name": "Python 3", 494 | "language": "python", 495 | "name": "python3" 496 | }, 497 | "language_info": { 498 | "codemirror_mode": { 499 | "name": "ipython", 500 | "version": 3 501 | }, 502 | "file_extension": ".py", 503 | "mimetype": "text/x-python", 504 | "name": "python", 505 | "nbconvert_exporter": "python", 506 | "pygments_lexer": "ipython3", 507 | "version": "3.5.2" 508 | } 509 | }, 510 | "nbformat": 4, 511 | "nbformat_minor": 1 512 | } 513 | -------------------------------------------------------------------------------- /keras_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "nbpresent": { 7 | "id": "2fd3e30c-683a-4822-b66c-b44433498fb4" 8 | } 9 | }, 10 | "source": [ 11 | "# Introduction to Keras #" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "nbpresent": { 18 | "id": "9aa90ae4-ccb4-4e11-ac66-694182876725" 19 | } 20 | }, 21 | "source": [ 22 | "## https://keras.io/ ## \n", 23 | "(Keras has great documentation; in fact, most of the examples below are taken in part or in full from the docs!)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 
28 | "metadata": { 29 | "nbpresent": { 30 | "id": "105c3772-4173-4ef2-ace2-34ae4cb87861" 31 | } 32 | }, 33 | "source": [ 34 | "## Keras workflow for a sequential model -- MNIST toy example ##" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "nbpresent": { 42 | "id": "1d0ff088-2f86-48bb-8818-f11e3d965029" 43 | } 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "import keras \n" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "collapsed": true, 54 | "nbpresent": { 55 | "id": "5f9c4818-f573-4be1-97c2-219b3f80f6a6" 56 | } 57 | }, 58 | "source": [ 59 | "The **Sequential** model class is a linear stack of layers. \n", 60 | "Each layer has its own class. Some examples that you might encounter frequently are included in the import statements below:" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "nbpresent": { 68 | "id": "9ac06fe8-8feb-429b-a227-6123d2121dbe" 69 | } 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "#import the Sequential model class. 
\n", 74 | "from keras.models import Sequential \n", 75 | "\n", 76 | "#Core layers \n", 77 | "from keras.layers.core import Dense, Activation,Dropout,Flatten\n", 78 | "\n", 79 | "#Convolution layers \n", 80 | "from keras.layers.convolutional import Conv1D, Conv2D \n", 81 | "\n", 82 | "#Pooling layers \n", 83 | "from keras.layers.pooling import MaxPooling1D, MaxPooling2D, AveragePooling1D, AveragePooling2D \n", 84 | "\n", 85 | "#Recurrent layers \n", 86 | "from keras.layers.recurrent import Recurrent, SimpleRNN, GRU, LSTM\n", 87 | "\n", 88 | "#Embedding layers \n", 89 | "from keras.layers.embeddings import Embedding\n", 90 | "\n", 91 | "#Merge layers \n", 92 | "from keras.layers.merge import Add, Multiply, Average, Maximum, Concatenate, Dot\n", 93 | "\n", 94 | "#Normalization layers \n", 95 | "from keras.layers.normalization import BatchNormalization \n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "nbpresent": { 103 | "id": "7f772d73-2cea-4a33-8768-40f559e85e90" 104 | } 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "#the MNIST dataset is included with keras installation \n", 109 | "from keras.datasets import mnist \n", 110 | "(X_train,y_train),(X_test,y_test)=mnist.load_data() " 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": { 117 | "nbpresent": { 118 | "id": "95776b5b-aef2-4e7b-be13-5927378649e7" 119 | } 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "#we want to create a validation set to illustrate training with a validation dataset, so we \"hack\" this dataset by \n", 124 | "#splitting the test data into a test and validation dataset \n", 125 | "X_valid=X_test[0:int(X_test.shape[0]/2)]\n", 126 | "y_valid=y_test[0:int(y_test.shape[0]/2)]\n", 127 | "X_test=X_test[int(X_test.shape[0]/2)::]\n", 128 | "y_test=y_test[int(y_test.shape[0]/2)::]" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | 
"metadata": { 135 | "nbpresent": { 136 | "id": "d9a12e90-f2c0-4993-b91e-fa464821cabf" 137 | } 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "#Let's briefly examine our data \n", 142 | "print(\"Training X:\"+str(X_train.shape))\n", 143 | "print(\"Training y:\"+str(y_train.shape))\n", 144 | "print(\"Valid X:\"+str(X_valid.shape))\n", 145 | "print(\"Valid y:\"+str(y_valid.shape))\n", 146 | "print(\"Test X:\"+str(X_test.shape))\n", 147 | "print(\"Test y:\"+str(y_test.shape))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "nbpresent": { 155 | "id": "180488c5-949f-49f4-84bb-dd23e196ab08" 156 | } 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "%matplotlib inline\n", 161 | "from matplotlib.pyplot import imshow\n", 162 | "#You can visualize the digits that we classify \n", 163 | "digit_index=0\n", 164 | "imshow(X_train[digit_index])\n", 165 | "print(\"training label:\"+str(y_train[digit_index]))" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "nbpresent": { 173 | "id": "3ac98894-da8c-414f-a91b-db87936a45ac" 174 | } 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "#plot our first test digit -- let's see if we can train a deep learning model to correctly predict this digit. \n", 179 | "imshow(X_test[digit_index])\n" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": { 186 | "nbpresent": { 187 | "id": "9cd05738-dbbe-4ca7-a050-8c632e5edc7e" 188 | } 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "#pre-process the input data \n", 193 | "from keras import backend as K \n", 194 | "from keras.utils import np_utils \n", 195 | "# input image dimensions \n", 196 | "img_rows, img_cols = 28, 28\n", 197 | "#number of output classes\n", 198 | "nb_classes=10 \n", 199 | "\n", 200 | "\n", 201 | "#WARNING! 
order of dimensions differs for theano & tensorflow (that's why we check the backend and re-arrange the image \n", 202 | "#dimensions accordingly)\n", 203 | "if K.image_dim_ordering() == 'th':\n", 204 | " X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)\n", 205 | " X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)\n", 206 | " X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols)\n", 207 | " input_shape = (1, img_rows, img_cols)\n", 208 | "else:\n", 209 | " X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)\n", 210 | " X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)\n", 211 | " X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1)\n", 212 | " input_shape = (img_rows, img_cols, 1)\n", 213 | "\n", 214 | " \n", 215 | "#normalize the data to [0, 1], since the grayscale pixel intensities range from 0 to 255 \n", 216 | "X_train = X_train.astype('float32')\n", 217 | "X_test = X_test.astype('float32')\n", 218 | "X_valid = X_valid.astype('float32')\n", 219 | "X_train /= 255\n", 220 | "X_test /= 255\n", 221 | "X_valid /= 255 \n", 222 | "\n", 223 | "print('X_train shape:', X_train.shape)\n", 224 | "print(X_train.shape[0], 'train samples')\n", 225 | "print(X_test.shape[0], 'test samples')\n", 226 | "print(X_valid.shape[0], 'validation samples')\n", 227 | "\n", 228 | "\n", 229 | "# convert integer class vectors to one-hot (binary) class matrices \n", 230 | "y_train = np_utils.to_categorical(y_train, nb_classes)\n", 231 | "y_valid = np_utils.to_categorical(y_valid, nb_classes)\n", 232 | "y_test = np_utils.to_categorical(y_test, nb_classes)\n", 233 | "\n", 234 | "print(\"y_train shape:\", y_train.shape)\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": { 241 | "nbpresent": { 242 | "id": "e8d472df-cfdd-4bee-8a02-63b55bf3b2cc" 243 | } 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "K.image_dim_ordering()" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 
252 | "metadata": { 253 | "nbpresent": { 254 | "id": "f640725b-e6d8-4472-9f44-e04ec053124e" 255 | } 256 | }, 257 | "source": [ 258 | "### We want to implement the following model architecture to train a CNN to recognize handwritten digits: \n", 259 | "```` \n", 260 | " Convolution2D\n", 261 | " |\n", 262 | " v\n", 263 | " ReLU\n", 264 | " |\n", 265 | " v\n", 266 | "Convolution2D\n", 267 | " |\n", 268 | " v\n", 269 | " ReLU\n", 270 | " |\n", 271 | " v\n", 272 | " MaxPool2D\n", 273 | " |\n", 274 | " v\n", 275 | " Dropout \n", 276 | " |\n", 277 | " v\n", 278 | " Flatten\n", 279 | " |\n", 280 | " v\n", 281 | " Dense\n", 282 | " |\n", 283 | " v\n", 284 | " ReLU\n", 285 | " |\n", 286 | " v\n", 287 | " Dropout\n", 288 | " |\n", 289 | " v\n", 290 | " Dense\n", 291 | " |\n", 292 | " v\n", 293 | " Softmax\n", 294 | "````" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": { 301 | "nbpresent": { 302 | "id": "0b20ac20-c1e1-4391-8798-fc45b4f6df01" 303 | } 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "#architecture hyperparameters \n", 308 | "# number of convolutional filters to use \n", 309 | "nb_filters = 32\n", 310 | "# size of pooling area for max pooling \n", 311 | "pool_size = (2, 2)\n", 312 | "# convolution kernel size \n", 313 | "kernel_size = (3, 3)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "nbpresent": { 321 | "id": "e841541b-d37a-4189-a07c-0cba4314cae4" 322 | } 323 | }, 324 | "outputs": [], 325 | "source": [ 326 | "#model architecture \n", 327 | "model = Sequential()\n", 328 | "\n", 329 | "model.add(Conv2D(nb_filters, (kernel_size[0], kernel_size[1]),\n", 330 | " padding='valid',\n", 331 | " input_shape=input_shape))\n", 332 | "model.add(Activation('relu'))\n", 333 | "model.add(Conv2D(nb_filters, (kernel_size[0], kernel_size[1])))\n", 334 | "model.add(Activation('relu'))\n", 335 | "model.add(MaxPooling2D(pool_size=pool_size))\n", 336 | 
"model.add(Dropout(0.25))\n", 337 | "\n", 338 | "model.add(Flatten())\n", 339 | "model.add(Dense(128))\n", 340 | "model.add(Activation('relu'))\n", 341 | "model.add(Dropout(0.5))\n", 342 | "model.add(Dense(nb_classes))\n", 343 | "model.add(Activation('softmax'))" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": { 349 | "nbpresent": { 350 | "id": "539bb367-a4a8-4db2-b7fd-b0e4d2e22537" 351 | } 352 | }, 353 | "source": [ 354 | "#### A quick note on the \"padding\" parameter: \n", 355 | "\n", 356 | "* With padding=\"valid\" you get an output that is smaller than the input because the convolution is only computed where the input and the filter fully overlap (with stride 1, each spatial dimension shrinks by \"filter size - 1\").\n", 357 | "\n", 358 | "\n", 359 | "* With padding=\"same\" (and stride 1) you get an output that is the \"same\" size as the input. This means the filter has to extend outside the bounds of the input by about \"filter size / 2\" on each side; the area outside of the input is normally padded with zeros.\n", 360 | "\n", 361 | "\n", 362 | "* Note that some libraries also support padding=\"full\", where the filter goes even further outside the bounds of the input, up to \"filter size - 1\". 
This results in an output shape larger than the input.\n", 363 | "\n" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": { 370 | "nbpresent": { 371 | "id": "647d0e86-4bb3-4b56-b7ca-cf9a269253e7" 372 | } 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "#compile the model \n", 377 | "\n", 378 | "#we can try several different optimizers \n", 379 | "#optimizer=\"adam\"\n", 380 | "#optimizer=\"adadelta\"\n", 381 | "#optimizer=\"adagrad\"\n", 382 | "optimizer=\"sgd\"\n", 383 | "\n", 384 | "#add momentum, alter learning rate: \n", 385 | "from keras.optimizers import SGD\n", 386 | "#optimizer=SGD(lr=0.01, momentum=0.9, nesterov=True)\n", 387 | "\n", 388 | "model.compile(loss='categorical_crossentropy',\n", 389 | " optimizer=optimizer,\n", 390 | " metrics=['accuracy'])\n" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": { 397 | "nbpresent": { 398 | "id": "dd208054-ab67-4d88-b0e1-7a701ea9e65b" 399 | } 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "from IPython.display import SVG\n", 404 | "from keras.utils.vis_utils import model_to_dot\n", 405 | "\n", 406 | "SVG(model_to_dot(model).create(prog='dot', format='svg'))" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": { 413 | "nbpresent": { 414 | "id": "f793c1ab-d747-4e68-af57-3de591f6def7" 415 | } 416 | }, 417 | "outputs": [], 418 | "source": [ 419 | "#training\n", 420 | "batch_size=128\n", 421 | "nb_epoch=10\n", 422 | "history=model.fit(X_train, y_train, batch_size=batch_size, epochs=nb_epoch,verbose=1, validation_data=(X_valid, y_valid),shuffle=True)\n" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": { 429 | "nbpresent": { 430 | "id": "e92d92be-440f-4f4f-a219-27c6f1a8273b" 431 | } 432 | }, 433 | "outputs": [], 434 | "source": [ 435 | "# summarize history for accuracy\n", 436 | "import matplotlib.pyplot as 
plt\n", 437 | "plt.plot(history.history['acc'])\n", 438 | "plt.plot(history.history['val_acc'])\n", 439 | "plt.title('model accuracy')\n", 440 | "plt.ylabel('accuracy')\n", 441 | "plt.xlabel('epoch')\n", 442 | "plt.legend(['train', 'validation'], loc='lower right')\n", 443 | "plt.show()\n" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": { 450 | "nbpresent": { 451 | "id": "7a3e7a2c-35a4-4d96-a632-82fa7ec953a9" 452 | } 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "# summarize history for loss\n", 457 | "plt.plot(history.history['loss'])\n", 458 | "plt.plot(history.history['val_loss'])\n", 459 | "plt.title('model loss')\n", 460 | "plt.ylabel('loss')\n", 461 | "plt.xlabel('epoch')\n", 462 | "plt.legend(['train', 'validation'], loc='upper right')\n", 463 | "plt.show()" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": { 470 | "nbpresent": { 471 | "id": "ad9da494-f1f3-4d0b-87b2-97d88881b017" 472 | } 473 | }, 474 | "outputs": [], 475 | "source": [ 476 | "#training with fit generator \n", 477 | "import random \n", 478 | "\n", 479 | "#note: in a real \"use case\" x_matrix & y_matrix would likely be stored in an hdf5 file rather than loaded into memory. 
 \n", 480 | "#check out http://www.h5py.org/\n", 481 | "def create_generator(x_matrix,y_matrix,samples_to_yield): \n", 482 | "    num_entries=x_matrix.shape[0]\n", 483 | "    upper_bound=num_entries-samples_to_yield\n", 484 | "    while True: #loop forever; Keras decides how many batches to draw\n", 485 | "        batch_index=random.randint(0,upper_bound)\n", 486 | "        x=x_matrix[batch_index:batch_index+samples_to_yield]\n", 487 | "        y=y_matrix[batch_index:batch_index+samples_to_yield]\n", 488 | "        yield x,y\n", 489 | "\n", 490 | "\n", 491 | "#create the generator for training data \n", 492 | "train_generator=create_generator(X_train,y_train,batch_size)\n", 493 | "#create the generator for validation data \n", 494 | "valid_generator=create_generator(X_valid,y_valid,batch_size)\n", 495 | "\n", 496 | "#fit_generator counts batches (steps), not samples, so convert the sample counts to steps\n", 497 | "steps_per_epoch=3000//batch_size\n", 498 | "validation_steps=1000//batch_size\n", 499 | "nb_epoch=2 #for illustrative purposes, we'll use 2 epochs to save time \n", 500 | "history=model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=nb_epoch, verbose=1, \n", 501 | "                            callbacks=[], validation_data=valid_generator, validation_steps=validation_steps)\n", 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": { 508 | "nbpresent": { 509 | "id": "58a6176c-b33d-4c85-9790-ff69ea2c9a94" 510 | } 511 | }, 512 | "outputs": [], 513 | "source": [ 514 | "#evaluate \n", 515 | "score = model.evaluate(X_test, y_test, verbose=0)\n", 516 | "print('Test score:', score[0])\n", 517 | "print('Test accuracy:', score[1])\n", 518 | "\n", 519 | "\n", 520 | "#evaluate with generator (for datasets too large to fit in memory); batches are drawn at random, so this is an estimate\n", 521 | "evaluation_generator=create_generator(X_test,y_test,batch_size)\n", 522 | "eval_steps=X_test.shape[0]//batch_size #number of batches, not samples\n", 523 | "score=model.evaluate_generator(evaluation_generator, steps=eval_steps)\n", 524 | "print(\"With evaluate_generator:\")\n", 525 | "print('Test score:', score[0])\n", 526 | "print('Test accuracy:', score[1])\n" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "metadata": { 533 | 
"nbpresent": { 534 | "id": "70bd33c6-077d-4c62-b583-5158ad3e7b4c" 535 | } 536 | }, 537 | "outputs": [], 538 | "source": [ 539 | "#predict classes\n", 540 | "class_predictions=model.predict_classes(X_test) \n", 541 | "\n", 542 | "#predict probabilities \n", 543 | "class_probabilities=model.predict_proba(X_test)\n" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": { 550 | "nbpresent": { 551 | "id": "99e0e9b3-977b-46e1-a114-1459cc66cf85" 552 | } 553 | }, 554 | "outputs": [], 555 | "source": [ 556 | "#let's look at the prediction for our test digit of interest \n", 557 | "print(\"predicted class:\"+ str(class_predictions[digit_index]))\n", 558 | "print(\"predicted probability:\"+str([round(i,2) for i in class_probabilities[digit_index]]))" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": { 564 | "nbpresent": { 565 | "id": "764b027e-9c02-40b5-88c7-034ee2f0e113" 566 | } 567 | }, 568 | "source": [ 569 | "## Keras functional API ##\n", 570 | "The Keras functional API is the way to go for defining complex models, such as multi-output models, directed acyclic graphs, or models with shared layers.\n", 571 | "\n", 572 | "Use the functional API if you have: \n", 573 | "\n", 574 | "* multiple inputs & outputs \n", 575 | "* bypass layers\n", 576 | "* merge layers \n", 577 | "* basically any non-linear connection between layers \n", 578 | "\n", 579 | "\n", 580 | "There are a few key points to remember for the functional API: \n", 581 | "\n", 582 | "* A layer instance is callable (on a tensor), and it returns a tensor\n", 583 | "* Input tensor(s) and output tensor(s) can then be used to define a Model\n", 584 | "* Such a model can be trained just like Keras Sequential models.\n", 585 | "* All models are callable (just like layers). \n", 586 | "\n", 587 | "For example, let's look at a simple model that includes all layers required in the computation of output **b** given input **a**." 
588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "metadata": { 594 | "nbpresent": { 595 | "id": "8f3be052-69e1-4b95-90bd-b58879a8fcd2" 596 | } 597 | }, 598 | "outputs": [], 599 | "source": [ 600 | "#The simplest possible example for the functional API \n", 601 | "from keras.models import Model\n", 602 | "from keras.layers import Input, Dense\n", 603 | "\n", 604 | "a = Input(shape=(32,))\n", 605 | "b = Dense(32)(a)\n", 606 | "model = Model(inputs=a, outputs=b)\n" 607 | ] 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "metadata": { 612 | "nbpresent": { 613 | "id": "fe336b7c-7d2e-4262-8427-e960ac68bd62" 614 | } 615 | }, 616 | "source": [ 617 | "Useful attributes of Model\n", 618 | "\n", 619 | "* `model.layers` is a flattened list of the layers comprising the model graph.\n", 620 | "* `model.inputs` is the list of input tensors.\n", 621 | "* `model.outputs` is the list of output tensors.\n", 622 | "\n", 623 | "It's very straightforward to extend this model formulation to multi-input and multi-output models: \n", 624 | "\n" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": { 631 | "nbpresent": { 632 | "id": "0bf44eb4-9f0f-401c-970c-e0ceff1dd14d" 633 | } 634 | }, 635 | "outputs": [], 636 | "source": [ 637 | "#example: multiple inputs and multiple outputs with the functional API \n", 638 | "a1 = Input(shape=(32,))\n", 639 | "a2 = Input(shape=(32,))\n", 640 | "\n", 641 | "b1 = Dense(32)(a1)\n", 642 | "b2 = Dense(32)(a2)\n", 643 | "\n", 644 | "model = Model(inputs=[a1, a2], outputs=[b1, b2])\n" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "metadata": { 651 | "nbpresent": { 652 | "id": "36a3cbf6-1b15-478d-bfea-5f9a7cdcc5ed" 653 | } 654 | }, 655 | "outputs": [], 656 | "source": [ 657 | "#let's examine the attributes of the model \n", 658 | "model.layers" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 
664 | "metadata": { 665 | "nbpresent": { 666 | "id": "bf2d9c31-f1eb-486c-978b-0d9ab65e18fb" 667 | } 668 | }, 669 | "outputs": [], 670 | "source": [ 671 | "model.inputs" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "metadata": { 678 | "nbpresent": { 679 | "id": "a9748b3a-71a1-46ed-a369-c175193e8269" 680 | } 681 | }, 682 | "outputs": [], 683 | "source": [ 684 | "model.outputs" 685 | ] 686 | }, 687 | { 688 | "cell_type": "markdown", 689 | "metadata": { 690 | "nbpresent": { 691 | "id": "2a24687d-0d00-4e1a-98ea-1513ba27cd38" 692 | } 693 | }, 694 | "source": [ 695 | "## Example: Word embedding & training an LSTM with the functional API ## \n", 696 | " \n", 697 | "The functional API makes it easy to manipulate a large number of intertwined data streams.\n", 698 | "\n", 699 | "Let's consider the following model. We seek to predict how many retweets and likes a news headline will receive on Twitter. The main input to the model will be the headline itself, as a sequence of words, but to spice things up, our model will also have an auxiliary input, receiving extra data such as the time of day when the headline was posted, etc. The model will also be supervised via two loss functions. \n", 700 | "\n", 701 | "Here's what our model looks like:\n", 702 | "![multi-input multi-output model graph](tutorial_images/multi-input-multi-output-graph.png)" 703 | ] 704 | }, 705 | { 706 | "cell_type": "markdown", 707 | "metadata": { 708 | "nbpresent": { 709 | "id": "a25899d4-5b46-4b2e-b0be-beee3cc19a7c" 710 | } 711 | }, 712 | "source": [ 713 | "The main input will receive the headline, as a sequence of integers (each integer encodes a word). The integers will be between 0 and 9,999 (a vocabulary of 10,000 words, matching input_dim=10000 in the Embedding layer) and the sequences will be 100 words long." 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": null, 719 | "metadata": { 720 | "nbpresent": { 721 | "id": "d934404a-046b-4c07-9912-5c7550d60900" 722 | } 723 | }, 724 | "outputs": [], 725 | "source": [ 726 | "from keras.layers import Input, Embedding, LSTM, Dense\n", 727 | "from keras.layers.merge import Concatenate\n", 728 | "from keras.models import Model\n", 729 | "\n", 730 | "# headline input: meant to receive sequences of 100 integers, between 0 and 9999.\n", 731 | "# note that we can name any layer by passing it a \"name\" argument.\n", 732 | "main_input = Input(shape=(100,), dtype='int32', name='main_input')\n", 733 | "\n", 734 | "# this embedding layer will encode the input sequence\n", 735 | "# into a sequence of dense 512-dimensional vectors.\n", 736 | "x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)\n", 737 | "\n", 738 | "# an LSTM will transform the vector sequence into a single vector,\n", 739 | "# containing information about the entire sequence\n", 740 | "lstm_out = LSTM(32)(x)" 741 | ] 742 | }, 743 | { 744 | "cell_type": "markdown", 745 | "metadata": { 746 | "nbpresent": { 747 | "id": "585e7344-e7fc-42f6-ae3e-cdb70f3181e1" 748 | } 749 | }, 750 | "source": [ 751 | "Here we insert the auxiliary loss, allowing the LSTM and Embedding layers to be trained smoothly even though the main loss will be computed much higher up in the model."
752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": null, 757 | "metadata": { 758 | "nbpresent": { 759 | "id": "5d6318c8-34f8-49c1-b35f-4e65cc51fc4c" 760 | } 761 | }, 762 | "outputs": [], 763 | "source": [ 764 | "auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)" 765 | ] 766 | }, 767 | { 768 | "cell_type": "markdown", 769 | "metadata": { 770 | "nbpresent": { 771 | "id": "f3501821-5b61-4a9e-8977-4ca7628c228a" 772 | } 773 | }, 774 | "source": [ 775 | "At this point, we feed into the model our auxiliary input data by concatenating it with the LSTM output:" 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "execution_count": null, 781 | "metadata": { 782 | "nbpresent": { 783 | "id": "9af8785d-62b2-45a0-8084-18a8fd7178d1" 784 | } 785 | }, 786 | "outputs": [], 787 | "source": [ 788 | "auxiliary_input = Input(shape=(5,), name='aux_input')\n", 789 | "x = Concatenate()([lstm_out, auxiliary_input])\n", 790 | "\n", 791 | "# we stack a deep fully-connected network on top\n", 792 | "x = Dense(64, activation='relu')(x)\n", 793 | "x = Dense(64, activation='relu')(x)\n", 794 | "x = Dense(64, activation='relu')(x)\n", 795 | "\n", 796 | "# and finally we add the main logistic regression layer\n", 797 | "main_output = Dense(1, activation='sigmoid', name='main_output')(x)" 798 | ] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "metadata": { 803 | "nbpresent": { 804 | "id": "e0c2ce08-3761-4376-b4b7-c805be534720" 805 | } 806 | }, 807 | "source": [ 808 | "This defines a model with two inputs and two outputs:" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": null, 814 | "metadata": { 815 | "nbpresent": { 816 | "id": "b9ea717a-bc57-4a73-90ac-1e66eceba1ef" 817 | } 818 | }, 819 | "outputs": [], 820 | "source": [ 821 | "model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])" 822 | ] 823 | }, 824 | { 825 | "cell_type": "markdown", 826 | "metadata": { 827 | 
"nbpresent": { 828 | "id": "da0aa40c-a37f-4a15-9964-cb11c9da2fad" 829 | } 830 | }, 831 | "source": [ 832 | "We compile the model and assign a weight of 0.2 to the auxiliary loss. To specify different loss_weights or losses for each output, you can use a list or a dictionary. Here we pass a single loss as the loss argument, so the same loss will be used on all outputs." 833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": null, 838 | "metadata": { 839 | "nbpresent": { 840 | "id": "cfaf628a-5399-4067-a1e4-6fdf6647f064" 841 | } 842 | }, 843 | "outputs": [], 844 | "source": [ 845 | "model.compile(optimizer='rmsprop', loss='binary_crossentropy',\n", 846 | " loss_weights=[1., 0.2])" 847 | ] 848 | }, 849 | { 850 | "cell_type": "markdown", 851 | "metadata": { 852 | "nbpresent": { 853 | "id": "3e07c460-a37a-4c19-b5ae-79f20806ab0e" 854 | } 855 | }, 856 | "source": [ 857 | "We can train the model by passing it lists of input arrays and target arrays (the call is commented out because headline_data, additional_data and labels are placeholders for your own data):" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": null, 863 | "metadata": { 864 | "nbpresent": { 865 | "id": "bb36eb19-dfc8-4a63-b13b-e22cb7937be3" 866 | } 867 | }, 868 | "outputs": [], 869 | "source": [ 870 | "#model.fit([headline_data, additional_data], [labels, labels],\n", 871 | "# epochs=50, batch_size=32)" 872 | ] 873 | }, 874 | { 875 | "cell_type": "markdown", 876 | "metadata": { 877 | "nbpresent": { 878 | "id": "e2798767-cd31-494d-a3f3-9bbb5ace434f" 879 | } 880 | }, 881 | "source": [ 882 | "Since our inputs and outputs are named (we passed them a \"name\" argument), we could also have compiled the model via:" 883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "execution_count": null, 888 | "metadata": { 889 | "nbpresent": { 890 | "id": "e21134bb-3418-4624-ab48-8d47f0cf6b40" 891 | } 892 | }, 893 | "outputs": [], 894 | "source": [ 895 | "model.compile(optimizer='rmsprop',\n", 896 | " loss={'main_output': 'binary_crossentropy', 'aux_output': 
'binary_crossentropy'},\n", 897 | " loss_weights={'main_output': 1., 'aux_output': 0.2})\n", 898 | "\n", 899 | "# and trained it via:\n", 900 | "#model.fit({'main_input': headline_data, 'aux_input': additional_data},\n", 901 | "# {'main_output': labels, 'aux_output': labels},\n", 902 | "# epochs=50, batch_size=32)" 903 | ] 904 | }, 905 | { 906 | "cell_type": "markdown", 907 | "metadata": { 908 | "nbpresent": { 909 | "id": "2a329ea9-8b2f-4afc-b9db-d44849104734" 910 | } 911 | }, 912 | "source": [ 913 | "## Getting fancy: writing your own Keras layers ##" 914 | ] 915 | }, 916 | { 917 | "cell_type": "markdown", 918 | "metadata": { 919 | "collapsed": true, 920 | "nbpresent": { 921 | "id": "cec3471e-f026-4f91-b2ac-6ccf79a4710a" 922 | } 923 | }, 924 | "source": [ 925 | "Here is the skeleton of a Keras 2 layer. There are only three methods you need to implement:\n", 926 | "\n", 927 | "* **build(input_shape)**: this is where you will define your weights, typically with self.add_weight(...), which registers them with the layer (as trainable weights unless you pass trainable=False). This method must set self.built = True at the end, which can be done by calling super(MyLayer, self).build(input_shape). For an example of a layer that also keeps non-trainable weights and state updates, see the code for the BatchNormalization layer.\n", 928 | "\n", 929 | "* **call(x)**: this is where the layer's logic lives. Unless you want your layer to support masking, you only have to care about the first argument passed to call: the input tensor.\n", 930 | "\n", 931 | "* **compute_output_shape(input_shape)** (called get_output_shape_for in Keras 1): in case your layer modifies the shape of its input, you should specify here the shape transformation logic. This allows Keras to do automatic shape inference.\n" 932 | ] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "execution_count": null, 937 | "metadata": { 938 | "nbpresent": { 939 | "id": "8b4e06ff-5b7e-4f93-9c9c-50ccd4fe1085" 940 | } 941 | }, 942 | "outputs": [], 943 | "source": [ 944 | "#skeleton code for a Keras 2 layer (a dense layer without bias)\n", 945 | "from keras import backend as K\n", 946 | "from keras.engine.topology import Layer\n", 947 | "\n", 948 | "class MyLayer(Layer):\n", 949 | "    def __init__(self, output_dim, **kwargs):\n", 950 | "        self.output_dim = output_dim\n", 951 | "        super(MyLayer, self).__init__(**kwargs)\n", 952 | "\n", 953 | "    def build(self, input_shape):\n", 954 | "        # create a trainable weight matrix of shape (input_dim, output_dim)\n", 955 | "        self.W = self.add_weight(name='W', shape=(input_shape[1], self.output_dim),\n", 956 | "                                 initializer='uniform', trainable=True)\n", 957 | "        super(MyLayer, self).build(input_shape)  # sets self.built = True\n", 958 | "\n", 959 | "    def call(self, x, mask=None):\n", 960 | "        return K.dot(x, self.W)\n", 961 | "\n", 962 | "    def compute_output_shape(self, input_shape):\n", 963 | "        # the batch dimension is unchanged; the last dimension becomes output_dim\n", 964 | "        return (input_shape[0], self.output_dim)" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "execution_count": null, 970 | "metadata": {}, 971 | "outputs": [], 972 | "source": [] 973 | } 974 | ], 975 | "metadata": { 976 | "kernelspec": { 977 | "display_name": "Python 3", 978 | "language": "python", 979 | "name": "python3" 980 | }, 981 | "language_info": { 982 | "codemirror_mode": { 983 | "name": "ipython", 984 | "version": 3 985 | }, 986 | "file_extension": ".py", 987 | "mimetype": "text/x-python", 988 | "name": "python", 989 | "nbconvert_exporter": "python", 990 | "pygments_lexer": "ipython3", 991 | "version": "3.6.6" 992 | } 993 | }, 994 | "nbformat": 4, 995 | "nbformat_minor": 1 996 | } 997 | -------------------------------------------------------------------------------- /pytorch_mnist_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Basic imports\n", 10 | "import torch\n", 11 | "import torchvision\n", 12 | "from torchvision.datasets import MNIST\n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.functional as F\n", 15 | "\n", 16 | "# Sort of a \"batteries included\" library for pytorch\n", 17 | "# Integrates with many useful sklearn tools like GridSearchCV\n", 18 | "# Implements boilerplate like early stopping that you might otherwise have to write yourself\n", 19 | "# However, because everything is made to be generally useful, it might not exactly fit your needs\n", 20 | "import skorch\n", 21 | "\n", 22 | "from sklearn.model_selection import train_test_split\n", 23 | "import numpy as np\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": { 33 | "scrolled": false 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "mnist_train_data = MNIST('datasets', train=True, download=True, transform=torchvision.transforms.Compose([\n", 38 | " torchvision.transforms.ToTensor(),\n", 39 | "]))" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "mnist_test_data = MNIST('datasets', train=False, download=True, transform=torchvision.transforms.Compose([\n", 49 | " torchvision.transforms.ToTensor(),\n", 50 | "]))" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "In this example we won't actually need data loaders since they are abstracted away by skorch. However, if you need to roll some custom training code, you will need to do this." 
58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "mnist_train_data_loader = torch.utils.data.DataLoader(\n", 67 | " mnist_train_data,\n", 68 | " num_workers=4, # This can be a useful option to speed up training by keeping your GPU \"fed\"\n", 69 | " batch_size=32,\n", 70 | " shuffle=True,\n", 71 | ")\n", 72 | "mnist_test_data_loader = torch.utils.data.DataLoader(\n", 73 | " mnist_test_data,\n", 74 | " num_workers=4,\n", 75 | " batch_size=32,\n", 76 | " shuffle=False,\n", 77 | ")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Initial data exploration\n", 85 | "It is often worth peeking at your data. This can catch weird issues with your data early on, and it also helps when debugging your network later." 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "examples = enumerate(mnist_test_data_loader)\n", 95 | "i, (example_data, example_targets) = next(examples)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 6, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "torch.Size([32, 1, 28, 28])" 107 | ] 108 | }, 109 | "execution_count": 6, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "example_data.shape" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABxYAAAE9CAYAAAAmk4f5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAXEQAAFxEByibzPwAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de7RdZX0v/N8D4SLBV7mEShslilVE5I4KRa1XLBJFCcKB0VY5B2yHVFoQa4scELGOl47S0qpQ31GqYg9llNtBwAilXCtWg1wKAh7SYZAjDIkhQAKRmDzvH3tF03Rn/xbZc62551qfzxhrLPaa3z3nw5OsLyG/PfcutdYAAAAAAAAAmMpmbS8AAAAAAAAAmPkMFgEAAAAAAICUwSIAAAAAAACQMlgEAAAAAAAAUgaLAAAAAAAAQMpgEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApAwWAQAAAAAAgJTBIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASLU+WCylbF1K+XQp5QellFWllB+XUi4spcxte20A6+gqYKbTU0AX6CqgC3QV0AW6CmhLqbW2d/FSto6IGyLioIh4NCJujYh5EfH6iHg8Ig6stS5ubYEAoauAmU9PAV2gq4Au0FVAF+gqoE2zWr7+n8ZE+d0eEe+qta6IiCilnBwRfxERF0bEWzb15KWUxyJim4j40fSXCsxAL42IZ2qtLxnwdXQVMB3D6KqB9lTvXLoKRpuuArpAVwFdoKuALtjkrmrtjsVSyhYR8ZOIeHFE7FtrvXOD43dHxJ4RsX+t9Y5NvMZTEfHC6a4VmNlqrWVQ59ZVQFMG1VXD6KneeXQVjAFdBXSBrgK6QFcBXbApXdXmz1g8OCbKb/GG5ddzae95/jSu4aspgOnSVcBMN4yeitBVwPToKqALdBXQBboKaFWbg8W9es/f28jx722QA2iDrgJmOj0FdIGuArpAVwFdoKuAVrU5WHxZ7/mRjRx/ZIMcQBt0FTDT6SmgC3QV0AW6CugCXQW0alaL19629/zMRo6v3CC3UaWU+zZyaNfnuyiADegqYKZrrKcidBUwMLoK6AJdBXSBrgJa1eYdi+t+IGRNjgO0SVcBM52eArpAVwFdoKuALtBVQKvavGPx6d7z7I0c36b3vCI7Ua31tZO93vtqi92f/9IAfkFXATNdYz0VoauAgdFVQBfoKqALdBXQqjbvWHy49zx3I8fnbpADaIOuAmY6PQV0ga4CukBXAV2gq4BWtTlYvLv3vO9Gjq97/Z4hrAVgY3QVMNPpKaALdBXQBboK6AJdBbSqzcHiv0bEkxGxaylln0mOL+g9Xz28JQH8F7oKmOn0FNAFugroAl0FdIGuAlrV2mCx1vpcRHy+9+HnSym/+J7QpZSTI2LPiLit1vrdNtYHEKGrgJlPTwFdoKuALtBVQBfoKqBtpdba3sVL2ToiboqIN0TEoxFxa0Ts0vv4pxHxxlrrQ9M4vx8wC2Og1loGeX5dBTRhkF016J7qXUNXwRjQVUAX6CqgC3QV0AWb0lVtfivUqLWuioi3RsRnIuKZiDg8IuZFxFciYp/plh9AE3QVMNPpKaALdBXQBboK6AJdBbSp1TsWB81XVcB4GPQdi4Omq2A86CqgC3QV0AW6CugCXQV0QefuWAQAAAAAAAC6wWARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkDJYBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAKYNFAAAAAAAAIDWr7QUAwDB8/OMfTzMveMEL0syee+6ZZhYsWNDXmjLnn39+mrn99tvTzEUXXdTEcgAAAACAMeeORQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkDJYBAAAAAAAAFI
GiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAqVltLwAApuuSSy5JMwsWLBjCSiasXbu2kfN85CMfSTPveMc70szNN9+cZh5++OG+1gSwKV71qlelmQceeCDNnHTSSWnmb/7mb/paE9C+2bNnp5k///M/TzP9/JnpjjvuSDNHHnlkmlmyZEmaAQCAUeaORQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkDJYBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAqVltLwAApnLJJZekmQULFgxhJRMeeOCBNPPNb34zzbziFa9IM/Pnz08zu+66a5o59thj08znPve5NAOwqfbZZ580s3bt2jTzyCOPNLEcYIbYeeed08zxxx+fZvrpj/322y/NHHbYYWnmC1/4QpoBZoZ99903zVx++eVpZt68eQ2sppve9a53pZn7779/yuM/+tGPmloO0CH9/J3WVVddlWZOPPHENHPBBRekmTVr1qQZ+ueORQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkDJYBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAqVltLwCA8bX//vunmfe///2NXOu+++5LM+9973vTzNKlS9PMihUr0syWW26ZZr797W+nmb322ivN7LDDDmkGYJD23nvvNLNy5co0c8UVVzSxHGAI5syZk2a+8pWvDGElwLg65JBD0sxWW201hJV01/z589PMcccdN+Xxo48+uqnlADNEP3/P9MUvfrGRa33+859PMxdeeGGaefbZZ5tYDj3uWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACA1Ky2F8CmWbBgwZTHjz/++PQcP/7xj9PMqlWr0sw//MM/pJnHHnsszTz00ENpBhgtO++8c5oppaSZ++67L80ccsghaebRRx9NM0055ZRT0szuu+/eyLWuueaaRs4DMJk99tgjzZx44olp5qKLLmpiOcAQfOxjH0szhx9+eJp5/etf38RyGvPmN785zWy2Wf712XfffXeaueWWW/paEzC5WbPyv9I89NBDh7CS0XbHHXekmZNPPnnK47Nnz07PsXLlyr7XBLSvnz8zzZ07t5FrXXzxxWmmnxkGzXLHIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACA1Ky2F8CmOeecc6Y8Pm/evOEsJCI+8pGPpJmnn346zdx3331NLGdkPfLII1Mez35PREQsWrSoqeVAI77+9a+nmVe+8pVppp+OWbZsWV9rGpajjz46zWyxxRZDWAnA9Oy2225pZvbs2WnmkksuaWI5wBD85V/+ZZpZu3btEFbSrA984AONZJYsWZJmjjrqqDRzxx13pBkYV29961vTzIEHHphm+vm7lHG23XbbpZndd999yuPbbLNNeo6VK1f2vSZgsLbaaqs0c9pppw1hJRMuuuiiNFNrHcJKWJ87FgEAAAAAAICUwSIAAAAAAACQMlgEAAAAAAAAUgaLAAAAAAAAQMpgEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApGa1vQA2zfHHHz/l8T333DM9x/33359mXvOa16SZfffdN8385m/+Zpp54xvfmGZ+9KMfpZmXvvSlaaYpP//5z9PM448/nmZ23nnnaa/l4YcfTjOLFi2a9nVg2JYsWdL2Ep63U089Nc286lWvauRa//Zv/9ZIBmBTfeITn0gz/XS5P6fAzHDttdemmc02697XKP/0pz9NMytWrEgzu+yyS5p5+ctfnma+853vpJnNN988zcAo2mOPPdLMxRdfnGYWL16cZv7sz/6srzWNq/e9731tLwE
Yste97nVpZr/99mvkWv383fo3vvGNRq5Fs7r3fwMAAAAAAADA0BksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkGpksFhK2a+U8slSyuWllP9bSqmllFV9fN7vlFK+U0pZUUpZVkq5tpRyUBNrAtiQrgK6QFcBXaCrgC7QVUAX6Cqga2Y1dJ7TI+J9z+cTSinnRsQfRcSzEXFdRGwdEe+MiHeVUo6stV7R0NoA1tFVQBfoKqALdBXQBboK6AJdBXRKU4PF2yPi7oj4bu/x2FThUsrbYqL4fhoRB9Za/0/v9QMj4qaI+PtSyk211icaWt/IueGGG6Z1vF8LFy5s5Dzbbbddmtl7773TzB133JFmDjjggL7W1IRVq9IvHoof/OAHaeb+++9PM9tvv/2UxxcvXpyeA13F9B122GFp5qyzzkozW265ZZr5yU9+kmb+5E/+JM0888wzaYYZRVcxY8ybNy/N7L///mmmnz8PrVy5sp8lMXPoqg56y1vekmZe/epXp5m1a9c2kmnKBRdckGauu+66NPPkk0+mmbe97W1p5rTTTksz/fj93//9KY+ff/75jVxnxOmqDvrUpz6VZmbPnp1m3v3ud6eZFStW9LWmUZT9PVNEf//dGGbfjzBdxYxxxBFHDO1a/fz5jJmpkcFirfX/Xf/jUkr2Kaf0ns9eV3y989xeSrkgIj4WEcdFxF80sT6ACF0FdIOuArpAVwFdoKuALtBVQNc08jMWn49SytYR8fbeh5dOEln32vzhrAjgv9JVQBfoKqALdBXQBboK6AJdBcwEQx8sRsRuEbFVRDxea31kkuPf6z3vObwlAfwXugroAl0FdIGuArpAVwFdoKuA1rUxWHxZ73my4ota68qIWB4R25VSXji0VQH8Z7oK6AJdBXSBrgK6QFcBXaCrgNY18jMWn6dte8/PTJFZGREv7mWfzk5YSrlvI4d2fX5LA/gFXQV0ga4CukBXAV2gq4Au0FVA69q4Y3HdT5+tfWQA2qKrgC7QVUAX6CqgC3QV0AW6CmhdG3csrvsqidlTZLbpPa/o54S11tdO9nrvqy12739pAL+gq4Au0FVAF+gqoAt0FdAFugpoXRt3LD7ce5472cFSyuyYuFV7ea01vVUbYEB0FdAFugroAl0FdIGuArpAVwGta+OOxQcj4mcRMaeUMrfWuuEPmt2393zPcJfFID3xxBNp5sYbb2zkWjfccEMj52nKEUcckWa22267NPPv//7vUx6/5JJL+l4TfdFVTGr//fdPM1tuuWUj1+rnfX3zzTc3ci06S1cxUG95y1saOc/jjz/eyHnoLF01BPPmzUsz//iP/5hmdtxxxwZW058lS5akmcsuuyzNfPrTn04zzzwz1Y+i6l8/az7hhBPSzJw5c9LMOeecM+XxrbfeOj3H5z//+TSzevXqNDMmdNUQLFiwIM0ceuihaeahhx5KM4sWLeprTePqtNNOSzNr165NMzfddNOUx5cvX97vkuiPrmKg3vzmNzdynueeey7N9NNDzExDv2Ox1vpsRPxL78PJ/jSx7rWrh7MigP9KVwFdoKuALtBVQBfoKqALdBUwE7TxrVAjIs7tPX+qlPLr614spRwYER+JiKci4u/aWBjAenQV0AW6CugCXQV0ga4CukBXAa1q5FuhllLeExGnb/DylqWUb6/38WdqrddERNRa/7mUcl5EnBQRd5VSro+ILSPinTEx7Dy21rqsibUBrKOrgC7QVUAX6CqgC3QV0AW6Cuiapn7G4pyIeMMGr5UNXvtPP0Cg1vqHpZS7IuLEmCi91RFxQ0ScXWu9raF1AaxPVwFdoKuALtBVQBfoKqALdBXQKY0MFmutX46ILw/r8wA2ha4CukBXAV2gq4Au0FVAF+gqoGva+hmLAAAAAAAAQIcYLAIAAAAAAAApg0UAAAAAAAAg1cjPWIRxtdNOO6WZL37xi2lms83yGf9ZZ5015fFly5a
l5wCmduWVV6aZd73rXY1c66tf/Wqa+dSnPtXItQA21ete97pGznPOOec0ch5g42bNyv/3fscddxzCSibcfPPNaeboo49OM0uXLm1iOY1ZsmRJmvnc5z6XZs4999w0s80220x5vJ9uveqqq9LM4sWL0ww05cgjj0wz2e/9iP7+rmWczZs3L80ce+yxaWbNmjVp5uyzz57y+OrVq9NzAMNx0EEHNZLpx8qVK9PMXXfd1ci1GD53LAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApAwWAQAAAAAAgJTBIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASM1qewHQZR/96EfTzJw5c9LME088kWYefPDBvtYETG7nnXdOMwcddFCa2WqrrdLM0qVL08zZZ5+dZlasWJFmADbVG9/4xjTz4Q9/OM3ceeedaeb666/va01ANyxatCjNHHfccWmmnz8zddFVV12VZo499tg0c8ABBzSxHBiqF73oRVMe7+fPH/04//zzGznPqDrhhBPSzI477phm7r///jRz44039rUmoH3D/LOFnh5t7lgEAAAAAAAAUgaLAAAAAAAAQMpgEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApAwWAQAAAAAAgJTBIgAAAAAAAJCa1fYCYKb6jd/4jTTzyU9+spFrHX744Wnm3nvvbeRaMK4uu+yyNLPDDjs0cq2vfe1raWbx4sWNXAtgU73jHe9IM9tvv32aWbhwYZpZtWpVX2sCBmuzzZr52uI3vOENjZxnVJVS0kw/vxZN/HqdeeaZaea3f/u3p30dWGerrbaa8viv/dqvpee4+OKLm1rO2Np1110bOY+/i4LRsv/++zdynuXLl6eZ888/v5FrMTO5YxEAAAAAAABIGSwCAAAAAAAAKYNFAAAAAAAAIGWwCAAAAAAAAKQMFgEAAAAAAICUwSIAAAAAAACQMlgEAAAAAAAAUgaLAAAAAAAAQGpW2wuAmerQQw9NM1tssUWaueGGG9LM7bff3teagMm9973vTTP77rtvI9e66aab0swZZ5zRyLUABmmvvfZKM7XWNHPppZc2sRxgmn7v934vzaxdu3YIK2H+/PlpZp999kkz2a9XP7+eZ555ZpqBJj399NNTHr/rrrvSc+y5555pZvvtt08zy5YtSzNdtNNOO6WZBQsWNHKt2267rZHzAIN38MEHp5ljjjmmkWs9+eSTaeaRRx5p5FrMTO5YBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAKYNFAAAAAAAAIGWwCAAAAAAAAKQMFgEAAAAAAICUwSIAAAAAAACQmtX2AqANL3jBC9LMu9/97jTz3HPPpZkzzjgjzaxevTrNwLjaYYcd0syf/umfppktttiiieXEXXfdlWZWrFjRyLUANtVLXvKSNPOmN70pzTz44INp5oorruhrTcBgzZ8/v+0ldN6cOXPSzO67755m+vmzaRMef/zxNOP/NRm2Z599dsrjixcvTs9xxBFHpJlrrrkmzZx77rlpZpj22GOPNPOKV7wizcybNy/N1Fr7WVJq7dq1jZwHGLx+/v5ss82auc/s+uuvb+Q8dJc7FgEAAAAAAICUwSIAAAAAAACQMlgEAAAAAAAAUgaLAAAAAAAAQMpgEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApGa1vQBow6mnnppm9tlnnzSzcOHCNPOtb32rrzUBkzvllFPSzAEHHNDIta688so0c8YZZzRyLYBB+tCHPpRmdtpppzTzjW98o4HVAHTDaaedlmY++tGPDmElE374wx9Oefx3f/d303M8/PDDDa0GmtHP/0+VUtLMe97znjRz8cUX97WmYVm6dGmaqbWmmR133LGJ5fTly1/+8tCuBUzPggULGjnP8uXL08zf/u3fNnItussdiwAAAAAAAEDKYBEAAAAAAAB
IGSwCAAAAAAAAKYNFAAAAAAAAIGWwCAAAAAAAAKQMFgEAAAAAAICUwSIAAAAAAACQMlgEAAAAAAAAUrPaXgA07T3veU+aOf3009PMU089lWbOOuusvtYEbLqTTz55aNc68cQT08yKFSuGsBKA6dlll10aOc8TTzzRyHkA2nbttdemmVe/+tVDWEn/vv/97095/LbbbhvSSqA5DzzwQJr54Ac/mGb23nvvNPPKV76yrzUNy6WXXtrIeb7yla+kmWOPPbaRaz377LONnAeYnrlz56aZY445ppFrPfLII2lm0aJFjVyL7nLHIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACA1Ky2FwDPxw477JBm/vqv/zrNbL755mnm2muvTTPf/va30wzQHdtvv32aWb169RBW0r8nn3wyzfSz5i222CLNvOhFL+prTVN58YtfnGZOPvnkaV+nX2vWrEkzf/zHf5xmnnnmmSaWA4057LDDGjnP17/+9UbOAwxeKSXNbLZZM19b/Fu/9VuNnOdLX/pSmvnVX/3VRq7Vz7/72rVrG7lWU+bPn9/2EmDGuuuuuxrJdNF//Md/DO1ae+yxR5q59957h7ASGG8HHXRQmmnqz3lXXnllI+dhtLljEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApAwWAQAAAAAAgNS0B4ullG1KKYeXUv6ulHJPKeWpUsrKUsrdpZT/WUrZdorP/Z1SyndKKStKKctKKdeWUg6a7poANqSrgC7QVUAX6CqgC3QV0AW6CuiiJu5YPCYiroiI43rnWxgRt0bEyyPi0xHx3VLKTht+Uinl3Ij4SkTsERH/HBHfiYh3RsQtpZT3N7AugPXpKqALdBXQBboK6AJdBXSBrgI6p4nB4nMRcX5EvKrWuket9YO11ndHxKsj4s6I2C0i/mr9TyilvC0i/igifhoRe9VaD+99zpsjYk1E/H0pZbsG1gawjq4CukBXAV2gq4Au0FVAF+gqoHNmTfcEtdavRsRXJ3n90VLKRyPiWxHxgVLKlrXW53qHT+k9n11r/T/rfc7tpZQLIuJjMfFVGn8x3fXRHZtvvnmaWbhwYZp5+ctfnmYWL16cZk4//fQ0Q3foKvpxzz33tL2E5+2f/umf0syjjz6aZn7lV34lzRx11FF9rWnUPPbYY2nms5/9bCPX0lX04+CDD04zL3nJS4awEsaVrpqZzj///DRzzjnnNHKtq6++Os2sXbu2kWs1dZ6Zdq0LLrhgaNcaV7qKUVVKaSTTj3vvvbeR87Bxuop+7LDDDo2cZ+nSpWnmvPPOa+RajLYm7licyt29560iYoeIiFLK1hHx9t7rl07yOetemz/YpQH8gq4CukBXAV2gq4Au0FVAF+gqYEYa9GDxFb3n1RGxrPfPu8VEGT5ea31kks/5Xu95zwGvDWAdXQV0ga4CukBXAV2gq4Au0FXAjDToweJJveeFtdaf9f75Zb3nyYovaq0rI2J5RGxXSnnhgNcHEKGrgG7QVUAX6CqgC3QV0AW6CpiRpv0zFjemlHJoRPz3mPiKivV/WN22vednpvj0lRHx4l726T6udd9GDu2arxQYZ7oK6AJdBXSBrgK6QFcBXaCrgJlsIHcsllJeExFfi4gSEafWWu9e/3DvuU51ikGsC2B9ugroAl0FdIGuArpAVwFdoKuAma7xOxZLKXMjYmFEbBcR59Zaz9sgsu6rJGZPcZptes8r+rlmrfW1G1nLfRGxez/nAMaLrgK6QFcBXaCrgC7QVUAX6CqgCxq9Y7GUsmNEXB8T3+v57yPi45PEHu49z93IOWbHxK3ay2ut6a3aAM+XrgK6QFcBXaCrgC7QVUAX6CqgKxq7Y7H3w2C/ERG7RcTlEXF8rXWyW7IfjIifRcScUsrcWuuGP2h2397zPU2tjW7Yddf
823bvt99+jVzr5JNPTjOLFy9u5FrMLLqqe6699to08773vW8IK5mZjjzyyLaX8J/8/Oc/n/L42rVrG7nOVVddlWYWLVrUyLVuvfXWRs7zfOgqpvL+978/zWy++eZp5s4770wzt9xyS19rYjzpqpnl8ssvTzOnnnpqmpkzZ04Ty+mkxx9/PM3cf//9aeaEE05IM48++mhfa2L6dBWjZvLfvs8/w8yiq5jKIYcc0sh5Hn744TTz5JNPNnItRlsjdyyWUraKiP8dEftHxDcj4r/VWtdMlq21PhsR/9L7cMEkkXWvXd3E2gDW0VVAF+gqoAt0FdAFugroAl0FdM20B4ullM0j4uKIeGtE3BoRH6i1Ppd82rm950+VUn59vXMdGBEfiYinIuLvprs2gHV0FdAFugroAl0FdIGuArpAVwFd1MS3Qj0xItZ9P6alEfHFUspkuY/XWpdGRNRa/7mUcl5EnBQRd5VSro+ILSPinTEx7Dy21rqsgbUBrKOrgC7QVUAX6CqgC3QV0AW6CuicJgaL2633z1P9wJczY6IcIyKi1vqHpZS7YqI83xkRqyPihog4u9Z6WwPrAlifrgK6QFcBXaCrgC7QVUAX6Cqgc6Y9WKy1nhkTxbYpn/vliPjydNcAkNFVQBfoKqALdBXQBboK6AJdBXTRtH/GIgAAAAAAADD6DBYBAAAAAACAlMEiAAAAAAAAkJr2z1iEfuyyyy5p5rrrrmvkWqeeemqaufrqqxu5FjB4H/jAB9LMJz7xiTSzxRZbNLGcvrz2ta9NM0cdddQQVjLhwgsvTDM//OEPG7nWZZddNuXxBx54oJHrwKjaZptt0syhhx7ayLUuvfTSNLNmzZpGrgUM3pIlS9LM0UcfnWYOP/zwNHPSSSf1taau+exnP5tmvvCFLwxhJQAbt/XWWzdynmeffbaR8wDT08/fV+26666NXGvVqlVpZvXq1Y1ci9HmjkUAAAAAAAAgZbAIAAAAAAAApAwWAQAAAAAAgJTBIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAAKlZbS+A8XDCCSekmZe97GWNXOvmm29OM7XWRq4FzAznnHNO2y6vdW4AAAw+SURBVEt43o455pi2lwDMQKtXr04zTzzxRJq56qqr0sx5553X15qA0XHLLbc0krnuuuvSTD//Dzh//vw000+ffelLX0ozpZQ08/3vfz/NALTtwx/+cJpZvnx5mvnMZz7TxHKAaVq7dm2aWbRoUZrZY4890sxDDz3U15og445FAAAAAAAAIGWwCAAAAAAAAKQMFgEAAAAAAICUwSIAAAAAAACQMlgEAAAAAAAAUgaLAAAAAAAAQMpgEQAAAAAAAEgZLAIAAAAAAACpWW0vgO47+OCD08wf/MEfDGElAADdtnr16jRz0EEHDWElABu3cOHCRjIAPH/f/e5308y5556bZm688cYmlgNM05o1a9LMaaedlmZqrWnmjjvu6GtNkHHHIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACA1Ky2F0D3velNb0oz2267bSPXWrx4cZpZsWJFI9cCAAAAgJlk/vz5bS8BGLIf//jHaea4444bwkpggjsWAQAAAAAAgJTBIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkZrW9AFjn7rvvTjNvf/vb08yyZcuaWA4AAAAAAADrccciAAAAAAAAkDJYBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAKYNFAAAAAAAAIGWwCAAAAAAAAKQMFgEAAAAAAIBUqbW2vYaBKaXcFxG7t70OYLBqraXtNUyHroLxoKuALtBVQBfoKqALdBXQBZvSVe5
YBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAKYNFAAAAAAAAIGWwCAAAAAAAAKQMFgEAAAAAAIDUqA8WX9r2AgD6oKuALtBVQBfoKqALdBXQBboKmNSsthcwYM/0nn+03mu79p4XD3kt48L+Dp49/qWXxi/f5122YVf5NR48ezx49viXdBWbyh4Pnj3+JV3FprLHg2ePf0lXsans8eDZ41/SVWwqezx49viXNrmrSq214bXMbKWU+yIiaq2vbXsto8j+Dp49Hn1+jQfPHg+ePR59fo0Hzx4Pnj0efX6NB88eD549Hn1+jQfPHg+ePR59fo0Hzx4Pnj1uxqh/K1QAAAAAAACgAQaLAAAAAAAAQMpgEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgVWqtba8BAAAAAAAAmOHcsQgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkDJYBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABIGSwCAAAAAAAAqbEYLJZSti6lfLqU8oNSyqpSyo9LKReWUua2vbYuKaXsV0r5ZCnl8lLK/y2l1FLKqj4+73dKKd8ppawopSwrpVxbSjloGGvuklLKNqWUw0spf1dKuaeU8lQpZWUp5e5Syv8spWw7xefa4xGgq6ZPTw2erkJXTZ+uGjxdha6aPl01eLoKXTV9umqw9BQRuqoJumqwdNXwlVpr22sYqFLK1hFxQ0QcFBGPRsStETEvIl4fEY9HxIG11sWtLbBDSilXRsT7Nnj5Z7XWraf4nHMj4o8i4tmIuC4ito6It0dEiYgja61XDGi5nVNK+R8R8f/1PrwvIr4fEf9PTPzefWFEPBARb6m1/mSDz7PHI0BXNUNPDZ6uGm+6qhm6avB01XjTVc3QVYOnq8abrmqGrhosPYWuaoauGixd1YJa60g/IuKsiKgR8a2I2Ha910/uvX5z22vsyiMi/jgiPh0Rh0XEr/T2b9UU+bf1Mksj4tfXe/3AiPhZRCyPiO3a/veaKY+I+J2I+OL6e9V7feeI+F5vL/+XPR7Nh65qbB/11OD3WFeN8UNXNbaPumrwe6yrxvihqxrbR101+D3WVWP80FWN7aOuGuz+6qkxf+iqxvZRVw12f3XVsPe87QUM9F8uYouIeKL3G2SfSY7f3Tu2X9tr7eKjjwK8ppf5w0mOndc7dkrb/x5dePQKrUbEqojY0h6P1kNXDXRv9dRw91tXjfBDVw10b3XVcPdbV43wQ1cNdG911XD3W1eN8ENXDXRvddXw9lpPjfhDVw10b3XV8PZaVw3gMeo/Y/HgiHhxRCyutd45yfFLe8/zh7ek8dC7Tf7tvQ8vnSRi75+fu3vPW0XEDhH2eMToqhZ4Dw2ErhptuqoF3kMDoatGm65qgffQQOiq0aarWuA91Dg9Nfp0VQu8jxqnqwZg1AeLe/Wev7eR49/bIEdzdouJN+vjtdZHJjm+bu/3HN6SOu0VvefVEbGs98/2eHToqnZ4DzVPV402XdUO76Hm6arRpqva4T3UPF012nRVO7yHmqWnRp+uaof3UbN01QCM+mDxZb3nyX5zrP/6yzZynE035d7XWldG7/sUl1JeOLRVdddJveeFtdaf9f7ZHo8OXdUO76Hm6arRpqva4T3UPF012nRVO7yHmqerRpuuaof3ULP01OjTVe3wPmqWrhqAUR8sbtt7fmYjx1dukKM52d5H2P++lFIOjYj/HhNfVXH6eofs8ejQVe3wHmqQrhoLuqod3kMN0lVjQVe1w3uoQbpqLOiqdngPNURPjQ1d1Q7vo4boqsEZ9cFi6T3X5DjNy/Z+/QwbUUp5TUR8LSb26tRa693rH+492+Pu01Xt8B5qiK4aG7qqHd5DDdFVY0NXtcN7qCG6amzoqnZ4DzVAT40VXdUO76MG6KrBGvXB4tO959kbOb5N73nFENYybrK9j7D/UyqlzI2IhRGxXUScW2s9b4OIPR4duqod3kMN0FV
jRVe1w3uoAbpqrOiqdngPNUBXjRVd1Q7voWnSU2NHV7XD+2iadNXgjfpg8eHe89yNHJ+7QY7mTLn3pZTZEfHiiFhea316ssw4K6XsGBHXx8T3e/77iPj4JDF7PDp0VTu8h6ZJV40dXdUO76Fp0lVjR1e1w3tomnTV2NFV7fAemgY9NZZ0VTu8j6ZBVw3HqA8W193euu9Gjq97/Z4hrGXcPBgRP4uIOb2vENiQvd+I3g+E/UZE7BYRl0fE8bXWyW7LtsejQ1e1w3toGnTVWNJV7fAemgZdNZZ0VTu8h6ZBV40lXdUO76FNpKfGlq5qh/fRJtJVwzPqg8V/jYgnI2LXUso+kxxf0Hu+enhLGg+11mcj4l96Hy6YJGLvJ1FK2Soi/ndE7B8R34yI/1ZrXTNZ1h6PFF3VAu+hTaerxpauaoH30KbTVWNLV7XAe2jT6aqxpata4D20afTUWNNVLfA+2jS6ashqrSP9iIizY+KHcP5rRMxe7/WTe6/f2vYau/ro7d+qKY6/o5dZGhG/vt7rB0bEqpj4D9P2bf97zJRHRGweE19JUSPilojYpo/Psccj8tBVA9tXPdX8nuqqMX7oqoHtq65qfk911Rg/dNXA9lVXNb+numqMH7pqYPuqq5rdTz015g9dNbB91VXN7qeuGvKj9DZrZJVSto6ImyLiDRHxaETcGhG79D7+aUS8sdb6UGsL7JBSynsi4vT1XnpDTLz5vrPea5+ptV6z3uf8VUScFBHPxMT3Nt4yIt4ZE3fLfrDWetmg190VpZSTIuKveh9eERFPbST68Vrr0vU+zx6PAF3VDD01eLpqvOmqZuiqwdNV401XNUNXDZ6uGm+6qhm6arD0FLqqGbpqsHRVC9qebA7jEREviIizIuKhmPjeuY9FxJcj4qVtr61Lj4j4UEwU3lSPD23k8xZFxMqIWB4RCyPi4Lb/fWbaIyLO7GN/a0TMs8ej+dBVjeyhnhr8HuuqMX/oqkb2UFcNfo911Zg/dFUje6irBr/HumrMH7qqkT3UVYPdXz3loaua2UNdNdj91VVDfoz8HYsAAAAAAADA9G3W9gIAAAAAAACAmc9gEQAAAAAAAEgZLAIAAAAAAAApg0UAAAAAAAAgZbAIAAAAAAAApAwWAQAAAAAAgJTBIgAAAAAAAJAyWAQAAAAAAABSBosAAAAAAABAymARAAAAAAAASBksAgAAAAAAACmDRQAAAAAAACBlsAgAAAAAAACkDBYBAAAAAACAlMEiAAAAAAAAkDJYBAAAAAAAAFIGiwAAAAAAAEDKYBEAAAAAAABI/f/EnzZoh5EqVQAAAABJRU5ErkJggg==\n", 126 | "text/plain": [ 127 | "
" 128 | ] 129 | }, 130 | "metadata": { 131 | "needs_background": "light" 132 | }, 133 | "output_type": "display_data" 134 | } 135 | ], 136 | "source": [ 137 | "fig, axes = plt.subplots(dpi=150, ncols=6, figsize=(15, 3))\n", 138 | "for i, ax in enumerate(axes):\n", 139 | "    ax.imshow(example_data[i][0], cmap='gray', interpolation='none')\n", 140 | "fig.show()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "class ConvNet(nn.Module):\n", 150 | "    def __init__(self, dropout=0.5):\n", 151 | "        \"\"\"\n", 152 | "        You need to define the *learning* components of your network here. For example, since\n", 153 | "        the ReLU function isn't something we learn, we do not need to instantiate it here.\n", 154 | "        \"\"\"\n", 155 | "        super(ConvNet, self).__init__()\n", 156 | "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", 157 | "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", 158 | "        self.conv2_drop = nn.Dropout2d(p=dropout) # We will need to deal with this later!\n", 159 | "        self.fc1 = nn.Linear(320, 50)\n", 160 | "        self.fc2 = nn.Linear(50, 10)\n", 161 | "    \n", 162 | "    def forward(self, x):\n", 163 | "        \"\"\"\n", 164 | "        This is where the model is actually put together\n", 165 | "        \"\"\"\n", 166 | "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", 167 | "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) # the tunable `dropout` is applied here\n", 168 | "        x = x.view(-1, 320) # Reshapes\n", 169 | "        x = F.relu(self.fc1(x))\n", 170 | "        x = F.dropout(x, training=self.training) # training=self.training turns dropout off in eval mode; this is EXTREMELY IMPORTANT\n", 171 | "        x = self.fc2(x)\n", 172 | "        return F.softmax(x, dim=1) # What you use here will dictate what loss function you use!"
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 9, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "from skorch import NeuralNetClassifier\n", 182 | "from skorch.dataset import CVSplit\n", 183 | "from skorch.helper import SliceDataset\n", 184 | "from skorch.callbacks import ProgressBar, EarlyStopping\n", 185 | "from sklearn.model_selection import GridSearchCV" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "Usually you have to write your own training loop for your network. With skorch this is not necessary, which makes training a breeze!" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 10, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "application/vnd.jupyter.widget-view+json": { 203 | "model_id": "", 204 | "version_major": 2, 205 | "version_minor": 0 206 | }, 207 | "text/plain": [ 208 | "HBox(children=(IntProgress(value=0, max=1876), HTML(value='')))" 209 | ] 210 | }, 211 | "metadata": {}, 212 | "output_type": "display_data" 213 | }, 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "\r", 219 | " epoch train_loss valid_acc valid_loss dur\n", 220 | "------- ------------ ----------- ------------ ------\n", 221 | " 1 \u001b[36m0.8856\u001b[0m \u001b[32m0.9159\u001b[0m \u001b[35m0.2769\u001b[0m 8.2096\n" 222 | ] 223 | }, 224 | { 225 | "data": { 226 | "application/vnd.jupyter.widget-view+json": { 227 | "model_id": "", 228 | "version_major": 2, 229 | "version_minor": 0 230 | }, 231 | "text/plain": [ 232 | "HBox(children=(IntProgress(value=0, max=1876), HTML(value='')))" 233 | ] 234 | }, 235 | "metadata": {}, 236 | "output_type": "display_data" 237 | }, 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "\r", 243 | " 2 \u001b[36m0.3907\u001b[0m \u001b[32m0.9496\u001b[0m \u001b[35m0.1686\u001b[0m 7.9793\n" 244 | ] 245 | }, 246 | { 247 | "data":
{ 248 | "application/vnd.jupyter.widget-view+json": { 249 | "model_id": "", 250 | "version_major": 2, 251 | "version_minor": 0 252 | }, 253 | "text/plain": [ 254 | "HBox(children=(IntProgress(value=0, max=1876), HTML(value='')))" 255 | ] 256 | }, 257 | "metadata": {}, 258 | "output_type": "display_data" 259 | }, 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "\r", 265 | " 3 \u001b[36m0.2891\u001b[0m \u001b[32m0.9587\u001b[0m \u001b[35m0.1357\u001b[0m 7.7666\n" 266 | ] 267 | }, 268 | { 269 | "data": { 270 | "application/vnd.jupyter.widget-view+json": { 271 | "model_id": "", 272 | "version_major": 2, 273 | "version_minor": 0 274 | }, 275 | "text/plain": [ 276 | "HBox(children=(IntProgress(value=0, max=1876), HTML(value='')))" 277 | ] 278 | }, 279 | "metadata": {}, 280 | "output_type": "display_data" 281 | }, 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "\r", 287 | " 4 \u001b[36m0.2439\u001b[0m \u001b[32m0.9681\u001b[0m \u001b[35m0.1117\u001b[0m 7.7007\n" 288 | ] 289 | }, 290 | { 291 | "data": { 292 | "application/vnd.jupyter.widget-view+json": { 293 | "model_id": "", 294 | "version_major": 2, 295 | "version_minor": 0 296 | }, 297 | "text/plain": [ 298 | "HBox(children=(IntProgress(value=0, max=1876), HTML(value='')))" 299 | ] 300 | }, 301 | "metadata": {}, 302 | "output_type": "display_data" 303 | }, 304 | { 305 | "name": "stdout", 306 | "output_type": "stream", 307 | "text": [ 308 | "\r", 309 | " 5 \u001b[36m0.2181\u001b[0m \u001b[32m0.9705\u001b[0m \u001b[35m0.1000\u001b[0m 7.6806\n" 310 | ] 311 | }, 312 | { 313 | "data": { 314 | "text/plain": [ 315 | "[initialized](\n", 316 | " module_=ConvNet(\n", 317 | " (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", 318 | " (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))\n", 319 | " (conv2_drop): Dropout2d(p=0.5, inplace=False)\n", 320 | " (fc1): Linear(in_features=320, out_features=50, bias=True)\n", 321 | " (fc2): 
Linear(in_features=50, out_features=10, bias=True)\n", 322 | " ),\n", 323 | ")" 324 | ] 325 | }, 326 | "execution_count": 10, 327 | "metadata": {}, 328 | "output_type": "execute_result" 329 | } 330 | ], 331 | "source": [ 332 | "torch.manual_seed(1234) # Because reproducibility is good\n", 333 | "\n", 334 | "net = NeuralNetClassifier(\n", 335 | " ConvNet,\n", 336 | " max_epochs=5, # You would train many more epochs in practice - this is just a toy example\n", 337 | " iterator_train__num_workers=6, # Can increase to keep your GPU \"fed\" with data\n", 338 | " iterator_valid__num_workers=6,\n", 339 | " lr=0.0002,\n", 340 | " device=DEVICE,\n", 341 | " criterion=nn.NLLLoss, # This is your loss function\n", 342 | " optimizer=torch.optim.Adam, # This is your optimizer\n", 343 | " batch_size=32,\n", 344 | " callbacks=[\n", 345 | " ProgressBar(), # Nice visual progress bar as you are training\n", 346 | " EarlyStopping(patience=5, monitor='valid_loss'), # Stops training if we see no improvement in validation loss for x epochs\n", 347 | " ],\n", 348 | ")\n", 349 | "y_train = np.array([y for x, y in mnist_train_data])\n", 350 | "net.fit(mnist_train_data, y=y_train)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 11, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "from sklearn import metrics\n", 360 | "y_pred = net.predict(mnist_test_data)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 12, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "text/plain": [ 371 | "0.9736" 372 | ] 373 | }, 374 | "execution_count": 12, 375 | "metadata": {}, 376 | "output_type": "execute_result" 377 | } 378 | ], 379 | "source": [ 380 | "metrics.accuracy_score(\n", 381 | " [y for x, y in mnist_test_data],\n", 382 | " y_pred,\n", 383 | ")" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "## But what about hyperparameters?" 
391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 13, 396 | "metadata": {}, 397 | "outputs": [ 398 | { 399 | "data": { 400 | "text/plain": [ 401 | "[initialized](\n", 402 | " module_=ConvNet(\n", 403 | " (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", 404 | " (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))\n", 405 | " (conv2_drop): Dropout2d(p=0.5, inplace=False)\n", 406 | " (fc1): Linear(in_features=320, out_features=50, bias=True)\n", 407 | " (fc2): Linear(in_features=50, out_features=10, bias=True)\n", 408 | " ),\n", 409 | ")" 410 | ] 411 | }, 412 | "execution_count": 13, 413 | "metadata": {}, 414 | "output_type": "execute_result" 415 | } 416 | ], 417 | "source": [ 418 | "net.set_params(max_epochs=5, verbose=False, train_split=False, callbacks=[])\n", 419 | "params = {\n", 420 | "    'module__dropout': [0, 0.5, 0.8], # You can add more parameters or values here; GridSearchCV will try every combination\n", 421 | "}\n", 422 | "net.initialize()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 14, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "gs = GridSearchCV(net, param_grid=params, scoring='accuracy', verbose=1, cv=3)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 15, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "mnist_train_sliceable = SliceDataset(mnist_train_data) # Helper class that wraps a torch dataset to make it work with sklearn."
441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 16, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "Fitting 3 folds for each of 3 candidates, totalling 9 fits\n" 453 | ] 454 | }, 455 | { 456 | "name": "stderr", 457 | "output_type": "stream", 458 | "text": [ 459 | "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", 460 | "[Parallel(n_jobs=1)]: Done 9 out of 9 | elapsed: 4.4min finished\n" 461 | ] 462 | }, 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "GridSearchCV(cv=3, error_score='raise-deprecating',\n", 467 | " estimator=[initialized](\n", 468 | " module_=ConvNet(\n", 469 | " (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", 470 | " (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))\n", 471 | " (conv2_drop): Dropout2d(p=0.5, inplace=False)\n", 472 | " (fc1): Linear(in_features=320, out_features=50, bias=True)\n", 473 | " (fc2): Linear(in_features=50, out_features=10, bias=True)\n", 474 | " ),\n", 475 | "),\n", 476 | " iid='warn', n_jobs=None,\n", 477 | " param_grid={'module__dropout': [0, 0.5, 0.8]},\n", 478 | " pre_dispatch='2*n_jobs', refit=True, return_train_score=False,\n", 479 | " scoring='accuracy', verbose=1)" 480 | ] 481 | }, 482 | "execution_count": 16, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "gs.fit(mnist_train_sliceable, y_train)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 17, 494 | "metadata": {}, 495 | "outputs": [ 496 | { 497 | "data": { 498 | "text/plain": [ 499 | "{'module__dropout': 0}" 500 | ] 501 | }, 502 | "execution_count": 17, 503 | "metadata": {}, 504 | "output_type": "execute_result" 505 | } 506 | ], 507 | "source": [ 508 | "gs.best_params_" 509 | ] 510 | } 511 | ], 512 | "metadata": { 513 | "kernelspec": { 514 | "display_name": "Python 3", 515 | "language": "python", 516 | "name": 
"python3" 517 | }, 518 | "language_info": { 519 | "codemirror_mode": { 520 | "name": "ipython", 521 | "version": 3 522 | }, 523 | "file_extension": ".py", 524 | "mimetype": "text/x-python", 525 | "name": "python", 526 | "nbconvert_exporter": "python", 527 | "pygments_lexer": "ipython3", 528 | "version": "3.7.4" 529 | } 530 | }, 531 | "nbformat": 4, 532 | "nbformat_minor": 2 533 | } 534 | -------------------------------------------------------------------------------- /tensorflow_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "## Introduction to TensorFlow ## \n", 10 | "\n", 11 | "(Much of this material is originally from the cs224d TensorFlow tutorial by Bharath Ramsundar)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "![TensorFlow logo](tutorial_images/tensorflow.png)\n", 19 | "\n", 20 | "TensorFlow provides primitives for defining functions on tensors and automatically computing their derivatives. \n", 21 | "\n", 22 | "* TensorFlow is a deep learning library for Python that Google open-sourced in 2015. \n", 23 | "* TensorFlow has better support for distributed systems than many competing libraries (e.g. Theano). \n", 24 | "* Keras (next tutorial) is a high-level library that builds on TensorFlow. \n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## What is a tensor?
##" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "![tensor definition](tutorial_images/tensor_definition.png) " 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "## There are some similarities between TensorFlow and Numpy ##" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "* Both TensorFlow and Numpy are N-d array libraries \n", 53 | "* Numpy does not have methods to create tensor functions and automatically compute derivatives. \n", 54 | "* Numpy does not have GPU support, but TensorFlow does. " 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### Numpy: ###\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "import numpy as np " 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "a=np.zeros((2,2)); b=np.ones((2,2))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "np.sum(b,axis=1)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "a.shape" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "np.reshape(a,(1,4))" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### Same commands in TensorFlow:###" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "import tensorflow as tf" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": 
{}, 129 | "outputs": [], 130 | "source": [ 131 | "tf.compat.v1.InteractiveSession()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "collapsed": true 138 | }, 139 | "source": [ 140 | "We just created an interactive Session. A Session object encapsulates the environment in which tensors are evaluated. " 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "a = tf.zeros((2,2)); b = tf.ones((2,2))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "tf.reduce_sum(b, axis=1).numpy()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "help(tf.reduce_sum)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "a.get_shape()" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "We see above that TensorShape behaves like a Python tuple. " 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "tf.reshape(a, (1, 4)).numpy()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "We can build a Numpy to TensorFlow dictionary: \n", 200 | "![Numpy To TensorFlow dictionary](tutorial_images/numpy_to_tensorflow.png)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "## TensorFlow requires explicit evaluation ##\n", 208 | "TensorFlow computations define a computation graph that has no value until evaluated.
Specifically, TensorFlow programs usually have two phases: \n", 209 | "\n", 210 | "* construction phase -- assembles the computation graph \n", 211 | "* evaluation phase -- uses a Session to execute operations in the graph; all computations add nodes to the global default graph. " 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "#in Numpy: \n", 221 | "a=np.zeros((2,2))\n", 222 | "print(a)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "#but in TensorFlow\n", 232 | "ta=tf.zeros((2,2))\n", 233 | "print(ta)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "#now, we retrieve the tensor's value as a NumPy array: \n", 243 | "print(ta.numpy())" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "## More on Sessions ##" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "# build the ops in a fresh graph so sess.run() works even when eager execution is enabled\n", 260 | "with tf.Graph().as_default(), tf.compat.v1.Session() as sess: \n", 261 | "    a=tf.constant(5.0)\n", 262 | "    b=tf.constant(6.0)\n", 263 | "    c=a*b \n", 264 | "    print(sess.run(c))\n", 265 | "    print(c)\n" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "```tf.compat.v1.InteractiveSession()``` is convenient syntax for keeping a default session open in IPython. " 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "## Variables ##\n", 280 | "\n", 281 | "Variables are in-memory buffers that contain tensors. They are used to hold and update parameters when a model is trained.
" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "\n", 291 | "\n", 292 | "with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n", 293 | "    W1 = tf.ones((2,2))\n", 294 | "    W2 = tf.Variable(tf.zeros((2,2)), name=\"weights\") \n", 295 | "    print(sess.run(W1))\n", 296 | "    sess.run(tf.compat.v1.global_variables_initializer())\n", 297 | "    print(sess.run(W2))" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "Unlike constant tensors, TensorFlow variables must be initialized before they have values. " 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "#variable objects can be initialized from either constants or random values: \n", 314 | "W=tf.Variable(tf.zeros((2,2)), name=\"weights\") # initialized from zero values \n", 315 | "R=tf.Variable(tf.random_normal((2,2)), name=\"random_weights\") #initialized from random values \n", 316 | "\n", 317 | "#initialize all variables with values specified above: \n", 318 | "with tf.Session() as sess: \n", 319 | " sess.run(tf.global_variables_initializer())\n", 320 | " print(sess.run(W))\n", 321 | " print(sess.run(R))\n", 322 | " " 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "Updating variable state:" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "state = tf.Variable(0, name=\"counter\")\n", 339 | "\n", 340 | "#new_value = state + 1\n", 341 | "new_value = tf.add(state, tf.constant(1))\n", 342 | "\n", 343 | "#state=new_value\n", 344 | "update = tf.assign(state, new_value)\n", 345 | "\n", 346 | "with tf.Session() as sess:\n", 347 | " #state=0 \n", 348 | " sess.run(tf.global_variables_initializer())\n", 349 | " #print(state)\n", 350 | "
print(sess.run(state))\n", 351 | " for _ in range(3):\n", 352 | "     #state=state+1\n", 353 | "     sess.run(update)\n", 354 | "     #print(state)\n", 355 | "     print(sess.run(state))" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "Fetching variable state: \n", 363 | "\n", 364 | "* Calling ```sess.run(var)``` on a ```tf.Session()``` object retrieves its value. \n", 365 | "* We can retrieve multiple variables simultaneously with ```sess.run([var1,var2])```\n", 366 | "\n", 367 | "For example, let's evaluate the following computational graph: \n", 368 | "![Computation Graph Eval Example](tutorial_images/comp_graph_eval.png) \n" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "input1 = tf.constant(3.0)\n", 378 | "input2 = tf.constant(2.0)\n", 379 | "input3 = tf.constant(5.0)\n", 380 | "intermed = tf.add(input2, input3)\n", 381 | "prod = tf.multiply(input1, intermed)\n", 382 | "with tf.Session() as sess:\n", 383 | "    result = sess.run([prod, intermed])\n", 384 | "    print(result)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "Data inputs to TensorFlow: " 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "#importing data from a numpy array with \"convert_to_tensor\" function \n", 401 | "a=np.zeros((3,3))\n", 402 | "ta=tf.convert_to_tensor(a)\n", 403 | "with tf.Session() as sess: \n", 404 | "    print(sess.run(ta))" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "A more scalable approach: \n", 412 | "* use ```tf.placeholder``` variables (dummy nodes that provide entry points for data to the computational graph) \n", 413 | "* a ```feed_dict``` is a Python dictionary mapping from ```tf.placeholder``` variables to data \n", 414 |
"\n", 415 | "![placeholders and feed forward dictionaries](tutorial_images/placeholder_feedforward_dict.png)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "#define placeholder objects for data entry \n", 425 | "input1 = tf.placeholder(tf.float32)\n", 426 | "input2 = tf.placeholder(tf.float32)\n", 427 | "\n", 428 | "output = tf.multiply(input1,input2)\n", 429 | "with tf.Session() as sess: \n", 430 | "    #fetch value of output from computational graph and \n", 431 | "    #feed data into the computational graph \n", 432 | "    print(sess.run([output], feed_dict={input1:[7.],input2:[2.]}))\n", 433 | " " 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "Variable scope is necessary to avoid name clashes between variables in complex models. \n", 441 | "* ```tf.variable_scope()``` provides simple name-spacing \n", 442 | "* ```tf.get_variable()``` creates/accesses variables from within a variable scope " 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [ 451 | "#setting a variable's scope adds the corresponding prefix to the variable name \n", 452 | "with tf.variable_scope(\"foo\",reuse=None):\n", 453 | "    with tf.variable_scope(\"bar\",reuse=None):\n", 454 | "        v = tf.get_variable(\"v\", [1])\n", 455 | "assert v.name == \"foo/bar/v:0\"" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "with tf.variable_scope(\"foo\",reuse=None):\n", 465 | "    v = tf.get_variable(\"v\", [1])\n", 466 | "    tf.get_variable_scope().reuse_variables()\n", 467 | "    v1 = tf.get_variable(\"v\", [1])\n", 468 | "assert v1 == v" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "```get_variable()``` will behave
differently depending on whether or not reuse is enabled." 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "#case 1: reuse is set to false \n", 485 | "# A new variable is created and returned -- but this will give an error if the variable already exists in this scope, \n", 486 | "#as is the case here \n", 487 | "\n", 488 | "#with tf.variable_scope(\"foo\"): \n", 489 | "#    v=tf.get_variable(\"v\", [1])\n", 490 | "#assert v.name==\"foo/v:0\"" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "#case 2: reuse is set to true \n", 500 | "# search for existing variable with a given name \n", 501 | "#raise ValueError if none is found \n", 502 | "with tf.variable_scope(\"foo\", reuse=True):\n", 503 | "    v1 = tf.get_variable(\"v\", [1])\n", 504 | "assert v1 == v" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "TensorFlow supports automatic differentiation, so gradients are computed without any hand-written derivative code.\n", 512 | "* ```tf.train.Optimizer``` creates an optimizer. \n", 513 | "* ```tf.train.Optimizer.minimize(loss, var_list)``` adds an optimization operation to the computation graph.
" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "Check out TensorBoard for visualizing the computational graph and training metrics: https://www.tensorflow.org/versions/r0.11/how_tos/summaries_and_tensorboard/index.html\n" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": {}, 526 | "source": [ 527 | "## MNIST ConvNet Example ##" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": {}, 533 | "source": [ 534 | "An example implementation of a convolutional network using the TensorFlow library.\n", 535 | "This example uses the MNIST database of handwritten digits\n", 536 | "(http://yann.lecun.com/exdb/mnist/)\n", 537 | "\n", 538 | "Author: Aymeric Damien\n", 539 | "Project: https://github.com/aymericdamien/TensorFlow-Examples/\n" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [] 548 | } 549 | ], 550 | "metadata": { 551 | "kernelspec": { 552 | "display_name": "Python 3", 553 | "language": "python", 554 | "name": "python3" 555 | }, 556 | "language_info": { 557 | "codemirror_mode": { 558 | "name": "ipython", 559 | "version": 3 560 | }, 561 | "file_extension": ".py", 562 | "mimetype": "text/x-python", 563 | "name": "python", 564 | "nbconvert_exporter": "python", 565 | "pygments_lexer": "ipython3", 566 | "version": "3.8.3" 567 | } 568 | }, 569 | "nbformat": 4, 570 | "nbformat_minor": 1 571 | } 572 | -------------------------------------------------------------------------------- /tutorial_images/CTCF.Tut4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/CTCF.Tut4.png -------------------------------------------------------------------------------- /tutorial_images/CTCF_known1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/CTCF_known1.png -------------------------------------------------------------------------------- /tutorial_images/ChangeRuntime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/ChangeRuntime.png -------------------------------------------------------------------------------- /tutorial_images/GenomeWideModel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/GenomeWideModel.png -------------------------------------------------------------------------------- /tutorial_images/MultiLayerTraining.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/MultiLayerTraining.png -------------------------------------------------------------------------------- /tutorial_images/RunAllCollab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/RunAllCollab.png -------------------------------------------------------------------------------- /tutorial_images/RunCellArrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/RunCellArrow.png -------------------------------------------------------------------------------- /tutorial_images/RuntimeType.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/RuntimeType.png -------------------------------------------------------------------------------- /tutorial_images/SIX5_known1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/SIX5_known1.png -------------------------------------------------------------------------------- /tutorial_images/SPI1.Tut4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/SPI1.Tut4.png -------------------------------------------------------------------------------- /tutorial_images/SPIB.Kat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/SPIB.Kat.png -------------------------------------------------------------------------------- /tutorial_images/SimArch1Layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/SimArch1Layer.png -------------------------------------------------------------------------------- /tutorial_images/TAL1_known4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/TAL1_known4.png -------------------------------------------------------------------------------- /tutorial_images/ZNF143_known2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/ZNF143_known2.png -------------------------------------------------------------------------------- /tutorial_images/classification_task.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/classification_task.jpg -------------------------------------------------------------------------------- /tutorial_images/comp_graph_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/comp_graph_eval.png -------------------------------------------------------------------------------- /tutorial_images/dnn_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/dnn_figure.png -------------------------------------------------------------------------------- /tutorial_images/dragonn_and_pssm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/dragonn_and_pssm.jpg -------------------------------------------------------------------------------- /tutorial_images/dragonn_model_figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/dragonn_model_figure.jpg -------------------------------------------------------------------------------- /tutorial_images/heterodimer_simulation.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/heterodimer_simulation.jpg -------------------------------------------------------------------------------- /tutorial_images/homotypic_motif_density_localization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/homotypic_motif_density_localization.jpg -------------------------------------------------------------------------------- /tutorial_images/homotypic_motif_density_localization_task.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/homotypic_motif_density_localization_task.jpg -------------------------------------------------------------------------------- /tutorial_images/inspecting_code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/inspecting_code.png -------------------------------------------------------------------------------- /tutorial_images/multi-input-multi-output-graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/multi-input-multi-output-graph.png -------------------------------------------------------------------------------- /tutorial_images/numpy_to_tensorflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/numpy_to_tensorflow.png 
-------------------------------------------------------------------------------- /tutorial_images/one_hot_encoding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/one_hot_encoding.png -------------------------------------------------------------------------------- /tutorial_images/placeholder_feedforward_dict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/placeholder_feedforward_dict.png -------------------------------------------------------------------------------- /tutorial_images/play_all_button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/play_all_button.png -------------------------------------------------------------------------------- /tutorial_images/play_button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/play_button.png -------------------------------------------------------------------------------- /tutorial_images/sequence_properties_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/sequence_properties_1.jpg -------------------------------------------------------------------------------- /tutorial_images/sequence_properties_2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/sequence_properties_2.jpg -------------------------------------------------------------------------------- /tutorial_images/sequence_properties_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/sequence_properties_3.jpg -------------------------------------------------------------------------------- /tutorial_images/sequence_properties_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/sequence_properties_4.jpg -------------------------------------------------------------------------------- /tutorial_images/sequence_simulations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/sequence_simulations.png -------------------------------------------------------------------------------- /tutorial_images/tensor_definition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/tensor_definition.png -------------------------------------------------------------------------------- /tutorial_images/tensorflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/tensorflow.png -------------------------------------------------------------------------------- /tutorial_images/tf_binding.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/cs273b/d142b80d28a37c4592dc7e07d396b6c58ee1eb4d/tutorial_images/tf_binding.jpg --------------------------------------------------------------------------------
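The placeholder/`feed_dict` and `minimize()` cells in tensorflow_tutorial.ipynb rest on two ideas. First, the graph is defined before any data exists, and values are supplied only when it runs. Second, gradients come from automatic differentiation rather than hand-derived formulas. The following is a minimal plain-Python sketch of both ideas, assuming scalar values only. `Placeholder`, `Variable`, `run`, `gradients`, and `minimize_step` are hypothetical names chosen for illustration; they are not TensorFlow's API or its implementation.

```python
# Toy define-then-run computation graph in plain Python (scalars only).
# All names here are hypothetical illustrations, NOT TensorFlow's API.

class Node:
    """Base graph node; stores the nodes it consumes as inputs."""
    def __init__(self, *inputs):
        self.inputs = inputs
    def backward(self, values, grad):
        """Return grad * d(self)/d(input) for each input (none by default)."""
        return []

class Placeholder(Node):
    """Has no value of its own; run() takes it from the feed dictionary."""

class Variable(Node):
    """A trainable scalar whose value persists between runs."""
    def __init__(self, value):
        super().__init__()
        self.value = value
    def forward(self, values):
        return self.value

class Mul(Node):
    def forward(self, values):
        a, b = (values[n] for n in self.inputs)
        return a * b
    def backward(self, values, grad):
        a, b = (values[n] for n in self.inputs)
        return [grad * b, grad * a]

class Sub(Node):
    def forward(self, values):
        a, b = (values[n] for n in self.inputs)
        return a - b
    def backward(self, values, grad):
        return [grad, -grad]

class Square(Node):
    def forward(self, values):
        x = values[self.inputs[0]]
        return x * x
    def backward(self, values, grad):
        x = values[self.inputs[0]]
        return [2 * x * grad]

def topo_order(output):
    """List every node reachable from `output`, inputs before consumers."""
    order, seen = [], set()
    def visit(node):
        if node not in seen:
            seen.add(node)
            for child in node.inputs:
                visit(child)
            order.append(node)
    visit(output)
    return order

def run(output, feed_dict):
    """Evaluate the graph; every placeholder must be supplied in feed_dict."""
    values = {}
    for node in topo_order(output):
        if isinstance(node, Placeholder):
            values[node] = feed_dict[node]  # KeyError if a placeholder is not fed
        else:
            values[node] = node.forward(values)
    return values

def gradients(output, values):
    """Reverse-mode differentiation: d(output)/d(node) for every node."""
    grads = {output: 1.0}
    for node in reversed(topo_order(output)):
        g = grads.get(node, 0.0)
        for child, child_grad in zip(node.inputs, node.backward(values, g)):
            grads[child] = grads.get(child, 0.0) + child_grad
    return grads

def minimize_step(loss, feed_dict, learning_rate=0.1):
    """One gradient-descent update of every Variable in the loss graph."""
    values = run(loss, feed_dict)
    grads = gradients(loss, values)
    for node in values:
        if isinstance(node, Variable):
            node.value -= learning_rate * grads.get(node, 0.0)
    return values[loss]

# Same graph, data supplied only at run time (the role feed_dict plays in sess.run).
x, y = Placeholder(), Placeholder()
product = Mul(x, y)
print(run(product, {x: 7.0, y: 2.0})[product])  # → 14.0

# Fit w so that w * x matches y by minimizing (w * x - y)^2.
w = Variable(0.0)
loss = Square(Sub(Mul(w, x), y))
for _ in range(50):
    minimize_step(loss, {x: 2.0, y: 6.0}, learning_rate=0.1)
print(round(w.value, 4))  # → 3.0
```

Because `minimize_step` re-runs the graph with a fresh feed dictionary on every call, the same graph could be trained on different batches of data without being rebuilt. That separation between graph definition and data is what placeholders provide in the TensorFlow 1.x style used in the tutorial.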