├── .github └── FUNDING.yml ├── A_Gentle_Introduction_to_PyTorch_1_2.ipynb ├── LICENSE ├── README.md ├── RNN_PT.ipynb ├── nn.ipynb ├── pytorch_hello_world.ipynb ├── pytorch_logistic_regression.ipynb └── pytorch_quick_start.ipynb /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [dair-ai] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: dairai 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Zaid Alyafeai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Notebooks 2 | A collection of PyTorch notebooks for studying and practicing deep learning. Each notebook contains a set of exercises that are specifically designed to engage and encourage the learner to conduct more research and experiments. (Work in progress!) 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 160 | 161 | 162 | 163 | 164 |
| Name | Description | Category | Level | Notebook | Blog |
|------|-------------|----------|-------|----------|------|
| Implementing a Logistic Regression Model from Scratch | Learn how to implement the fundamental building blocks of a neural network using PyTorch. | Machine Learning | Beginner |  | read |
| PyTorch Hello World | Create a hello world for deep learning using PyTorch. | Deep Learning | Beginner |  | read |
| PyTorch Quickstart | Learn about PyTorch's basic building blocks to build and train a CNN model for image classification. | Image Classification | Intermediate |  | read |
| A Gentle Introduction to PyTorch 1.2 | This comprehensive tutorial aims to introduce the fundamentals of PyTorch building blocks for training neural networks. | Neural Networks | Beginner |  | read |
| Building RNNs is Fun with PyTorch | This notebook teaches you how to build a recurrent neural network (RNN) with a single layer, consisting of one single neuron. It also teaches how to implement a simple RNN-based model for image classification. | Neural Networks | Beginner |  | read |
| A Simple Neural Network from Scratch with PyTorch and Google Colab | In this tutorial we implement a simple neural network from scratch using PyTorch. | Neural Networks | Beginner |  | read |
| NLP Basics | In this tutorial we show the basics of preparing your textual data for NLP. | NLP | Beginner |  | coming soon! |
| Deep Learning for NLP | In this notebook we are going to use deep learning (RNN model) for approaching NLP tasks. | Deep Learning, NLP | Beginner |  | coming soon! |
| Neural Machine Translation with Attention using PyTorch | In this notebook we are going to perform neural machine translation using a deep learning based approach and an attention mechanism. | Deep Learning, NLP | Advanced |  | coming soon! |
| Fine-tuning BERT Language Model for English Sentiment Classification | In this tutorial we demonstrate how to fine-tune a BERT-based model for sentiment classification. | Deep Learning, NLP | Intermediate |  | coming soon! |
| Fine-tuning BERT Language Model for English Emotion Classification | In this tutorial we demonstrate how to fine-tune a BERT-based model for multiclass emotion classification. | Deep Learning, NLP | Advanced |  | coming soon! |
| Text Similarity Search using Pretrained Language Models | In this tutorial we show how to build a simple text similarity search application using pretrained language models and Elasticsearch. | Deep Learning, NLP, Applications | Advanced |  | coming soon! |
| Spinal Cord Gray Matter Segmentation Using PyTorch | In this notebook we are going to explore a medical imaging open-source library known as MedicalTorch, which was built on top of PyTorch. | Deep Learning in Medicine | Advanced |  | coming soon! |
165 | 166 | -------------------------------------------------------------------------------- /nn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "nn.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "language": "python", 14 | "name": "python3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "Ee4B4v5tAp1C", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "## A Simple Neural Network from Scratch with PyTorch and Google Colab" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": { 41 | "id": "w4cEhtf_Ap1E", 42 | "colab_type": "text" 43 | }, 44 | "source": [ 45 | "In this tutorial we will implement a simple neural network from scratch using PyTorch. The idea of the tutorial is to teach you the basics of PyTorch and how it can be used to implement a neural network from scratch. I will go over some of the basic functionalities and concepts available in PyTorch that will allow you to build your own neural networks. \n", 46 | "\n", 47 | "This tutorial assumes you have prior knowledge of how a neural network works. Don’t worry! Even if you are not so sure, you will be okay. For advanced PyTorch users, this tutorial may still serve as a refresher. This tutorial is heavily inspired by this [Neural Network implementation](https://repl.it/talk/announcements/Build-a-Neural-Network-in-Python/5457) coded purely using Numpy. In fact, I tried re-implementing the code using PyTorch instead and added my own intuitions and explanations. Thanks to [Samay](https://repl.it/@shamdasani) for his phenomenal work, I hope this inspires many others as it did with me.\n", 48 | "\n", 49 | "Since we are working on Google Colab, we will need to install the PyTorch library. 
You can do this by using the following command:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "SpBiWQF2BrJK", 56 | "colab_type": "code", 57 | "outputId": "858d4853-ca0d-4d5b-f61c-10481b46f309", 58 | "colab": { 59 | "base_uri": "https://localhost:8080/", 60 | "height": 326 61 | } 62 | }, 63 | "source": [ 64 | "!pip3 install torch torchvision" 65 | ], 66 | "execution_count": 0, 67 | "outputs": [ 68 | { 69 | "output_type": "stream", 70 | "text": [ 71 | "Collecting torch\n", 72 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7e/60/66415660aa46b23b5e1b72bc762e816736ce8d7260213e22365af51e8f9c/torch-1.0.0-cp36-cp36m-manylinux1_x86_64.whl (591.8MB)\n", 73 | "\u001b[K 100% |████████████████████████████████| 591.8MB 26kB/s \n", 74 | "tcmalloc: large alloc 1073750016 bytes == 0x61f82000 @ 0x7f400bb202a4 0x591a07 0x5b5d56 0x502e9a 0x506859 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x504c28 0x502540 0x502f3d 0x507641\n", 75 | "\u001b[?25hCollecting torchvision\n", 76 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ca/0d/f00b2885711e08bd71242ebe7b96561e6f6d01fdb4b9dcf4d37e2e13c5e1/torchvision-0.2.1-py2.py3-none-any.whl (54kB)\n", 77 | "\u001b[K 100% |████████████████████████████████| 61kB 23.4MB/s \n", 78 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.14.6)\n", 79 | "Collecting pillow>=4.1.1 (from torchvision)\n", 80 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/92/e3/217dfd0834a51418c602c96b110059c477260c7fee898542b100913947cf/Pillow-5.4.0-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)\n", 81 | "\u001b[K 100% |████████████████████████████████| 2.0MB 6.8MB/s \n", 82 | "\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)\n", 83 | "Installing collected packages: torch, pillow, torchvision\n", 84 | " Found existing installation: Pillow 4.0.0\n", 85 | " Uninstalling Pillow-4.0.0:\n", 86 | " Successfully uninstalled Pillow-4.0.0\n", 87 | "Successfully installed pillow-5.4.0 torch-1.0.0 torchvision-0.2.1\n" 88 | ], 89 | "name": "stdout" 90 | } 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": { 96 | "id": "MP9ewMSlC7JU", 97 | "colab_type": "text" 98 | }, 99 | "source": [ 100 | "\n", 101 | "The `torch` module provides all the necessary **tensor** operators you will need to implement your first neural network from scratch in PyTorch. That's right! In PyTorch everything is a Tensor, so this is the first thing you will need to get used to." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "bKmXKSQnAp1G", 108 | "colab_type": "code", 109 | "colab": {} 110 | }, 111 | "source": [ 112 | "import torch\n", 113 | "import torch.nn as nn" 114 | ], 115 | "execution_count": 0, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": { 121 | "id": "1EWBBl1nAp1M", 122 | "colab_type": "text" 123 | }, 124 | "source": [ 125 | "## Data\n", 126 | "Let's start by creating some sample data using the `torch.tensor` command. In Numpy, this could be done with `np.array`. Both functions serve the same purpose, but in PyTorch everything is a Tensor as opposed to a vector or matrix. We define types in PyTorch using the `dtype=torch.xxx` command. 
\n", 127 | "\n", 128 | "In the data below, `X` represents the amount of hours studied and how much time students spent sleeping, whereas `y` represent grades. The variable `xPredicted` is a single input for which we want to predict a grade using the parameters learned by the neural network. Remember, the neural network wants to learn a mapping between `X` and `y`, so it will try to take a guess from what it has learned from the training data. " 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "id": "fsAVbHnjAp1P", 135 | "colab_type": "code", 136 | "colab": {} 137 | }, 138 | "source": [ 139 | "X = torch.tensor(([2, 9], [1, 5], [3, 6]), dtype=torch.float) # 3 X 2 tensor\n", 140 | "y = torch.tensor(([92], [100], [89]), dtype=torch.float) # 3 X 1 tensor\n", 141 | "xPredicted = torch.tensor(([4, 8]), dtype=torch.float) # 1 X 2 tensor" 142 | ], 143 | "execution_count": 0, 144 | "outputs": [] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "RC0ru9kCAp1U", 150 | "colab_type": "text" 151 | }, 152 | "source": [ 153 | "You can check the size of the tensors we have just created with the `size` command. This is equivalent to the `shape` command used in tools such as Numpy and Tensorflow. " 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "id": "sfC-B1BEAp1W", 160 | "colab_type": "code", 161 | "outputId": "d2ec7994-41ad-41fa-a69c-c0b7123ef7cd", 162 | "colab": { 163 | "base_uri": "https://localhost:8080/", 164 | "height": 51 165 | } 166 | }, 167 | "source": [ 168 | "print(X.size())\n", 169 | "print(y.size())" 170 | ], 171 | "execution_count": 0, 172 | "outputs": [ 173 | { 174 | "output_type": "stream", 175 | "text": [ 176 | "torch.Size([3, 2])\n", 177 | "torch.Size([3, 1])\n" 178 | ], 179 | "name": "stdout" 180 | } 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "id": "zrND9MS9Ap1f", 187 | "colab_type": "text" 188 | }, 189 | "source": [ 190 | "## Scaling\n", 191 | "\n", 192 | "Below we are performing some scaling on the sample data. Notice that the `max` function returns both a tensor and the corresponding indices. So we use `_` to capture the indices which we won't use here because we are only interested in the max values to conduct the scaling. Perfect! Our data is now in a very nice format our neural network will appreciate later on. " 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "hlBvtfAmAp1i", 199 | "colab_type": "code", 200 | "outputId": "23e1d24b-fa29-4173-f884-f44bc8a48cea", 201 | "colab": { 202 | "base_uri": "https://localhost:8080/", 203 | "height": 34 204 | } 205 | }, 206 | "source": [ 207 | "# scale units\n", 208 | "X_max, _ = torch.max(X, 0)\n", 209 | "xPredicted_max, _ = torch.max(xPredicted, 0)\n", 210 | "\n", 211 | "X = torch.div(X, X_max)\n", 212 | "xPredicted = torch.div(xPredicted, xPredicted_max)\n", 213 | "y = y / 100 # max test score is 100\n", 214 | "print(xPredicted)" 215 | ], 216 | "execution_count": 0, 217 | "outputs": [ 218 | { 219 | "output_type": "stream", 220 | "text": [ 221 | "tensor([0.5000, 1.0000])\n" 222 | ], 223 | "name": "stdout" 224 | } 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": { 230 | "id": "R1kTs5S5Ap1m", 231 | "colab_type": "text" 232 | }, 233 | "source": [ 234 | "Notice that there are two functions `max` and `div` that I didn't discuss above. They do exactly what they imply: `max` finds the maximum value in a vector... 
I mean tensor; and `div` is basically a nice little function to divide two tensors. " 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": { 240 | "id": "xRvMSpEFAp1n", 241 | "colab_type": "text" 242 | }, 243 | "source": [ 244 | "## Model (Computation Graph)\n", 245 | "Once the data has been processed and it is in the proper format, all you need to do now is to define your model. Here is where things begin to change a little as compared to how you would build your neural networks using, say, something like Keras or Tensorflow. However, you will realize quickly as you go along that PyTorch doesn't differ much from other deep learning tools. At the end of the day we are constructing a computation graph, which is used to dictate how data should flow and what type of operations are performed on this information. \n", 246 | "\n", 247 | "For illustration purposes, we are building the following neural network or computation graph:\n", 248 | "\n", 249 | "\n", 250 | "![alt text](https://drive.google.com/uc?export=view&id=1l-sKpcCJCEUJV1BlAqcVAvLXLpYCInV6)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "metadata": { 256 | "id": "C7pDC5SfAp1p", 257 | "colab_type": "code", 258 | "colab": {} 259 | }, 260 | "source": [ 261 | "class Neural_Network(nn.Module):\n", 262 | " def __init__(self, ):\n", 263 | " super(Neural_Network, self).__init__()\n", 264 | " # parameters\n", 265 | " # TODO: parameters can be parameterized instead of declaring them here\n", 266 | " self.inputSize = 2\n", 267 | " self.outputSize = 1\n", 268 | " self.hiddenSize = 3\n", 269 | " \n", 270 | " # weights\n", 271 | " self.W1 = torch.randn(self.inputSize, self.hiddenSize) # 3 X 2 tensor\n", 272 | " self.W2 = torch.randn(self.hiddenSize, self.outputSize) # 3 X 1 tensor\n", 273 | " \n", 274 | " def forward(self, X):\n", 275 | " self.z = torch.matmul(X, self.W1) # 3 X 3 \".dot\" does not broadcast in PyTorch\n", 276 | " self.z2 = self.sigmoid(self.z) # activation function\n", 277 | " self.z3 = torch.matmul(self.z2, self.W2)\n", 278 | " o = self.sigmoid(self.z3) # final activation function\n", 279 | " return o\n", 280 | " \n", 281 | " def sigmoid(self, s):\n", 282 | " return 1 / (1 + torch.exp(-s))\n", 283 | " \n", 284 | " def sigmoidPrime(self, s):\n", 285 | " # derivative of sigmoid\n", 286 | " return s * (1 - s)\n", 287 | " \n", 288 | " def backward(self, X, y, o):\n", 289 | " self.o_error = y - o # error in output\n", 290 | " self.o_delta = self.o_error * self.sigmoidPrime(o) # derivative of sig to error\n", 291 | " self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))\n", 292 | " self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)\n", 293 | " self.W1 += torch.matmul(torch.t(X), self.z2_delta)\n", 294 | " self.W2 += torch.matmul(torch.t(self.z2), self.o_delta)\n", 295 | " \n", 296 | " def train(self, X, y):\n", 297 | " # forward + backward pass for training\n", 298 | " o = self.forward(X)\n", 299 | " self.backward(X, y, o)\n", 300 | " \n", 301 | " def saveWeights(self, model):\n", 302 | " # we will use the PyTorch internal storage functions\n", 303 | " torch.save(model, \"NN\")\n", 304 | " # you can reload model with all the weights and so forth with:\n", 305 | " # torch.load(\"NN\")\n", 306 | " \n", 307 | " def predict(self):\n", 308 | " print (\"Predicted data based on trained weights: \")\n", 309 | " print (\"Input (scaled): \\n\" + str(xPredicted))\n", 310 | " print (\"Output: \\n\" + str(self.forward(xPredicted)))\n", 311 | " " 312 | ], 313 | "execution_count": 0, 314 | "outputs": [] 
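Before moving on, here is a quick sanity check you can run (a minimal sketch, assuming the `X` tensor and the `Neural_Network` class defined above): instantiate the network and confirm that a forward pass returns one prediction per training example.

```python
# Minimal sanity check (assumes X and Neural_Network from the cells above).
# Because Neural_Network subclasses nn.Module, calling model(X) invokes forward(X).
model = Neural_Network()
out = model(X)
print(out.size())  # torch.Size([3, 1]) -- one prediction per row of X
```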
315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "qm5gimnyAp1s", 320 | "colab_type": "text" 321 | }, 322 | "source": [ 323 | "For the purpose of this tutorial, we are not going to be talking math stuff, that's for another day. I just want you to get a gist of what it takes to build a neural network from scratch using PyTorch. Let's break down the model which was declared via the class above. \n", 324 | "\n", 325 | "## Class Header\n", 326 | "First, we defined our model via a class because that is the recommended way to build the computation graph. The class header contains the name of the class `Neural Network` and the parameter `nn.Module` which basically indicates that we are defining our own neural network. \n", 327 | "\n", 328 | "```python\n", 329 | "class Neural_Network(nn.Module):\n", 330 | "```\n", 331 | "\n", 332 | "## Initialization\n", 333 | "The next step is to define the initializations ( `def __init__(self,)`) that will be performed upon creating an instance of the customized neural network. You can declare the parameters of your model here, but typically, you would declare the structure of your network in this section -- the size of the hidden layers and so forth. Since we are building the neural network from scratch, we explicitly declared the size of the weights matrices: one that stores the parameters from the input to hidden layer; and one that stores the parameter from the hidden to output layer. Both weight matrices are initialized with values randomly chosen from a normal distribution via `torch.randn(...)`. Note that we are not using bias just to keep things as simple as possible. \n", 334 | "\n", 335 | "```python\n", 336 | "def __init__(self, ):\n", 337 | " super(Neural_Network, self).__init__()\n", 338 | " # parameters\n", 339 | " # TODO: parameters can be parameterized instead of declaring them here\n", 340 | " self.inputSize = 2\n", 341 | " self.outputSize = 1\n", 342 | " self.hiddenSize = 3\n", 343 | "\n", 344 | " # weights\n", 345 | " self.W1 = torch.randn(self.inputSize, self.hiddenSize) # 3 X 2 tensor\n", 346 | " self.W2 = torch.randn(self.hiddenSize, self.outputSize) # 3 X 1 tensor\n", 347 | "```\n", 348 | "\n", 349 | "## The Forward Function\n", 350 | "The `forward` function is where all the magic happens (see below). This is where the data enters and is fed into the computation graph (i.e., the neural network structure we have built). Since we are building a simple neural network with one hidden layer, our forward function looks very simple:\n", 351 | "\n", 352 | "```python\n", 353 | "def forward(self, X):\n", 354 | " self.z = torch.matmul(X, self.W1) \n", 355 | " self.z2 = self.sigmoid(self.z) # activation function\n", 356 | " self.z3 = torch.matmul(self.z2, self.W2)\n", 357 | " o = self.sigmoid(self.z3) # final activation function\n", 358 | " return o\n", 359 | "```\n", 360 | "\n", 361 | "The `forward` function above takes the input `X`and then performs a matrix multiplication (`torch.matmul(...)`) with the first weight matrix `self.W1`. Then the result is applied an activation function, `sigmoid`. The resulting matrix of the activation is then multiplied with the second weight matrix `self.W2`. Then another activation if performed, which renders the output of the neural network or computation graph. The process I described above is simply what's known as a `feedforward pass`. In order for the weights to optimize when training, we need a backpropagation algorithm. 
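The backward pass that follows leans on one identity: for `s = sigmoid(z)`, the derivative is `sigmoid'(z) = s * (1 - s)`. That is why `sigmoidPrime` is applied to the already-activated values (`o` and `self.z2`) rather than to the raw pre-activations. Here is a minimal sketch (assuming `torch` is imported as above) that checks this identity against PyTorch's autograd:

```python
# Minimal sketch: verify sigmoid'(z) = s * (1 - s), where s = sigmoid(z).
# This is the identity sigmoidPrime relies on when it is fed activated values.
z = torch.randn(5, requires_grad=True)
s = torch.sigmoid(z)
s.sum().backward()                       # z.grad now holds d(sigmoid)/dz element-wise
manual = s.detach() * (1 - s.detach())   # the sigmoidPrime formula applied to s
print(torch.allclose(z.grad, manual))    # True
```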
\n", 362 | "\n", 363 | "## The Backward Function\n", 364 | "The `backward` function contains the backpropagation algorithm, where the goal is to essentially minimize the loss with respect to our weights. In other words, the weights need to be updated in such a way that the loss decreases while the neural network is training (well, that is what we hope for). All this magic is possible with the gradient descent algorithm which is declared in the `backward` function. Take a minute or two to inspect what is happening in the code below:\n", 365 | "\n", 366 | "```python\n", 367 | "def backward(self, X, y, o):\n", 368 | " self.o_error = y - o # error in output\n", 369 | " self.o_delta = self.o_error * self.sigmoidPrime(o) \n", 370 | " self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))\n", 371 | " self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)\n", 372 | " self.W1 += torch.matmul(torch.t(X), self.z2_delta)\n", 373 | " self.W2 += torch.matmul(torch.t(self.z2), self.o_delta)\n", 374 | "```\n", 375 | "\n", 376 | "Notice that we are performing a lot of matrix multiplications along with the transpose operations via the `torch.matmul(...)` and `torch.t(...)` operations, respectively. The rest is simply gradient descent -- there is nothing to it." 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": { 382 | "id": "9t26Dr5zAp1u", 383 | "colab_type": "text" 384 | }, 385 | "source": [ 386 | "## Training\n", 387 | "All that is left now is to train the neural network. First we create an instance of the computation graph we have just built:\n", 388 | "\n", 389 | "```python\n", 390 | "NN = Neural_Network()\n", 391 | "```\n", 392 | "\n", 393 | "Then we train the model for `1000` rounds. Notice that in PyTorch `NN(X)` automatically calls the `forward` function so there is no need to explicitly call `NN.forward(X)`. \n", 394 | "\n", 395 | "After we have obtained the predicted output for ever round of training, we compute the loss, with the following code:\n", 396 | "\n", 397 | "```python\n", 398 | "torch.mean((y - NN(X))**2).detach().item()\n", 399 | "```\n", 400 | "\n", 401 | "The next step is to start the training (foward + backward) via `NN.train(X, y)`. After we have trained the neural network, we can store the model and output the predicted value of the single instance we declared in the beginning, `xPredicted`. \n", 402 | "\n", 403 | "Let's train!" 
404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "metadata": { 409 | "id": "9sTddOpLAp1w", 410 | "colab_type": "code", 411 | "outputId": "a02d2b93-34da-4068-f1f2-843f1e30abf8", 412 | "colab": { 413 | "base_uri": "https://localhost:8080/", 414 | "height": 17156 415 | } 416 | }, 417 | "source": [ 418 | "NN = Neural_Network()\n", 419 | "for i in range(1000): # trains the NN 1,000 times\n", 420 | " print (\"#\" + str(i) + \" Loss: \" + str(torch.mean((y - NN(X))**2).detach().item())) # mean sum squared loss\n", 421 | " NN.train(X, y)\n", 422 | "NN.saveWeights(NN)\n", 423 | "NN.predict()" 424 | ], 425 | "execution_count": 0, 426 | "outputs": [ 427 | { 428 | "output_type": "stream", 429 | "text": [ 430 | "#0 Loss: 0.28770461678504944\n", 431 | "#1 Loss: 0.19437099993228912\n", 432 | "#2 Loss: 0.129642054438591\n", 433 | "#3 Loss: 0.08898762613534927\n", 434 | "#4 Loss: 0.0638350322842598\n", 435 | "#5 Loss: 0.04783045873045921\n", 436 | "#6 Loss: 0.037219222635030746\n", 437 | "#7 Loss: 0.029889358207583427\n", 438 | "#8 Loss: 0.024637090042233467\n", 439 | "#9 Loss: 0.020752854645252228\n", 440 | "#10 Loss: 0.01780204102396965\n", 441 | "#11 Loss: 0.015508432872593403\n", 442 | "#12 Loss: 0.013690348714590073\n", 443 | "#13 Loss: 0.012224685400724411\n", 444 | "#14 Loss: 0.011025689542293549\n", 445 | "#15 Loss: 0.0100322300568223\n", 446 | "#16 Loss: 0.009199750609695911\n", 447 | "#17 Loss: 0.008495191112160683\n", 448 | "#18 Loss: 0.007893583737313747\n", 449 | "#19 Loss: 0.007375772576779127\n", 450 | "#20 Loss: 0.006926907692104578\n", 451 | "#21 Loss: 0.006535270716995001\n", 452 | "#22 Loss: 0.006191555876284838\n", 453 | "#23 Loss: 0.005888286512345076\n", 454 | "#24 Loss: 0.005619380157440901\n", 455 | "#25 Loss: 0.0053798723965883255\n", 456 | "#26 Loss: 0.005165652371942997\n", 457 | "#27 Loss: 0.004973314236849546\n", 458 | "#28 Loss: 0.0048000202514231205\n", 459 | "#29 Loss: 0.004643348511308432\n", 460 | "#30 Loss: 0.00450127711519599\n", 461 | "#31 Loss: 0.004372074268758297\n", 462 | "#32 Loss: 0.004254247527569532\n", 463 | "#33 Loss: 0.004146536346524954\n", 464 | "#34 Loss: 0.004047831054776907\n", 465 | "#35 Loss: 0.003957169130444527\n", 466 | "#36 Loss: 0.0038737261202186346\n", 467 | "#37 Loss: 0.0037967758253216743\n", 468 | "#38 Loss: 0.0037256714422255754\n", 469 | "#39 Loss: 0.0036598537117242813\n", 470 | "#40 Loss: 0.003598827635869384\n", 471 | "#41 Loss: 0.0035421468783169985\n", 472 | "#42 Loss: 0.0034894247073680162\n", 473 | "#43 Loss: 0.003440307453274727\n", 474 | "#44 Loss: 0.0033944963943213224\n", 475 | "#45 Loss: 0.003351695602759719\n", 476 | "#46 Loss: 0.003311669686809182\n", 477 | "#47 Loss: 0.003274182789027691\n", 478 | "#48 Loss: 0.0032390293199568987\n", 479 | "#49 Loss: 0.0032060390803962946\n", 480 | "#50 Loss: 0.0031750358175486326\n", 481 | "#51 Loss: 0.0031458677258342505\n", 482 | "#52 Loss: 0.003118406282737851\n", 483 | "#53 Loss: 0.0030925225000828505\n", 484 | "#54 Loss: 0.0030680971685796976\n", 485 | "#55 Loss: 0.0030450366903096437\n", 486 | "#56 Loss: 0.003023233264684677\n", 487 | "#57 Loss: 0.0030026088934391737\n", 488 | "#58 Loss: 0.002983089303597808\n", 489 | "#59 Loss: 0.0029645822942256927\n", 490 | "#60 Loss: 0.00294703827239573\n", 491 | "#61 Loss: 0.0029303862247616053\n", 492 | "#62 Loss: 0.002914572134613991\n", 493 | "#63 Loss: 0.0028995368629693985\n", 494 | "#64 Loss: 0.0028852447867393494\n", 495 | "#65 Loss: 0.002871639095246792\n", 496 | "#66 Loss: 0.002858673455193639\n", 497 | "#67 Loss: 
0.0028463276103138924\n", 498 | "#68 Loss: 0.0028345445170998573\n", 499 | "#69 Loss: 0.0028233081102371216\n", 500 | "#70 Loss: 0.0028125671669840813\n", 501 | "#71 Loss: 0.002802313072606921\n", 502 | "#72 Loss: 0.0027925113681703806\n", 503 | "#73 Loss: 0.002783131552860141\n", 504 | "#74 Loss: 0.0027741591911762953\n", 505 | "#75 Loss: 0.00276556215249002\n", 506 | "#76 Loss: 0.0027573201805353165\n", 507 | "#77 Loss: 0.002749415347352624\n", 508 | "#78 Loss: 0.002741842297837138\n", 509 | "#79 Loss: 0.0027345670387148857\n", 510 | "#80 Loss: 0.00272758468054235\n", 511 | "#81 Loss: 0.0027208721730858088\n", 512 | "#82 Loss: 0.002714422531425953\n", 513 | "#83 Loss: 0.002708215033635497\n", 514 | "#84 Loss: 0.0027022461872547865\n", 515 | "#85 Loss: 0.0026964957360178232\n", 516 | "#86 Loss: 0.002690958557650447\n", 517 | "#87 Loss: 0.0026856244076043367\n", 518 | "#88 Loss: 0.002680474892258644\n", 519 | "#89 Loss: 0.002675510011613369\n", 520 | "#90 Loss: 0.002670713933184743\n", 521 | "#91 Loss: 0.0026660896837711334\n", 522 | "#92 Loss: 0.0026616165414452553\n", 523 | "#93 Loss: 0.0026572979986667633\n", 524 | "#94 Loss: 0.0026531198527663946\n", 525 | "#95 Loss: 0.002649075584486127\n", 526 | "#96 Loss: 0.002645164029672742\n", 527 | "#97 Loss: 0.0026413705199956894\n", 528 | "#98 Loss: 0.0026377029716968536\n", 529 | "#99 Loss: 0.002634142292663455\n", 530 | "#100 Loss: 0.00263069081120193\n", 531 | "#101 Loss: 0.0026273438706994057\n", 532 | "#102 Loss: 0.0026240937877446413\n", 533 | "#103 Loss: 0.0026209382340312004\n", 534 | "#104 Loss: 0.002617868361994624\n", 535 | "#105 Loss: 0.002614888595417142\n", 536 | "#106 Loss: 0.0026119956746697426\n", 537 | "#107 Loss: 0.002609172137454152\n", 538 | "#108 Loss: 0.0026064326521009207\n", 539 | "#109 Loss: 0.002603760687634349\n", 540 | "#110 Loss: 0.00260116346180439\n", 541 | "#111 Loss: 0.002598624676465988\n", 542 | "#112 Loss: 0.0025961531791836023\n", 543 | "#113 Loss: 0.0025937433820217848\n", 544 | "#114 Loss: 0.0025913880672305822\n", 545 | "#115 Loss: 0.0025890925899147987\n", 546 | "#116 Loss: 0.002586849732324481\n", 547 | "#117 Loss: 0.002584656234830618\n", 548 | "#118 Loss: 0.0025825174525380135\n", 549 | "#119 Loss: 0.0025804194156080484\n", 550 | "#120 Loss: 0.0025783723685890436\n", 551 | "#121 Loss: 0.002576368162408471\n", 552 | "#122 Loss: 0.002574402838945389\n", 553 | "#123 Loss: 0.002572478959336877\n", 554 | "#124 Loss: 0.0025705902371555567\n", 555 | "#125 Loss: 0.0025687431916594505\n", 556 | "#126 Loss: 0.002566935494542122\n", 557 | "#127 Loss: 0.0025651559699326754\n", 558 | "#128 Loss: 0.002563410671427846\n", 559 | "#129 Loss: 0.0025617002975195646\n", 560 | "#130 Loss: 0.0025600148364901543\n", 561 | "#131 Loss: 0.0025583638343960047\n", 562 | "#132 Loss: 0.002556734485551715\n", 563 | "#133 Loss: 0.002555140992626548\n", 564 | "#134 Loss: 0.0025535663589835167\n", 565 | "#135 Loss: 0.0025520166382193565\n", 566 | "#136 Loss: 0.002550497418269515\n", 567 | "#137 Loss: 0.0025489996187388897\n", 568 | "#138 Loss: 0.002547516720369458\n", 569 | "#139 Loss: 0.0025460589677095413\n", 570 | "#140 Loss: 0.0025446258950978518\n", 571 | "#141 Loss: 0.0025432079564779997\n", 572 | "#142 Loss: 0.0025418128352612257\n", 573 | "#143 Loss: 0.0025404333136975765\n", 574 | "#144 Loss: 0.0025390759110450745\n", 575 | "#145 Loss: 0.002537728287279606\n", 576 | "#146 Loss: 0.0025364060420542955\n", 577 | "#147 Loss: 0.0025350917130708694\n", 578 | "#148 Loss: 0.002533797873184085\n", 579 | "#149 Loss: 
0.002532513812184334\n", 580 | "#150 Loss: 0.0025312507059425116\n", 581 | "#151 Loss: 0.0025300011038780212\n", 582 | "#152 Loss: 0.0025287508033216\n", 583 | "#153 Loss: 0.0025275293737649918\n", 584 | "#154 Loss: 0.002526313764974475\n", 585 | "#155 Loss: 0.00252510909922421\n", 586 | "#156 Loss: 0.0025239146780222654\n", 587 | "#157 Loss: 0.0025227360893040895\n", 588 | "#158 Loss: 0.002521563321352005\n", 589 | "#159 Loss: 0.002520401030778885\n", 590 | "#160 Loss: 0.002519249450415373\n", 591 | "#161 Loss: 0.0025181034579873085\n", 592 | "#162 Loss: 0.0025169753935188055\n", 593 | "#163 Loss: 0.0025158498901873827\n", 594 | "#164 Loss: 0.0025147362612187862\n", 595 | "#165 Loss: 0.002513629151508212\n", 596 | "#166 Loss: 0.002512530190870166\n", 597 | "#167 Loss: 0.0025114361196756363\n", 598 | "#168 Loss: 0.0025103483349084854\n", 599 | "#169 Loss: 0.0025092700961977243\n", 600 | "#170 Loss: 0.0025081969797611237\n", 601 | "#171 Loss: 0.0025071338750422\n", 602 | "#172 Loss: 0.0025060747284442186\n", 603 | "#173 Loss: 0.0025050221011042595\n", 604 | "#174 Loss: 0.002503973664715886\n", 605 | "#175 Loss: 0.002502931747585535\n", 606 | "#176 Loss: 0.002501895884051919\n", 607 | "#177 Loss: 0.0025008656084537506\n", 608 | "#178 Loss: 0.00249984092079103\n", 609 | "#179 Loss: 0.002498818328604102\n", 610 | "#180 Loss: 0.002497798763215542\n", 611 | "#181 Loss: 0.0024967871140688658\n", 612 | "#182 Loss: 0.00249578058719635\n", 613 | "#183 Loss: 0.0024947759229689837\n", 614 | "#184 Loss: 0.0024937766138464212\n", 615 | "#185 Loss: 0.002492778468877077\n", 616 | "#186 Loss: 0.0024917826522141695\n", 617 | "#187 Loss: 0.0024907945189625025\n", 618 | "#188 Loss: 0.002489812206476927\n", 619 | "#189 Loss: 0.002488828031346202\n", 620 | "#190 Loss: 0.0024878503754734993\n", 621 | "#191 Loss: 0.0024868694599717855\n", 622 | "#192 Loss: 0.002485897159203887\n", 623 | "#193 Loss: 0.002484926488250494\n", 624 | "#194 Loss: 0.0024839574471116066\n", 625 | "#195 Loss: 0.0024829902686178684\n", 626 | "#196 Loss: 0.002482031239196658\n", 627 | "#197 Loss: 0.0024810675531625748\n", 628 | "#198 Loss: 0.002480114810168743\n", 629 | "#199 Loss: 0.00247915368527174\n", 630 | "#200 Loss: 0.0024782009422779083\n", 631 | "#201 Loss: 0.002477245405316353\n", 632 | "#202 Loss: 0.0024762984830886126\n", 633 | "#203 Loss: 0.002475348999723792\n", 634 | "#204 Loss: 0.002474398585036397\n", 635 | "#205 Loss: 0.0024734551552683115\n", 636 | "#206 Loss: 0.002472516382113099\n", 637 | "#207 Loss: 0.002471569227054715\n", 638 | "#208 Loss: 0.002470628125593066\n", 639 | "#209 Loss: 0.0024696916807442904\n", 640 | "#210 Loss: 0.002468749647960067\n", 641 | "#211 Loss: 0.0024678176268935204\n", 642 | "#212 Loss: 0.0024668758269399405\n", 643 | "#213 Loss: 0.002465949160978198\n", 644 | "#214 Loss: 0.0024650150444358587\n", 645 | "#215 Loss: 0.00246407906524837\n", 646 | "#216 Loss: 0.002463151467964053\n", 647 | "#217 Loss: 0.002462216652929783\n", 648 | "#218 Loss: 0.0024612878914922476\n", 649 | "#219 Loss: 0.002460360061377287\n", 650 | "#220 Loss: 0.0024594322312623262\n", 651 | "#221 Loss: 0.0024585050996392965\n", 652 | "#222 Loss: 0.002457576571032405\n", 653 | "#223 Loss: 0.0024566520005464554\n", 654 | "#224 Loss: 0.002455727430060506\n", 655 | "#225 Loss: 0.002454800298437476\n", 656 | "#226 Loss: 0.002453884808346629\n", 657 | "#227 Loss: 0.0024529551155865192\n", 658 | "#228 Loss: 0.002452034503221512\n", 659 | "#229 Loss: 0.002451109467074275\n", 660 | "#230 Loss: 0.0024501883890479803\n", 661 | "#231 
Loss: 0.002449269639328122\n", 662 | "#232 Loss: 0.0024483499582856894\n", 663 | "#233 Loss: 0.002447424689307809\n", 664 | "#234 Loss: 0.0024465022142976522\n", 665 | "#235 Loss: 0.0024455797392874956\n", 666 | "#236 Loss: 0.0024446637835353613\n", 667 | "#237 Loss: 0.002443745033815503\n", 668 | "#238 Loss: 0.0024428225588053465\n", 669 | "#239 Loss: 0.0024419049732387066\n", 670 | "#240 Loss: 0.002440983895212412\n", 671 | "#241 Loss: 0.0024400672409683466\n", 672 | "#242 Loss: 0.002439146162942052\n", 673 | "#243 Loss: 0.0024382262490689754\n", 674 | "#244 Loss: 0.002437308896332979\n", 675 | "#245 Loss: 0.0024363857228308916\n", 676 | "#246 Loss: 0.002435472561046481\n", 677 | "#247 Loss: 0.0024345542769879103\n", 678 | "#248 Loss: 0.0024336313363164663\n", 679 | "#249 Loss: 0.00243271142244339\n", 680 | "#250 Loss: 0.00243179383687675\n", 681 | "#251 Loss: 0.0024308778811246157\n", 682 | "#252 Loss: 0.0024299558717757463\n", 683 | "#253 Loss: 0.0024290340952575207\n", 684 | "#254 Loss: 0.002428111620247364\n", 685 | "#255 Loss: 0.002427193336188793\n", 686 | "#256 Loss: 0.002426273887977004\n", 687 | "#257 Loss: 0.002425355603918433\n", 688 | "#258 Loss: 0.002424436155706644\n", 689 | "#259 Loss: 0.002423514612019062\n", 690 | "#260 Loss: 0.002422596327960491\n", 691 | "#261 Loss: 0.0024216733872890472\n", 692 | "#262 Loss: 0.0024207504466176033\n", 693 | "#263 Loss: 0.002419829135760665\n", 694 | "#264 Loss: 0.0024189057294279337\n", 695 | "#265 Loss: 0.0024179841857403517\n", 696 | "#266 Loss: 0.002417063107714057\n", 697 | "#267 Loss: 0.0024161438923329115\n", 698 | "#268 Loss: 0.0024152155965566635\n", 699 | "#269 Loss: 0.0024142952170222998\n", 700 | "#270 Loss: 0.0024133676197379827\n", 701 | "#271 Loss: 0.002412450732663274\n", 702 | "#272 Loss: 0.002411528956145048\n", 703 | "#273 Loss: 0.0024105983320623636\n", 704 | "#274 Loss: 0.0024096802808344364\n", 705 | "#275 Loss: 0.0024087547790259123\n", 706 | "#276 Loss: 0.0024078262504190207\n", 707 | "#277 Loss: 0.0024068995844572783\n", 708 | "#278 Loss: 0.0024059752468019724\n", 709 | "#279 Loss: 0.002405051840469241\n", 710 | "#280 Loss: 0.002404116792604327\n", 711 | "#281 Loss: 0.0024031943175941706\n", 712 | "#282 Loss: 0.0024022667203098536\n", 713 | "#283 Loss: 0.002401341451331973\n", 714 | "#284 Loss: 0.002400410594418645\n", 715 | "#285 Loss: 0.0023994811344891787\n", 716 | "#286 Loss: 0.0023985551670193672\n", 717 | "#287 Loss: 0.0023976238444447517\n", 718 | "#288 Loss: 0.0023966955486685038\n", 719 | "#289 Loss: 0.0023957621306180954\n", 720 | "#290 Loss: 0.002394832205027342\n", 721 | "#291 Loss: 0.0023939006496220827\n", 722 | "#292 Loss: 0.002392966765910387\n", 723 | "#293 Loss: 0.00239203916862607\n", 724 | "#294 Loss: 0.002391106216236949\n", 725 | "#295 Loss: 0.0023901707027107477\n", 726 | "#296 Loss: 0.002389240777119994\n", 727 | "#297 Loss: 0.0023883050307631493\n", 728 | "#298 Loss: 0.0023873704485595226\n", 729 | "#299 Loss: 0.0023864342365413904\n", 730 | "#300 Loss: 0.0023854991886764765\n", 731 | "#301 Loss: 0.0023845701944082975\n", 732 | "#302 Loss: 0.0023836297914385796\n", 733 | "#303 Loss: 0.0023826900869607925\n", 734 | "#304 Loss: 0.0023817545734345913\n", 735 | "#305 Loss: 0.002380818361416459\n", 736 | "#306 Loss: 0.0023798795882612467\n", 737 | "#307 Loss: 0.0023789377883076668\n", 738 | "#308 Loss: 0.0023780011106282473\n", 739 | "#309 Loss: 0.0023770590778440237\n", 740 | "#310 Loss: 0.0023761214688420296\n", 741 | "#311 Loss: 0.0023751859553158283\n", 742 | "#312 Loss: 
0.0023742406629025936\n", 743 | "#313 Loss: 0.002373295836150646\n", 744 | "#314 Loss: 0.0023723554331809282\n", 745 | "#315 Loss: 0.002371413866057992\n", 746 | "#316 Loss: 0.0023704750929027796\n", 747 | "#317 Loss: 0.002369531663134694\n", 748 | "#318 Loss: 0.0023685868363827467\n", 749 | "#319 Loss: 0.002367644337937236\n", 750 | "#320 Loss: 0.002366698579862714\n", 751 | "#321 Loss: 0.0023657495621591806\n", 752 | "#322 Loss: 0.0023648033384233713\n", 753 | "#323 Loss: 0.002363859675824642\n", 754 | "#324 Loss: 0.0023629090283066034\n", 755 | "#325 Loss: 0.0023619639687240124\n", 756 | "#326 Loss: 0.0023610175121575594\n", 757 | "#327 Loss: 0.002360069891437888\n", 758 | "#328 Loss: 0.002359122270718217\n", 759 | "#329 Loss: 0.0023581702262163162\n", 760 | "#330 Loss: 0.0023572223726660013\n", 761 | "#331 Loss: 0.002356275450438261\n", 762 | "#332 Loss: 0.0023553166538476944\n", 763 | "#333 Loss: 0.0023543667048215866\n", 764 | "#334 Loss: 0.0023534176871180534\n", 765 | "#335 Loss: 0.002352464245632291\n", 766 | "#336 Loss: 0.0023515131324529648\n", 767 | "#337 Loss: 0.0023505568969994783\n", 768 | "#338 Loss: 0.0023496015928685665\n", 769 | "#339 Loss: 0.002348652807995677\n", 770 | "#340 Loss: 0.002347696339711547\n", 771 | "#341 Loss: 0.0023467380087822676\n", 772 | "#342 Loss: 0.0023457861971110106\n", 773 | "#343 Loss: 0.0023448301944881678\n", 774 | "#344 Loss: 0.0023438704665750265\n", 775 | "#345 Loss: 0.002342912135645747\n", 776 | "#346 Loss: 0.002341957064345479\n", 777 | "#347 Loss: 0.0023409996647387743\n", 778 | "#348 Loss: 0.0023400387726724148\n", 779 | "#349 Loss: 0.002339078113436699\n", 780 | "#350 Loss: 0.002338117454200983\n", 781 | "#351 Loss: 0.0023371621500700712\n", 782 | "#352 Loss: 0.0023361986968666315\n", 783 | "#353 Loss: 0.00233523640781641\n", 784 | "#354 Loss: 0.0023342801723629236\n", 785 | "#355 Loss: 0.002333313226699829\n", 786 | "#356 Loss: 0.002332353265956044\n", 787 | "#357 Loss: 0.002331388648599386\n", 788 | "#358 Loss: 0.0023304217029362917\n", 789 | "#359 Loss: 0.0023294605780392885\n", 790 | "#360 Loss: 0.002328496426343918\n", 791 | "#361 Loss: 0.002327530412003398\n", 792 | "#362 Loss: 0.0023265639320015907\n", 793 | "#363 Loss: 0.0023255993146449327\n", 794 | "#364 Loss: 0.0023246288765221834\n", 795 | "#365 Loss: 0.0023236607667058706\n", 796 | "#366 Loss: 0.002322700573131442\n", 797 | "#367 Loss: 0.0023217289708554745\n", 798 | "#368 Loss: 0.0023207550402730703\n", 799 | "#369 Loss: 0.002319787396118045\n", 800 | "#370 Loss: 0.002318824175745249\n", 801 | "#371 Loss: 0.0023178488481789827\n", 802 | "#372 Loss: 0.002316881902515888\n", 803 | "#373 Loss: 0.0023159075062721968\n", 804 | "#374 Loss: 0.002314941259101033\n", 805 | "#375 Loss: 0.0023139675613492727\n", 806 | "#376 Loss: 0.0023129950277507305\n", 807 | "#377 Loss: 0.0023120215628296137\n", 808 | "#378 Loss: 0.002311046002432704\n", 809 | "#379 Loss: 0.002310073934495449\n", 810 | "#380 Loss: 0.002309101400896907\n", 811 | "#381 Loss: 0.0023081284016370773\n", 812 | "#382 Loss: 0.00230714725330472\n", 813 | "#383 Loss: 0.00230617169290781\n", 814 | "#384 Loss: 0.0023051972966641188\n", 815 | "#385 Loss: 0.002304219640791416\n", 816 | "#386 Loss: 0.0023032415192574263\n", 817 | "#387 Loss: 0.002302265027537942\n", 818 | "#388 Loss: 0.0023012871388345957\n", 819 | "#389 Loss: 0.002300310181453824\n", 820 | "#390 Loss: 0.002299328101798892\n", 821 | "#391 Loss: 0.0022983483504503965\n", 822 | "#392 Loss: 0.0022973709274083376\n", 823 | "#393 Loss: 0.002296391176059842\n", 824 
| "#394 Loss: 0.002295407932251692\n", 825 | "#395 Loss: 0.00229442841373384\n", 826 | "#396 Loss: 0.002293441677466035\n", 827 | "#397 Loss: 0.0022924619261175394\n", 828 | "#398 Loss: 0.0022914784494787455\n", 829 | "#399 Loss: 0.0022904963698238134\n", 830 | "#400 Loss: 0.0022895135916769505\n", 831 | "#401 Loss: 0.0022885303478688\n", 832 | "#402 Loss: 0.0022875459399074316\n", 833 | "#403 Loss: 0.0022865592036396265\n", 834 | "#404 Loss: 0.0022855724673718214\n", 835 | "#405 Loss: 0.0022845915518701077\n", 836 | "#406 Loss: 0.002283601788803935\n", 837 | "#407 Loss: 0.002282612957060337\n", 838 | "#408 Loss: 0.002281626919284463\n", 839 | "#409 Loss: 0.0022806443739682436\n", 840 | "#410 Loss: 0.0022796487901359797\n", 841 | "#411 Loss: 0.0022786634508520365\n", 842 | "#412 Loss: 0.0022776739206165075\n", 843 | "#413 Loss: 0.0022766822949051857\n", 844 | "#414 Loss: 0.0022756929975003004\n", 845 | "#415 Loss: 0.0022747062612324953\n", 846 | "#416 Loss: 0.00227371440269053\n", 847 | "#417 Loss: 0.0022727230098098516\n", 848 | "#418 Loss: 0.002271731849759817\n", 849 | "#419 Loss: 0.0022707392927259207\n", 850 | "#420 Loss: 0.002269746968522668\n", 851 | "#421 Loss: 0.002268751384690404\n", 852 | "#422 Loss: 0.002267759060487151\n", 853 | "#423 Loss: 0.0022667646408081055\n", 854 | "#424 Loss: 0.0022657769732177258\n", 855 | "#425 Loss: 0.002264777896925807\n", 856 | "#426 Loss: 0.002263784408569336\n", 857 | "#427 Loss: 0.0022627897560596466\n", 858 | "#428 Loss: 0.0022617937065660954\n", 859 | "#429 Loss: 0.002260798355564475\n", 860 | "#430 Loss: 0.0022597969509661198\n", 861 | "#431 Loss: 0.002258802531287074\n", 862 | "#432 Loss: 0.0022578088100999594\n", 863 | "#433 Loss: 0.0022568099666386843\n", 864 | "#434 Loss: 0.002255811123177409\n", 865 | "#435 Loss: 0.0022548120468854904\n", 866 | "#436 Loss: 0.0022538129705935717\n", 867 | "#437 Loss: 0.0022528113331645727\n", 868 | "#438 Loss: 0.002251812256872654\n", 869 | "#439 Loss: 0.00225081411190331\n", 870 | "#440 Loss: 0.0022498099133372307\n", 871 | "#441 Loss: 0.002248812699690461\n", 872 | "#442 Loss: 0.002247813157737255\n", 873 | "#443 Loss: 0.0022468070965260267\n", 874 | "#444 Loss: 0.002245804527774453\n", 875 | "#445 Loss: 0.0022448061499744654\n", 876 | "#446 Loss: 0.002243800787255168\n", 877 | "#447 Loss: 0.0022427986841648817\n", 878 | "#448 Loss: 0.0022417923901230097\n", 879 | "#449 Loss: 0.0022407902870327234\n", 880 | "#450 Loss: 0.0022397860884666443\n", 881 | "#451 Loss: 0.002238777931779623\n", 882 | "#452 Loss: 0.002237774431705475\n", 883 | "#453 Loss: 0.00223676860332489\n", 884 | "#454 Loss: 0.0022357627749443054\n", 885 | "#455 Loss: 0.002234755316749215\n", 886 | "#456 Loss: 0.0022337529808282852\n", 887 | "#457 Loss: 0.0022327450569719076\n", 888 | "#458 Loss: 0.0022317382972687483\n", 889 | "#459 Loss: 0.002230728277936578\n", 890 | "#460 Loss: 0.0022297168616205454\n", 891 | "#461 Loss: 0.0022287091705948114\n", 892 | "#462 Loss: 0.002227703807875514\n", 893 | "#463 Loss: 0.002226694021373987\n", 894 | "#464 Loss: 0.002225684467703104\n", 895 | "#465 Loss: 0.0022246765438467264\n", 896 | "#466 Loss: 0.0022236653603613377\n", 897 | "#467 Loss: 0.0022226530127227306\n", 898 | "#468 Loss: 0.002221642527729273\n", 899 | "#469 Loss: 0.0022206297144293785\n", 900 | "#470 Loss: 0.0022196185309439898\n", 901 | "#471 Loss: 0.0022186103742569685\n", 902 | "#472 Loss: 0.0022175933700054884\n", 903 | "#473 Loss: 0.0022165849804878235\n", 904 | "#474 Loss: 0.0022155700717121363\n", 905 | "#475 Loss: 
0.0022145553957670927\n", 906 | "#476 Loss: 0.0022135439794510603\n", 907 | "#477 Loss: 0.0022125281393527985\n", 908 | "#478 Loss: 0.002211514627560973\n", 909 | "#479 Loss: 0.002210496924817562\n", 910 | "#480 Loss: 0.0022094829473644495\n", 911 | "#481 Loss: 0.0022084659431129694\n", 912 | "#482 Loss: 0.0022074568551033735\n", 913 | "#483 Loss: 0.002206437522545457\n", 914 | "#484 Loss: 0.0022054200526326895\n", 915 | "#485 Loss: 0.0022044044453650713\n", 916 | "#486 Loss: 0.0022033853456377983\n", 917 | "#487 Loss: 0.0022023695055395365\n", 918 | "#488 Loss: 0.002201352035626769\n", 919 | "#489 Loss: 0.0022003341000527143\n", 920 | "#490 Loss: 0.002199317794293165\n", 921 | "#491 Loss: 0.0021982965990900993\n", 922 | "#492 Loss: 0.0021972774993628263\n", 923 | "#493 Loss: 0.00219626072794199\n", 924 | "#494 Loss: 0.0021952392999082804\n", 925 | "#495 Loss: 0.002194217639043927\n", 926 | "#496 Loss: 0.002193200634792447\n", 927 | "#497 Loss: 0.002192180836573243\n", 928 | "#498 Loss: 0.0021911589428782463\n", 929 | "#499 Loss: 0.0021901384461671114\n", 930 | "#500 Loss: 0.002189117018133402\n", 931 | "#501 Loss: 0.0021880920976400375\n", 932 | "#502 Loss: 0.0021870729979127645\n", 933 | "#503 Loss: 0.0021860499400645494\n", 934 | "#504 Loss: 0.0021850315388292074\n", 935 | "#505 Loss: 0.002184005454182625\n", 936 | "#506 Loss: 0.0021829840261489153\n", 937 | "#507 Loss: 0.002181959105655551\n", 938 | "#508 Loss: 0.0021809397730976343\n", 939 | "#509 Loss: 0.002179911592975259\n", 940 | "#510 Loss: 0.002178889000788331\n", 941 | "#511 Loss: 0.0021778629161417484\n", 942 | "#512 Loss: 0.002176836598664522\n", 943 | "#513 Loss: 0.002175812376663089\n", 944 | "#514 Loss: 0.0021747888531535864\n", 945 | "#515 Loss: 0.0021737609058618546\n", 946 | "#516 Loss: 0.002172738080844283\n", 947 | "#517 Loss: 0.002171711064875126\n", 948 | "#518 Loss: 0.0021706840489059687\n", 949 | "#519 Loss: 0.0021696598269045353\n", 950 | "#520 Loss: 0.0021686323452740908\n", 951 | "#521 Loss: 0.0021676046308130026\n", 952 | "#522 Loss: 0.0021665773820132017\n", 953 | "#523 Loss: 0.0021655478049069643\n", 954 | "#524 Loss: 0.0021645205561071634\n", 955 | "#525 Loss: 0.002163497731089592\n", 956 | "#526 Loss: 0.002162465127184987\n", 957 | "#527 Loss: 0.0021614336874336004\n", 958 | "#528 Loss: 0.0021604085341095924\n", 959 | "#529 Loss: 0.0021593787241727114\n", 960 | "#530 Loss: 0.0021583528723567724\n", 961 | "#531 Loss: 0.0021573195699602365\n", 962 | "#532 Loss: 0.0021562918554991484\n", 963 | "#533 Loss: 0.0021552571561187506\n", 964 | "#534 Loss: 0.0021542287431657314\n", 965 | "#535 Loss: 0.0021532000973820686\n", 966 | "#536 Loss: 0.0021521716844290495\n", 967 | "#537 Loss: 0.0021511383820325136\n", 968 | "#538 Loss: 0.0021501071751117706\n", 969 | "#539 Loss: 0.0021490773651748896\n", 970 | "#540 Loss: 0.0021480440627783537\n", 971 | "#541 Loss: 0.002147009363397956\n", 972 | "#542 Loss: 0.0021459797862917185\n", 973 | "#543 Loss: 0.002144948346540332\n", 974 | "#544 Loss: 0.002143915044143796\n", 975 | "#545 Loss: 0.0021428829059004784\n", 976 | "#546 Loss: 0.002141848672181368\n", 977 | "#547 Loss: 0.0021408156026154757\n", 978 | "#548 Loss: 0.002139780670404434\n", 979 | "#549 Loss: 0.0021387485321611166\n", 980 | "#550 Loss: 0.002137715695425868\n", 981 | "#551 Loss: 0.0021366847213357687\n", 982 | "#552 Loss: 0.0021356476936489344\n", 983 | "#553 Loss: 0.0021346136927604675\n", 984 | "#554 Loss: 0.0021335785277187824\n", 985 | "#555 Loss: 0.002132538938894868\n", 986 | "#556 Loss: 
0.002131509128957987\n", 987 | "#557 Loss: 0.002130476525053382\n", 988 | "#558 Loss: 0.0021294394973665476\n", 989 | "#559 Loss: 0.002128403866663575\n", 990 | "#560 Loss: 0.002127366838976741\n", 991 | "#561 Loss: 0.0021263323724269867\n", 992 | "#562 Loss: 0.002125295577570796\n", 993 | "#563 Loss: 0.002124261111021042\n", 994 | "#564 Loss: 0.0021232208237051964\n", 995 | "#565 Loss: 0.002122187288478017\n", 996 | "#566 Loss: 0.0021211470011621714\n", 997 | "#567 Loss: 0.002120112767443061\n", 998 | "#568 Loss: 0.002119072712957859\n", 999 | "#569 Loss: 0.0021180338226258755\n", 1000 | "#570 Loss: 0.00211700308136642\n", 1001 | "#571 Loss: 0.0021159613970667124\n", 1002 | "#572 Loss: 0.0021149280946701765\n", 1003 | "#573 Loss: 0.0021138915326446295\n", 1004 | "#574 Loss: 0.0021128482185304165\n", 1005 | "#575 Loss: 0.0021118095610290766\n", 1006 | "#576 Loss: 0.0021107716020196676\n", 1007 | "#577 Loss: 0.002109734108671546\n", 1008 | "#578 Loss: 0.0021087005734443665\n", 1009 | "#579 Loss: 0.002107657492160797\n", 1010 | "#580 Loss: 0.00210661836899817\n", 1011 | "#581 Loss: 0.002105577616021037\n", 1012 | "#582 Loss: 0.0021045382600277662\n", 1013 | "#583 Loss: 0.002103500533849001\n", 1014 | "#584 Loss: 0.0021024595480412245\n", 1015 | "#585 Loss: 0.0021014243829995394\n", 1016 | "#586 Loss: 0.002100378042086959\n", 1017 | "#587 Loss: 0.002099341945722699\n", 1018 | "#588 Loss: 0.00209829886443913\n", 1019 | "#589 Loss: 0.0020972639322280884\n", 1020 | "#590 Loss: 0.0020962206181138754\n", 1021 | "#591 Loss: 0.002095181494951248\n", 1022 | "#592 Loss: 0.0020941428374499083\n", 1023 | "#593 Loss: 0.002093098359182477\n", 1024 | "#594 Loss: 0.002092057839035988\n", 1025 | "#595 Loss: 0.0020910180173814297\n", 1026 | "#596 Loss: 0.002089978661388159\n", 1027 | "#597 Loss: 0.0020889334846287966\n", 1028 | "#598 Loss: 0.0020878936629742384\n", 1029 | "#599 Loss: 0.0020868529099971056\n", 1030 | "#600 Loss: 0.002085815416648984\n", 1031 | "#601 Loss: 0.0020847702398896217\n", 1032 | "#602 Loss: 0.0020837283227592707\n", 1033 | "#603 Loss: 0.0020826871041208506\n", 1034 | "#604 Loss: 0.0020816465839743614\n", 1035 | "#605 Loss: 0.002080598147585988\n", 1036 | "#606 Loss: 0.002079556928947568\n", 1037 | "#607 Loss: 0.0020785192027688026\n", 1038 | "#608 Loss: 0.0020774772856384516\n", 1039 | "#609 Loss: 0.002076430944725871\n", 1040 | "#610 Loss: 0.00207538646645844\n", 1041 | "#611 Loss: 0.00207435037009418\n", 1042 | "#612 Loss: 0.002073307754471898\n", 1043 | "#613 Loss: 0.002072261879220605\n", 1044 | "#614 Loss: 0.0020712194964289665\n", 1045 | "#615 Loss: 0.0020701782777905464\n", 1046 | "#616 Loss: 0.0020691361278295517\n", 1047 | "#617 Loss: 0.0020680923480540514\n", 1048 | "#618 Loss: 0.0020670518279075623\n", 1049 | "#619 Loss: 0.0020660050213336945\n", 1050 | "#620 Loss: 0.0020649584475904703\n", 1051 | "#621 Loss: 0.0020639190915971994\n", 1052 | "#622 Loss: 0.002062877407297492\n", 1053 | "#623 Loss: 0.0020618324633687735\n", 1054 | "#624 Loss: 0.0020607870537787676\n", 1055 | "#625 Loss: 0.00205974536947906\n", 1056 | "#626 Loss: 0.0020587043836712837\n", 1057 | "#627 Loss: 0.0020576564129441977\n", 1058 | "#628 Loss: 0.0020566147286444902\n", 1059 | "#629 Loss: 0.002055570250377059\n", 1060 | "#630 Loss: 0.0020545274019241333\n", 1061 | "#631 Loss: 0.0020534859504550695\n", 1062 | "#632 Loss: 0.002052436349913478\n", 1063 | "#633 Loss: 0.0020513960625976324\n", 1064 | "#634 Loss: 0.0020503487903624773\n", 1065 | "#635 Loss: 0.0020493092015385628\n", 1066 | "#636 Loss: 
0.002048263093456626\n", 1067 | "#637 Loss: 0.002047223038971424\n", 1068 | "#638 Loss: 0.002046172507107258\n", 1069 | "#639 Loss: 0.002045132452622056\n", 1070 | "#640 Loss: 0.002044085180386901\n", 1071 | "#641 Loss: 0.002043043961748481\n", 1072 | "#642 Loss: 0.0020420013461261988\n", 1073 | "#643 Loss: 0.0020409554708749056\n", 1074 | "#644 Loss: 0.002039908664301038\n", 1075 | "#645 Loss: 0.002038867911323905\n", 1076 | "#646 Loss: 0.0020378208719193935\n", 1077 | "#647 Loss: 0.0020367794204503298\n", 1078 | "#648 Loss: 0.0020357321482151747\n", 1079 | "#649 Loss: 0.002034691860899329\n", 1080 | "#650 Loss: 0.002033643191680312\n", 1081 | "#651 Loss: 0.002032601274549961\n", 1082 | "#652 Loss: 0.002031555864959955\n", 1083 | "#653 Loss: 0.0020305109210312366\n", 1084 | "#654 Loss: 0.002029466675594449\n", 1085 | "#655 Loss: 0.002028421498835087\n", 1086 | "#656 Loss: 0.002027378184720874\n", 1087 | "#657 Loss: 0.0020263351034373045\n", 1088 | "#658 Loss: 0.0020252885296940804\n", 1089 | "#659 Loss: 0.0020242466125637293\n", 1090 | "#660 Loss: 0.002023200271651149\n", 1091 | "#661 Loss: 0.002022160217165947\n", 1092 | "#662 Loss: 0.0020211131777614355\n", 1093 | "#663 Loss: 0.0020200731232762337\n", 1094 | "#664 Loss: 0.0020190232899039984\n", 1095 | "#665 Loss: 0.0020179767161607742\n", 1096 | "#666 Loss: 0.002016937592998147\n", 1097 | "#667 Loss: 0.002015892183408141\n", 1098 | "#668 Loss: 0.0020148518960922956\n", 1099 | "#669 Loss: 0.0020138081163167953\n", 1100 | "#670 Loss: 0.0020127587486058474\n", 1101 | "#671 Loss: 0.0020117172971367836\n", 1102 | "#672 Loss: 0.0020106742158532143\n", 1103 | "#673 Loss: 0.002009629737585783\n", 1104 | "#674 Loss: 0.002008582465350628\n", 1105 | "#675 Loss: 0.0020075414795428514\n", 1106 | "#676 Loss: 0.002006495138630271\n", 1107 | "#677 Loss: 0.0020054553169757128\n", 1108 | "#678 Loss: 0.002004409907385707\n", 1109 | "#679 Loss: 0.002003363100811839\n", 1110 | "#680 Loss: 0.002002324676141143\n", 1111 | "#681 Loss: 0.002001277869567275\n", 1112 | "#682 Loss: 0.002000238513574004\n", 1113 | "#683 Loss: 0.001999191241338849\n", 1114 | "#684 Loss: 0.0019981495570391417\n", 1115 | "#685 Loss: 0.0019971048459410667\n", 1116 | "#686 Loss: 0.0019960617646574974\n", 1117 | "#687 Loss: 0.0019950189162045717\n", 1118 | "#688 Loss: 0.0019939783960580826\n", 1119 | "#689 Loss: 0.001992932753637433\n", 1120 | "#690 Loss: 0.001991888275370002\n", 1121 | "#691 Loss: 0.0019908458925783634\n", 1122 | "#692 Loss: 0.001989804906770587\n", 1123 | "#693 Loss: 0.0019887599628418684\n", 1124 | "#694 Loss: 0.0019877159502357244\n", 1125 | "#695 Loss: 0.001986677525565028\n", 1126 | "#696 Loss: 0.0019856367725878954\n", 1127 | "#697 Loss: 0.001984592527151108\n", 1128 | "#698 Loss: 0.001983546419069171\n", 1129 | "#699 Loss: 0.001982505898922682\n", 1130 | "#700 Loss: 0.0019814646802842617\n", 1131 | "#701 Loss: 0.0019804220646619797\n", 1132 | "#702 Loss: 0.0019793810788542032\n", 1133 | "#703 Loss: 0.0019783375319093466\n", 1134 | "#704 Loss: 0.0019772977102547884\n", 1135 | "#705 Loss: 0.0019762550946325064\n", 1136 | "#706 Loss: 0.0019752129446715117\n", 1137 | "#707 Loss: 0.001974171493202448\n", 1138 | "#708 Loss: 0.001973131438717246\n", 1139 | "#709 Loss: 0.001972092781215906\n", 1140 | "#710 Loss: 0.0019710464403033257\n", 1141 | "#711 Loss: 0.0019700077828019857\n", 1142 | "#712 Loss: 0.001968963770195842\n", 1143 | "#713 Loss: 0.0019679246470332146\n", 1144 | "#714 Loss: 0.0019668852910399437\n", 1145 | "#715 Loss: 0.001965844538062811\n", 1146 | "#716 
Loss: 0.001964807277545333\n", 1147 | "#717 Loss: 0.0019637665245682\n", 1148 | "#718 Loss: 0.0019627264700829983\n", 1149 | "#719 Loss: 0.00196168408729136\n", 1150 | "#720 Loss: 0.00196064286865294\n", 1151 | "#721 Loss: 0.0019596030469983816\n", 1152 | "#722 Loss: 0.001958560897037387\n", 1153 | "#723 Loss: 0.001957525731995702\n", 1154 | "#724 Loss: 0.001956489635631442\n", 1155 | "#725 Loss: 0.001955445623025298\n", 1156 | "#726 Loss: 0.001954407896846533\n", 1157 | "#727 Loss: 0.0019533671438694\n", 1158 | "#728 Loss: 0.0019523290684446692\n", 1159 | "#729 Loss: 0.0019512904109433293\n", 1160 | "#730 Loss: 0.0019502503564581275\n", 1161 | "#731 Loss: 0.0019492128631100059\n", 1162 | "#732 Loss: 0.0019481779308989644\n", 1163 | "#733 Loss: 0.0019471339182928205\n", 1164 | "#734 Loss: 0.0019461024785414338\n", 1165 | "#735 Loss: 0.0019450596300885081\n", 1166 | "#736 Loss: 0.0019440216710790992\n", 1167 | "#737 Loss: 0.001942987204529345\n", 1168 | "#738 Loss: 0.0019419504096731544\n", 1169 | "#739 Loss: 0.0019409122178331017\n", 1170 | "#740 Loss: 0.0019398737931624055\n", 1171 | "#741 Loss: 0.0019388411892578006\n", 1172 | "#742 Loss: 0.001937802298925817\n", 1173 | "#743 Loss: 0.001936764339916408\n", 1174 | "#744 Loss: 0.001935729756951332\n", 1175 | "#745 Loss: 0.0019346913322806358\n", 1176 | "#746 Loss: 0.0019336584955453873\n", 1177 | "#747 Loss: 0.0019326211186125875\n", 1178 | "#748 Loss: 0.001931585487909615\n", 1179 | "#749 Loss: 0.0019305492751300335\n", 1180 | "#750 Loss: 0.0019295121310278773\n", 1181 | "#751 Loss: 0.0019284767331555486\n", 1182 | "#752 Loss: 0.0019274475052952766\n", 1183 | "#753 Loss: 0.0019264090806245804\n", 1184 | "#754 Loss: 0.0019253772916272283\n", 1185 | "#755 Loss: 0.0019243452697992325\n", 1186 | "#756 Loss: 0.0019233074272051454\n", 1187 | "#757 Loss: 0.0019222754053771496\n", 1188 | "#758 Loss: 0.0019212419865652919\n", 1189 | "#759 Loss: 0.0019202110124751925\n", 1190 | "#760 Loss: 0.0019191773608326912\n", 1191 | "#761 Loss: 0.0019181432435289025\n", 1192 | "#762 Loss: 0.0019171085441485047\n", 1193 | "#763 Loss: 0.0019160775700584054\n", 1194 | "#764 Loss: 0.0019150450825691223\n", 1195 | "#765 Loss: 0.0019140088697895408\n", 1196 | "#766 Loss: 0.0019129784777760506\n", 1197 | "#767 Loss: 0.0019119485514238477\n", 1198 | "#768 Loss: 0.0019109140848740935\n", 1199 | "#769 Loss: 0.0019098850898444653\n", 1200 | "#770 Loss: 0.001908852718770504\n", 1201 | "#771 Loss: 0.001907822792418301\n", 1202 | "#772 Loss: 0.0019067925168201327\n", 1203 | "#773 Loss: 0.0019057630561292171\n", 1204 | "#774 Loss: 0.0019047335954383016\n", 1205 | "#775 Loss: 0.0019037051824852824\n", 1206 | "#776 Loss: 0.0019026693189516664\n", 1207 | "#777 Loss: 0.0019016433507204056\n", 1208 | "#778 Loss: 0.0019006148213520646\n", 1209 | "#779 Loss: 0.0018995892023667693\n", 1210 | "#780 Loss: 0.0018985569477081299\n", 1211 | "#781 Loss: 0.0018975288840010762\n", 1212 | "#782 Loss: 0.0018965002382174134\n", 1213 | "#783 Loss: 0.0018954715924337506\n", 1214 | "#784 Loss: 0.0018944436451420188\n", 1215 | "#785 Loss: 0.0018934140680357814\n", 1216 | "#786 Loss: 0.0018923920579254627\n", 1217 | "#787 Loss: 0.0018913644598796964\n", 1218 | "#788 Loss: 0.001890333485789597\n", 1219 | "#789 Loss: 0.0018893079832196236\n", 1220 | "#790 Loss: 0.0018882853910326958\n", 1221 | "#791 Loss: 0.001887254766188562\n", 1222 | "#792 Loss: 0.0018862345023080707\n", 1223 | "#793 Loss: 0.0018852058565244079\n", 1224 | "#794 Loss: 0.00188418326433748\n", 1225 | "#795 Loss: 
0.0018831556662917137\n", 1226 | "#796 Loss: 0.0018821310950443149\n", 1227 | "#797 Loss: 0.0018811067566275597\n", 1228 | "#798 Loss: 0.001880083349533379\n", 1229 | "#799 Loss: 0.001879060291685164\n", 1230 | "#800 Loss: 0.0018780353711917996\n", 1231 | "#801 Loss: 0.0018770135939121246\n", 1232 | "#802 Loss: 0.0018759918166324496\n", 1233 | "#803 Loss: 0.0018749730661511421\n", 1234 | "#804 Loss: 0.0018739477964118123\n", 1235 | "#805 Loss: 0.0018729200819507241\n", 1236 | "#806 Loss: 0.0018719009822234511\n", 1237 | "#807 Loss: 0.001870879321359098\n", 1238 | "#808 Loss: 0.001869861502200365\n", 1239 | "#809 Loss: 0.0018688408890739083\n", 1240 | "#810 Loss: 0.001867820625193417\n", 1241 | "#811 Loss: 0.0018667984986677766\n", 1242 | "#812 Loss: 0.0018657720647752285\n", 1243 | "#813 Loss: 0.001864760648459196\n", 1244 | "#814 Loss: 0.0018637363100424409\n", 1245 | "#815 Loss: 0.0018627209356054664\n", 1246 | "#816 Loss: 0.0018617023015394807\n", 1247 | "#817 Loss: 0.0018606797093525529\n", 1248 | "#818 Loss: 0.0018596658483147621\n", 1249 | "#819 Loss: 0.0018586452351883054\n", 1250 | "#820 Loss: 0.0018576303264126182\n", 1251 | "#821 Loss: 0.001856614020653069\n", 1252 | "#822 Loss: 0.001855594920925796\n", 1253 | "#823 Loss: 0.0018545795464888215\n", 1254 | "#824 Loss: 0.001853560097515583\n", 1255 | "#825 Loss: 0.0018525446066632867\n", 1256 | "#826 Loss: 0.0018515288829803467\n", 1257 | "#827 Loss: 0.00185050955042243\n", 1258 | "#828 Loss: 0.0018494967371225357\n", 1259 | "#829 Loss: 0.0018484825268387794\n", 1260 | "#830 Loss: 0.001847467734478414\n", 1261 | "#831 Loss: 0.0018464555032551289\n", 1262 | "#832 Loss: 0.0018454398959875107\n", 1263 | "#833 Loss: 0.0018444285960868\n", 1264 | "#834 Loss: 0.0018434150842949748\n", 1265 | "#835 Loss: 0.0018424022709950805\n", 1266 | "#836 Loss: 0.001841390854679048\n", 1267 | "#837 Loss: 0.0018403776921331882\n", 1268 | "#838 Loss: 0.0018393672071397305\n", 1269 | "#839 Loss: 0.00183835718780756\n", 1270 | "#840 Loss: 0.0018373435596004128\n", 1271 | "#841 Loss: 0.001836334471590817\n", 1272 | "#842 Loss: 0.0018353263149037957\n", 1273 | "#843 Loss: 0.0018343138508498669\n", 1274 | "#844 Loss: 0.0018333062762394547\n", 1275 | "#845 Loss: 0.001832296489737928\n", 1276 | "#846 Loss: 0.0018312829779461026\n", 1277 | "#847 Loss: 0.0018302792450413108\n", 1278 | "#848 Loss: 0.0018292715540155768\n", 1279 | "#849 Loss: 0.0018282626988366246\n", 1280 | "#850 Loss: 0.0018272522138431668\n", 1281 | "#851 Loss: 0.001826247200369835\n", 1282 | "#852 Loss: 0.0018252409063279629\n", 1283 | "#853 Loss: 0.001824233098886907\n", 1284 | "#854 Loss: 0.0018232259899377823\n", 1285 | "#855 Loss: 0.001822225865907967\n", 1286 | "#856 Loss: 0.0018212157301604748\n", 1287 | "#857 Loss: 0.0018202122300863266\n", 1288 | "#858 Loss: 0.0018192125717177987\n", 1289 | "#859 Loss: 0.0018182039493694901\n", 1290 | "#860 Loss: 0.0018171994015574455\n", 1291 | "#861 Loss: 0.0018161969492211938\n", 1292 | "#862 Loss: 0.00181519181933254\n", 1293 | "#863 Loss: 0.0018141911132261157\n", 1294 | "#864 Loss: 0.001813187263906002\n", 1295 | "#865 Loss: 0.0018121921457350254\n", 1296 | "#866 Loss: 0.0018111892277374864\n", 1297 | "#867 Loss: 0.0018101868918165565\n", 1298 | "#868 Loss: 0.0018091824604198337\n", 1299 | "#869 Loss: 0.001808184664696455\n", 1300 | "#870 Loss: 0.0018071848899126053\n", 1301 | "#871 Loss: 0.0018061831360682845\n", 1302 | "#872 Loss: 0.0018051863880828023\n", 1303 | "#873 Loss: 0.0018041870789602399\n", 1304 | "#874 Loss: 0.0018031877698376775\n", 
1305 | "#875 Loss: 0.0018021933501586318\n", 1306 | "#876 Loss: 0.0018011946231126785\n", 1307 | "#877 Loss: 0.0018001968273892999\n", 1308 | "#878 Loss: 0.0017991961212828755\n", 1309 | "#879 Loss: 0.001798199606128037\n", 1310 | "#880 Loss: 0.0017972056521102786\n", 1311 | "#881 Loss: 0.0017962086712941527\n", 1312 | "#882 Loss: 0.0017952205380424857\n", 1313 | "#883 Loss: 0.0017942209960892797\n", 1314 | "#884 Loss: 0.001793228555470705\n", 1315 | "#885 Loss: 0.0017922349506989121\n", 1316 | "#886 Loss: 0.001791241578757763\n", 1317 | "#887 Loss: 0.001790247275494039\n", 1318 | "#888 Loss: 0.0017892572795972228\n", 1319 | "#889 Loss: 0.0017882628599181771\n", 1320 | "#890 Loss: 0.0017872735625132918\n", 1321 | "#891 Loss: 0.0017862803069874644\n", 1322 | "#892 Loss: 0.001785286352969706\n", 1323 | "#893 Loss: 0.0017842984525486827\n", 1324 | "#894 Loss: 0.0017833089223131537\n", 1325 | "#895 Loss: 0.0017823184607550502\n", 1326 | "#896 Loss: 0.0017813298618420959\n", 1327 | "#897 Loss: 0.0017803410300984979\n", 1328 | "#898 Loss: 0.0017793524311855435\n", 1329 | "#899 Loss: 0.0017783649964258075\n", 1330 | "#900 Loss: 0.001777378492988646\n", 1331 | "#901 Loss: 0.0017763897776603699\n", 1332 | "#902 Loss: 0.0017754010623320937\n", 1333 | "#903 Loss: 0.001774418051354587\n", 1334 | "#904 Loss: 0.0017734314315021038\n", 1335 | "#905 Loss: 0.0017724483041092753\n", 1336 | "#906 Loss: 0.0017714608693495393\n", 1337 | "#907 Loss: 0.0017704787896946073\n", 1338 | "#908 Loss: 0.0017694927519187331\n", 1339 | "#909 Loss: 0.0017685088096186519\n", 1340 | "#910 Loss: 0.0017675244016572833\n", 1341 | "#911 Loss: 0.001766547211445868\n", 1342 | "#912 Loss: 0.001765563734807074\n", 1343 | "#913 Loss: 0.001764580956660211\n", 1344 | "#914 Loss: 0.0017636003904044628\n", 1345 | "#915 Loss: 0.0017626197077333927\n", 1346 | "#916 Loss: 0.0017616351833567023\n", 1347 | "#917 Loss: 0.0017606564797461033\n", 1348 | "#918 Loss: 0.0017596777761355042\n", 1349 | "#919 Loss: 0.0017587020993232727\n", 1350 | "#920 Loss: 0.001757721765898168\n", 1351 | "#921 Loss: 0.0017567459726706147\n", 1352 | "#922 Loss: 0.0017557647079229355\n", 1353 | "#923 Loss: 0.0017547908937558532\n", 1354 | "#924 Loss: 0.0017538117244839668\n", 1355 | "#925 Loss: 0.0017528367461636662\n", 1356 | "#926 Loss: 0.0017518624663352966\n", 1357 | "#927 Loss: 0.0017508859746158123\n", 1358 | "#928 Loss: 0.0017499076202511787\n", 1359 | "#929 Loss: 0.0017489390447735786\n", 1360 | "#930 Loss: 0.0017479656962677836\n", 1361 | "#931 Loss: 0.0017469911836087704\n", 1362 | "#932 Loss: 0.0017460188828408718\n", 1363 | "#933 Loss: 0.0017450453015044332\n", 1364 | "#934 Loss: 0.00174407206941396\n", 1365 | "#935 Loss: 0.0017430986044928432\n", 1366 | "#936 Loss: 0.0017421283992007375\n", 1367 | "#937 Loss: 0.001741158775985241\n", 1368 | "#938 Loss: 0.0017401917139068246\n", 1369 | "#939 Loss: 0.0017392206937074661\n", 1370 | "#940 Loss: 0.0017382544465363026\n", 1371 | "#941 Loss: 0.0017372820293530822\n", 1372 | "#942 Loss: 0.001736316829919815\n", 1373 | "#943 Loss: 0.0017353454604744911\n", 1374 | "#944 Loss: 0.0017343764193356037\n", 1375 | "#945 Loss: 0.0017334137810394168\n", 1376 | "#946 Loss: 0.0017324457876384258\n", 1377 | "#947 Loss: 0.0017314818687736988\n", 1378 | "#948 Loss: 0.001730515738017857\n", 1379 | "#949 Loss: 0.0017295492580160499\n", 1380 | "#950 Loss: 0.0017285882495343685\n", 1381 | "#951 Loss: 0.0017276207217946649\n", 1382 | "#952 Loss: 0.0017266602953895926\n", 1383 | "#953 Loss: 0.0017256977735087276\n", 1384 | "#954 
Loss: 0.0017247359501197934\n", 1385 | "#955 Loss: 0.0017237764550372958\n", 1386 | "#956 Loss: 0.0017228134674951434\n", 1387 | "#957 Loss: 0.0017218533903360367\n", 1388 | "#958 Loss: 0.0017208936624228954\n", 1389 | "#959 Loss: 0.001719936146400869\n", 1390 | "#960 Loss: 0.001718974090181291\n", 1391 | "#961 Loss: 0.0017180143622681499\n", 1392 | "#962 Loss: 0.001717058359645307\n", 1393 | "#963 Loss: 0.0017161048017442226\n", 1394 | "#964 Loss: 0.001715144026093185\n", 1395 | "#965 Loss: 0.0017141870921477675\n", 1396 | "#966 Loss: 0.0017132310895249248\n", 1397 | "#967 Loss: 0.0017122785793617368\n", 1398 | "#968 Loss: 0.0017113216454163194\n", 1399 | "#969 Loss: 0.001710368786007166\n", 1400 | "#970 Loss: 0.0017094146460294724\n", 1401 | "#971 Loss: 0.001708458294160664\n", 1402 | "#972 Loss: 0.0017075081123039126\n", 1403 | "#973 Loss: 0.0017065554857254028\n", 1404 | "#974 Loss: 0.0017056027427315712\n", 1405 | "#975 Loss: 0.0017046512803062797\n", 1406 | "#976 Loss: 0.0017037037760019302\n", 1407 | "#977 Loss: 0.0017027502181008458\n", 1408 | "#978 Loss: 0.0017018018988892436\n", 1409 | "#979 Loss: 0.001700854511000216\n", 1410 | "#980 Loss: 0.0016999054932966828\n", 1411 | "#981 Loss: 0.001698957639746368\n", 1412 | "#982 Loss: 0.0016980115324258804\n", 1413 | "#983 Loss: 0.0016970612341538072\n", 1414 | "#984 Loss: 0.0016961172223091125\n", 1415 | "#985 Loss: 0.0016951701836660504\n", 1416 | "#986 Loss: 0.001694221398793161\n", 1417 | "#987 Loss: 0.0016932813450694084\n", 1418 | "#988 Loss: 0.0016923333751037717\n", 1419 | "#989 Loss: 0.0016913922736421227\n", 1420 | "#990 Loss: 0.0016904502408578992\n", 1421 | "#991 Loss: 0.0016895070439204574\n", 1422 | "#992 Loss: 0.0016885654767975211\n", 1423 | "#993 Loss: 0.001687621814198792\n", 1424 | "#994 Loss: 0.0016866797814145684\n", 1425 | "#995 Loss: 0.001685741706751287\n", 1426 | "#996 Loss: 0.0016847997903823853\n", 1427 | "#997 Loss: 0.0016838625306263566\n", 1428 | "#998 Loss: 0.0016829235246405005\n", 1429 | "#999 Loss: 0.0016819849843159318\n", 1430 | "Predicted data based on trained weights: \n", 1431 | "Input (scaled): \n", 1432 | "tensor([0.5000, 1.0000])\n", 1433 | "Output: \n", 1434 | "tensor([0.9505])\n" 1435 | ], 1436 | "name": "stdout" 1437 | }, 1438 | { 1439 | "output_type": "stream", 1440 | "text": [ 1441 | "/usr/local/lib/python3.6/dist-packages/torch/serialization.py:241: UserWarning: Couldn't retrieve source code for container of type Neural_Network. It won't be checked for correctness upon loading.\n", 1442 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 1443 | ], 1444 | "name": "stderr" 1445 | } 1446 | ] 1447 | }, 1448 | { 1449 | "cell_type": "markdown", 1450 | "metadata": { 1451 | "id": "L9nBzkgdbjcA", 1452 | "colab_type": "text" 1453 | }, 1454 | "source": [ 1455 | "The loss keeps decreasing, which means that the neural network is learning something. That's it. Congratulations! You have just learned how to create and train a neural network from scratch using PyTorch. There are so many things you can do with the shallow network we have just implemented. You can add more hidden layers or try to incorporate the bias terms for practice. I would love to see what you will build from here. Reach me out on [Twitter](https://twitter.com/omarsar0) if you have any further questions or leave your comments here. Until next time!" 
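For the bias exercise mentioned above, a minimal sketch using `nn.Linear(..., bias=True)` could look like the code below. The `Neural_Network` class trained in this notebook is not reproduced in this excerpt, so the hidden size of 3 and the sigmoid activations are assumptions; only the two-feature input and the single output match the tensors printed above.

```python
import torch
import torch.nn as nn

# Hypothetical variant of the shallow network above, this time with bias terms.
# The layer sizes (2 -> 3 -> 1) and sigmoid activations are assumptions, not
# necessarily the notebook's exact architecture.
class NeuralNetworkWithBias(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(2, 3, bias=True)  # bias enabled on both layers
        self.out = nn.Linear(3, 1, bias=True)

    def forward(self, x):
        x = torch.sigmoid(self.hidden(x))
        return torch.sigmoid(self.out(x))

model = NeuralNetworkWithBias()
print(model(torch.tensor([[0.5, 1.0]])))  # same scaled input as the prediction above
```

If the original network updates its weight matrices by hand, switching to an `nn.Module` like this one would also mean optimizing `model.parameters()` with a built-in optimizer instead.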
1456 | ] 1457 | }, 1458 | { 1459 | "cell_type": "markdown", 1460 | "metadata": { 1461 | "id": "zcms4BCySKXj", 1462 | "colab_type": "text" 1463 | }, 1464 | "source": [ 1465 | "## References:\n", 1466 | "- [PyTorch nn. Modules](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-custom-nn-modules)\n", 1467 | "- [Build a Neural Network with Numpy](https://enlight.nyc/neural-network)\n" 1468 | ] 1469 | } 1470 | ] 1471 | } -------------------------------------------------------------------------------- /pytorch_hello_world.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pytorch_hello_world.ipynb", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "view-in-github", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "\"Open" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "H7gQFbUxOQtb", 30 | "colab_type": "text" 31 | }, 32 | "source": [ 33 | "# A First Shot at Deep Learning with PyTorch\n", 34 | "\n", 35 | "In this notebook, we are going to take a baby step into the world of deep learning using PyTorch. There are a ton of notebooks out there that teach you the fundamentals of deep learning and PyTorch, so here the idea is to give you a basic introduction to deep learning and PyTorch at a very high level. Therefore, this notebook is targeting beginners but it can also serve as a review for more experienced developers.\n", 36 | "\n", 37 | "After completion of this notebook, you are expected to know the basic components involved in training a neural network with PyTorch. I have also left a couple of exercises towards the end with the intention of encouraging more research and practice of your deep learning skills. \n", 38 | "\n", 39 | "---\n", 40 | "\n", 41 | "**Author:** Elvis Saravia ([Twitter](https://twitter.com/omarsar0) | [LinkedIn](https://www.linkedin.com/in/omarsar/))\n", 42 | "\n", 43 | "**Complete Code Walkthrough:** [Blog post]()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "CkzttrQCwaSQ", 50 | "colab_type": "text" 51 | }, 52 | "source": [ 53 | "## Importing the libraries\n", 54 | "\n", 55 | "Like with any other programming exercise, the first step is to import the necessary libraries. As we are going to be using Google Colab to program our neural network, we need to install and import the necessary PyTorch libraries." 
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "7Exoj-CDskQD", 62 | "colab_type": "code", 63 | "outputId": "a1817018-cec5-4b96-e42b-20ccc9337179", 64 | "colab": { 65 | "base_uri": "https://localhost:8080/", 66 | "height": 119 67 | } 68 | }, 69 | "source": [ 70 | "!pip3 install torch torchvision" 71 | ], 72 | "execution_count": 11, 73 | "outputs": [ 74 | { 75 | "output_type": "stream", 76 | "text": [ 77 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.3.1)\n", 78 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.4.2)\n", 79 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.17.4)\n", 80 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.12.0)\n", 81 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.3.0)\n", 82 | "Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)\n" 83 | ], 84 | "name": "stdout" 85 | } 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "metadata": { 91 | "id": "FuhJIaeXO2W9", 92 | "colab_type": "code", 93 | "colab": { 94 | "base_uri": "https://localhost:8080/", 95 | "height": 34 96 | }, 97 | "outputId": "f883a3f5-020a-4c81-d9e9-83a4ad60953b" 98 | }, 99 | "source": [ 100 | "## The usual imports\n", 101 | "import torch\n", 102 | "import torch.nn as nn\n", 103 | "\n", 104 | "## print out the pytorch version used\n", 105 | "print(torch.__version__)" 106 | ], 107 | "execution_count": 12, 108 | "outputs": [ 109 | { 110 | "output_type": "stream", 111 | "text": [ 112 | "1.3.1\n" 113 | ], 114 | "name": "stdout" 115 | } 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": { 121 | "id": "0a2C_nneO_wp", 122 | "colab_type": "text" 123 | }, 124 | "source": [ 125 | "## The Neural Network\n", 126 | "\n", 127 | "![alt text](https://drive.google.com/uc?export=view&id=1Lpi4VPBfAV3JkOLopcsGK4L8dyxmPF1b)\n", 128 | "\n", 129 | "Before building and training a neural network, the first step is to process and prepare the data. In this notebook, we are going to use synthetic data (i.e., fake data) so we won't be using any real-world data. \n", 130 | "\n", 131 | "For the sake of simplicity, we are going to use the following input and output pairs converted to tensors, which is how data is typically represented in the world of deep learning. The x values represent the input of dimension `(6,1)` and the y values represent the output of similar dimension. The example is taken from this [tutorial](https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%202%20-%20Lesson%202%20-%20Notebook.ipynb). \n", 132 | "\n", 133 | "The objective of the neural network model that we are going to build and train is to automatically learn patterns that better characterize the relationship between the `x` and `y` values. Essentially, the model learns the relationship that exists between inputs and outputs which can then be used to predict the corresponding `y` value for any given input `x`." 
134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "metadata": { 139 | "id": "JWFtgUX85iwO", 140 | "colab_type": "code", 141 | "colab": {} 142 | }, 143 | "source": [ 144 | "## our data in tensor form\n", 145 | "x = torch.tensor([[-1.0], [0.0], [1.0], [2.0], [3.0], [4.0]], dtype=torch.float)\n", 146 | "y = torch.tensor([[-3.0], [-1.0], [1.0], [3.0], [5.0], [7.0]], dtype=torch.float)" 147 | ], 148 | "execution_count": 0, 149 | "outputs": [] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "metadata": { 154 | "id": "NcQUjR_95z5J", 155 | "colab_type": "code", 156 | "outputId": "9f7cfdfe-de53-4640-ddc1-7f4fc55e9aa0", 157 | "colab": { 158 | "base_uri": "https://localhost:8080/", 159 | "height": 34 160 | } 161 | }, 162 | "source": [ 163 | "## print size of the input tensor\n", 164 | "x.size()" 165 | ], 166 | "execution_count": 0, 167 | "outputs": [ 168 | { 169 | "output_type": "execute_result", 170 | "data": { 171 | "text/plain": [ 172 | "torch.Size([6, 1])" 173 | ] 174 | }, 175 | "metadata": { 176 | "tags": [] 177 | }, 178 | "execution_count": 5 179 | } 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": { 185 | "id": "9CJXO5WX1QtQ", 186 | "colab_type": "text" 187 | }, 188 | "source": [ 189 | "## The Neural Network Components\n", 190 | "As said earlier, we are going to first define and build out the components of our neural network before training the model.\n", 191 | "\n", 192 | "### Model\n", 193 | "\n", 194 | "Typically, when building a neural network model, we define the layers and weights which form the basic components of the model. Below we show an example of how to define a hidden layer named `layer1` with size `(1, 1)`. For the purpose of this tutorial, we won't explicitly define the `weights` and allow the built-in functions provided by PyTorch to handle that part for us. By the way, the `nn.Linear(...)` function applies a linear transformation ($y = xA^T + b$) to the data that was provided as its input. We ignore the bias for now by setting `bias=False`.\n", 195 | "\n", 196 | "\n", 197 | "\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "metadata": { 203 | "id": "N1Ii5JRz3Jud", 204 | "colab_type": "code", 205 | "colab": {} 206 | }, 207 | "source": [ 208 | "## Neural network with 1 hidden layer\n", 209 | "layer1 = nn.Linear(1,1, bias=False)\n", 210 | "model = nn.Sequential(layer1)" 211 | ], 212 | "execution_count": 0, 213 | "outputs": [] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": { 218 | "id": "9HTWYD4aMBXQ", 219 | "colab_type": "text" 220 | }, 221 | "source": [ 222 | "### Loss and Optimizer\n", 223 | "The loss function, `nn.MSELoss()`, is in charge of letting the model know how good it has learned the relationship between the input and output. The optimizer (in this case an `SGD`) primary role is to minimize or lower that loss value as it tunes its weights." 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "id": "3hglFpejArxx", 230 | "colab_type": "code", 231 | "colab": {} 232 | }, 233 | "source": [ 234 | "## loss function\n", 235 | "criterion = nn.MSELoss()\n", 236 | "\n", 237 | "## optimizer algorithm\n", 238 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)" 239 | ], 240 | "execution_count": 0, 241 | "outputs": [] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "id": "FKj6jvZTUtGh", 247 | "colab_type": "text" 248 | }, 249 | "source": [ 250 | "## Training the Neural Network Model\n", 251 | "We have all the components we need to train our model. 
Below is the code used to train our model. \n", 252 | "\n", 253 | "In simple terms, we train the model by feeding it the input and output pairs for a couple of rounds (i.e., `epoch`). After a series of forward and backward steps, the model somewhat learns the relationship between x and y values. This is notable by the decrease in the computed `loss`. For a more detailed explanation of this code check out this [tutorial](https://medium.com/dair-ai/a-simple-neural-network-from-scratch-with-pytorch-and-google-colab-c7f3830618e0). " 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "metadata": { 259 | "id": "JeOr9i-aBzRv", 260 | "colab_type": "code", 261 | "outputId": "b0633322-f198-4f09-814c-61166bd7f8a3", 262 | "colab": { 263 | "base_uri": "https://localhost:8080/", 264 | "height": 1000 265 | } 266 | }, 267 | "source": [ 268 | "## training\n", 269 | "epoch = 150\n", 270 | "for i in range(150):\n", 271 | " model = model.train()\n", 272 | " train_running_loss = 0.0\n", 273 | "\n", 274 | " ## forward\n", 275 | " output = model(x)\n", 276 | " loss = criterion(output, y)\n", 277 | " optimizer.zero_grad()\n", 278 | "\n", 279 | " ## backward + update model params \n", 280 | " loss.backward()\n", 281 | " optimizer.step()\n", 282 | " train_running_loss += loss.detach().item()\n", 283 | "\n", 284 | " model.eval()\n", 285 | " print('Epoch: %d | Loss: %.4f' %(i, train_running_loss) )" 286 | ], 287 | "execution_count": 0, 288 | "outputs": [ 289 | { 290 | "output_type": "stream", 291 | "text": [ 292 | "Epoch: 0 | Loss: 4.0147\n", 293 | "Epoch: 1 | Loss: 3.3385\n", 294 | "Epoch: 2 | Loss: 2.7949\n", 295 | "Epoch: 3 | Loss: 2.3577\n", 296 | "Epoch: 4 | Loss: 2.0063\n", 297 | "Epoch: 5 | Loss: 1.7237\n", 298 | "Epoch: 6 | Loss: 1.4965\n", 299 | "Epoch: 7 | Loss: 1.3139\n", 300 | "Epoch: 8 | Loss: 1.1670\n", 301 | "Epoch: 9 | Loss: 1.0489\n", 302 | "Epoch: 10 | Loss: 0.9540\n", 303 | "Epoch: 11 | Loss: 0.8776\n", 304 | "Epoch: 12 | Loss: 0.8163\n", 305 | "Epoch: 13 | Loss: 0.7669\n", 306 | "Epoch: 14 | Loss: 0.7273\n", 307 | "Epoch: 15 | Loss: 0.6954\n", 308 | "Epoch: 16 | Loss: 0.6697\n", 309 | "Epoch: 17 | Loss: 0.6491\n", 310 | "Epoch: 18 | Loss: 0.6325\n", 311 | "Epoch: 19 | Loss: 0.6192\n", 312 | "Epoch: 20 | Loss: 0.6085\n", 313 | "Epoch: 21 | Loss: 0.5999\n", 314 | "Epoch: 22 | Loss: 0.5929\n", 315 | "Epoch: 23 | Loss: 0.5874\n", 316 | "Epoch: 24 | Loss: 0.5829\n", 317 | "Epoch: 25 | Loss: 0.5793\n", 318 | "Epoch: 26 | Loss: 0.5764\n", 319 | "Epoch: 27 | Loss: 0.5741\n", 320 | "Epoch: 28 | Loss: 0.5722\n", 321 | "Epoch: 29 | Loss: 0.5707\n", 322 | "Epoch: 30 | Loss: 0.5695\n", 323 | "Epoch: 31 | Loss: 0.5685\n", 324 | "Epoch: 32 | Loss: 0.5677\n", 325 | "Epoch: 33 | Loss: 0.5671\n", 326 | "Epoch: 34 | Loss: 0.5666\n", 327 | "Epoch: 35 | Loss: 0.5662\n", 328 | "Epoch: 36 | Loss: 0.5659\n", 329 | "Epoch: 37 | Loss: 0.5656\n", 330 | "Epoch: 38 | Loss: 0.5654\n", 331 | "Epoch: 39 | Loss: 0.5652\n", 332 | "Epoch: 40 | Loss: 0.5651\n", 333 | "Epoch: 41 | Loss: 0.5650\n", 334 | "Epoch: 42 | Loss: 0.5649\n", 335 | "Epoch: 43 | Loss: 0.5648\n", 336 | "Epoch: 44 | Loss: 0.5648\n", 337 | "Epoch: 45 | Loss: 0.5647\n", 338 | "Epoch: 46 | Loss: 0.5647\n", 339 | "Epoch: 47 | Loss: 0.5646\n", 340 | "Epoch: 48 | Loss: 0.5646\n", 341 | "Epoch: 49 | Loss: 0.5646\n", 342 | "Epoch: 50 | Loss: 0.5646\n", 343 | "Epoch: 51 | Loss: 0.5646\n", 344 | "Epoch: 52 | Loss: 0.5646\n", 345 | "Epoch: 53 | Loss: 0.5645\n", 346 | "Epoch: 54 | Loss: 0.5645\n", 347 | "Epoch: 55 | Loss: 0.5645\n", 348 | "Epoch: 56 | Loss: 0.5645\n", 
349 | "Epoch: 57 | Loss: 0.5645\n", 350 | "Epoch: 58 | Loss: 0.5645\n", 351 | "Epoch: 59 | Loss: 0.5645\n", 352 | "Epoch: 60 | Loss: 0.5645\n", 353 | "Epoch: 61 | Loss: 0.5645\n", 354 | "Epoch: 62 | Loss: 0.5645\n", 355 | "Epoch: 63 | Loss: 0.5645\n", 356 | "Epoch: 64 | Loss: 0.5645\n", 357 | "Epoch: 65 | Loss: 0.5645\n", 358 | "Epoch: 66 | Loss: 0.5645\n", 359 | "Epoch: 67 | Loss: 0.5645\n", 360 | "Epoch: 68 | Loss: 0.5645\n", 361 | "Epoch: 69 | Loss: 0.5645\n", 362 | "Epoch: 70 | Loss: 0.5645\n", 363 | "Epoch: 71 | Loss: 0.5645\n", 364 | "Epoch: 72 | Loss: 0.5645\n", 365 | "Epoch: 73 | Loss: 0.5645\n", 366 | "Epoch: 74 | Loss: 0.5645\n", 367 | "Epoch: 75 | Loss: 0.5645\n", 368 | "Epoch: 76 | Loss: 0.5645\n", 369 | "Epoch: 77 | Loss: 0.5645\n", 370 | "Epoch: 78 | Loss: 0.5645\n", 371 | "Epoch: 79 | Loss: 0.5645\n", 372 | "Epoch: 80 | Loss: 0.5645\n", 373 | "Epoch: 81 | Loss: 0.5645\n", 374 | "Epoch: 82 | Loss: 0.5645\n", 375 | "Epoch: 83 | Loss: 0.5645\n", 376 | "Epoch: 84 | Loss: 0.5645\n", 377 | "Epoch: 85 | Loss: 0.5645\n", 378 | "Epoch: 86 | Loss: 0.5645\n", 379 | "Epoch: 87 | Loss: 0.5645\n", 380 | "Epoch: 88 | Loss: 0.5645\n", 381 | "Epoch: 89 | Loss: 0.5645\n", 382 | "Epoch: 90 | Loss: 0.5645\n", 383 | "Epoch: 91 | Loss: 0.5645\n", 384 | "Epoch: 92 | Loss: 0.5645\n", 385 | "Epoch: 93 | Loss: 0.5645\n", 386 | "Epoch: 94 | Loss: 0.5645\n", 387 | "Epoch: 95 | Loss: 0.5645\n", 388 | "Epoch: 96 | Loss: 0.5645\n", 389 | "Epoch: 97 | Loss: 0.5645\n", 390 | "Epoch: 98 | Loss: 0.5645\n", 391 | "Epoch: 99 | Loss: 0.5645\n", 392 | "Epoch: 100 | Loss: 0.5645\n", 393 | "Epoch: 101 | Loss: 0.5645\n", 394 | "Epoch: 102 | Loss: 0.5645\n", 395 | "Epoch: 103 | Loss: 0.5645\n", 396 | "Epoch: 104 | Loss: 0.5645\n", 397 | "Epoch: 105 | Loss: 0.5645\n", 398 | "Epoch: 106 | Loss: 0.5645\n", 399 | "Epoch: 107 | Loss: 0.5645\n", 400 | "Epoch: 108 | Loss: 0.5645\n", 401 | "Epoch: 109 | Loss: 0.5645\n", 402 | "Epoch: 110 | Loss: 0.5645\n", 403 | "Epoch: 111 | Loss: 0.5645\n", 404 | "Epoch: 112 | Loss: 0.5645\n", 405 | "Epoch: 113 | Loss: 0.5645\n", 406 | "Epoch: 114 | Loss: 0.5645\n", 407 | "Epoch: 115 | Loss: 0.5645\n", 408 | "Epoch: 116 | Loss: 0.5645\n", 409 | "Epoch: 117 | Loss: 0.5645\n", 410 | "Epoch: 118 | Loss: 0.5645\n", 411 | "Epoch: 119 | Loss: 0.5645\n", 412 | "Epoch: 120 | Loss: 0.5645\n", 413 | "Epoch: 121 | Loss: 0.5645\n", 414 | "Epoch: 122 | Loss: 0.5645\n", 415 | "Epoch: 123 | Loss: 0.5645\n", 416 | "Epoch: 124 | Loss: 0.5645\n", 417 | "Epoch: 125 | Loss: 0.5645\n", 418 | "Epoch: 126 | Loss: 0.5645\n", 419 | "Epoch: 127 | Loss: 0.5645\n", 420 | "Epoch: 128 | Loss: 0.5645\n", 421 | "Epoch: 129 | Loss: 0.5645\n", 422 | "Epoch: 130 | Loss: 0.5645\n", 423 | "Epoch: 131 | Loss: 0.5645\n", 424 | "Epoch: 132 | Loss: 0.5645\n", 425 | "Epoch: 133 | Loss: 0.5645\n", 426 | "Epoch: 134 | Loss: 0.5645\n", 427 | "Epoch: 135 | Loss: 0.5645\n", 428 | "Epoch: 136 | Loss: 0.5645\n", 429 | "Epoch: 137 | Loss: 0.5645\n", 430 | "Epoch: 138 | Loss: 0.5645\n", 431 | "Epoch: 139 | Loss: 0.5645\n", 432 | "Epoch: 140 | Loss: 0.5645\n", 433 | "Epoch: 141 | Loss: 0.5645\n", 434 | "Epoch: 142 | Loss: 0.5645\n", 435 | "Epoch: 143 | Loss: 0.5645\n", 436 | "Epoch: 144 | Loss: 0.5645\n", 437 | "Epoch: 145 | Loss: 0.5645\n", 438 | "Epoch: 146 | Loss: 0.5645\n", 439 | "Epoch: 147 | Loss: 0.5645\n", 440 | "Epoch: 148 | Loss: 0.5645\n", 441 | "Epoch: 149 | Loss: 0.5645\n" 442 | ], 443 | "name": "stdout" 444 | } 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": { 450 | "id": "Bp50Q7J0Xkiw", 451 | 
"colab_type": "text" 452 | }, 453 | "source": [ 454 | "## Testing the Model\n", 455 | "After training the model we have the ability to test the model predictive capability by passing it an input. Below is a simple example of how you could achieve this with our model. The result we obtained aligns with the results obtained in this [notebook](https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%202%20-%20Lesson%202%20-%20Notebook.ipynb), which inspired this entire tutorial. " 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "metadata": { 461 | "id": "V1odfZpGFoBi", 462 | "colab_type": "code", 463 | "outputId": "7353c0a0-92ef-4dc4-b8fa-5a0de17fbe3a", 464 | "colab": { 465 | "base_uri": "https://localhost:8080/", 466 | "height": 34 467 | } 468 | }, 469 | "source": [ 470 | "## test the model\n", 471 | "sample = torch.tensor([10.0], dtype=torch.float)\n", 472 | "predicted = model(sample)\n", 473 | "print(predicted.detach().item())" 474 | ], 475 | "execution_count": 0, 476 | "outputs": [ 477 | { 478 | "output_type": "stream", 479 | "text": [ 480 | "17.096769332885742\n" 481 | ], 482 | "name": "stdout" 483 | } 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": { 489 | "id": "ozX4V1GhPLyr", 490 | "colab_type": "text" 491 | }, 492 | "source": [ 493 | "## Final Words\n", 494 | "\n", 495 | "Congratulations! In this tutorial you learned how to train a simple neural network using PyTorch. You also learned about the basic components that make up a neural network model such as the linear transformation layer, optimizer, and loss function. We then trained the model and tested its predictive capabilities. You are well on your way to become more knowledgeable about deep learning and PyTorch. I have provided a bunch of references below if you are interested in practising and learning more. \n", 496 | "\n", 497 | "*I would like to thank Laurence Moroney for his excellent [tutorial](https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%202%20-%20Lesson%202%20-%20Notebook.ipynb) which I used as an inspiration for this tutorial.*" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": { 503 | "id": "LAABGiMHeDOr", 504 | "colab_type": "text" 505 | }, 506 | "source": [ 507 | "## Exercises\n", 508 | "- Add more examples in the input and output tensors. In addition, try to change the dimensions of the data, say by adding an extra value in each array. What needs to be changed to successfully train the network with the new data?\n", 509 | "- The model converged really fast, which means it learned the relationship between x and y values after a couple of iterations. Do you think it makes sense to continue training? How would you automate the process of stopping the training after the model loss doesn't subtantially change?\n", 510 | "- In our example, we used a single hidden layer. Try to take a look at the PyTorch documentation to figure out what you need to do to get a model with more layers. What happens if you add more hidden layers?\n", 511 | "- We did not discuss the learning rate (`lr-0.001`) and the optimizer in great detail. 
Check out the [PyTorch documentation](https://pytorch.org/docs/stable/optim.html) to learn more about what other optimizers you can use.\n" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "id": "4-o4w9vpPHZz", 518 | "colab_type": "text" 519 | }, 520 | "source": [ 521 | "## References\n", 522 | "- [The Hello World of Deep Learning with Neural Networks](https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%202%20-%20Lesson%202%20-%20Notebook.ipynb)\n", 523 | "- [A Simple Neural Network from Scratch with PyTorch and Google Colab](https://medium.com/dair-ai/a-simple-neural-network-from-scratch-with-pytorch-and-google-colab-c7f3830618e0?source=collection_category---4------1-----------------------)\n", 524 | "- [PyTorch Official Docs](https://pytorch.org/docs/stable/nn.html)\n", 525 | "- [PyTorch 1.2 Quickstart with Google Colab](https://medium.com/dair-ai/pytorch-1-2-quickstart-with-google-colab-6690a30c38d)\n", 526 | "- [A Gentle Introduction to PyTorch](https://medium.com/dair-ai/pytorch-1-2-introduction-guide-f6fa9bb7597c)" 527 | ] 528 | } 529 | ] 530 | } -------------------------------------------------------------------------------- /pytorch_quick_start.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pytorch_quick_start.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "view-in-github", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\"Open" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "9XHd5ExbUIUg", 32 | "colab_type": "text" 33 | }, 34 | "source": [ 35 | "# PyTorch 1.2 Quickstart with Google Colab\n", 36 | "In this code tutorial we will quickly train a model to understand some of PyTorch's basic building blocks for training a deep learning model. This notebook is inspired by the [\"Tensorflow 2.0 Quickstart for experts\"](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/quickstart/advanced.ipynb#scrollTo=DUNzJc4jTj6G) notebook. 
\n", 37 | "\n", 38 | "After completion of this tutorial, you should be able to import data, transform it, and efficiently feed the data in batches to a convolution neural network (CNN) model for image classification.\n", 39 | "\n", 40 | "**Author:** [Elvis Saravia](https://twitter.com/omarsar0)\n", 41 | "\n", 42 | "**Complete Code Walkthrough:** [Blog post](https://medium.com/dair-ai/pytorch-1-2-quickstart-with-google-colab-6690a30c38d)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "metadata": { 48 | "id": "KzsiN3l_Vy1p", 49 | "colab_type": "code", 50 | "outputId": "a5f58f2e-64e6-41da-ca0f-73801869c8e0", 51 | "colab": { 52 | "base_uri": "https://localhost:8080/", 53 | "height": 374 54 | } 55 | }, 56 | "source": [ 57 | "!pip3 install torch==1.2.0+cu92 torchvision==0.4.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html" 58 | ], 59 | "execution_count": 0, 60 | "outputs": [ 61 | { 62 | "output_type": "stream", 63 | "text": [ 64 | "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n", 65 | "Collecting torch==1.2.0+cu92\n", 66 | "\u001b[?25l Downloading https://download.pytorch.org/whl/cu92/torch-1.2.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl (663.1MB)\n", 67 | "\u001b[K |████████████████████████████████| 663.1MB 20kB/s \n", 68 | "\u001b[?25hCollecting torchvision==0.4.0+cu92\n", 69 | "\u001b[?25l Downloading https://download.pytorch.org/whl/cu92/torchvision-0.4.0%2Bcu92-cp36-cp36m-manylinux1_x86_64.whl (8.8MB)\n", 70 | "\u001b[K |████████████████████████████████| 8.8MB 44.2MB/s \n", 71 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.2.0+cu92) (1.16.4)\n", 72 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision==0.4.0+cu92) (1.12.0)\n", 73 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.4.0+cu92) (4.3.0)\n", 74 | "Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision==0.4.0+cu92) (0.46)\n", 75 | "Installing collected packages: torch, torchvision\n", 76 | " Found existing installation: torch 1.1.0\n", 77 | " Uninstalling torch-1.1.0:\n", 78 | " Successfully uninstalled torch-1.1.0\n", 79 | " Found existing installation: torchvision 0.3.0\n", 80 | " Uninstalling torchvision-0.3.0:\n", 81 | " Successfully uninstalled torchvision-0.3.0\n", 82 | "Successfully installed torch-1.2.0+cu92 torchvision-0.4.0+cu92\n" 83 | ], 84 | "name": "stdout" 85 | } 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "uF1P_cRoWpvM", 92 | "colab_type": "text" 93 | }, 94 | "source": [ 95 | "Note: We will be using the latest stable version of PyTorch so be sure to run the command above to install the latest version of PyTorch, which as the time of this tutorial was 1.2.0. We PyTorch belowing using the `torch` module. 
" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "Su0COdCqT2Wk", 102 | "colab_type": "code", 103 | "colab": {} 104 | }, 105 | "source": [ 106 | "import torch\n", 107 | "import torch.nn as nn\n", 108 | "import torch.nn.functional as F\n", 109 | "import torchvision\n", 110 | "import torchvision.transforms as transforms" 111 | ], 112 | "execution_count": 0, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "metadata": { 118 | "id": "rXCYmmjyVRq5", 119 | "colab_type": "code", 120 | "outputId": "a9ea67e6-cd29-4c4e-bb8f-127eac9ab764", 121 | "colab": { 122 | "base_uri": "https://localhost:8080/", 123 | "height": 34 124 | } 125 | }, 126 | "source": [ 127 | "print(torch.__version__)" 128 | ], 129 | "execution_count": 0, 130 | "outputs": [ 131 | { 132 | "output_type": "stream", 133 | "text": [ 134 | "1.2.0+cu92\n" 135 | ], 136 | "name": "stdout" 137 | } 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "hhuQyU7AYE6K", 144 | "colab_type": "text" 145 | }, 146 | "source": [ 147 | "## Import The Data\n", 148 | "The first step before training the model is to import the data. We will use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) which is like the Hello World dataset of machine learning. \n", 149 | "\n", 150 | "Besides importing the data, we will also do a few more things:\n", 151 | "- We will tranform the data into tensors using the `transforms` module\n", 152 | "- We will use `DataLoader` to build convenient data loaders or what are referred to as iterators, which makes it easy to efficiently feed data in batches to deep learning models. \n", 153 | "- As hinted above, we will also create batches of the data by setting the `batch` parameter inside the data loader. Notice we use batches of `32` in this tutorial but you can change it to `64` if you like. I encourage you to experiment with different batches." 
154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "id": "tSjjLXrOVWBy", 160 | "colab_type": "code", 161 | "outputId": "47502e82-f178-452b-995f-8a469670a471", 162 | "colab": { 163 | "base_uri": "https://localhost:8080/", 164 | "height": 285 165 | } 166 | }, 167 | "source": [ 168 | "BATCH_SIZE = 32\n", 169 | "\n", 170 | "## transformations\n", 171 | "transform = transforms.Compose(\n", 172 | " [transforms.ToTensor()])\n", 173 | "\n", 174 | "## download and load training dataset\n", 175 | "trainset = torchvision.datasets.MNIST(root='./data', train=True,\n", 176 | " download=True, transform=transform)\n", 177 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,\n", 178 | " shuffle=True, num_workers=2)\n", 179 | "\n", 180 | "## download and load testing dataset\n", 181 | "testset = torchvision.datasets.MNIST(root='./data', train=False,\n", 182 | " download=True, transform=transform)\n", 183 | "testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,\n", 184 | " shuffle=False, num_workers=2)" 185 | ], 186 | "execution_count": 0, 187 | "outputs": [ 188 | { 189 | "output_type": "stream", 190 | "text": [ 191 | "\r0it [00:00, ?it/s]" 192 | ], 193 | "name": "stderr" 194 | }, 195 | { 196 | "output_type": "stream", 197 | "text": [ 198 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" 199 | ], 200 | "name": "stdout" 201 | }, 202 | { 203 | "output_type": "stream", 204 | "text": [ 205 | "9920512it [00:02, 3643813.85it/s] \n" 206 | ], 207 | "name": "stderr" 208 | }, 209 | { 210 | "output_type": "stream", 211 | "text": [ 212 | "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n" 213 | ], 214 | "name": "stdout" 215 | }, 216 | { 217 | "output_type": "stream", 218 | "text": [ 219 | "\r0it [00:00, ?it/s]" 220 | ], 221 | "name": "stderr" 222 | }, 223 | { 224 | "output_type": "stream", 225 | "text": [ 226 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" 227 | ], 228 | "name": "stdout" 229 | }, 230 | { 231 | "output_type": "stream", 232 | "text": [ 233 | "32768it [00:00, 57582.77it/s] \n", 234 | "0it [00:00, ?it/s]" 235 | ], 236 | "name": "stderr" 237 | }, 238 | { 239 | "output_type": "stream", 240 | "text": [ 241 | "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", 242 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" 243 | ], 244 | "name": "stdout" 245 | }, 246 | { 247 | "output_type": "stream", 248 | "text": [ 249 | "1654784it [00:01, 973571.63it/s] \n", 250 | "0it [00:00, ?it/s]" 251 | ], 252 | "name": "stderr" 253 | }, 254 | { 255 | "output_type": "stream", 256 | "text": [ 257 | "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", 258 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" 259 | ], 260 | "name": "stdout" 261 | }, 262 | { 263 | "output_type": "stream", 264 | "text": [ 265 | "8192it [00:00, 21777.08it/s] " 266 | ], 267 | "name": "stderr" 268 | }, 269 | { 270 | "output_type": "stream", 271 | "text": [ 272 | "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", 273 | "Processing...\n", 274 | "Done!\n" 275 | ], 276 | "name": "stdout" 277 | }, 278 | { 279 | "output_type": "stream", 280 | "text": [ 281 | "\n" 282 | ], 283 | "name": "stderr" 
284 | } 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": { 290 | "id": "0nZwZukWXUDn", 291 | "colab_type": "text" 292 | }, 293 | "source": [ 294 | "## Exploring the Data\n", 295 | "As a practioner and researcher, I am always spending a bit of time and effort exploring and understanding the dataset. It's fun and this is a good practise to ensure that everything is in order. " 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": { 301 | "id": "NW_loWKga7CH", 302 | "colab_type": "text" 303 | }, 304 | "source": [ 305 | "Let's check what the train and test dataset contains. I will use `matplotlib` to print out some of the images from our dataset. " 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "zWd9Pt1Ca6K9", 312 | "colab_type": "code", 313 | "outputId": "1c02a3b5-f5bb-4c51-a999-52d0472f43af", 314 | "colab": { 315 | "base_uri": "https://localhost:8080/", 316 | "height": 220 317 | } 318 | }, 319 | "source": [ 320 | "import matplotlib.pyplot as plt\n", 321 | "import numpy as np\n", 322 | "\n", 323 | "## functions to show an image\n", 324 | "def imshow(img):\n", 325 | " #img = img / 2 + 0.5 # unnormalize\n", 326 | " npimg = img.numpy()\n", 327 | " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", 328 | "\n", 329 | "## get some random training images\n", 330 | "dataiter = iter(trainloader)\n", 331 | "images, labels = dataiter.next()\n", 332 | "\n", 333 | "## show images\n", 334 | "imshow(torchvision.utils.make_grid(images))" 335 | ], 336 | "execution_count": 0, 337 | "outputs": [ 338 | { 339 | "output_type": "display_data", 340 | "data": { 341 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAADLCAYAAABgQVj0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsnXd4FFXXwH8DhEACBAmRmtCkSkR6\nbwLShAACIkVUhFcEQUAp6guIwiclKqJSBUFAehGkGEDhlSZNaVKkBwRCCRICQnbP98dmLrtphCS7\nE+L9Pc99sjN7Z+dkdvbMueeee44hImg0Go0m45LJagE0Go1G4160otdoNJoMjlb0Go1Gk8HRil6j\n0WgyOFrRazQaTQZHK3qNRqPJ4LhN0RuG0cwwjKOGYfxpGMZQd51Ho9FoNEljuCOO3jCMzMAxoAkQ\nDuwCXhSRw2l+Mo1Go9Ekibss+mrAnyJyUkTuAguAEDedS6PRaDRJkMVNn1sIOOe0HQ5UT6yzYRh6\nea5Go9E8PFdEJOBBndyl6B+IYRi9gF5WnV+j0WgyAGeS08ldiv48EOi0XTh2n0JEpgHTQFv0Go1G\n407c5aPfBZQ0DKOYYRhZgU7A9246l0aj0WiSwC0WvYjEGIbRF1gPZAZmisghd5xLo9FoNEnjlvDK\nhxZCu240Go0mJewRkSoP6qRXxmo0Gk0GRyt6zQOpU6cO06dPZ/r06dhsNpfWvXt3q8XTaNIlXl5e\n5MyZk5w5c7J+/Xq2bt3K1q1b1W+nffv2tG/f3iOyWBZemdaMHz+esmXLsnbtWgCmTZvGvXv3LJYq\naX7//XcAAgICaNSoEX/88YfFEt0nd+7c1KhRA4AZM2ZQoEABAOx2u0u/d955hyeeeAKA0aNHc+fO\nHc8KGofp06fTpUsXwPGA2rt3r6XyJISPjw8Af/zxB/369WPlypUWS6RxB6+//jqffvqp2jYMAwAr\n3OXaotdoNJoMToax6EWEZs2a0axZMwAKFy7Mn3/+CcCOHTs4dCh9Bf2UKFGC3LlzA5A/f36WL19O\nmTJlLJbKQe7cuVm6dCn16tV7YN/SpUvz7rvvAhAWFsaWLVvcLV6SnD59mmzZsgFQsmTJdGnRe3l5\nARAYGKhGQ5qUMWLECAD27t3LqlWrEu03fPhwPvjgAwCqVq3K7t273SrXyJEj6d+/f4LvRUREMH/+\nfLfL4IKIWN4ASW2rXr26xMTEJNhu374t/fv3l8qVK0vlypVTfa60aF26dBG73a7auXPnLJfJbIsX\nL5Z79+6pduDAAdm4caNs3LhRmjRpIv369ZOoqCiJioqSe/fuic1mE5vNJvXq1bNc9m7duqlr+sMP\nP1guT0LNz89P/Pz8xG63S9u2bS2XJzktICBApkyZIps3b5bNmzeL3W4Xm80mly9flsuXL0vXrl0t\nkatv377St29fmT9/fqJ9ihcvLtHR0eq+qFKlitvkyZUrl+TKlUsOHz7sooP27t0rJ0+elJMnT8oT\nTzyRlufcnRwdm2Es+r179/Lee+9RtmxZAJ577jllMWfNmpXQ0FAOHjwIQMOGDbl+/bplsqZ3/P39\nAThy5AgAbdq04cSJE+r9sLAwZa0EBQV5XsBkkt7naADLR0AJERDgSJ3Stm1bevbsCUDevHkJCgpS\n/mVTgZj3Su3atZk7d65H5SxSpAhvvfUWQJLW8eeff062bNm4cOECALdv33abTBMnTgSgVKlSLvvr\n1q1L27ZtAZSnwZNoH71Go9FkcDKMRX/v3j3Gjh2rtv39/cmcOTPg8M+9/vrrlC9fHnD477VFnzg/\n/vgjp0+f5ujRowAu1rxJpkyZ1
F/ztRlVYCWm1QTw3XffWShJ4phRQekNX19f2rZty5w5cwCH1e4c\nKXL79m2XyLC8efNStGhRAHr16kWOHDkA6Natm9tlzZEjB5s2bSJXrlwADBkyJF6f6tUdCXMbN24M\nwNdffw3A4cPuKYtRoUIFQkJcs7GPGzcOgDt37nh8xONMhlH0cbl69ap6/dFHH9G7d2+1nR4UUnrm\n448/TvL9l19+WQ3vncMtrVxlXbFiRQBatmypvvvvv09/6ZUee+wx3nnnHavFcMF0dy5ZsoTSpUu7\nuGf+97//AbB8+XLWr1+v3HngUO6TJ09WfT0RTJA1a1YlT7FixTh9+jQAd+/edelnGAZ169YFHMZI\nVFQUy5cvV7K6g/79+6sHD8DKlSvVZLHNZnPLOZOLdt1oNBpNBifDWvTOBAQEuDzF00N+n0eZIkWK\n4O3tbbUYLpiWnpeXlxpluHPSLaUULVqUIkWKWC2G4qOPPmLYsGGAwwqOiIjg7NmzAIwZM0ZZwXFp\n1qwZH374oXLb2e12rly54nZ5zdFQo0aNAOjbty8Af/31l0u/7NmzK7cJwPz58/ntt9/cJpe/vz+V\nKlVS26dPn2b16tXpJiAgQyv6Bg0aADBz5kzgvjsnKirKKpEyJEuWLAHg119/tUyG559/3rJzp5QT\nJ05YvpJ42LBhyvC5cuUKAwcOZN68eQn2DQgIUP77Z599FhFRD9U//viDMWPGuFXWCRMm0K9fPwBi\nYmKYOHEiP/74o0sfMwps/Pjxat+2bdvUce4iKChIzQEC7Nq1i1mzZqntqlWrurznaTKsoi9evDjL\nli0DwM/PDxFRi2dOnTplpWiPNHnz5qV169Yu+z7//HMAS5WWmaIhvWNaouCYFLx165Ylcnz77beA\nw4o3/fC9e/d2mWwtW7asGn0MGzaMunXrqoeCOc9l/qaaN2/uFou+ePHiStbg4GCyZHGorKioKHXf\ngeO+HDVqlModkzdvXuW3f/XVV+P58N1N2bJl+e6772jRogXgGGmao59//vmH69evq8WdzvMe7kL7\n6DUajSaDk2Et+tq1a7vMgIM1Q6bEKFmypNUixKNOnToA/PLLL/Hee/nll4H7IWomv/zyC1u3bnW7\nbBmFwoULK3dHWFiYZXKYETLO81V169ZVDRwL5cwEbHFWsgOwbNkyFc3mLv/8kCFDqFmzpto2/exj\nx47l3LlzrFmzBkBZx86YaU88EWVXq1Ytl/MEBwcTHBzs0se06L28vMiRI4eSb9GiRWqexIwiSmsy\nrKIfNGiQuvCZMmXi+PHjylefHjCHdFZirhyePn06/v7+6sdvDiXNv+PGjVMrIONmr4yr+K0ga9as\nFCtWTG17YiicUkJCQtQc0RdffGGxNK5hiPXq1YsXO++svJxfDx8+nNGjR7tVtqZNm/LSSy+57DNX\nZBcpUoQDBw5Qrly5RI83Hz7uUp7ObNu2Lckgj1u3bql1BnH7dejQgW+++QZwn6zadaPRaDQZnAxn\n0ZvDpcKFC6snp91uZ9SoUR55sieX48ePq5l4q8I9TeuoTZs2wP2hpbkYyrT0evbs6RJGl97w9fWl\nVq1aanvDhg0WSpMwZm7/QoUKER0dbbE0qAiZLl26qO//ypUrKoABHKuMzXvBvEdNK96d1nzOnDkB\nR96YuGG8mzdvdtk2F2yFhIRQsGBBtf/3339X4ZX//POP22RNjF27djF16lS1vX//fqpVqwYkPJIz\nV/b+/PPPbpE3xYreMIxAYA6QD0cWtWkiMtEwjDzAQqAocBroKCIeyzdgxtWabglw+PU2btzoKRGS\nxZo1a3jxxRctO/93331HkyZNAIiMjOT33393GbInN967dOnSbpUzOcSNuFm3bp1FkiROoUKFAFTU\niNWY8fEJxcmb/u527dp53F1jnhfiJwYDVOqSDRs2sGDBApUUrmzZshQsWFDNF7355ptujZt/EDdu\n3OC7775zUdr79+8HHA/U119/XYV/AyoleNmyZd0id2ruuhhgkIjsNQwjJ7DHMIww4GVgo4h8bBjG\nUGAoED8RhRsoVaoUL7zwQrz9zZo1IyIiwhMipBhPW6H58uXDz88PcFhJZj4QkxEjRvD+++8/8HMG\nDx6sLKmPP/5Y5cfxJP/973/V6x9++CFd5qB3Jr3OIQQEBDBs2DA10hMR9bvp1q1bvJh1d3H+/HkA\nwsPDWbt2LceOHQMck9dnzpwBHIoUUNk1TaU5YcIEAEuVPDjy69SsWZOff/5Z7YuJiQEc606aN2/u\nUXlS7KMXkb9EZG/s65vAH0AhIASYHdttNtAmtUJqNBqNJuWkyTjSMIyiQEVgJ5BPRMz1yBdxuHYS\nOqYX0Cstzm8yYMAANbPtTHq05s3MeiZ///23R85rhlCWKVNG+d0LFCjAli1blE8+IT+82TchzNDL\nFi1aMGTIELVS1lMrkJ0XIV2/ft3yBFIPYunSpVaLkCCVK1emf//+Lu4ac1GSp6x5uD+6rVixIteu\nXUt0DqtEiRIuo7l3332XnTt3ekTGuBw/fpzvvvvOxR07b948Nf+xa9cuFaq6du3aeOHVptzuylWf\nakVvGEYOYCnwloj87XyTiIgYhpHgtyQi04BpsZ+RJrORvr6+LjdpYpNe/v7+REZGWqoQ4k4ymbla\n3I0ZI3/kyBE10Wb6Qk0FH1fRX7lyRf3Qv/76a0JCQpSLLF+++8/xPHnyMH36dPXA6N+/v9uVfb58\n+fDy8kr3GUk7duwIOHKymGkE0gumT3zy5MmIiApLHDNmjCqkYQXOGWjjki1bNsaPH0/hwoUBx2/9\nl19+4eLFi54Sz4WoqCjGjBmj5r3y5s1Lvnz5VHnD8+fPq7TpzqkSTEJDQ9XnuIVUlgD0AtYDA532\nHQUKxL4uABz1RClBX19f2bVrl0v5rpkzZ8rMmTOlYMGCUqFCBalZs6bUrFlTrly5Io8//rglpc/M\nFh4ebmkpwf79+7uUC3QuCWhuHz9+XI4fPy4VKlSId3zRokWlaNGisnv37njHme2ll15y+/+xcuVK\nl+vYrFkzS7/XhJphGLJ69WpZvXq17N2713J5nFuzZs3U92eWB2zXrp20a9fOctmSatWqVZPbt2+r\n771NmzaWywTI0KFDZejQobJhw4Z4JU3N6xx3/7Rp0yR79uySPXv2lJwzWaUEU+yjNxwm1NfAHyLy\nidNb3wPdY193B1am9BwajUajST2pcd3UBroBBwzDMKe43wU+BhYZhtEDOAN0TJ2IyePWrVusWbNG\nFaAA1Kq6Bg0aEBQUpKIxIiIiPJ7kKC4zZ85MVlSLu/jyyy+V79NMUma6PyIiIhgyZIhKUnbp0qV4\nx5trEpo2barC8SZOnKgiedyNOWSvXLkygAqfXb9+vUfO/zD4+/urldBr1661WBoHcd014HDRjRkz\nxiWWPr1hzhUNGjQIb29vJXt6SWpnFu1Zv349GzZsSPT3EBkZqdw1n3zyidtj/VOs6EXkFy
Axx2ij\nRPa7la+//lrFfjsXBy9SpAjXrl1j8ODBgCMHRmRkpBUiKuKGIXo6dXJMTIyaaHPOAviwXL161SWt\n7TfffKPCB92ZA+fxxx8HUKGds2c7Ar3SY62Bjh07qoeoJyc1E6NKlSosXrwYcChOc06mQ4cO6bJY\nuTPmg71Dhw7A/XvXXDiVXti3bx/ff/+9ivuvUKEC27dvBxxG6cSJEz26cE6nQNBoNJqMTmomY9Oq\n4YZJkQEDBsjff/8tf//9t2zatEnq1Klj+USNc+vSpYuaSAoPD5eyZctaLtOj1CpVqiSVKlUSu90u\nW7ZskVy5ckmuXLkslyuh9vbbb8upU6fk1KlTUrRoUUtlKVu2rCxdutRlgnDUqFEyatQoyZs3r+XX\n6kEtLCxMwsLCxG63y/bt2yVPnjySJ08ey+WysCVrMtZyJe8uRa+bbrrdb0WKFJEiRYrIpUuXRERU\nBEh6j66J25YsWSJLliwRu90urVq1slyedNCSpeiN9ODTTKs4eo1GkzDmYp1hw4bx7rvvqkpSCcV0\nax4p9ohIlQd10j56jUajyeBoi16j0WgeXbRFr9FoNBqt6DUajSbDoxW9RqPRZHC0otdoNJoMjlb0\nGo1Gk8HRil6j0WgyOFrRazSaDMG6devUSlAz6ZnGQfooSa/RaDSpoFatWjRu3Fhl4kwP64PSE9qi\n12QosmbNytixYxk7diwiolLDajImgYGBBAYGMnXqVAzD4MqVK1y5csVttVcfVbSi12g0mgyOdt08\nIjz77LMALF++nC+//BJAFVLR3Kd9+/a8/fbbgGP4nh4qTvn6+jJixAgAqlWrxvHjx9V70dHR6vs8\nduyYJfIlxbvvvgtA9erVadWqldq/ceNGVQjbSr7++msAypUrB0C/fv0A+O233xI95l+J1SmKdZri\npJuXl5dMmjRJbt26Jbdu3RKbzSadOnWSTp06WS5bempmsXLzGtlsNjl69GhKCy6nuj3xxBPyxBNP\nyG+//SYTJkxQtQfu3bunXpv54G/cuCE3btyQ//znP5ZfR0AmTJggEyZMkBs3bihZzWtqtsOHD0vh\nwoUtlbNhw4Zy584duXPnjthsNlmzZo08/fTT8vTTT1t+DT3Y3FscXKPRaDSPBql23RiGkRnYDZwX\nkecMwygGLAD8gT1ANxGxthL3I8Jjjz0GQNGiRXn++ecBaNOmDWXLllU1R+fPn8/KlSvdKke+fPl4\n77331PZrr71G9uzZAVi0aBGXLl3iwIEDABw+fJi//voLgJMnT7pVrqQwr1e2bNnUvk8//ZTbt29b\nIo9Zy7ZgwYIMHDiQe/fuATBlyhRVKNzX15dcuXKRM2dOAMaPH8+JEycA2LBhgwVSQ4sWLejVq5eS\nLzFKly7NzJkzlUvR0wQGBjJ+/Hi8vLwA2LNnD88//7xl37e7yJ49O6VKlVK1Z8+ePZuiz0l1mmLD\nMAYCVYBcsYp+EbBMRBYYhjEF+F1EkqzcmxZpimvWrMmaNWvYv38/AI0aNSImJsalT+bMmQGoU6cO\n27dv5+5d654/xYsXp3bt2mo7W7ZsjB49GoC8efPGCw/btm0b4Ch6fuPGDbfIVKBAAQBWr17N008/\nnezjzp8/D0BQUJBb5HoQZcqUYdeuXcD9Ahtw//u2gqxZswIOhfTKK6+oYvArV64kUybHQLpWrVr4\n+Pjwf//3fwCUKFFCFYl/9tln2bFjh0dlLlGiBLt37yZXrlxqn2lgxL0fIyIi6Ny5M5s2bfKojIGB\ngYBjrqpixYpcvnwZgE6dOrF582aPypJSSpUqRevWrdV2rVq1KFasmEsf87pny5aNkiVLKkWfN2/e\nuB+XrDTFqbLoDcMoDLQERgMDDYd0zwCdY7vMBkYCbivRbv6wQ0ND8fPzo27dugCsXbuWEiVKxJUX\ngCJFijB79mxef/11AP755x93iZcow4cPp3v37kD8H5EpJzgmm4YPH66sZndSvXp1AKXkFy5cCMCW\nLVvUhCFAVFQUb775JgDBwcHcvHnT7bIlRY8ePVysz61bt1oojQPTiDhx4gTvv/9+gn3WrFlDq1at\nyJEjh9pnvs6fP7/7hYwlJCQEgFGjRrkoeXBMugIUKlSI0qVLq/379+/3uJIH+PbbbwGoWLEiAG+8\n8QbAI6HkJ092qMGXX35ZGQLg+L0npgPM/Xv27EnVuVPro/8MGAzYY7f9gUgRMU3pcKBQQgcahtHL\nMIzdhmHsTqUMGo1Go0mCFFv0hmE8B1wWkT2GYTR42ONFZBowLfazUuy6+eKLLwCoUaOGy/5GjRol\neVz37t2ZNm0agCWLaqZPn865c+fUdrFixdi3b5/aXrJkCeBwi8R1QbmLFStWADBp0iTefPNNwsPD\nAbDZbC79xo8fz5w5czwi04No3LgxAwYMUJbPrVu3ePXVVy2WKnGefPJJ2rVrp7ZHjBihXDmAcj3+\n8MMPHpOpb9++wP36sWfOnAFgzpw5jBw5EoDChQvz/fffU6FCBQAqVapE69at+f777z0mZ//+/alW\nrZranjNnDmFhYR47/8OSM2dONR/z3Xffqf3Hjh3jyy+/JDIyMlmfs3r1auW6STGpCIn8PxwW+2ng\nIhANzAOuAFli+9QE1rsrvLJ3794q3MsMA0tumzlzphiGIbEPGY82f39/2b59u2zevFk2b95sWQhg\nYi1z5szyww8/xAups9lscvr0aSlVqpTlMpqtefPmLvKdPHnScpkSaq1atZJWrVrJ+fPn492Lc+bM\nkTlz5sibb74pwcHBEhwc7HZ5cubMKTlz5pTPP/9cLl26JJcuXZLbt2/LoUOHpHPnztK5c+d4xwQF\nBcnZs2fl7NmzYrPZ5LffflOf405Zn3vuOXnuueckOjpafc+bN2+WHDlyWP69JtaqV68umzdvdrk3\n9+3bJ/v27ZNChQql5bncG14pIsNEpLCIFAU6AZtEpAvwE9A+tlt3wL0hIhqNRqNJEnesjB0CLDAM\n4yNgH/B1Wp+gcOHCgGPyyHniEu6viFu3bp2aEJs/fz47d+7Ez89P9Tt27JhliY9eeeUVqlWrxrBh\nwwDSXUiYzWZjwIABBAcHA46JOJNXXnklXa3gbNq0KYAaBrdv3z6p7pZhRosUKFBAJd46duwY0dHR\nyj3iyfDUxo0bA9CnTx+178iRIzz55JOJHnP27Fk1CT9ixAgqVKigXKSm2y+tKVCgAP/9738B8Pb2\n5sqVKwB8+OGHKkIpIXx8fNSEbePGjVm3bh07d+50i4wmmTNnViuJ33rrLXLnzs3Vq1cB6Natm6Wr\ntNNE0YvIz8DPsa9PAtWS6p9azDA+f39/l/0TJkxgwoQJACrsyiSun/nUqVNulDBpfvrpJ65du0bH\njh0BmDhxoiWRP0lx7Ngx5at1VvS3bt2ySqQEMRV7REQEAHv37rVSnEQpVaqUej1//nwAXnrpJavE\nwdvbW702fxvLli174HHmGo5bt26xfv16l89xB
0uXLqVKFUf04NWrV+nc2RHQZ0YDOePr60vDhg0B\neOedd6hTp456b9iwYbzzzjsALF68mIsXL6a5rJ06dVIPbcMwCAsLUw9Gy40jq9MfpMRH/9RTT8lT\nTz0lt2/fliNHjsiRI0fkxRdflEyZMiXYv2XLli4+0evXr4u/v7+lPrzBgwcreebNm5fu/PSlSpWS\nyMhIiYyMdPEzVqtWzXLZAOnatat07dpVRETsdrvMmjVLZs2aZblcibWOHTtKx44dxW63y+3bt+X2\n7dvy3nvvSdasWS2R5/Dhw3L48GGx2WwycuRIGTly5EMd37hxY7HZbHLo0CE5dOiQW2Rs3bq1Sm9g\ns9mkWbNmifYtW7as/PzzzwnOK505c0YiIiLU9vjx490i79tvv61+0/v37/fUb1qnQNBoNBoNWG7N\npybqpmTJkpIjR44Hzr6/++67Lhb9uHHjPPGkfWDr3bu39O7dW27evCnz5s0Tb29v8fb2tlwuQP7z\nn/8oC8g52mHAgAGWRCrFbf3795f+/fuL3W6XGzduSLly5aRcuXJSvXp1yZ07t+TOndtyGZ2bGeG1\nYsUKl0RhixcvlgYNGkiDBg3Ey8vLI7IEBwfL5cuX5fLly2K326VRo0bSqFGjh/qMxo0bi91uVxE7\naRkp5O/vL/7+/rJjxw6x2WwSFhYmYWFh4uvrm+gxbdq0iRfh0r59e2nfvr0ULVpU6tevr967ffu2\nW65rjx491Dnu3Lkj06dP98T3mSyL3nIlnxpF/6Dm5eUlXl5esn37dhdF37FjR098AcluW7dulaio\nKClfvryUL1/ecnkA+frrr9VNGxoaqhSDzWZLF0r0hx9+kB9++EHsdrts27ZNBg0aJIMGDRKbzSbn\nzp2Tc+fOybFjx2TevHnqvWzZslkud758+WTVqlWyatUq+f333yUqKkrdlzdu3JCGDRtKw4YN3SrD\nV199pb7bjz76SLJkySJZsmR5qM8wXTdm69atW5rJN2bMGBkzZoyLyyYptw0gGzdudJGnffv2Lu+X\nLFnS7Yrey8tLNmzYIBs2bBC73S5RUVEyffp0dyt8rej9/PzEz88vXtxyelP0xYsXlxs3bsjs2bNl\n9uzZlssDroq+fPnyKg7cZrNJz549LZfPWdGPHDlSatSoITVq1BCbzabmFiIiIuTevXvq/9i9e3e6\nWgNQrVo1CQ8Pl+vXr8v169fFZrPJnj17ZM+ePUlar6lt4eHh6ppUrVo1RZ/xzjvvuEXRt2rVSu7d\nu6e+t+XLl4uPj4/4+PgkedyKFStc5NmwYYM0b95cmjdvLn379lXzETabTS5duuS2a1ugQAEpUKCA\nDB06VI2Go6Oj4z140rBpH71Go9FowHJr3pMW/d27d+Xu3btSsWJFtz3RU9ref/99VVzEalmyZs0q\ne/fuVRZQQECAFCpUSAoVKiRnz56VVatWWS6js0Xft29f8fX1FV9fX2nWrJkEBARIQECAAPLCCy+4\nRA+dO3dOihUrJsWKFbP8f4jbFixYoO7VFStWuO084eHhsn79elm/fv1DR4aY190s8HL8+HE5fvy4\n+Pn5pYls//3vf10s88DAwGQdV7t27QQjbhJq77zzjke+z0mTJrl4Etx0nmRZ9P+qUoJmTnDnnDLu\nYtKkSWzfvp2lS5cCD86QeefOHZVLPXv27JYuovLy8lI5TUzMVMQXLlygVq1aBAQEAPfj1z2Nmc3T\nMAxq1aqlch6tW7fOpd/ChQtZu3Yt4MgAWKJECTp06ADAuHHjUnz+NWvWqLjzRYsW8ffff6f4s0yc\n13bETVub1pj1BcqWLZustQcNGjSgatWqNGjQALif9/+zzz4DSLPU2XF/J6tWreKnn3564HFm/v+E\nEBEuX77Mxx9/DKRtdlNzTc/t27fj/RZGjRqlrle5cuVUnn8zx5Yn0a4bjUajyeBkaIveXEVnkuoM\ncA/B4cOH+fbbb1UaATPdQXLInTt3ukuL4Iy3t7da0m+VRX/69GnAYa3Vr18/0RFGjhw5VMZDX19f\nRCTJyknJpXDhwsoyGzJkiFrleuPGDUqXLs3q1auB+6PIpDCtUefi2+aqZHdhFr358ccf1epNs7qV\nSdOmTalfvz4A9erVS7CQiznSSyvGjRunMtGGhIQQHBysfkPJwcz0arfb1e/9ww8/VLng0xrTOzBm\nzBhCQ0Nd3ouIiFA1HT744AOVtdQKiz5DK/qnnnrKZdtd+TgSYsGCBQwdOpQhQ4YAjuXuPXv2BBzD\nfmdy5sxJ7969lTvCE0VGkkJEiI6OdqnW5Ez27Nl54oknAOtSDpjpfMFRpOPTTz8FHD+ookWLAg5l\nVqtWLZeU1RcuXFDutNQQGhq9TDgZAAAgAElEQVSqHiA9e/akd+/egKOkYfbs2ZWizpEjBzNnzuTw\n4cMA/Pnnn9y5cwdwPBReeeUVmjRpAjhSGJuKY9SoUamWMTk89thjzJ07N9H3zRTKZn4eZ7Zt2+aW\n/DGmQhw8eDBeXl5UqlQJcJTVTIyzZ88yY8YMFi9eDHgu5YBZ/jMuderUoV27di4ps61Mu2L5RKy7\nJmOLFSumwtbMyZAOHTpIhw4dPDIRA0ihQoXkjz/+kD/++EPFSd+4cUOmT58uzZo1kwoVKkiFChVk\n8uTJcufOHbVwxlPyJdWcwyu7d++u9puLWMwl/VbLOW7cOLl27VqCk27moiRzEv6XX36RsmXLprkM\nn3zyiQwfPlyGDx8u0dHRYrfbZdOmTbJp06aHTp/9wQcfyAcffODWa/bJJ58ke+LSJO7+AwcOSIEC\nBSz//q1uJna7Xf78889436eZ7mLy5MnukkGHV2o0Go0mA7tucufO7ZKWGO77dT3F+fPnadmyJeCo\naDVgwAAAXn31VV599VWXupCXLl1SBaTTG8WLF3fZvnbtmqUpV50ZPHgwU6dOVVlLnYsuz5s3j4iI\nCOWy27Jli1tkGDhwoHo9fvx4WrRoQadOnR76c2bPns306dPTUrQEmTZtGm3btgWSX9D97t27/PPP\nP8pFtm/fPstdjOkB0y/ftWtXihUrZnoo2L9/P9evX2fs2LEAlv9eMqyiTy+YOcZHjBjBkSNHABgw\nYACVK1dWfRYvXswnn3ySrn44zuGCffv2Valsq1atyrJly9IsnC4tOHHihFJcVnP79m2WLl2q5mGO\nHDnCiRMnVGHt2rVrq/sAoEyZMnz++eeAIwVwciZvU8uRI0dUHv927drx7LPPAo6J2fbt26sU32Yf\ngA4dOqgJZs193n77bcCh8M2AAHCdQ0oPaNeNRqPRZHSsnoh112RsixYt4k2MVK1aNcW5Pf5tzc/P\nT9asWSNr1qyJNxH3/vvvWy6fbrrphvBvXxlrDkdN7t27l+6qOKVnbty4wfvvvw9A/fr11UrIxYsX\n8/XXaV4dUqPRuJNUWuK5gSXAEeAPoCaQBwgDjsf+fcwKi75Zs2bKkj9z5oxLiKBuuummWwZpHgmv\nnAis
[... remainder of the base64-encoded PNG data for the matplotlib figure output omitted ...]\n",        342 | 
"text/plain": [ 343 | "
" 344 | ] 345 | }, 346 | "metadata": { 347 | "tags": [] 348 | } 349 | } 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "XFWll5Lseiht", 356 | "colab_type": "text" 357 | }, 358 | "source": [ 359 | "**EXERCISE:** Try to understand what the code above is doing. This will help you to better understand your dataset before moving forward. " 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": { 365 | "id": "d9mXAVmRvhrq", 366 | "colab_type": "text" 367 | }, 368 | "source": [ 369 | "Let's check the dimensions of a batch." 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "metadata": { 375 | "id": "cNFKWz1GZ4R5", 376 | "colab_type": "code", 377 | "outputId": "cc1fd627-b8b0-42d4-d1a7-cd1eeaefc7fb", 378 | "colab": { 379 | "base_uri": "https://localhost:8080/", 380 | "height": 52 381 | } 382 | }, 383 | "source": [ 384 | "for images, labels in trainloader:\n", 385 | " print(\"Image batch dimensions:\", images.shape)\n", 386 | " print(\"Image label dimensions:\", labels.shape)\n", 387 | " break" 388 | ], 389 | "execution_count": 0, 390 | "outputs": [ 391 | { 392 | "output_type": "stream", 393 | "text": [ 394 | "Image batch dimensions: torch.Size([32, 1, 28, 28])\n", 395 | "Image label dimensions: torch.Size([32])\n" 396 | ], 397 | "name": "stdout" 398 | } 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": { 404 | "id": "tmaCTw5tXowR", 405 | "colab_type": "text" 406 | }, 407 | "source": [ 408 | "## The Model\n", 409 | "Now using the classical deep learning framework pipeline, let's build the 1 convolutional layer model. \n", 410 | "\n", 411 | "Here are a few notes for those who are beginning with PyTorch:\n", 412 | "- The model below consists of an `__init__()` portion which is where you include the layers and components of the neural network. In our model, we have a convolutional layer denoted by `nn.Conv2d(...)`. We are dealing with an image dataset that is in a grayscale so we only need one channel going in, hence `in_channels=1`. We hope to get a nice representation of this layer, so we use `out_channels=32`. Kernel size is 3, and for the rest of parameters we use the default values which you can find [here](https://pytorch.org/docs/stable/nn.html?highlight=conv2d#conv2d). \n", 413 | "- We use 2 back to back dense layers or what we refer to as linear transformations to the incoming data. Notice for `d1` I have a dimension which looks like it came out of nowhere. 128 represents the size we want as output and the (`26*26*32`) represents the dimension of the incoming data. If you would like to find out how to calculate those numbers refer to the [PyTorch documentation](https://pytorch.org/docs/stable/nn.html?highlight=linear#conv2d). In short, the convolutional layer transforms the input data into a specific dimension that has to be considered in the linear layer. The same applies for the second linear transformation (`d2`) where the dimension of the output of the previous linear layer was added as `in_features=128`, and `10` is just the size of the output which also corresponds to the number of classes.\n", 414 | "- After each one of those layers, we also apply an activation function such as `ReLU`. For prediction purposes, we then apply a `softmax` layer to the last transformation and return the output of that. 
" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "metadata": { 420 | "id": "_IYnV4ZBa3cJ", 421 | "colab_type": "code", 422 | "colab": {} 423 | }, 424 | "source": [ 425 | "class MyModel(nn.Module):\n", 426 | " def __init__(self):\n", 427 | " super(MyModel, self).__init__()\n", 428 | "\n", 429 | " # 28x28x1 => 26x26x32\n", 430 | " self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)\n", 431 | " self.d1 = nn.Linear(26 * 26 * 32, 128)\n", 432 | " self.d2 = nn.Linear(128, 10)\n", 433 | "\n", 434 | " def forward(self, x):\n", 435 | " # 32x1x28x28 => 32x32x26x26\n", 436 | " x = self.conv1(x)\n", 437 | " x = F.relu(x)\n", 438 | "\n", 439 | " # flatten => 32 x (32*26*26)\n", 440 | " x = x.flatten(start_dim = 1)\n", 441 | "\n", 442 | " # 32 x (32*26*26) => 32x128\n", 443 | " x = self.d1(x)\n", 444 | " x = F.relu(x)\n", 445 | "\n", 446 | " # logits => 32x10\n", 447 | " logits = self.d2(x)\n", 448 | " out = F.softmax(logits, dim=1)\n", 449 | " return out" 450 | ], 451 | "execution_count": 0, 452 | "outputs": [] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": { 457 | "id": "evsFbkq_X6bc", 458 | "colab_type": "text" 459 | }, 460 | "source": [ 461 | "As I have done in my previous tutorials, I always encourage to test the model with 1 batch to ensure that the output dimensions are what we expect. " 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "metadata": { 467 | "id": "1poxFYqftKov", 468 | "colab_type": "code", 469 | "outputId": "0a845d9b-54c8-43b9-c3d6-1abc1b7a4f28", 470 | "colab": { 471 | "base_uri": "https://localhost:8080/", 472 | "height": 52 473 | } 474 | }, 475 | "source": [ 476 | "## test the model with 1 batch\n", 477 | "model = MyModel()\n", 478 | "for images, labels in trainloader:\n", 479 | " print(\"batch size:\", images.shape)\n", 480 | " out = model(images)\n", 481 | " print(out.shape)\n", 482 | " break" 483 | ], 484 | "execution_count": 0, 485 | "outputs": [ 486 | { 487 | "output_type": "stream", 488 | "text": [ 489 | "batch size: torch.Size([32, 1, 28, 28])\n", 490 | "torch.Size([32, 10])\n" 491 | ], 492 | "name": "stdout" 493 | } 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": { 499 | "id": "9h_3eZQRHV_P", 500 | "colab_type": "text" 501 | }, 502 | "source": [ 503 | "## Training the Model\n", 504 | "Now we are ready to train the model but before that we are going to setup a loss function, an optimizer and a function to compute accuracy of the model. 
" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "metadata": { 510 | "id": "3_0Vjq2RHlph", 511 | "colab_type": "code", 512 | "colab": {} 513 | }, 514 | "source": [ 515 | "learning_rate = 0.001\n", 516 | "num_epochs = 5\n", 517 | "\n", 518 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 519 | "model = MyModel()\n", 520 | "model = model.to(device)\n", 521 | "criterion = nn.CrossEntropyLoss()\n", 522 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)" 523 | ], 524 | "execution_count": 0, 525 | "outputs": [] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "metadata": { 530 | "id": "44IdrNNeIi_I", 531 | "colab_type": "code", 532 | "colab": {} 533 | }, 534 | "source": [ 535 | "## compute accuracy\n", 536 | "def get_accuracy(logit, target, batch_size):\n", 537 | " ''' Obtain accuracy for training round '''\n", 538 | " corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()\n", 539 | " accuracy = 100.0 * corrects/batch_size\n", 540 | " return accuracy.item()" 541 | ], 542 | "execution_count": 0, 543 | "outputs": [] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "metadata": { 548 | "id": "nK3EcuIOISSR", 549 | "colab_type": "text" 550 | }, 551 | "source": [ 552 | "Now it's time for training." 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "metadata": { 558 | "id": "E59hwZlAIVcL", 559 | "colab_type": "code", 560 | "outputId": "ab16b14b-8a6e-4568-8500-2f2f5b447a93", 561 | "colab": { 562 | "base_uri": "https://localhost:8080/", 563 | "height": 105 564 | } 565 | }, 566 | "source": [ 567 | "for epoch in range(num_epochs):\n", 568 | " train_running_loss = 0.0\n", 569 | " train_acc = 0.0\n", 570 | "\n", 571 | " model = model.train()\n", 572 | "\n", 573 | " ## training step\n", 574 | " for i, (images, labels) in enumerate(trainloader):\n", 575 | " \n", 576 | " images = images.to(device)\n", 577 | " labels = labels.to(device)\n", 578 | "\n", 579 | " ## forward + backprop + loss\n", 580 | " logits = model(images)\n", 581 | " loss = criterion(logits, labels)\n", 582 | " optimizer.zero_grad()\n", 583 | " loss.backward()\n", 584 | "\n", 585 | " ## update model params\n", 586 | " optimizer.step()\n", 587 | "\n", 588 | " train_running_loss += loss.detach().item()\n", 589 | " train_acc += get_accuracy(logits, labels, BATCH_SIZE)\n", 590 | " \n", 591 | " model.eval()\n", 592 | " print('Epoch: %d | Loss: %.4f | Train Accuracy: %.2f' \\\n", 593 | " %(epoch, train_running_loss / i, train_acc/i)) " 594 | ], 595 | "execution_count": 0, 596 | "outputs": [ 597 | { 598 | "output_type": "stream", 599 | "text": [ 600 | "Epoch: 0 | Loss: 1.4901 | Train Accuracy: 96.97\n", 601 | "Epoch: 1 | Loss: 1.4808 | Train Accuracy: 97.90\n", 602 | "Epoch: 2 | Loss: 1.4767 | Train Accuracy: 98.34\n", 603 | "Epoch: 3 | Loss: 1.4748 | Train Accuracy: 98.55\n", 604 | "Epoch: 4 | Loss: 1.4725 | Train Accuracy: 98.81\n" 605 | ], 606 | "name": "stdout" 607 | } 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "metadata": { 613 | "id": "QuZxfQc1UIU-", 614 | "colab_type": "text" 615 | }, 616 | "source": [ 617 | "We can also compute accuracy on the testing dataset to see how well the model performs on the image classificaiton task. As you can see below, our basic CNN model is performing very well on the MNIST classification task." 
618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "metadata": { 623 | "id": "YU5WR0BTUHv1", 624 | "colab_type": "code", 625 | "outputId": "e0f48883-e06a-4108-a933-0f33b2e56b4f", 626 | "colab": { 627 | "base_uri": "https://localhost:8080/", 628 | "height": 34 629 | } 630 | }, 631 | "source": [ 632 | "test_acc = 0.0\n", 633 | "for i, (images, labels) in enumerate(testloader, 0):\n", 634 | " images = images.to(device)\n", 635 | " labels = labels.to(device)\n", 636 | " outputs = model(images)\n", 637 | " test_acc += get_accuracy(outputs, labels, BATCH_SIZE)\n", 638 | " \n", 639 | "print('Test Accuracy: %.2f' % (test_acc / i))" 640 | ], 641 | "execution_count": 0, 642 | "outputs": [ 643 | { 644 | "output_type": "stream", 645 | "text": [ 646 | "Test Accuracy: 98.04\n" 647 | ], 648 | "name": "stdout" 649 | } 650 | ] 651 | }, 652 | { 653 | "cell_type": "markdown", 654 | "metadata": { 655 | "id": "BZz7LAewgGAK", 656 | "colab_type": "text" 657 | }, 658 | "source": [ 659 | "**EXERCISE:** As a way to practice, try to include the testing step inside the loop where I was outputting the training accuracy, so that you can keep evaluating the model on the test data as you proceed with the training steps. This is useful because you often don't want to wait until the model has completed training before testing it on the test data." 660 | ] 661 | }, 662 | { 663 | "cell_type": "markdown", 664 | "metadata": { 665 | "id": "uLQlqGPsVjOB", 666 | "colab_type": "text" 667 | }, 668 | "source": [ 669 | "## Final Words\n", 670 | "That's it for this tutorial! Congratulations! You are now able to implement a basic CNN model in PyTorch for image classification. If you would like, you can further extend the model by adding more convolutional layers and max pooling, but as you saw, you don't really need that here since the results already look good. If you are interested in implementing a similar image classification model using RNNs, see the references below. " 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "metadata": { 676 | "id": "ztAiTq9HcS_H", 677 | "colab_type": "text" 678 | }, 679 | "source": [ 680 | "## References\n", 681 | "- [Building RNNs is Fun with PyTorch and Google Colab](https://colab.research.google.com/drive/1NVuWLZ0cuXPAtwV4Fs2KZ2MNla0dBUas)\n", 682 | "- [CNN Basics with PyTorch by Sebastian Raschka](https://github.com/rasbt/deeplearning-models/blob/master/pytorch_ipynb/cnn/cnn-basic.ipynb)\n", 683 | "- [Tensorflow 2.0 Quickstart for experts](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/quickstart/advanced.ipynb#scrollTo=DUNzJc4jTj6G) " 684 | ] 685 | } 686 | ] 687 | } --------------------------------------------------------------------------------