├── dogs-vs-cats ├── .gitignore └── dogs-vs-cats-part-3.ipynb ├── README.md ├── LICENSE ├── .gitignore ├── mnist-keras.ipynb └── n-armed-bandits.ipynb /dogs-vs-cats/.gitignore: -------------------------------------------------------------------------------- 1 | *.pth -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jupyter-notebooks-public -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /dogs-vs-cats/dogs-vs-cats-part-3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Dogs vs Cats - Part 3\n", 8 | "\n", 9 | "Classify whether images contain either a dog or a cat. Download the data from: https://www.kaggle.com/c/dogs-vs-cats/data\n", 10 | "\n", 11 | "Using an ensemble we achieve an accuracy of 99.0889% on our test set.\n", 12 | "\n", 13 | "This notebook assumes you have already run the steps from the Dogs vs Cats - Part 1 notebook where you downloaded the images and created the training, validation, and test directories.\n", 14 | "\n", 15 | "The dataset contains 25,000 images of dogs and cats (12,500 from each class). It is split into 3 subsets: a training set with 16,000 images, a validation set with 4,500 images, and a test set with 4,500 images.\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import torch\n", 25 | "import torch.nn as nn\n", 26 | "import torch.optim as optim\n", 27 | "import torch.utils.data\n", 28 | "import torch.nn.functional as F\n", 29 | "import torchvision\n", 30 | "import torchvision.models as models\n", 31 | "from torchvision import transforms\n", 32 | "from PIL import Image\n", 33 | "import matplotlib.pyplot as plt" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 44 | "text": [ 45 | "Using cache found in /home/wtf/.cache/torch/hub/pytorch_vision_master\n", 46 | "Using cache found in /home/wtf/.cache/torch/hub/pytorch_vision_master\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "model_resnet18 = torch.hub.load('pytorch/vision', 
'resnet18', pretrained=True)\n", 52 | "model_resnet34 = torch.hub.load('pytorch/vision', 'resnet34', pretrained=True)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# Freeze all params except the BatchNorm layers, as here they are trained to the\n", 62 | "# mean and standard deviation of ImageNet and we may lose some signal\n", 63 | "for name, param in model_resnet18.named_parameters():\n", 64 | " if(\"bn\" not in name):\n", 65 | " param.requires_grad = False\n", 66 | " \n", 67 | "for name, param in model_resnet34.named_parameters():\n", 68 | " if(\"bn\" not in name):\n", 69 | " param.requires_grad = False" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# Replace the classifier\n", 79 | "num_classes = 2\n", 80 | "\n", 81 | "model_resnet18.fc = nn.Sequential(nn.Linear(model_resnet18.fc.in_features,512),\n", 82 | " nn.ReLU(),\n", 83 | " nn.Dropout(),\n", 84 | " nn.Linear(512, num_classes))\n", 85 | "\n", 86 | "model_resnet34.fc = nn.Sequential(nn.Linear(model_resnet34.fc.in_features,512),\n", 87 | " nn.ReLU(),\n", 88 | " nn.Dropout(),\n", 89 | " nn.Linear(512, num_classes))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 5, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=5, device=\"cpu\"):\n", 99 | " for epoch in range(epochs):\n", 100 | " training_loss = 0.0\n", 101 | " valid_loss = 0.0\n", 102 | " model.train()\n", 103 | " for batch in train_loader:\n", 104 | " optimizer.zero_grad()\n", 105 | " inputs, targets = batch\n", 106 | " inputs = inputs.to(device)\n", 107 | " targets = targets.to(device)\n", 108 | " output = model(inputs)\n", 109 | " loss = loss_fn(output, targets)\n", 110 | " loss.backward()\n", 111 | " optimizer.step()\n", 112 | " training_loss += 
loss.data.item() * inputs.size(0)\n", 113 | " training_loss /= len(train_loader.dataset)\n", 114 | " \n", 115 | " model.eval()\n", 116 | " num_correct = 0 \n", 117 | " num_examples = 0\n", 118 | " for batch in val_loader:\n", 119 | " inputs, targets = batch\n", 120 | " inputs = inputs.to(device)\n", 121 | " output = model(inputs)\n", 122 | " targets = targets.to(device)\n", 123 | " loss = loss_fn(output,targets) \n", 124 | " valid_loss += loss.data.item() * inputs.size(0)\n", 125 | " \n", 126 | " correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets).view(-1)\n", 127 | " num_correct += torch.sum(correct).item()\n", 128 | " num_examples += correct.shape[0]\n", 129 | " valid_loss /= len(val_loader.dataset)\n", 130 | "\n", 131 | " print('Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}, accuracy = {:.4f}'.format(epoch, training_loss,\n", 132 | " valid_loss, num_correct / num_examples))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 6, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "batch_size=32\n", 142 | "img_dimensions = 224\n", 143 | "\n", 144 | "# Normalize to the ImageNet mean and standard deviation\n", 145 | "# Could calculate it for the cats/dogs data set, but the ImageNet\n", 146 | "# values give acceptable results here.\n", 147 | "img_transforms = transforms.Compose([\n", 148 | " transforms.Resize((img_dimensions, img_dimensions)),\n", 149 | " transforms.ToTensor(),\n", 150 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )\n", 151 | " ])\n", 152 | "\n", 153 | "img_test_transforms = transforms.Compose([\n", 154 | " transforms.Resize((img_dimensions,img_dimensions)), \n", 155 | " transforms.ToTensor(),\n", 156 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )\n", 157 | " ])\n", 158 | "\n", 159 | "def check_image(path):\n", 160 | " try:\n", 161 | " im = Image.open(path)\n", 162 | " return True\n", 163 | " except:\n", 164 | " 
return False\n", 165 | "\n", 166 | "train_data_path = \"/home/wtf/dogs-vs-cats/train/\"\n", 167 | "train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms, is_valid_file=check_image)\n", 168 | "\n", 169 | "validation_data_path = \"/home/wtf/dogs-vs-cats/validation/\"\n", 170 | "validation_data = torchvision.datasets.ImageFolder(root=validation_data_path,transform=img_test_transforms, is_valid_file=check_image)\n", 171 | "\n", 172 | "test_data_path = \"/home/wtf/dogs-vs-cats/test/\"\n", 173 | "test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=img_test_transforms, is_valid_file=check_image)\n", 174 | "\n", 175 | "num_workers = 6\n", 176 | "train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n", 177 | "validation_data_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n", 178 | "test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n", 179 | "\n", 180 | "\n", 181 | "if torch.cuda.is_available():\n", 182 | " device = torch.device(\"cuda\") \n", 183 | "else:\n", 184 | " device = torch.device(\"cpu\")" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 7, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "Num training images: 16000\n", 197 | "Num validation images: 4500\n", 198 | "Num test images: 4500\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "print(f'Num training images: {len(train_data_loader.dataset)}')\n", 204 | "print(f'Num validation images: {len(validation_data_loader.dataset)}')\n", 205 | "print(f'Num test images: {len(test_data_loader.dataset)}')" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "### Train and test the models" 213 | ] 214 | 
}, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 8, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "def test_model(model):\n", 222 | " correct = 0\n", 223 | " total = 0\n", 224 | " with torch.no_grad():\n", 225 | " for data in test_data_loader:\n", 226 | " images, labels = data[0].to(device), data[1].to(device)\n", 227 | " outputs = model(images)\n", 228 | " _, predicted = torch.max(outputs.data, 1)\n", 229 | " total += labels.size(0)\n", 230 | " correct += (predicted == labels).sum().item()\n", 231 | " print('correct: {:d} total: {:d}'.format(correct, total))\n", 232 | " print('accuracy = {:f}'.format(correct / total))" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 9, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "name": "stdout", 242 | "output_type": "stream", 243 | "text": [ 244 | "Epoch: 0, Training Loss: 0.0855, Validation Loss: 0.0358, accuracy = 0.9878\n", 245 | "Epoch: 1, Training Loss: 0.0498, Validation Loss: 0.0309, accuracy = 0.9873\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "model_resnet18.to(device)\n", 251 | "optimizer = optim.Adam(model_resnet18.parameters(), lr=0.001)\n", 252 | "train(model_resnet18, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, validation_data_loader, epochs=2, device=device)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 10, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "correct: 4456 total: 4500\n", 265 | "accuracy = 0.990222\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "test_model(model_resnet18)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 11, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "Epoch: 0, Training Loss: 0.0678, Validation Loss: 0.0239, accuracy = 0.9907\n", 283 | "Epoch: 1, Training Loss: 
0.0354, Validation Loss: 0.0317, accuracy = 0.9887\n" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "model_resnet34.to(device)\n", 289 | "optimizer = optim.Adam(model_resnet34.parameters(), lr=0.001)\n", 290 | "train(model_resnet34, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, validation_data_loader, epochs=2, device=device)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 12, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "correct: 4450 total: 4500\n", 303 | "accuracy = 0.988889\n" 304 | ] 305 | } 306 | ], 307 | "source": [ 308 | "test_model(model_resnet34)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "### Make some predictions" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 13, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "dogs\n", 328 | "cats\n" 329 | ] 330 | } 331 | ], 332 | "source": [ 333 | "import os\n", 334 | "def find_classes(dir):\n", 335 | " classes = os.listdir(dir)\n", 336 | " classes.sort()\n", 337 | " class_to_idx = {classes[i]: i for i in range(len(classes))}\n", 338 | " return classes, class_to_idx\n", 339 | "\n", 340 | "def make_prediction(model, filename):\n", 341 | " labels, _ = find_classes('/home/wtf/dogs-vs-cats/test')\n", 342 | " img = Image.open(filename)\n", 343 | " img = img_test_transforms(img)\n", 344 | " img = img.unsqueeze(0)\n", 345 | " prediction = model(img.to(device))\n", 346 | " prediction = prediction.argmax()\n", 347 | " print(labels[prediction])\n", 348 | " \n", 349 | "make_prediction(model_resnet34, '/home/wtf/dogs-vs-cats/test/dogs/dog.11460.jpg')\n", 350 | "make_prediction(model_resnet34, '/home/wtf/dogs-vs-cats/test/cats/cat.12262.jpg')" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 
357 | "### Save the models to disk" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 14, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "torch.save(model_resnet18.state_dict(), \"./model_resnet18.pth\")\n", 367 | "torch.save(model_resnet34.state_dict(), \"./model_resnet34.pth\")" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": {}, 373 | "source": [ 374 | "### Load the models from disk and test with an ensemble" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 15, 380 | "metadata": {}, 381 | "outputs": [ 382 | { 383 | "name": "stderr", 384 | "output_type": "stream", 385 | "text": [ 386 | "Using cache found in /home/wtf/.cache/torch/hub/pytorch_vision_master\n", 387 | "Using cache found in /home/wtf/.cache/torch/hub/pytorch_vision_master\n" 388 | ] 389 | }, 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "done\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "# Remember that you must call model.eval() to set dropout and batch normalization layers to\n", 400 | "# evaluation mode before running inference. 
Failing to do this will yield inconsistent inference results.\n", 401 | "\n", 402 | "resnet18 = torch.hub.load('pytorch/vision', 'resnet18')\n", 403 | "resnet18.fc = nn.Sequential(nn.Linear(resnet18.fc.in_features,512),nn.ReLU(), nn.Dropout(), nn.Linear(512, num_classes))\n", 404 | "resnet18.load_state_dict(torch.load('./model_resnet18.pth'))\n", 405 | "resnet18.eval()\n", 406 | "\n", 407 | "resnet34 = torch.hub.load('pytorch/vision', 'resnet34')\n", 408 | "resnet34.fc = nn.Sequential(nn.Linear(resnet34.fc.in_features,512),nn.ReLU(), nn.Dropout(), nn.Linear(512, num_classes))\n", 409 | "resnet34.load_state_dict(torch.load('./model_resnet34.pth'))\n", 410 | "resnet34.eval()\n", 411 | "\n", 412 | "print(\"done\")" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 16, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stdout", 422 | "output_type": "stream", 423 | "text": [ 424 | "accuracy = 0.990889\n", 425 | "correct: 4459 total: 4500\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "# Test against the average of each prediction from the two models\n", 431 | "models_ensemble = [resnet18.to(device), resnet34.to(device)]\n", 432 | "correct = 0\n", 433 | "total = 0\n", 434 | "with torch.no_grad():\n", 435 | " for data in test_data_loader:\n", 436 | " images, labels = data[0].to(device), data[1].to(device)\n", 437 | " predictions = [i(images).data for i in models_ensemble]\n", 438 | " avg_predictions = torch.mean(torch.stack(predictions), dim=0)\n", 439 | " _, predicted = torch.max(avg_predictions, 1)\n", 440 | "\n", 441 | " total += labels.size(0)\n", 442 | " correct += (predicted == labels).sum().item()\n", 443 | " \n", 444 | "print('accuracy = {:f}'.format(correct / total))\n", 445 | "print('correct: {:d} total: {:d}'.format(correct, total))" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [] 454 | } 455 | ], 456 | "metadata": { 457 | 
"kernelspec": { 458 | "display_name": "Python 3", 459 | "language": "python", 460 | "name": "python3" 461 | }, 462 | "language_info": { 463 | "codemirror_mode": { 464 | "name": "ipython", 465 | "version": 3 466 | }, 467 | "file_extension": ".py", 468 | "mimetype": "text/x-python", 469 | "name": "python", 470 | "nbconvert_exporter": "python", 471 | "pygments_lexer": "ipython3", 472 | "version": "3.8.2" 473 | } 474 | }, 475 | "nbformat": 4, 476 | "nbformat_minor": 4 477 | } 478 | -------------------------------------------------------------------------------- /mnist-keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Neural Network Image Classification with Keras and the MNIST Dataset\n", 8 | "\n", 9 | "In this post we'll use Keras and implement the Hello, World of machine learning, predict which number is in an image using the MNIST database of handwritten digits, and achieve 99% classification accuracy.\n", 10 | "\n", 11 | "Much of this is inspired by the book Deep Learning with Python by François Chollet. 
I highly recommend reading the book if you would like to dig deeper or learn more.\n", 12 | "https://www.manning.com/books/deep-learning-with-python" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "Using TensorFlow backend.\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "%matplotlib inline\n", 30 | "\n", 31 | "from keras import models, layers\n", 32 | "from keras.datasets import mnist\n", 33 | "from keras.utils import to_categorical\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import numpy as np" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "Since working with the MNIST digits is so common, Keras provides a function to load the data.\n", 43 | "\n", 44 | "You can see a full list of datasets Keras has packaged up here https://keras.io/datasets/\n", 45 | "\n", 46 | "Let's load the data." 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "The training set consists of 60,000 28x28 pixel images, and the test set 10,000." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "((60000, 28, 28), (10000, 28, 28))" 74 | ] 75 | }, 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "train_images.shape, test_images.shape" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "Let's look at the first ten training images. They are each 28x28 grayscale images with one color value between 0 and 255." 
90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": { 96 | "scrolled": true 97 | }, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlIAAABRCAYAAAAZ1Ej0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFGtJREFUeJzt3XmUzfUfx/Hn2LLkF8VEi0o7ORJpkUR7EVK0KUMhhYpKRdqTNhXZypKKkGhxlJTtNNVkiJIlh1QaUhmlhTS/P+55f77XbGa+c5fvd7we53RMc+/c+Xzm3vu9n8/78/68Pyk5OTmIiIiISPGVSXYDRERERMJKAykRERERnzSQEhEREfFJAykRERERnzSQEhEREfFJAykRERERnzSQEhEREfFJAykRERERnzSQEhEREfGpXIJ/X9jLqKcU4T6lvY+lvX+gPoaB+lj6+wfqYxjs831UREpERETEJw2kRERERHzSQEpERETEJw2kRERERHzSQEpERETEJw2kQmrJkiWkpaWRlpZGmTJlKFOmjPv/zMzMZDdPREKob9++pKSkkJKSQoMGDWjQoAHfffddspslEhetWrVy/5WEBlIiIiIiPiW6jlTM7d69m+zs7DzfHz58OAB//vknAKtXrwZgxIgR9O/fH4DJkycDULFiRQYMGADA4MGD497mkli2bBkA5513Htu3bwcgJSVS4uKVV14BYNasWfz666/JaWCCzJs3D4Brr70WgAULFnD88ccns0kx8cgjjwBw//33k5MTKb0yf/58AFq0aJGsZkkhfv/9d/744w8A3nvvPQC2bNkCQL9+/dhvv/2S1rai2rBhAwCTJk1y15OVK1cCsGrVKo444ohkNS1m1qxZA8DOnTtZtGgRAL169QK8a2hB2rVrB8CUKVMAqFChQryaGRO7du3ik08+AeCee+4BcP8vcPvttwOQnp7O9ddfX+LHC8VAauPGjezcuRPwXgyLFy8GYNu2bUyfPn2vj3H44YcD0Lt3b9566y0AqlatCkDDhg0D/yH1+eefA9ChQwcAsrOz3Zv/f//7H+C9ubdu3Up6ejoAjRs33uO2eFq4cCEAv/zyC+3bt4/r78rIyACgSZMmcf09iTJhwgQAhgwZAkDZsmXZvXs3sPeLvCTW+vXrARg6dCgQuRivWLEi3/tmZWXx/PPPJ6xtftWsWROIDNZnzZqV5NbExldffQXAxIkTAZg2bRoA//33Hz/++CPgvbf29h6zv0nPnj0BGDZsmLvuBlF2djbnnHMOALVq1QIir0X7el9lAZNRo0YBUL58ec4999wSP66W9kRERER8CnREaunSpUAkISy/5buiKFu2LOAtmVSpUsUtBx1yyCEAVK9ePZDLQrYsmZmZyXXXXQfApk2b8tzv2GOPBeCuu+4CoFOnTjRr1gzw+n3vvffGvb22BLV27dq4RqT+++8/FxXYuHEjgFsGCytL6P3nn3+S3BL/PvvsMyZNmgR40UmLCgA8/fTTgPe+W7RoEZ07dwbgtNNOS2RTi23VqlVAJBLx6quvAvDXX38BkddenTp1AC/KbctiU6dOdctHJ5xwQkLbXBxVqlQBKBVLeMauebbcGgsW3eratStnnXVWzB43nrKysty/+3pE6tNPPwVwK1xnnXUWHTt2LPHjKiIlIiIi4lOgI1I2O6pRo0aRIlI2q61evToff/wx4OUG2cw3THr06AHA66+/Xuj9lixZAuASXlu0aOGiQwXlbsSDzdbOPPPMuP6en376iTFjxgDe8xrk2X5hPvzwQ4A8eTQnnHAC777
7LgAHH3xwwttVHG+88QYQ2Tr/888/A16E8JxzzmHr1q0AbpOHycnJcbdZEm9Q2PXm7rvvBrw+2gaPaMcddxzvv/8+4M107fX4888/uz4G2bZt2wD48ssvk9yS2Dn//POBvBGp1NRUunXrBkSi2wBlyngxBcvDXbBgQSKaKT4tXLiQRx99FPA2jh144IEF3n/y5Mnu8/CYY44B4KmnnopJWwI9kLI/ypNPPsk777wDQKNGjQDo06ePu9/JJ58MeB9KVapUcUsKYUj0zM0GRvZBGr1sZQmErVu3dh9MtlRif5vogWQil7zsohRvN954o/valjXDaPHixXTp0gXI+wF95513BnaZ5d9//wW8hP+bbroJgB07drhNG4MGDQIioXNbrrQQug06ILibBWxDytixYwu8j12M586d6zazrF27Nv6NiwNLI8ivZlRGRoYbGAb1NZmfm2++GfB23Jny5csXusRl78WTTjoJwCWmRz/WqaeeGtO2JoItRZcW3bt3dzsxbSm9sOXWRx991O1mf+mll4DIRrNY0NKeiIiIiE+BjkiZdu3aucqjlsy5fPlyIDKytMiMJUyCN5uwJaAwiK4RBexRJ+qSSy4BvBDm/PnzXVjTIjS2hblhw4ZuO6+FtTMzMznllFPi0m57LjZv3hyXx8/NliHAC9+H0cSJE/NsHrCIYyxqm8SLJVvb8oi54IIL3BJY9NZw+150JAoiJUluuOGGeDbVt6lTp+b7/SOPPJKmTZsC8MQTTwBeaRXwktLDxqLaaWlpeWrpDR48mGrVqgFw6623JrxtfpUrF/l4i35+isJep7/99lue2+yxwlAbLLclS5ZwxhlnJLsZMVOpUiX3Off3338XeD/7XN24cWOR7u+HIlIiIiIiPoUiIgXkKX52wAEHuK9tvfOqq64C9kwcDIs1a9a4An+W6GoRptq1a7uZ+/777w9EcqRat26918e13Iennnpqr0nrfs2ePRuI/xq8RbysCjPAoYceGtffGQ+WfPzyyy+78hw24x84cGDS2lUUAwcO5LHHHgO8Ioa33HILECm1kV+RQouc5vb888+713jQ2DXFItoXXHABEMmLSk1NLfDnEhWVjZdBgwYF/nSHeLEND/ac27Uz2kMPPZTQNvlVrlw5d02xCP66deuS2aSYsfzLr776ihNPPBHIP9dpx44dgBc53rFjB6effjoAV1xxRUzbFL4Rh4iIiEhAhCYildsDDzwARNZ9bau/7dqz2WMY2I6m/v37u3wmm9Xb2XlNmjQpcbTn+++/L9HPF8bOMTT169ePy++xXLisrCxXQNVy5sLAImmXX355ntt69+4NUOJTyOPFZuKPPfaYyw+58MILAW/GV6lSJXd/y0H44IMP3E4w20FqM8q2bdsmoOX+WM6QXWeKqjScZxb24rbFYfl+Q4YMcREbK2ERzXaGly9fPnGNK4Fq1arRvHlzALfjPezsM8x20pYrV44RI0YA5BvZvuOOOwAv3/HQQw+N2/sztAMpSywfO3asS6K2bdgtW7Z026pt2SGo55VlZmYCe9Y6sXOdgn7+X0FisTV4+/btzJkzB/Audh988IG73ZbALHwdBtaf6Npeds5T3759k9KmvbFlgRdffBGIvI9sADVz5sw89//2228B7zDpL774wt125ZVXAl4F/rCykiq2dJCTk+OuL9GV3AGaNWsWugTfop4/F3TRBzGDN9GOZocX59dXm9A+8cQTbrNP9GRBEsOulzYBtVp1ffr0yfcz0mpD2fml5r777otbG7W0JyIiIuJTaCNS5uijj3Yjz7S0NCCyJGbLYjZrtO3ktWvXTnwjC2Hhx5ycHLf1PRaRqNzh+USG663oWW5WNdkKd86bNw+AH374wYXTX3vtNXcfm/1ZxXpbUtq1a1dgCzkWZObMme7kcdO8eXNXDT5680SQ2PNis0DwIjJbtmwBYPz48UAkkvr1118D8PvvvwORmb5t/rDzIqPLlASdJRxbvx566KE8lbK
jI1LGlgbHjx/vNhRI4qxYsYLLLrsM8M7jLK6zzz4biBR+LA1++eWXZDehyKzo76uvvkrXrl0B7zPM3mvp6elu40u/fv2AyGfPtGnT9ri/bdSyk0LiQREpEREREZ9CH5ECaN++PeAd2dCvXz+3Hn7PPfcA3tEH9913XyC2zNvxL1YsLCUlxc2gYiF3noMlS8aDRY7sd/Xo0cPNFKJZRMpmCpa4WblyZbeN1WYfjRs3dhE6O2vusMMOAyJlFsJytl5hCeZ169YN/Dl6dlalbfnfsmULRx55JJB/Xom9tyy/ZNOmTdSoUQOANm3axLu5MbFr1y6WLl0KQIcOHQBc4dTKlSu7aJOdKTlnzhwX+Ta7d+8GYMaMGS7/zf6WkliFReMLu82StGfPnu1ypMLs7bffTnYTisxKUXTr1i3PdcaOBcvIyHDHVFnffvzxR/detWvWuHHj4t7eUjGQMg0aNAAiWfr2JrCzzEaNGgVEzsKaO3duUtoXzXbh2dJJamoqnTp1KtFj2g7A6J1Glsw8ZMiQEj12YSwR2c7hKmhnRJ06dQBvt1a9evUAXG2PglhdF1tKqlu3bglbnDi2oy2/5Z3cS31BZMn8lljeunVrt0RgExd7Prt06eLOx7Sabps2bXJfB529F+fMmeMmZ8beUy1btnTnedkSdqtWrfIcDm6v1QEDBrjXvZ3TFvSq2PkNLhYuXAiEp7J5gwYN3G5uSza/6KKLAKhYsWK+P/Pyyy8D4TyfNT8tW7YEwrVrz05BsDSdChUquGuQ1UGsXr06EEmLsYOlbUAVvcxu9fqsGv38+fM5+uij49JuLe2JiIiI+FSqIlKmWrVqdO7cGfDOodu1axcQmVnZTMWWjoKgYsWKvhPhLRL1yCOPADB06FA3CrckPKuIHk933313XB7XktJNrKvSxoMt2eY+Xw5wS7hWCysMLOE/Ouk8Pxa5sJliSkpK4COIdm2wit52wgDAxRdfDHh1vqpVq+b+Brbcs3z5chdlstIOFqGaNWsW11xzDeCdC3nXXXe5WbVp1KhRjHvlX37lD958800AVq5cCXjR5CCzCHlRTwuwqGNpiUhZJNTs3LnTpbjY3yZoRo8eDXhRpIEDB7p0j9yGDx/uNgKkp6fnud02NVlkLl7RKFBESkRERMS3UhWRWr58OQDTp093a6Y22zT16tVz21qDxE+iuUU9bAZt68tt27ZlxowZsWtcwFiuSZBZdf3oE+QtqmMlD0ojy/2LjmoEOUdq9+7drtL6k08+CUSit48//jgAV199NeDlimVkZLjolBXTPe644xg5ciTgzX63b98ORPIFraSHJcRaZAq8qMH69evj0T1fevbsCXjRgWiWrzhs2LCEtikR8oseh1m5cnt+vOfk5LjVi6CyfEvbnGORqfxs3brVlSUxU6ZM4aSTTtrje7ZJKZ4UkRIRERHxKfQRqdWrV/PCCy8AuChMVlZWnvvZ6Lx27dquQGAy2c4Y+3fmzJk899xzRf75Z555hocffhiA7OxswCt4aMVIJXlsx0j0bj07rigR+WrJYsfHhMWYMWNcJMoKhY4ePdpFFD/99FPAKzo6e/ZsF3WznKq0tLQ8M2cr/3DRRRe53WKTJ08GvKKzAM8++2zsO1VCVookTGzlwaJK5557brGOcxk3bhy33XZbXNqWLBbdsVIxq1atcpFE22kdNEU5Kss+76ZOneq+th3EHTt2jF/jChG6gZQNkmwr5PDhw12tnvzYuW92zk4sazWVRO6EzqysLPr06QN4tZQOOuggIHIxty28Vovp+++/dwmDdqHu1atXglqfXGvXrg3s+WW2bdcGyFZPCLy6Q6VZ2JZH7DBm8KopDx061CUer127Ns/PPPjgg4BXo66olcttmdD+DSpburQJqp2fCLjJnt0nngm8RbVo0SJXt87O49ywYUOhy0JWumL27NlAZFNO7lpglStXBsJ/vp5NbjZt2sQzzzyT5NaUnA0CR44c6erwffTRR8lskpb
2RERERPwKRURq8+bNLqnMCsKtWrWqwPufdtppbhuyhTeDsJxXmH///ZcRI0YAkWR58M5fW7NmTZ77n3nmmbRq1QrYc1a9L7BtrUGzbNkyV+zVIo22Lb5Xr16Br2IeC+vWrUt2E4qlVq1arnimJeJa1Bfg0ksvBbxz19q1a+cqu5f2M/Tq168PBP857d27d56CqEOHDqVq1aoF/oy9T5csWQLsWerByuJYhN82EIRdSkpKqKvrW+mGsWPHApHPdCt/kIiE8sIEe3QhIiIiEmCBjEjZ+rWd1rxs2bJCZ0XNmjUDvOKTF154YeDXtS3Hp2nTpgB8/vnn7jbLA9u8ebP7np1XZlvJi5OYXtqkp6e7o3+CZNu2bXs8Z4A7l+3pp59ORpMSrnnz5kDhZ5gFycKFC93xN1bOIDU11eUpWuHMMM/k/bLZfpjOaDPFTaZOTU11+bN2bS3oKJmwys7Odq/1/M7+DDorG2KRqc6dO7t8xWQLzEDqs88+AyIhWasB9cMPPxR4f0sE7NOnj0skt103YWChSNtpOHr0aLcLL7e+ffty8803A96BjSJBZOdd2ut03bp1bhJUs2bNpLWrIFWrVnWnINi/EmHVy+vVq+cqmgfR+PHjXWJ8UWq0HXPMMe7zwwb+N910k3vtljZWX7BixYqhqEhfEJs8W923oGwcAy3tiYiIiPiWkuAQfIG/bMCAAcCe51yZevXq0aZNG8BL8Ozfvz/gVRxOkJS936XgPobE3vqY8P5NmDAB8EoLdO/ePd+qy0UUt+cwKyuLTp06AZEt2QBHHXUUkPCE3aS/Tu0569atGy1atAAipUogZue0Jb2PCRC492KMxew5tI0C9robOHCgSxGxkxCsNljbtm2pVatWMZvqW9Jfp5YO8s0337hl2hiftZf0PiZAoX1UREpERETEp8BEpEJinx95U/r7B+pjidlZcx07dnRbzTt06AB4VcJLmNOY9D4mgN6L6mMY7PN9VERKRERExCdFpIpnnx95U/r7B+pjzGzfvt3tqrUt6VY8sYS5UoHpYxzpvag+hsE+30cNpIpnn3/BUPr7B+pjGKiPpb9/oD6GwT7fRy3tiYiIiPiU6IiUiIiISKmhiJSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiITxpIiYiIiPikgZSIiIiIT/8Hn3hdHsbNVR4AAAAASUVORK5CYII=\n", 102 | "text/plain": [ 103 | "
" 104 | ] 105 | }, 106 | "metadata": { 107 | "needs_background": "light" 108 | }, 109 | "output_type": "display_data" 110 | } 111 | ], 112 | "source": [ 113 | "_, ax = plt.subplots(1, 10, figsize=(10,10))\n", 114 | "\n", 115 | "for i in range(0, 10):\n", 116 | " ax[i].axis('off')\n", 117 | " ax[i].imshow(train_images[i], cmap=plt.cm.binary)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "And the labels representing which class the image represents." 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 5, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=uint8)" 136 | ] 137 | }, 138 | "execution_count": 5, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "train_labels[0:10]" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "### Build the neural network\n", 152 | "\n", 153 | "Now build the neural network. We'll be using a number of convolutional layers. Note that we only have to specify the input shape in the first layer. The last layer provides the output. 
It has 10 units (one for each digit 0 to 9) and uses a softmax activation to map the output of a network to a probability distribution over the predicted output classes.\n", 154 | "\n", 155 | "https://en.wikipedia.org/wiki/Convolutional_neural_network" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 5, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "WARNING:tensorflow:From /Users/wtf/anaconda3/envs/tf_cpu/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 168 | "Instructions for updating:\n", 169 | "Colocations handled automatically by placer.\n" 170 | ] 171 | } 172 | ], 173 | "source": [ 174 | "model = models.Sequential()\n", 175 | "model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))\n", 176 | "model.add(layers.MaxPooling2D((2, 2)))\n", 177 | "model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", 178 | "model.add(layers.MaxPooling2D((2, 2)))\n", 179 | "model.add(layers.Conv2D(64, (3, 3), activation='relu'))\n", 180 | "model.add(layers.Flatten())\n", 181 | "model.add(layers.Dense(64, activation='relu'))\n", 182 | "model.add(layers.Dense(10, activation='softmax'))\n", 183 | "\n", 184 | "model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "One way to see what the network looks like is to use the summary() function:" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 6, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "_________________________________________________________________\n", 204 | "Layer (type) Output Shape Param # \n", 205 | 
"=================================================================\n", 206 | "conv2d_1 (Conv2D) (None, 26, 26, 32) 320 \n", 207 | "_________________________________________________________________\n", 208 | "max_pooling2d_1 (MaxPooling2 (None, 13, 13, 32) 0 \n", 209 | "_________________________________________________________________\n", 210 | "conv2d_2 (Conv2D) (None, 11, 11, 64) 18496 \n", 211 | "_________________________________________________________________\n", 212 | "max_pooling2d_2 (MaxPooling2 (None, 5, 5, 64) 0 \n", 213 | "_________________________________________________________________\n", 214 | "conv2d_3 (Conv2D) (None, 3, 3, 64) 36928 \n", 215 | "_________________________________________________________________\n", 216 | "flatten_1 (Flatten) (None, 576) 0 \n", 217 | "_________________________________________________________________\n", 218 | "dense_1 (Dense) (None, 64) 36928 \n", 219 | "_________________________________________________________________\n", 220 | "dense_2 (Dense) (None, 10) 650 \n", 221 | "=================================================================\n", 222 | "Total params: 93,322\n", 223 | "Trainable params: 93,322\n", 224 | "Non-trainable params: 0\n", 225 | "_________________________________________________________________\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "model.summary()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 7, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "We need to do some preprocessing of the images. We'll also use the first 50,000 training images for training, and the remaining 10,000 training examples for cross validation." 
247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 8, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "train_images = train_images.reshape((60000, 28, 28, 1))\n", 256 | "train_images= train_images.astype('float32') / 255 # rescale pixel values from range [0, 255] to [0, 1]\n", 257 | "\n", 258 | "test_images = test_images.reshape((10000, 28, 28, 1))\n", 259 | "test_images= test_images.astype('float32') / 255\n", 260 | "\n", 261 | "train_labels = to_categorical(train_labels)\n", 262 | "test_labels = to_categorical(test_labels)\n", 263 | "\n", 264 | "validation_images = train_images[50000:]\n", 265 | "validation_labels = train_labels[50000:]\n", 266 | "\n", 267 | "train_images = train_images[:50000]\n", 268 | "train_labels = train_labels[:50000]" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 9, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "WARNING:tensorflow:From /Users/wtf/anaconda3/envs/tf_cpu/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 281 | "Instructions for updating:\n", 282 | "Use tf.cast instead.\n", 283 | "Train on 50000 samples, validate on 10000 samples\n", 284 | "Epoch 1/5\n", 285 | "50000/50000 [==============================] - 20s 391us/step - loss: 0.1959 - acc: 0.9387 - val_loss: 0.0798 - val_acc: 0.9760\n", 286 | "Epoch 2/5\n", 287 | "50000/50000 [==============================] - 19s 380us/step - loss: 0.0509 - acc: 0.9845 - val_loss: 0.0513 - val_acc: 0.9849\n", 288 | "Epoch 3/5\n", 289 | "50000/50000 [==============================] - 19s 382us/step - loss: 0.0343 - acc: 0.9892 - val_loss: 0.0408 - val_acc: 0.9880\n", 290 | "Epoch 4/5\n", 291 | "50000/50000 [==============================] - 19s 379us/step - loss: 0.0257 - acc: 0.9918 - val_loss: 0.0448 - val_acc: 
0.9874\n", 292 | "Epoch 5/5\n", 293 | "50000/50000 [==============================] - 19s 377us/step - loss: 0.0208 - acc: 0.9938 - val_loss: 0.0356 - val_acc: 0.9903\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "history = model.fit(train_images, train_labels, epochs=5, batch_size=64, validation_data=(validation_images, validation_labels))" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 10, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "10000/10000 [==============================] - 1s 122us/step\n", 311 | "Accuracy: 0.992\n", 312 | "Loss: 0.027386048025220953\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "test_loss, test_acc = model.evaluate(test_images, test_labels)\n", 318 | "print('Accuracy:', test_acc)\n", 319 | "print('Loss: ', test_loss)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "Looks pretty good: we're seeing ~99% accuracy on the test set." 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "## Visualize training\n", 334 | "\n", 335 | "Now let's create a function that lets us graph the accuracy and loss values during training."
336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 11, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "def plot_accuracy_and_loss(history):\n", 345 | " acc = history.history['acc']\n", 346 | " val_acc = history.history['val_acc']\n", 347 | " loss = history.history['loss']\n", 348 | " val_loss = history.history['val_loss']\n", 349 | "\n", 350 | " epochs = range(1, len(acc) + 1)\n", 351 | "\n", 352 | " plt.plot(epochs, acc, 'bo', label='Training acc')\n", 353 | " plt.plot(epochs, val_acc, 'b', label='Validation acc')\n", 354 | " plt.title('Training and validation accuracy')\n", 355 | " plt.legend()\n", 356 | " plt.show()\n", 357 | "\n", 358 | " plt.plot(epochs, loss, 'bo', label='Training loss')\n", 359 | " plt.plot(epochs, val_loss, 'b', label='Validation loss')\n", 360 | " plt.title('Training and validation loss')\n", 361 | " plt.legend()\n", 362 | " plt.show()\n" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 12, 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "image/png": "\n", 373 | "text/plain": [ 374 | "
" 375 | ] 376 | }, 377 | "metadata": { 378 | "needs_background": "light" 379 | }, 380 | "output_type": "display_data" 381 | }, 382 | { 383 | "data": { 384 | "image/png": "\n", 385 | "text/plain": [ 386 | "
" 387 | ] 388 | }, 389 | "metadata": { 390 | "needs_background": "light" 391 | }, 392 | "output_type": "display_data" 393 | } 394 | ], 395 | "source": [ 396 | "plot_accuracy_and_loss(history)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": {}, 402 | "source": [ 403 | "The above looks pretty good. We appear to be starting to overfit the data as we get further in, but the training and validation curves are still pretty close to each other.\n", 404 | "\n", 405 | "Now let's look at a prediction. First, we'll generate predictions for the test set." 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 13, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "preds = model.predict(test_images)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": {}, 420 | "source": [ 421 | "We'll use the network to try to figure out what the first digit in the test set is. If we manually look, it appears to be a 7." 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 16, 427 | "metadata": {}, 428 | "outputs": [ 429 | { 430 | "data": { 431 | "text/plain": [ 432 | "" 433 | ] 434 | }, 435 | "execution_count": 16, 436 | "metadata": {}, 437 | "output_type": "execute_result" 438 | }, 439 | { 440 | "data": { 441 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADWtJREFUeJzt3X+oXPWZx/HPZ900gqmakKuJNu7tJqIbgpsuQ1h1WV1/hEQCsX9UEqRkoTQFK26h6EpAq8hCWG26glJNNDRCa1tM3QQJbiWsaGAtGY1Wa3a3/rim2Vxyb4zQFISQ5Nk/7km5jXfOjPPrzM3zfoHMzHnOmfN4yOeemfmema8jQgDy+bOqGwBQDcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpP+/nzubOnRvDw8P93CWQysjIiI4cOeJW1u0o/LZXSHpU0jmSnoqIjWXrDw8Pq16vd7JLACVqtVrL67b9st/2OZIel7RS0mJJa20vbvf5APRXJ+/5l0l6LyI+iIjjkn4qaXV32gLQa52E/1JJv5v0+GCx7E/YXm+7brs+Pj7ewe4AdFMn4Z/qQ4XPfD84IjZHRC0iakNDQx3sDkA3dRL+g5IWTHr8JUmHOmsHQL90Ev69ki63/WXbX5C0RtLO7rQFoNfaHuqLiBO275T0H5oY6tsaEb/pWmcAeqqjcf6I2CVpV5d6AdBHXN4LJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUh3N0mt7RNIxSSclnYiIWjeaAtB7HYW/8A8RcaQLzwOgj3jZDyTVafhD0i9tv257fTcaAtAfnb7svzYiDtm+SNJLtv87Il6ZvELxR2G9JF122WUd7g5At3R05o+IQ8XtmKTnJS2bYp3NEVGLiNrQ0FAnuwPQRW2H3/Z5tr94+r6k5ZLe6VZjAHqrk5f9F0t63vbp5/lJRLzYla4A9Fzb4Y+IDyT9dRd7AdBHDPUBSRF+ICnCDyRF+IGkCD+QFOEHkurGt/pSeO655xrWtmzZUrrtJZdcUlo/99xzS+u33357aX3evHkNa4sWLSrdFnlx5geSIvxAUoQfSIrwA0kRfiApwg8kRfiBpBjnb9Hdd9/dsDYyMtLTfT/xxBOl9fPPP79hbfHixd1uZ9pYsGBBw9o999xTum2tdvb/Cj1nfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IinH+Fj311FMNa2+99Vbpts3G2t99993S+r59+0rrL7/8csPaa6+9VrptsynUDhw4UFrvxIwZM0rrc+fOLa2Pjo6W1sv+38uuAZAY5wdwFiP8QFKEH0iK8ANJEX4gKcIPJEX4gaSajvPb3ipplaSxiFhSLJsj6WeShiWNSLotIj7pXZvVu/HGG9uqtWLFihUdbf/JJ40PfbNrBJqNZ+/du7etnloxc+bM0voVV1xRWr/yyitL60ePHm1YW7hwYem2GbRy5v+RpDP/dd4raXdEXC5pd/EYwDTSNPwR8YqkM/+Erpa0rbi/TdKtXe4LQI+1+57/4ogYlaTi9qLutQSgH3r+gZ/t9bbrtuvj4+O93h2AFrUb/sO250tScTvWaMWI2BwRtYioDQ0Ntbk7AN3Wbvh3SlpX3F8naUd32gHQL03Db/tZSf8l6QrbB21/Q9JGSTfb/q2km4vHAKaRpuP8EbG2QamzwW10zezZsxvWbrjhho6eu9NrGDqxffv20nrZ9Q2SdNVVVzWsrVmzpq2eziZc4QckRfiBpAg/kBThB5Ii/EBShB9Iip/uRmXGxhpeGCpJuuOOO0rrEVFav//++xvW5syZU7ptBpz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApxvlRmccff7y03uw6gAsvvLC03uynv7P
jzA8kRfiBpAg/kBThB5Ii/EBShB9IivADSTHOj57as2dPw9rGjZ1N97BjR/lcMUuWLOno+c92nPmBpAg/kBThB5Ii/EBShB9IivADSRF+IKmm4/y2t0paJWksIpYUyx6Q9E1J48VqGyJiV6+axPS1a1fjfxbHjx8v3famm24qrV999dVt9YQJrZz5fyRpxRTLfxARS4v/CD4wzTQNf0S8IuloH3oB0EedvOe/0/avbW+1PbtrHQHoi3bD/0NJCyUtlTQq6fuNVrS93nbddn18fLzRagD6rK3wR8ThiDgZEackbZG0rGTdzRFRi4ja0NBQu30C6LK2wm97/qSHX5X0TnfaAdAvrQz1PSvpeklzbR+U9D1J19teKikkjUj6Vg97BNADTcMfEWunWPx0D3rBNPTpp5+W1l988cWGtZkzZ5Zu++CDD5bWZ8yYUVpHOa7wA5Ii/EBShB9IivADSRF+ICnCDyTFT3ejIw8//HBpfd++fQ1rK1euLN32mmuuaasntIYzP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kxTg/Sr3wwgul9Yceeqi0fsEFFzSs3XfffW31hO7gzA8kRfiBpAg/kBThB5Ii/EBShB9IivADSTHOn9zHH39cWr/rrrtK6ydOnCit33LLLQ1rTLFdLc78QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5BU03F+2wskPSNpnqRTkjZHxKO250j6maRhSSOSbouIT3rXKtpx8uTJ0vqKFStK6x9++GFpfdGiRaX1Zt/3R3VaOfOfkPTdiPgrSX8r6du2F0u6V9LuiLhc0u7iMYBpomn4I2I0It4o7h+TtF/SpZJWS9pWrLZN0q29ahJA932u9/y2hyV9RdKvJF0cEaPSxB8ISRd1uzkAvdNy+G3PkrRd0nci4vefY7v1tuu26+Pj4+30CKAHWgq/7RmaCP6PI+IXxeLDtucX9fmSxqbaNiI2R0QtImpDQ0Pd6BlAFzQNv21LelrS/ojYNKm0U9K64v46STu63x6AXmnlK73XSvq6pLdtv1ks2yBpo6Sf2/6GpAOSvtabFtGJ999/v7Rer9c7ev5NmzaV1hcuXNjR86N3moY/IvZIcoPyjd1tB0C/cIUfkBThB5Ii/EBShB9IivADSRF+ICl+uvss8NFHHzWsLV++vKPnfuSRR0rrq1at6uj5UR3O/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOP8Z4Enn3yyYa3sGoBWXHfddaX1id96wXTEmR9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmKcfxp49dVXS+uPPfZYnzrB2YQzP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1XSc3/YCSc9ImifplKTNEfGo7QckfVPSeLHqhojY1atGM9uzZ09p/dixY20/96JFi0rrs2bNavu5MdhaucjnhKTvRsQbtr8o6XXbLxW1H0RE+awOAAZS0/BHxKik0eL+Mdv7JV3a68YA9Nbnes9ve1jSVyT9qlh0p+1f295qe3aDbdbbrtuuj4+PT7UKgAq0HH7bsyRtl/SdiPi9pB9KWihpqSZeGXx/qu0iYnNE1CKiNjQ01IWWAXRDS+G3PUMTwf9xRPxCkiLicEScjIhTkrZIWta7NgF0W9Pwe+LnWZ+WtD8iNk1aPn/Sal+V9E732wPQK6182n+tpK9Letv2m8WyDZLW2l4qKSSNSPpWTzpER5YuXVpa3717d2l9zpw53WwHA6SVT/v3SJrqx9kZ0wemMa7wA5Ii/EBShB9IivADSRF+ICnCDyTliOjbzmq1WtTr9b7tD8imVqupXq+3NG86Z34gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSKqv4/y2xyV9NGnRXElH+tbA5zOovQ1qXxK9taubvf1FRLT0e3l9Df9ndm7XI6JWWQMlBrW3Qe1Lord2VdUbL/uBpAg/kFTV4d9c8f7LDGpvg9qXRG/tqqS3St/zA6hO1Wd+ABWpJPy2V9j
+H9vv2b63ih4asT1i+23bb9qu9PvHxTRoY7bfmbRsju2XbP+2uJ1ymrSKenvA9v8Vx+5N27dU1NsC2/9pe7/t39j+p2J5pceupK9KjlvfX/bbPkfS/0q6WdJBSXslrY2Id/vaSAO2RyTVIqLyMWHbfy/pD5KeiYglxbJ/lXQ0IjYWfzhnR8Q/D0hvD0j6Q9UzNxcTysyfPLO0pFsl/aMqPHYlfd2mCo5bFWf+ZZLei4gPIuK4pJ9KWl1BHwMvIl6RdPSMxaslbSvub9PEP56+a9DbQIiI0Yh4o7h/TNLpmaUrPXYlfVWiivBfKul3kx4f1GBN+R2Sfmn7ddvrq25mChcX06afnj79oor7OVPTmZv76YyZpQfm2LUz43W3VRH+qX5iaJCGHK6NiL+RtFLSt4uXt2hNSzM398sUM0sPhHZnvO62KsJ/UNKCSY+/JOlQBX1MKSIOFbdjkp7X4M0+fPj0JKnF7VjF/fzRIM3cPNXM0hqAYzdIM15XEf69ki63/WXbX5C0RtLOCvr4DNvnFR/EyPZ5kpZr8GYf3ilpXXF/naQdFfbyJwZl5uZGM0ur4mM3aDNeV3KRTzGU8W+SzpG0NSL+pe9NTMH2X2ribC9NTGL6kyp7s/2spOs18a2vw5K+J+nfJf1c0mWSDkj6WkT0/YO3Br1dr4mXrn+cufn0e+w+9/Z3kl6V9LakU8XiDZp4f13ZsSvpa60qOG5c4QckxRV+QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS+n89yrzr7tkdMgAAAABJRU5ErkJggg==\n", 442 | "text/plain": [ 443 | "
" 444 | ] 445 | }, 446 | "metadata": { 447 | "needs_background": "light" 448 | }, 449 | "output_type": "display_data" 450 | } 451 | ], 452 | "source": [ 453 | "# reload the test images so it will be in a format imshow() will understand\n", 454 | "(_, _), (test_images, _) = mnist.load_data()\n", 455 | "\n", 456 | "plt.imshow(test_images[0], cmap=plt.cm.binary)" 457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "Since the output of the network was a layer with 10 units and a softmax activation, we will get an array of length 10 with a prediction for each potential number. Here you can see that the network is 99.9% certain it is a seven." 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 14, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "[2.6081236e-12 1.8943378e-09 1.0174886e-08 6.8640638e-08 2.3309353e-11\n", 476 | " 1.9539477e-10 7.4824168e-19 9.9999988e-01 4.3342949e-10 8.6599723e-09]\n" 477 | ] 478 | } 479 | ], 480 | "source": [ 481 | "print(preds[0])" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "We can also find the class with the highest prediction score with a numpy function:" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 15, 494 | "metadata": {}, 495 | "outputs": [ 496 | { 497 | "data": { 498 | "text/plain": [ 499 | "7" 500 | ] 501 | }, 502 | "execution_count": 15, 503 | "metadata": {}, 504 | "output_type": "execute_result" 505 | } 506 | ], 507 | "source": [ 508 | "np.argmax(preds[0])" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [ 515 | "The next step would be to retrain the model with all 60,000 training examples (remember that in the model above we trained on 50,000 examples and validated on the remaining 10,000). I'll leave that as an exercise to the reader." 
516 | ] 517 | } 518 | ], 519 | "metadata": { 520 | "kernelspec": { 521 | "display_name": "Python 3", 522 | "language": "python", 523 | "name": "python3" 524 | }, 525 | "language_info": { 526 | "codemirror_mode": { 527 | "name": "ipython", 528 | "version": 3 529 | }, 530 | "file_extension": ".py", 531 | "mimetype": "text/x-python", 532 | "name": "python", 533 | "nbconvert_exporter": "python", 534 | "pygments_lexer": "ipython3", 535 | "version": "3.7.1" 536 | } 537 | }, 538 | "nbformat": 4, 539 | "nbformat_minor": 2 540 | } 541 | -------------------------------------------------------------------------------- /n-armed-bandits.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# N-armed Bandits" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Epsilon-greedy strategy\n", 15 | "\n", 16 | "Given 10 slot machines with differing probabilities of paying out, find a strategy to maximize getting paid. Each arm has a random probability of paying out, and we are trying to determine which has the highest probability.\n", 17 | "\n", 18 | "Most of the time choose the arm with the best known probability of paying out (exploitation), but occasionally (with probability epsilon) pick a random arm to try to learn more about the environment and potentially find a better arm (exploration)."
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np\n", 28 | "from scipy import stats\n", 29 | "import random\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "\n", 32 | "n = 10 # number of arms\n", 33 | "probs = np.random.rand(n) # hidden probabilities associated with each arm\n", 34 | "eps = 0.2 # epsilon for the epsilon-greedy action selection" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "def get_reward(prob, n=10):\n", 44 | " reward = 0\n", 45 | " for i in range(n):\n", 46 | " if random.random() < prob:\n", 47 | " reward += 1\n", 48 | " return reward" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Given a single arm with a 70% probability of paying out a reward of 1 and running 10 iterations (so a max reward of 10), we see that the longer we run the closer we get to a mean reward of 7 for the arm." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "7.0335" 67 | ] 68 | }, 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "reward_test = [get_reward(0.7, n=10) for _ in range(2000)]\n", 76 | "np.mean(reward_test)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "(array([ 1., 3., 22., 73., 209., 383., 500., 489., 320.]),\n", 88 | " array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]),\n", 89 | " )" 90 | ] 91 | }, 92 | "execution_count": 4, 93 | "metadata": {}, 94 | "output_type": "execute_result" 95 | }, 96 | { 97 | "data": { 98 | "image/png": "\n", 99 | "text/plain": [ 100 | "
" 101 | ] 102 | }, 103 | "metadata": { 104 | "needs_background": "light" 105 | }, 106 | "output_type": "display_data" 107 | } 108 | ], 109 | "source": [ 110 | "plt.figure(figsize=(9,5))\n", 111 | "plt.xlabel(\"Reward\",fontsize=22)\n", 112 | "plt.ylabel(\"# Observations\",fontsize=22)\n", 113 | "plt.hist(reward_test,bins=9)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "array([[0., 0.],\n", 125 | " [0., 0.],\n", 126 | " [0., 0.],\n", 127 | " [0., 0.],\n", 128 | " [0., 0.],\n", 129 | " [0., 0.],\n", 130 | " [0., 0.],\n", 131 | " [0., 0.],\n", 132 | " [0., 0.],\n", 133 | " [0., 0.]])" 134 | ] 135 | }, 136 | "execution_count": 5, 137 | "metadata": {}, 138 | "output_type": "execute_result" 139 | } 140 | ], 141 | "source": [ 142 | "# 10 actions x 2 columns\n", 143 | "# Columns: Count #, Avg Reward\n", 144 | "records = np.zeros((n, 2))\n", 145 | "records" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 6, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# Takes the records array, an action (index value of the arm), and a new reward observation\n", 155 | "# Then updates the average reward for the arm\n", 156 | "def update_records(records, action, r):\n", 157 | " new_r = (records[action, 0] * records[action, 1] + r) / (records[action, 0] + 1)\n", 158 | " records[action, 0] += 1\n", 159 | " records[action, 1] = new_r\n", 160 | " return records" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# Given our array of records, find the arm with the highest probability of payout\n", 170 | "def get_best_arm(records):\n", 171 | " arm_index = np.argmax(records[:, 1], axis=0)\n", 172 | " return arm_index" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": {}, 179 | 
"outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "Highest probability: 0.9518422506414441\n" 185 | ] 186 | }, 187 | { 188 | "data": { 189 | "image/png": "\n", 190 | "text/plain": [ 191 | "
" 192 | ] 193 | }, 194 | "metadata": { 195 | "needs_background": "light" 196 | }, 197 | "output_type": "display_data" 198 | } 199 | ], 200 | "source": [ 201 | "fig, ax = plt.subplots(1, 1)\n", 202 | "ax.set_xlabel(\"Plays\")\n", 203 | "ax.set_ylabel(\"Avg Reward\")\n", 204 | "fig.set_size_inches(9,5)\n", 205 | "\n", 206 | "records = np.zeros((n, 2))\n", 207 | "\n", 208 | "num_arms = 10\n", 209 | "probs = np.random.rand(num_arms) # hidden probabilities associated with each arm\n", 210 | "eps = 0.2 # epsilon for the epsilon-greedy action selection\n", 211 | "\n", 212 | "rewards = [0]\n", 213 | "for i in range(500):\n", 214 | " if random.random() > eps:\n", 215 | " choice = get_best_arm(records)\n", 216 | " else:\n", 217 | " choice = np.random.randint(num_arms)\n", 218 | " r = get_reward(probs[choice])\n", 219 | " records = update_records(records,choice,r)\n", 220 | " mean_reward = ((i+1) * rewards[-1] + r)/(i+2)\n", 221 | " rewards.append(mean_reward)\n", 222 | "ax.scatter(np.arange(len(rewards)),rewards)\n", 223 | "\n", 224 | "print('Highest probability: ' + str(np.amax(probs)))" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "## Softmax selection policy\n", 232 | "\n", 233 | "Softmax gives us a probability distribution over our options, with the largest probability being the best arm. This allows us to explore, but with a lower likelihood of picking a poorly performing arm.\n", 234 | "\n", 235 | "For this problem softmax tends to converge on an optimal policy faster than epsilon-greedy, but is very sensitive to the tau value and can take some time to find a good value (whereas finding an epsilon parameter tends to be more intuitive)."
236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 9, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "def softmax(av, tau=1.12):\n", 245 | " softm = np.exp(av / tau) / np.sum(np.exp(av / tau))\n", 246 | " return softm" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 10, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "Highest probability: 0.8402165336734575\n" 259 | ] 260 | }, 261 | { 262 | "data": { 263 | "image/png": "\n", 264 | "text/plain": [ 265 | "
" 266 | ] 267 | }, 268 | "metadata": { 269 | "needs_background": "light" 270 | }, 271 | "output_type": "display_data" 272 | } 273 | ], 274 | "source": [ 275 | "n = 10\n", 276 | "probs = np.random.rand(n)\n", 277 | "records = np.zeros((n,2))\n", 278 | "\n", 279 | "fig,ax = plt.subplots(1,1)\n", 280 | "ax.set_xlabel(\"Plays\")\n", 281 | "ax.set_ylabel(\"Avg Reward\")\n", 282 | "fig.set_size_inches(9,5)\n", 283 | "rewards = [0]\n", 284 | "for i in range(500):\n", 285 | " p = softmax(records[:,1],tau=0.7)\n", 286 | " choice = np.random.choice(np.arange(n),p=p)\n", 287 | " r = get_reward(probs[choice])\n", 288 | " records = update_records(records,choice,r)\n", 289 | " mean_reward = ((i+1) * rewards[-1] + r)/(i+2)\n", 290 | " rewards.append(mean_reward)\n", 291 | "ax.scatter(np.arange(len(rewards)),rewards)\n", 292 | "\n", 293 | "print('Highest probability: ' + str(np.amax(probs)))" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "## Contextual Bandits" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 11, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "import numpy as np\n", 310 | "import torch" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 12, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "class ContextBandit:\n", 320 | " def __init__(self, arms=10):\n", 321 | " self.arms = arms\n", 322 | " self.init_distribution(arms)\n", 323 | " self.update_state()\n", 324 | " \n", 325 | " def init_distribution(self, arms):\n", 326 | " # Num states = Num Arms to keep things simple\n", 327 | " self.bandit_matrix = np.random.rand(arms,arms)\n", 328 | " # each row represents a state, each column an arm\n", 329 | " \n", 330 | " def reward(self, prob):\n", 331 | " reward = 0\n", 332 | " for i in range(self.arms):\n", 333 | " if random.random() < prob:\n", 334 | " reward += 1\n", 335 | " return reward\n", 336 | " \n", 337 | " def 
get_state(self):\n", 338 | " return self.state\n", 339 | " \n", 340 | " def update_state(self):\n", 341 | " self.state = np.random.randint(0,self.arms)\n", 342 | " \n", 343 | " def get_reward(self,arm):\n", 344 | " return self.reward(self.bandit_matrix[self.get_state()][arm])\n", 345 | " \n", 346 | " def choose_arm(self, arm):\n", 347 | " reward = self.get_reward(arm)\n", 348 | " self.update_state()\n", 349 | " return reward" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 13, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "def one_hot(N, pos, val=1):\n", 359 | " one_hot_vec = np.zeros(N)\n", 360 | " one_hot_vec[pos] = val\n", 361 | " return one_hot_vec\n", 362 | "\n", 363 | "def running_mean(x,N=50):\n", 364 | " c = x.shape[0] - N\n", 365 | " y = np.zeros(c)\n", 366 | " conv = np.ones(N)\n", 367 | " for i in range(c):\n", 368 | " y[i] = (x[i:i+N] @ conv)/N\n", 369 | " return y" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 14, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "arms = 10\n", 379 | "loss_fn = torch.nn.MSELoss()\n", 380 | "env = ContextBandit(arms)\n", 381 | "\n", 382 | "N = 1\n", 383 | "D_in = arms\n", 384 | "H = 100\n", 385 | "D_out = arms" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 15, 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "model = torch.nn.Sequential(\n", 395 | " torch.nn.Linear(D_in, H),\n", 396 | " torch.nn.ReLU(),\n", 397 | " torch.nn.Linear(H, D_out),\n", 398 | " torch.nn.ReLU(),\n", 399 | ")" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 16, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "def train(env, epochs=5000, learning_rate=1e-2):\n", 409 | " # Convert the environment's current state to a PyTorch variable\n", 410 | " cur_state = torch.Tensor(one_hot(arms,env.get_state()))\n", 411 | " \n", 412 | " optimizer = 
torch.optim.Adam(model.parameters(), lr=learning_rate)\n", 413 | " rewards = []\n", 414 | " for i in range(epochs):\n", 415 | " y_pred = model(cur_state)\n", 416 | " \n", 417 | " # Convert reward predictions to probability distribution with softmax\n", 418 | " av_softmax = softmax(y_pred.data.numpy(), tau=2.0)\n", 419 | " \n", 420 | " # Normalize distribution to ensure it sums to 1\n", 421 | " av_softmax /= av_softmax.sum()\n", 422 | " \n", 423 | " # Choose new action probabilistically\n", 424 | " choice = np.random.choice(arms, p=av_softmax)\n", 425 | " \n", 426 | " # Take action, recieve reward\n", 427 | " cur_reward = env.choose_arm(choice)\n", 428 | " \n", 429 | " # Convert PyTorch data to a numpy array\n", 430 | " one_hot_reward = y_pred.data.numpy().copy()\n", 431 | " \n", 432 | " # Update one_hot_reward array to use as labeled training data\n", 433 | " one_hot_reward[choice] = cur_reward\n", 434 | " reward = torch.Tensor(one_hot_reward)\n", 435 | " rewards.append(cur_reward)\n", 436 | " loss = loss_fn(y_pred, reward)\n", 437 | " optimizer.zero_grad()\n", 438 | " loss.backward()\n", 439 | " optimizer.step()\n", 440 | " \n", 441 | " # Update current environment state\n", 442 | " cur_state = torch.Tensor(one_hot(arms,env.get_state()))\n", 443 | " return np.array(rewards)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 17, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "rewards = train(env)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 18, 458 | "metadata": {}, 459 | "outputs": [ 460 | { 461 | "data": { 462 | "text/plain": [ 463 | "[]" 464 | ] 465 | }, 466 | "execution_count": 18, 467 | "metadata": {}, 468 | "output_type": "execute_result" 469 | }, 470 | { 471 | "data": { 472 | "image/png": "\n", 473 | "text/plain": [ 474 | "
" 475 | ] 476 | }, 477 | "metadata": { 478 | "needs_background": "light" 479 | }, 480 | "output_type": "display_data" 481 | } 482 | ], 483 | "source": [ 484 | "plt.plot(running_mean(rewards,N=500))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [] 507 | } 508 | ], 509 | "metadata": { 510 | "kernelspec": { 511 | "display_name": "Python 3", 512 | "language": "python", 513 | "name": "python3" 514 | }, 515 | "language_info": { 516 | "codemirror_mode": { 517 | "name": "ipython", 518 | "version": 3 519 | }, 520 | "file_extension": ".py", 521 | "mimetype": "text/x-python", 522 | "name": "python", 523 | "nbconvert_exporter": "python", 524 | "pygments_lexer": "ipython3", 525 | "version": "3.8.2" 526 | } 527 | }, 528 | "nbformat": 4, 529 | "nbformat_minor": 4 530 | } 531 | --------------------------------------------------------------------------------