├── README.md └── mnist.ipynb /README.md: -------------------------------------------------------------------------------- 1 | ## Quickstart 2 | 3 | Quickstart project for FloydHub using Pytorch 4 | 5 | [Welcome on FloydHub :)](https://www.floydhub.com/welcome) 6 | -------------------------------------------------------------------------------- /mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Quick Start PyTorch - MNIST\n", 8 | "\n", 9 | "To run a Code Cell you can click on the `⏯ Run` button in the Navigation Bar above or type `Shift + Enter`" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Populating the interactive namespace from numpy and matplotlib\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%pylab inline\n", 27 | "import numpy as np\n", 28 | "import torch\n", 29 | "import torch.nn as nn\n", 30 | "import torch.nn.functional as F\n", 31 | "import torch.utils.data.dataloader as dataloader\n", 32 | "import torch.optim as optim\n", 33 | "\n", 34 | "from torch.utils.data import TensorDataset\n", 35 | "from torch.autograd import Variable\n", 36 | "from torchvision import transforms\n", 37 | "from torchvision.datasets import MNIST\n", 38 | "\n", 39 | "SEED = 1\n", 40 | "\n", 41 | "# CUDA?\n", 42 | "cuda = torch.cuda.is_available()\n", 43 | "\n", 44 | "# For reproducibility\n", 45 | "torch.manual_seed(SEED)\n", 46 | "\n", 47 | "if cuda:\n", 48 | " torch.cuda.manual_seed(SEED)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 61 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 62 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", 63 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", 64 | "Processing...\n", 65 | "Done!\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "train = MNIST('./data', train=True, download=True, transform=transforms.Compose([\n", 71 | " transforms.ToTensor(), # ToTensor does min-max normalization. \n", 72 | "]), )\n", 73 | "\n", 74 | "test = MNIST('./data', train=False, download=True, transform=transforms.Compose([\n", 75 | " transforms.ToTensor(), # ToTensor does min-max normalization. \n", 76 | "]), )\n", 77 | "\n", 78 | "# Create DataLoader\n", 79 | "dataloader_args = dict(shuffle=True, batch_size=256,num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)\n", 80 | "train_loader = dataloader.DataLoader(train, **dataloader_args)\n", 81 | "test_loader = dataloader.DataLoader(test, **dataloader_args)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "[Train]\n", 94 | " - Numpy Shape: (60000, 28, 28)\n", 95 | " - Tensor Shape: torch.Size([60000, 28, 28])\n", 96 | " - min: 0.0\n", 97 | " - max: 1.0\n", 98 | " - mean: 0.13066047740240005\n", 99 | " - std: 0.3081078089011192\n", 100 | " - var: 0.0949304219058486\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "train_data = train.train_data\n", 106 | "train_data = train.transform(train_data.numpy())\n", 107 | "\n", 108 | "print('[Train]')\n", 109 | "print(' - Numpy Shape:', train.train_data.cpu().numpy().shape)\n", 110 | "print(' - Tensor Shape:', train.train_data.size())\n", 111 | "print(' - min:', torch.min(train_data))\n", 112 | "print(' - max:', torch.max(train_data))\n", 113 | "print(' - mean:', torch.mean(train_data))\n", 114 | "print(' - std:', torch.std(train_data))\n", 115 | "print(' - var:', torch.var(train_data))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "## Model" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 4, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "# One hidden Layer NN\n", 132 | "class Model(nn.Module):\n", 133 | " def __init__(self):\n", 134 | " super(Model, self).__init__()\n", 135 | " self.fc = nn.Linear(784, 1000)\n", 136 | " self.fc2 = nn.Linear(1000, 10)\n", 137 | "\n", 138 | " def forward(self, x):\n", 139 | " x = x.view((-1, 784))\n", 140 | " h = F.relu(self.fc(x))\n", 141 | " h = self.fc2(h)\n", 142 | " return F.log_softmax(h) \n", 143 | " \n", 144 | " \n", 145 | "model = Model()\n", 146 | "if cuda:\n", 147 | " model.cuda() # CUDA!\n", 148 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "## Train\n", 156 | "\n", 157 | "Training time:\n", 158 | "\n", 159 | "- CPU, about 1 minute and 30 seconds\n", 160 | "- GPU, about 10 seconds" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 5, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | " Train Epoch: 1/5 [60000/60000 (100%)]\tLoss: 0.282395\t Test Accuracy: 94.8300%\n", 173 | " Train Epoch: 2/5 [60000/60000 (100%)]\tLoss: 0.239222\t Test Accuracy: 96.7600%\n", 174 | " Train Epoch: 3/5 [60000/60000 (100%)]\tLoss: 0.055666\t Test Accuracy: 97.3800%\n", 175 | " Train Epoch: 4/5 [60000/60000 (100%)]\tLoss: 0.124187\t Test Accuracy: 97.7700%\n", 176 | " Train Epoch: 5/5 [60000/60000 (100%)]\tLoss: 0.008199\t Test Accuracy: 97.8100%\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "EPOCHS = 5\n", 182 | "losses = []\n", 183 | "\n", 184 | "model.train()\n", 185 | "for epoch in range(EPOCHS):\n", 186 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 187 | " # Get Samples\n", 188 | " data, target = Variable(data), Variable(target)\n", 189 | " \n", 190 | " if cuda:\n", 191 | " data, target = data.cuda(), target.cuda()\n", 192 | " \n", 193 | " # Init\n", 194 | " optimizer.zero_grad()\n", 195 | "\n", 196 | " # Predict\n", 197 | " y_pred = model(data) \n", 198 | "\n", 199 | " # Calculate loss\n", 200 | " loss = F.cross_entropy(y_pred, target)\n", 201 | " losses.append(loss.cpu().data[0])\n", 202 | " # Backpropagation\n", 203 | " loss.backward()\n", 204 | " optimizer.step()\n", 205 | " \n", 206 | " \n", 207 | " # Display\n", 208 | " if batch_idx % 100 == 1:\n", 209 | " print('\\r Train Epoch: {}/{} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 210 | " epoch+1,\n", 211 | " EPOCHS,\n", 212 | " batch_idx * len(data), \n", 213 | " len(train_loader.dataset),\n", 214 | " 100. * batch_idx / len(train_loader), \n", 215 | " loss.cpu().data[0]), \n", 216 | " end='')\n", 217 | " # Eval\n", 218 | " evaluate_x = Variable(test_loader.dataset.test_data.type_as(torch.FloatTensor()))\n", 219 | " evaluate_y = Variable(test_loader.dataset.test_labels)\n", 220 | " if cuda:\n", 221 | " evaluate_x, evaluate_y = evaluate_x.cuda(), evaluate_y.cuda()\n", 222 | "\n", 223 | " model.eval()\n", 224 | " output = model(evaluate_x)\n", 225 | " pred = output.data.max(1)[1]\n", 226 | " d = pred.eq(evaluate_y.data).cpu()\n", 227 | " accuracy = d.sum()/d.size()[0]\n", 228 | " \n", 229 | " print('\\r Train Epoch: {}/{} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\t Test Accuracy: {:.4f}%'.format(\n", 230 | " epoch+1,\n", 231 | " EPOCHS,\n", 232 | " len(train_loader.dataset), \n", 233 | " len(train_loader.dataset),\n", 234 | " 100. * batch_idx / len(train_loader), \n", 235 | " loss.cpu().data[0],\n", 236 | " accuracy*100,\n", 237 | " end=''))" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 6, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "[]" 249 | ] 250 | }, 251 | "execution_count": 6, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | }, 255 | { 256 | "data": { 257 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD8CAYAAAB9y7/cAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VNX9x/H3d2YSdmQLiGwBARUrCiIiLnVBRbRal1bc\ntVrbWlurVkWt1mpd21r1516wLnWra10AFUXEBSQouywhbEGWsIVACNnO74+5M0ySWRIIJHf8vJ6H\nhzt3bu49lxs+c+bcc88x5xwiIpJ+Ag1dABER2T0U8CIiaUoBLyKSphTwIiJpSgEvIpKmFPAiImlK\nAS8ikqYU8CIiaUoBLyKSpkINdeAOHTq47Ozshjq8iIgvTZ8+fZ1zLqs22zZYwGdnZ5OTk9NQhxcR\n8SUzW1bbbdVEIyKSphTwIiJpSgEvIpKmFPAiImlKAS8ikqYU8CIiaUoBLyKSpnwX8AtWF/GPDxew\nfsv2hi6KiEij5ruAX1ywhf/7JJd1W0obuigiIo2a7wI+FDAAyioqG7gkIiKNm+8CPiMYLrICXkQk\nOR8HvGvgkoiING6+C/hQMNxEU64avIhIUr4L+GgNvlI1eBGRZHwY8N5N1nLV4EVEkvFdwIcC4SKX\nVyrgRUSS8V3AZ4bCNfhS3WQVEUnKdwEfrcHrJquISFK+C/iMUCTgVYMXEUnGfwEfiDTRqAYvIpKM\n7wI+FFQTjYhIbfgu4KPdJNVEIyKSlA8DPvKgk2rwIiLJ+DbgdZNVRCQ53wV8MGCYaTRJEZFUfBfw\nABmBgNrgRURS8GfAB001eBGRFHwZ8KFgQN0kRURS8GXAZwQDGi5YRCQFnwa8abhgEZEUUga8mXUz\ns4lmNs/M5prZNXG2MTN7xMxyzWyWmQ3cPcUNCwWNctXgRUSSCtVim3LgeufcN2bWCphuZh855+bF\nbHMK0Mf7czjwhPf3bpERDGgsGhGRFFLW4J1zq5xz33jLRcB3QJdqm50BPO/CpgBtzKxzvZfWkxHQ\nTVYRkVTq1AZvZtnAAGBqtbe6ACtiXudT80Og3mSETE+yioikUOuAN7OWwBvAH5xzm3fmYGZ2pZnl\nmFlOQUHBzuwCCE/6oSYaEZHkahXwZpZBONxfdM69GWeTlUC3mNddvXVVOOeeds4Ncs4NysrK2pny\nAuFeNKrBi4gkV5teNAaMAb5zzj2YYLN3gIu93jRDgELn3Kp6LGcVGcGAnmQVEUmhNr1ojgQuAmab\n2Qxv3S1AdwDn3JPAWGAEkAsUA5fVf1F3CAUDbC2t2J2HEBHxvZQB75z7HLAU2zjgt/VVqFQyg6Ze\nNCIiKfjySdZQQE00IiKp+DLgM0IB3WQVEUnBnwEfMHWTFBFJwZcBH1I3SRGRlHwZ8BnBAOWadFtE\nJCnfBnyphgsWEUnKpwGv4YJFRFLxZcCH9CSriEhKvgz4jIBRVuEIP18lIiLx+DPgg+Fiq5lGRCQx\nXwZ8KBLw6iopIpKQLwM+IxgeGqdMXSVFRBLyacCHi12mrpIiIgn5MuBDXg1ebfAiIon5MuAjNXg9\n7CQikphPA141eBGRVHwa8JFeNKrBi4gk4suADwW8JhoFvIhIQr4M+GgTjfrBi4gk5NOA97pJqgYv\nIpKQLwM+0k2yTDV4EZGEfBnwmdGxaFSDFxFJxJcBH1ITjYhISv4M+ICaaEREUvFlwGeGVIMXEUnF\nlwEfqcGrm6SISGK+DHh1kxQRSc3nAa8avIhIIr4M+B3DBasGLyKSiC8DXsMFi4ik5tOA13DBIiKp\n+DTgNVywiEgqvgz4SDfJUt1kFRFJyJcBb2aEAqYavIhIEr4MeAg306gfvIhIYikD3syeMbO1ZjYn\nwfvHmlmhmc3w/txe/8WsKRQ09YMXEUkiVIttngUeBZ5Pss1k59xp9VKiWsoMBtQPXkQkiZQ1eOfc\nZ8CGPVCWOgkFjbJy1eBFRBKprzb4I8xsppmNM7MDE21kZleaWY6Z5RQUFOzSAUOBAGWqwYuIJFQf\nAf8N0MM5dzDwf8DbiTZ0zj3tnBvknBuUlZW1SwfNDAXUBi8iksQuB7xzbrNzbou3PBbIMLMOu1yy\nFNRNUkQkuV0OeDPb28zMWx7s7XP9ru43lXA3SdXgRUQSSdmLxsxeBo4FOphZPvBnIAPAOfckcA7w\nGzMrB7YBI51zuz15M4KmfvAiIkmkDHjn3Hkp3n+UcDfKPSqkbpIiIkn5+ElWdZMUEUnGxwGvbpIi\nIsn4OuA16baISGK+DfhQQDdZRUSS8W3AZ4Q0mqSISDL+DfiARpMUEUnGtwEfCgb0JKuISBK+Dfhw\nLxrV4EVEEvFxwOsmq4hIMj4OeHWTFBFJxrcBHwoaparBi4gk5NuAzwjoJquISDL+DfhggEoHFbrR\nKiISl28DPhQ0AN1oFRFJwLcBnxkMF71cNXgRkbh8G/DRGny5avAiIvH4OODDRdeQwSIi8fk24DOj\nbfBqohERice3AR8KeG3wuskqIhKXbwM+I+Q10agGLyISl38DPqBukiIiyfg24CM3WTUejYhIfL4N\n+AzvJqvGoxERic/HAa+brCIiyfg/4PUkq4hIXL4N+JCaaEREkvJtwGcEdJNVRCQZ/wZ8SN0kRUSS\n8W3AR55kVcCLiMTn24DPVD94EZGkfBvwmvBDRCQ5/we8ukmKiMTl24CPNNFowg8Rkfh8G/DRsWg0\n4YeISFwpA97MnjGztWY2J8H7ZmaPmFmumc0ys4H1X8yaMjThh4hIUrWpwT8LDE/y/ilAH+/PlcAT\nu16s1DLUTVJEJKmUAe+c+wzYkGSTM4DnXdgUoI2Zda6vAiYSCBgBUzdJEZFE6qMNvguwIuZ1vreu\nBjO70sxyzCynoKBglw+cEQyoBi8iksAevcnqnHvaOTfIOTcoKytrl/cXDnjV4EVE4qmPgF8JdIt5\n3dVbt9tlBE29aEREEqiPgH8HuNjrTTMEKHTOraqH/aYUUhONiEhCoVQbmNnLwLFABzPLB/4MZAA4\n554ExgIjgFygGLhsdxW2uoyAqYlGRCSBlAHvnDsvxfsO+G29lagOMkKqwYuIJOLbJ1kBggGjpKyi\noYshItIopazBN2Z5BVvJK9hKRaUjGLCGLo6ISKPi6xp8hJppRERqSouAr9CQwSIiNaRFwJcr4EVE\navB1wO/duimgGryISDy+Dvirj+8NaEx4EZF4fB3wIa/njGrwIiI1+TrgI10jNWSwiEhNvg74yMTb\nqsGLiNTk64APBiLzsirgRUSq83XAqw1eRCQxXwd8tA1evWhERGrwdcCrBi8ikpivA35HDV4BLyJS\nna8DPuTdZFUNXkSkJl8HvPrBi4gk5uuAj/SDr3QKeBGR6nwd8M0zgwBs2FrawCUREWl8fB3wfTu1\nIiNozFu1uaGLIiLS6Pg64DOCAVo1zaCopKyhiyIi0uj4OuABmmUEKS7VxNsiItX5PuBbNAlSvF0B\nLyJSne8DvnlmiK2l5Q1dDBGRRicNAj7INjXRiIjU4PuAb5YRJGfZRnWVFBGpxvcB36VtMwD+N2Nl\nA5dERKRx8X3ARybe/su78ygpU1ONiEiE7wO+Q4sm0eX9bxvfgCUREWlcfB/wAW/AMRERqcr3AQ9w\nWv/ODV0EEZFGJy0Cvm3zzBrrKiud2uRF5ActLQL+mmF9aqy787157H/beMoqNF+riPwwpUXAt2wS\nqrHulWnLARTwIvKDVauAN7PhZrbAzHLNbFSc9y81swIzm+H9uaL+i5pYk1DN0zA0X6uI/LDVrPpW\nY2ZB4DHgRCAfmGZm7zjn5lXb9FXn3NW7oYwpmSXuSaPp/ETkh6o2NfjBQK5zLs85Vwq8Apyxe4u1\n6yKZX64mGhH5gapNwHcBVsS8zvfWVXe2mc0ys9fNrFu9lK4OendsCcBH89YAEPASvixBE83guyfw\nr8/y9kzhREQaQH3dZH0XyHbO9Qc+Ap6Lt5GZXWlmOWaWU1BQUE+HDhvxo70B+OXzOfz8qa+INNok\nqsGvLdrO3WO/q9cyiIg0JrUJ+JVAbI28q7cuyjm33jm33Xs5Gjg03o6cc0875wY55wZlZWXtTHkT\nKtiyPbr89ZIN0eWyOG3wzqldXkTSX20CfhrQx8x6mlkmMBJ4J3YDM4t9lPR0YI9XjasHeVllpbe+\nZg0+XuiLiKSblL1onHPlZnY18AEQBJ5xzs01szuBHOfcO8Dvzex0oBzYAFy6G8scV0W1tvaSsnCw\nV+9FU15RyUtTl+2xcomINJSUAQ/gnBsLjK227vaY5ZuBm+u3aHWTqL/7jPxNPDEpl94dW3Hh4d15\nf/Yq/vJu9R6eIiLpp1YB7weVCQL+trfneEurmZK3niN6td9zhRIRaUBpMVQBQHll6v7uazeXMPf7\nwj1QGhGRhpc2NfjqbfDxrN5cwtL1xXugNCIiDS+NavCpAz5y41VE5IcgbQK+NjX4eLaVVtD31nE8\n+OGCei6RiEjDSpuAH9C97U793JzvCymtqOSRT3JZu7kk2m9+yD0f8+BHC+uziCIie1TaBPw1J9Sc\n9KM2Ymv+g+/5ONrrZvXmEh75eFG9lE1EpCGkTcAHA8aE646p089kBK1G084Hc1fXGMrg0wVr+Wb5\nxjrtW9MFikhDS5uAB+jdsVWdti+rcLw/e1WVdRuLy5j7/eYq6y799zTOevzLWu93w9ZS9r9tPKMn\na7RKEWk4aRXw1R3crU3KbV6aurzGuouf+XqXjrt5WxkAD09QE4+INJy0C/heWS04sV8nlt53KpcO\n7bFT+yguLY+7vrS8kqcmLU7Y/FJYXMb6Ldup8Jp4irbH34+IyJ6QdgH/yfXH8q+LBwGQalTgW0bs\nH3f99vId/eVjh0B4Ycoy7h03n+v+O4M3pueTPep9NhWXRt8fcNeHHPrXCZomUEQahbQL+FipusZ3\na9s87vrYD4atMbX5hyeEu02Onb2a61+bCcBUb+z5sorK6PFihyjeXr7rN1tz126p801eEZG0DvhA\n4rm4AcgMxT/9Lm2aRZc3l5THXY741QvTuX/8fC4aMzW6Ljbg66Or5bAHJ3HW41/yRe465q/enPoH\nRERI84AfcVBnzhvcnV5ZLeK+3yurZdz1Kzdtiy4fed8nKY/zxKeLmZK3Yxap2GETlm8I7+vBDxeQ\ns3THNp8vWses/E1x9zfv+828/HXNm78XjJ7K8IcmpyyPiAikecA3zQhy71kHkdWySdz3e3ZowYzb\nT6z345bGtOG3aZaBc45HPsnlnCe/is4Re+GYqZz+6Bdxf37EI5O5+c3Z9V6u16fn896s7+t9vyLS\nOKXNaJLJtGyy4zQvP6onYz5fEn3dpnlmvR/v9v/NiS5vLC6tctO2963jGHZAx4Q/OzVv/U4ds7C4\njMxQgGaZwYTb/NG7b3Ba/3126hh70hvT86lwjp8P6pZ6YxGJ6wcR8Pef059b35rNiIM6c8YhXbjy\nmF5xe9icN7gbL3+9YpePt7hgKwChgLGxuJRtpVVvtE74bm10uaBoO1mtwt8wlq8v5tynp8R9L5WD\n7/yQ7u2a89mNx+1q8RuFyE1sBbzIzkvrJpqIDi2b8NRFgzjjkC4AdGrdlL33alpju3vP6s9or4tl\nxFtXDeW93x1VY9t/X3ZYyuO2bBrii9z1PDQh8aBlh909gdWFJbw4dRkl1XrcHHb3hJTHiLV8QzEz\nVmwir2AL785UU4zID90PogafykPnHkJnL/CH9esUXT/3LyfTwmvemfjHYznu759G32tbrWnnoC57\nMXtlIYd0a8OMFeGbp1u8XjfPfZV8ku8Hxs/nzW9Xct2JfWu8V9dhkH/62I52/RP7daJpRrjJpvr4\nOvWhtLySvn8ax10//REXDdm5h8pEZPf5QdTgU/npgC4cHmeu1hYxbfc9O7TghcsHR1+3bFK1rfvd\n3x3F0vtOjX5QAByW3a5Wx498IMQbnnhaTM+biDem53PqI5NZvr6YL3PXJdzvqsIS1haVAFUf3ooo\nKimr0qWzrgq9IRnuG/sdD4yfz9ZqT+5u3Foa3SaRykqHc44HP1zAB3NX73RZRKQm1eDjyPnTsLhP\nox7dJ4tu7ZqxYsM2MoPxb2b+9rjejJsTDqr/O38Ag/66o5mlZ4cWLFm3tcbP5MVZFzEypk0+ItI+\nfczfJgJwx0/6xf3ZyDeOpfedSkHR9uj6TcWlHPPAxGi//kk3HMvzXy3jwiE96NkhfpfSeCIfDltL\nK3j808U0zQjy+xP68Oq05XRr25zzR4efDVh09ylkBOPXJXrdMpbzBnePdgtdet+pNbZxzvHClGX8\npP8+tG1R/zfFRdKVavBxdGjZJG4bPcBRvTsAsFezjLjv7xPzkFSHlk0YcdDe0defXP/jKgFWm3b8\n2rjj3XlJ3z/yvk94OOaBqw/mrq7y0NaP//YpYz5fwsinvwLCgTp6ch7Zo95n0F8/InvU+zw0YSHZ\no96v8kGxrdqYPEY49G96Y3Y03AH63DqOL3LXMXlRAdmj3uezhQXAjmEg4vX5jzV7ZSG3/28uN74x\nK+l2u+Lj79ZEv+2IpAsFfB395fQfMemGY9mr+Y6Aj70J26Ta07GPX3BodNlsx6O1R/Rqz5H7dtiN\nJd1h5aZtvD49P/o60UxVazZv5+Wvl9Pz5rH89f3vAFi3JTzWzkPeyJiH3T2B+as3c+bjX7Bmc9VA\nbN4kxOrC+CH51rcro5OpXPzM11RWurjNRvFE5tLduLU0xZY7bCut4OslNZu34qmodFz+XA7n/2tq\nwm1e/no5T01anHQ/m4pLGT05b7fc7xDZGWqiqaPMUIAe7as2Y/yoy17R5eoBD+GbnbEjUC65dwQQ\nDvzT+nfmvVmravzM7rRm8/aE79XmAavI07TVA7GsopItCUbQjP2AAVixsZgFq4tSHgt23CC2akNP\nTMlbz+Dsdtz/wXx6tGvB+Yd3j7532bNfMyVvA5NvPI5u7ZrzwpRlNA0F+FmcbpeR8YKWrNvKbW/P\noaBoO307teS6k/aLbhP5d/nVj/dNWM4bX5/Fh/PWMKB7GwZ0a0sg1VgZu2hV4TZGT17CLSMOILib\njyX+pIDfBWN/fzSLC7ZUWReK09b8r2pdL2Nr8odlt9vjAb+7fJG7jn/UcvLyL3LXc8tbqT9Mxny+\nhNdyws8mGOF/t4/mrWHmik08OjGXX/94X56aFJ5Y5fzDu7Nxayl567ZGh44o8pqiIt8e4ga89w3B\nCI8YCjB+LlUCPmJW/ib6d40/z0DkG83ZT3zFfp1aMe6ao3dryI96YzaTFhZwYr9ODInTSWB3WLim\niKlLNnDqQZ1pp/shjZ6aaHZBv31a85ODd+2p0OrdIJtmJL4k+yS4L9BYTF60jrJaDpWcKNyzR71f\nZSygu96bx3yvph/5XPzl8zk8OjEXgCdjmk0mLljLgLs+4olPd6yrdC7l9ImR5w/Ka9El9T9TljFx\n/lq2lVZQWFzGTa/PYnNJuKdQacy5L1hTxLlPf8XLXy9ne3kFU/LW8+q05Pca6qq8MvzBtCeHpz7p\nn59x29tzGFzHZzQSufWt2fxvxsp62ZfUpIBvYKFg1RrepUN7AvDu1UfVaO55+uJBLLr7FJbedypH\n9m6fNPD/ee7BdG8Xfzjkxu7ZL5bEXW8GZz+ReOrEP/9vLgATvlsTXbe2qIT9bxtfZbuLxkzlmAcm\ncv/4+WSPep9Fa6p+C0tmwZotXPbsNK58IYeD7/yQV3NW8IL3nEP1bqLTlm7k5jdnc+2rMxj59BRu\nemM22aPeJ39jMRDuZlq9m2ptupZGBLxPvIqYNv8vF69jzeaS6IfOta/O4IHx82v8bHlFZcJnLCoq\nHd8s38jE+WurzHdQ5efr+HxGIi9OXc41r8yol33FU7itjBtem5mw6TDdqYlmN4ntPZPMzw7tRl7B\nVp79cikAN568H6f178yPuuzF9Sf15Z6x8/nJwfvQr3NrDtyndbR558UrhgDhABt898c19nvmgK5k\nBANc/dK3SY8/OLsdX8fpax/r85uO46j7J8Z97/bT+nFwtzbc+e5cZuYXpjrduPp2asnCmJD91+T4\nAR87Ymc8yzcU11g3q1qZLhozlcmLws8ORGr6iaZodM5VaU4L7y/8zEJkHxCeAeyN6flxjw/h+QNi\nPTUpjyN7t+fX//kmuu4nB+9D7totfLcqPBz0ExcMZGCPtixZt5Wxs1dx22n92FRcRlFJWXQU1EjA\nR3ojOeeq3BfJ+dMw3vo2XDu+cfj+5CzdwOjJS3j0/AH0vnUcQ3q145Urj6hR3kc/yeWf3tPXh/ds\nx6u/qrkNwOKCLXRt24wmoZpdhldu2sZpj0zmlSuPYL+9W7FwTRFL1m3l5AN3/L/YmYnpN5eU0Swj\nyMI1RWS1bELH1jsqOZ8vWkdxaTknxRzjqUmLeW16Pvt2bMmvk9w/2ZPi/V7tLgr43SBeX+5EmmUG\nueP0Azksux2rCrcRCFj0pu0g70Gpswd24dj94g9Q1rFVU/4wrA+GRf9TvvGboUB4ULEXpyznhAM6\nYmaMm72KnGVVJw4Zfekgnv9yKZcf1YunPltMWUUl781axbL1O8KqQ5zRONs2z2Dovh34xVHhbxzH\n79+JmfmFHNm7PR1bNeWtb1dy9sCuvPFNfo2fjfXEBQM5tEdbBt9T80OqPjxUbV7c2GBOZfXmEoIB\nq1LTjddBprzSRZ9NqI0XpiyLtvVHVB9a4jcvflPl9eCe7aIf1o9fMJARB3WO3lh9bGIuG4tLObV/\n56r7+M/06PK9Y7/jqc/C9yoiYyVNydvA9vIKmoSC3Pb2HIb/aG+2lVZEf48A5q1KPP/ACf+YxGn9\nO/Po+QOB8E3fNZu3c0i3Nrw/63s2Fpdx8kOfccPJ+/G3D8L3ZmL/b2yI6RVVm9BzztH/jg/5ycH7\n8O7M72meGWTencMBWLC6iAu9ORm+ue3E6P2ByLebXenYNDu/kE3bSjm6T1bS7baXV7Dfn8Zz9XG9\n2W/vVtHm2/OensLgnu249sS+zFyxiTMe+4KXfnk4Q/dALzoFfCNR/T8nwMDubZl/1/DocAOJ/GFY\nXxYXbOGfExZyaI+2HNqjbfS9l68cEl2+4PDujJ6cR4eWTRjl9Qpp3TSDq4/vE90PwNJ1xSxbX8x/\nf3UEwUC4Z1CrpqHoDUuAr24+oUq5IoOilZW76H+uHu1TNxGdclD4vMf+/mhGPNI4xrofum97vly8\nniPuTT0XAIRH8tzdYr+J3fDaTK5+6ZvoDGI5yzaSs2wj1/236ofMtKU7Pswj4Q4w9/sd32pOeWhy\n9EG7F6cuqzELWlFJOdmj3uePJ/WlW5wmv/dmreLhkY5nv1zKgx8uYGtpBW9eNZR7xu5oFoqEO8CE\neWt44IP5PPeLwVUCvufNY7nymF5cNKRH3OPAjucuIh+GxTGD+P3syR1Nd2c/8SUT/3hs3H0ksm7L\ndk55eDJjLhlE/65tmL5sI1c8N43rTuzLbV7T338uP5yj+oRD+YbXZrJgTRFvX3Vk9EZ65Pcgcn/o\n6D4d2KtZBl/lreervPVce2JfPpkfHmhw0sKCPRLwaoNv5FKFe0SvDi24ZcT+PHr+gKT7uvr4Powc\n3D3hNgB/+1l/nv/FYAb3bMehPdphZsy+42Q++MMxXDo0m7x7RtQoVyTMM0JGkdf+u3frpjz484Oj\n21SfAzf2HkK/fVpz22nxn8j98NpjyE7xYXHT8Pjz69bV17eeUOt/84hXpu0YgTTZTfL6srW0IuV0\nlMnEfnuIfYo62T7//uHChG3l+94ylrvem8dWL3DPejzxfZIrns9h4ZotHHHvJ9HpLiOe/iyPox+Y\nyH+nreCJTxfjnGN2fiGPfrKI5euLySuo+cT3JO+hudgH95as28qXi9cxenIeeOeU6MvBjBWbyN9Y\nzIR5aygo2s6/v1jK69PzOfuJL9lYXBYNdwjP4VBSVsH4Oat5bXo+s/ILo/c6oOaDf4fc+RE9bx4b\nff1F7rroA4ehPdSt1RrqoYxBgwa5nJycBjm2hNtIt24vp2+nVvWyv5KyCv7+wQKuOLoXFc5x/7j5\n3Hf2QTTPDJE96n0g/PX82+UbOfPxL+nbqSXv/e7oGtMm/vuLJfzl3Xncdlo/7npvHif168TTFw/C\nOcdFY77mc2/snSX3jsDMOPepr5i6ZAMvXD6Yb5ZtqtK8APC743uzdH1xlSaQY/pmcc6hXfn9yzXv\nTyy971Te/nYlf3i1Zpjdd9ZBPDoxl1P7d+bEAzpxzpNfVXn/njMPYuRh3dj/9vFVJn2p7oDOraNt\n7RD+UIg8zCX1K/b+zoTrjmHJumKuenF6ld5eTUKB6EN3qZ5Lef4Xg6vcszluvyx6dmhJRsg4br+O\ncYcWief3x/eO2w23NsxsunNuUOotFfCyB7z5TT5FJeVcMjQb5xyPTczlzIFdq8x9W1tf5q6jSUaA\nQ3uE70/krt3CPz9ayD9+fjBNM4Jc9+oMpi7ZwMpN2xi6b3te+uUQCreV8eSkxdGbqnn3jCAQMKYv\n21ilV87M20+KPqEc+VCC8Fg/Iwd3j1uzf316fnQilch+c9cW8a/PlvBqTtW5BYYd0IlrTuhDZijA\nyQ99xrADOpKzbCMdWzWpcpO5NiJjIkXsv3eraHfSxubcQd2Yv6aImSviT1H5Q3TtsL5cM6zPTv1s\nXQK+Vt8nzWy4mS0ws1wzGxXn/SZm9qr3/lQzy65bkSWdnTWwK5cMzQbCD3ldfXyfnQp3gKG9O0TD\nHaB3x5Y8dsHAaPg+eO4hfH7Tccy64ySevSw8+udezTK4afj+DDugE4d0axNtMz20R1smXPdj/nP5\n4bzxm6FVhp+Yf9dw+ncN3+xu2yIzYbPNOYd2jS5H9tu7YytuOfUAIPzk85hLBvGnUw9g9CWDOKjr\nXuy3dyuW3ncqoy85jOl/OpGOrcJNVY+dP5AFfx3OyMO68Ysje3LgPq2rHCu22+tbVx0ZXT5zQBfG\nXXM0R/Zuz9B9ww88tW+RySPn1Wyui31u47Vfx+8dE7H/3ju+3T1y3gAyqz3E9+IVhyf9eYAD92nN\n/ef0Z8wlNfPozjMOrNKEF8/ZA7vGXd8iycxle8KTFw7cpZ9PMPZevUtZgzezILAQOBHIB6YB5znn\n5sVscxXbHYV8AAAIFElEQVTQ3zn3azMbCZzpnDs32X5Vg5fGrqSsglenreDCIT2SDgWwdnMJZZWu\nyoeWc47Lnp3GJUOzOS5BD6iIVYXbGDN5CTcnGXIg0svk0U8WsWV7BaNO2Z85KwtZs7mEEw7oVGXb\ntZtLaNU0g2aZQcbNXsWS9VtxDoIB49Kh2ex/23gO6daGt397JE98upgnJy3m0fMHsHV7BT/um8U/\nJyzk/MHd6dauOWUVlRQUbadbu+bM+34zYz5fQsDgten5LL3vVBYXbKGsopJpSzYwM7+QrFZN2FRc\nxpeL17FsfTGz7ziJVk3DH5zj56zivzn5PHLegCrTaAJMnL+Wy56dBoS/MT3+6WIuHNKDMwd04aIx\nU7n+pP0oKinnjEP2iQ7jvXTdVo79+6cMO6AjFxzegyl561m4poiJCwqS/nvHiszfYAZXHt2Lpz7L\nI7t9c5aur9rl9YDOrSmrqCR37RZaNw0x646T+fi7NVz+XDjDFv71FPr+aVzcY7RsEqrRD//6E/vy\nuxN2fw2+NgF/BHCHc+5k7/XNAM65e2O2+cDb5iszCwGrgSyXZOcKeJGGsW7LdlpkhpLO37urNhWX\nMn91UZ2GUHDOMW3pRg7LbrtL/cSnLd3A3q2b0rZFJi2bhKo0twE8eWF4AMBJC9dy71n9o8eOHNM5\nx9L1xXRv15x7xn7HCQd0jPZ4+X7TNpqEArRv2QTnHGNnr+akAzuREQywurCEJqEAbVtkUlJWwSMf\nL2JIr/bs27Elqwu30TQjSL/Ordn3lrH85th9ueHknesYUN8Bfw4w3Dl3hff6IuBw59zVMdvM8bbJ\n914v9rZJ2OlYAS8ie8LqwhI2bStl36yWCecl2JOueeVbjtuvIz8d0GWnfr4uAb9H+8Gb2ZXAlQDd\nuyfvqiciUh/23iv+HMwN5eGRibsy17fafJytBGKH4OvqrYu7jddEsxewvvqOnHNPO+cGOecGZWUl\nfypMRER2TW0CfhrQx8x6mlkmMBJ4p9o27wCXeMvnAJ8ka38XEZHdL2UTjXOu3MyuBj4AgsAzzrm5\nZnYnkOOcewcYA7xgZrnABsIfAiIi0oBq1QbvnBsLjK227vaY5RLgZ/VbNBER2RUNf0tZRER2CwW8\niEiaUsCLiKQpBbyISJpqsNEkzawAWJZyw/g6ALWfmqfx0/k0bjqfxu2Hdj49nHO1epCowQJ+V5hZ\nTm0f1fUDnU/jpvNp3HQ+iamJRkQkTSngRUTSlF8D/umGLkA90/k0bjqfxk3nk4Av2+BFRCQ1v9bg\nRUQkBd8FfKr5YRsjM+tmZhPNbJ6ZzTWza7z17czsIzNb5P3d1ltvZvaId46zzGzXJoDcDcwsaGbf\nmtl73uue3ny8ud78vJne+kY/X6+ZtTGz181svpl9Z2ZH+PzaXOv9ns0xs5fNrKmfro+ZPWNma72J\nhCLr6nw9zOwSb/tFZnZJvGPtCQnO52/e79ssM3vLzNrEvHezdz4LzOzkmPV1zz7nnG/+EB7NcjHQ\nC8gEZgL9GrpctSh3Z2Cgt9yK8By3/YAHgFHe+lHA/d7yCGAcYMAQYGpDn0Occ7oOeAl4z3v9X2Ck\nt/wk8Btv+SrgSW95JPBqQ5c9zrk8B1zhLWcCbfx6bYAuwBKgWcx1udRP1wc4BhgIzIlZV6frAbQD\n8ry/23rLbRvR+ZwEhLzl+2POp5+Xa02Anl7eBXc2+xr8F7KO/1BHAB/EvL4ZuLmhy7UT5/E/wpOY\nLwA6e+s6Awu85acIT2we2T66XWP4Q3jSl4+B44H3vP9c62J+YaPXifAw00d4yyFvO2voc4g5l728\nQLRq6/16bboAK7xgC3nX52S/XR8gu1og1ul6AOcBT8Wsr7JdQ59PtffOBF70lqtkWuT67Gz2+a2J\nJvLLG5HvrfMN7yvwAGAq0Mk5t8p7azXQyVtu7Of5EHAjUOm9bg9scs5Fpo6PLW/0XLz3C73tG4ue\nQAHwb6/JabSZtcCn18Y5txL4O7AcWEX433s6/r0+EXW9Ho36OlXzC8LfQqCez8dvAe9rZtYSeAP4\ng3Nuc+x7Lvyx3Oi7NJnZacBa59z0hi5LPQkR/vr8hHNuALCVcBNAlF+uDYDXNn0G4Q+ufYAWwPAG\nLVQ989P1SMXMbgXKgRd3x/79FvC1mR+2UTKzDMLh/qJz7k1v9Roz6+y93xlY661vzOd5JHC6mS0F\nXiHcTPMw0MbC8/FC1fLWar7eBpQP5DvnpnqvXycc+H68NgDDgCXOuQLnXBnwJuFr5tfrE1HX69HY\nrxNmdilwGnCB96EF9Xw+fgv42swP2+iYmRGe1vA759yDMW/FzmV7CeG2+cj6i70eAkOAwpivpw3K\nOXezc66rcy6b8L//J865C4CJhOfjhZrn0mjn63XOrQZWmNl+3qoTgHn48Np4lgNDzKy593sXOR9f\nXp8Ydb0eHwAnmVlb71vNSd66RsHMhhNu5jzdOVcc89Y7wEivd1NPoA/wNTubfQ19M2UnblaMINwL\nZTFwa0OXp5ZlPorwV8pZwAzvzwjCbZ0fA4uACUA7b3sDHvPOcTYwqKHPIcF5HcuOXjS9vF/EXOA1\noIm3vqn3Otd7v1dDlzvOeRwC5HjX523CvS58e22AvwDzgTnAC4R7ZPjm+gAvE75/UEb4G9blO3M9\nCLdt53p/Lmtk55NLuE09kgdPxmx/q3c+C4BTYtbXOfv0JKuISJryWxONiIjUkgJeRCRNKeBFRNKU\nAl5EJE0p4EVE0pQCXkQkTSngRUTSlAJeRCRN/T8AZbLAbRIGGAAAAABJRU5ErkJggg==\n", 258 | "text/plain": [ 259 | "" 260 | ] 261 | }, 262 | "metadata": {}, 263 | "output_type": "display_data" 264 | } 265 | ], 266 | "source": [ 267 | "plot(losses)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "## Evaluate" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 7, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "Accuracy: 97.81\n" 287 | ] 288 | } 289 | ], 290 | "source": [ 291 | "evaluate_x = Variable(test_loader.dataset.test_data.type_as(torch.FloatTensor()))\n", 292 | "evaluate_y = Variable(test_loader.dataset.test_labels)\n", 293 | "if cuda:\n", 294 | " evaluate_x, evaluate_y = evaluate_x.cuda(), evaluate_y.cuda()\n", 295 | "\n", 296 | "model.eval()\n", 297 | "output = model(evaluate_x)\n", 298 | "pred = output.data.max(1)[1]\n", 299 | "d = pred.eq(evaluate_y.data).cpu()\n", 300 | "accuracy = d.sum()/d.size()[0]\n", 301 | "\n", 302 | "print('Accuracy:', accuracy*100)" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "kernelspec": { 308 | "display_name": "Python 3", 309 | "language": "python", 310 | "name": "python3" 311 | }, 312 | "language_info": { 313 | "codemirror_mode": { 314 | "name": "ipython", 315 | "version": 3 316 | }, 317 | "file_extension": ".py", 318 | "mimetype": "text/x-python", 319 | "name": "python", 320 | "nbconvert_exporter": "python", 321 | "pygments_lexer": "ipython3", 322 | "version": "3.6.2" 323 | } 324 | }, 325 | "nbformat": 4, 326 | "nbformat_minor": 2 327 | } 328 | --------------------------------------------------------------------------------