├── .gitignore ├── Part1.ipynb └── Part2.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | -------------------------------------------------------------------------------- /Part2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 6.86x - Introduction to ML Packages (Part 2)\n", 8 | "\n", 9 | "This tutorial is designed to provide a short introduction to deep learning with PyTorch.\n", 10 | "\n", 11 | "You can start studying this tutorial as you work through unit 3 of the course.", 12 | "\n", 13 | "For more resources, check out [the PyTorch tutorials](https://pytorch.org/tutorials/)! There are many more in-depth examples available there.\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "Source code for this notebook is hosted at: https://github.com/varal7/ml-tutorial" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## PyTorch\n", 28 | "\n", 29 | "[PyTorch](https://pytorch.org) is a flexible scientific computing package targeted towards gradient-based deep learning. Its low-level API closely follows [NumPy](http://www.numpy.org/). However, there are several key additions:\n", 30 | "\n", 31 | "- GPU support!\n", 32 | "- Automatic differentiation!\n", 33 | "- Deep learning modules!\n", 34 | "- Data loading!\n", 35 | "- And other generally useful goodies.\n", 36 | "\n", 37 | "If you don't have GPU-enabled hardware, don't worry. Like NumPy, PyTorch runs pre-compiled, highly efficient C code to handle all intensive backend functions.\n", 38 | "\n", 39 | "Go to pytorch.org to download the correct package for your computing environment."
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "# Start by importing torch\n", 51 | "import torch" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": { 57 | "ExecuteTime": { 58 | "end_time": "2019-02-20T16:35:28.485220Z", 59 | "start_time": "2019-02-20T16:35:28.476294Z" 60 | } 61 | }, 62 | "source": [ 63 | "### Tensors\n", 64 | "\n", 65 | "Tensors are PyTorch's equivalent of NumPy ndarrays. " 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": { 72 | "ExecuteTime": { 73 | "end_time": "2019-02-20T16:35:28.492012Z", 74 | "start_time": "2019-02-20T16:35:28.487024Z" 75 | } 76 | }, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "tensor([[1., 1.],\n", 83 | " [1., 1.]])\n", 84 | "tensor([[0., 0.],\n", 85 | " [0., 0.]])\n", 86 | "tensor([[-2.4514, -0.6150],\n", 87 | " [ 0.9997, 0.4635]])\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "# Construct a bunch of ones\n", 93 | "some_ones = torch.ones(2, 2)\n", 94 | "print(some_ones)\n", 95 | "\n", 96 | "# Construct a bunch of zeros\n", 97 | "some_zeros = torch.zeros(2, 2)\n", 98 | "print(some_zeros)\n", 99 | "\n", 100 | "# Construct some normally distributed values\n", 101 | "some_normals = torch.randn(2, 2)\n", 102 | "print(some_normals)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": { 108 | "ExecuteTime": { 109 | "end_time": "2019-02-20T16:35:28.497368Z", 110 | "start_time": "2019-02-20T16:35:28.494171Z" 111 | }, 112 | "scrolled": true 113 | }, 114 | "source": [ 115 | "PyTorch tensors and NumPy ndarrays even share the same memory handles, so you can switch between the two types essentially for free:" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | 
"torch_tensor = torch.randn(5, 5)\n", 127 | "numpy_ndarray = torch_tensor.numpy()\n", 128 | "back_to_torch = torch.from_numpy(numpy_ndarray)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": { 134 | "ExecuteTime": { 135 | "end_time": "2019-02-20T16:35:28.564975Z", 136 | "start_time": "2019-02-20T16:35:28.499321Z" 137 | } 138 | }, 139 | "source": [ 140 | "Like NumPy, there are a zillion different operations you can do with tensors. The best thing to do is to go to https://pytorch.org/docs/stable/tensors.html if you know you want to do something to a tensor but don't know how!\n", 141 | "\n", 142 | "We can cover a few major ones here:" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": { 148 | "ExecuteTime": { 149 | "end_time": "2019-02-19T16:57:10.540510Z", 150 | "start_time": "2019-02-19T16:57:10.496709Z" 151 | } 152 | }, 153 | "source": [ 154 | "In the NumPy tutorial, we covered the basics of NumPy: arrays, element-wise operations, matrix operations, and generating random matrices. 
\n", 155 | "In this section, we'll cover indexing, slicing and broadcasting, which are useful concepts that will be reused in `Pandas` and `PyTorch`.\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 5, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "tensor([[-2.4792, 0.7683, -0.8724, -1.0555, -1.3677],\n", 168 | " [ 0.2659, 0.3905, 0.4132, 1.0330, 1.3572],\n", 169 | " [-0.3723, -0.8348, -1.1457, -1.4766, -1.0380],\n", 170 | " [ 1.7401, 1.5151, -0.6725, -0.8755, 0.2736],\n", 171 | " [ 0.9129, 0.9838, -0.8510, -0.2960, -0.3731]])\n", 172 | "tensor([[ 0.7255, 0.7353, -0.5352, 1.4629, -0.4881],\n", 173 | " [-1.2316, 0.7042, -1.3126, 0.8110, -1.3477],\n", 174 | " [-2.4669, 0.0770, 0.9740, 0.4297, -0.5245],\n", 175 | " [-1.0458, -1.2261, 0.6324, 0.8264, -1.3746],\n", 176 | " [-2.2290, 0.1202, -0.4826, -1.9797, -0.0879]])\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "# Create two tensors\n", 182 | "a = torch.randn(5, 5)\n", 183 | "b = torch.randn(5, 5)\n", 184 | "print(a)\n", 185 | "print(b)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 7, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "tensor(-1.1457)\n", 198 | "-1.1456899642944336\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "# Indexing by i,j\n", 204 | "another_tensor = a[2, 2]\n", 205 | "print(another_tensor)\n", 206 | "\n", 207 | "# The above returns a tensor type! 
To get the python value:\n", 208 | "python_value = a[2, 2].item()\n", 209 | "print(python_value)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 8, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "tensor([[-1.1457, -1.4766],\n", 222 | " [-0.6725, -0.8755]])\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "# Getting a whole row or column or range\n", 228 | "first_row = a[0, :]\n", 229 | "first_column = a[:, 0]\n", 230 | "combo = a[2:4, 2:4]\n", 231 | "print(combo)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 9, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "# Addition\n", 241 | "c = a + b\n", 242 | "\n", 243 | "# Elementwise multiplication: c_ij = a_ij * b_ij\n", 244 | "c = a * b\n", 245 | "\n", 246 | "# Matrix multiplication: c_ik = a_ij * b_jk\n", 247 | "c = a.mm(b)\n", 248 | "\n", 249 | "# Matrix vector multiplication\n", 250 | "c = a.matmul(b[:, 0])" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 11, 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "torch.Size([5, 5])\n", 263 | "torch.Size([5])\n", 264 | "tensor([ 3.8873, 2.8224, 0.5655, -1.8550, 3.2441])\n", 265 | "tensor([[ 3.8873],\n", 266 | " [ 2.8224],\n", 267 | " [ 0.5655],\n", 268 | " [-1.8550],\n", 269 | " [ 3.2441]])\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "a = torch.randn(5, 5)\n", 275 | "print(a.size())\n", 276 | "\n", 277 | "vec = a[:, 0]\n", 278 | "print(vec.size())\n", 279 | "\n", 280 | "# Matrix multiply: 5x5 * 5x5 --> 5x5\n", 281 | "aa = a.mm(a)\n", 282 | "\n", 283 | "# Matrix-vector multiply: 5x5 * 5 --> 5\n", 284 | "v1 = a.matmul(vec)\n", 285 | "print(v1)\n", 286 | "\n", 287 | "\n", 288 | "vec_as_matrix = vec.view(5, 1)\n", 289 | "v2 = a.mm(vec_as_matrix)\n", 290 | "print(v2)" 291 | ] 292 | }, 293 | { 294 | "cell_type": 
"markdown", 295 | "metadata": {}, 296 | "source": [ 297 | "In-place operations exist too, generally denoted by a trailing '_' (e.g. my_tensor.my_inplace_function_)." 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 12, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/plain": [ 308 | "tensor([[0., 0., 0., 0., 0.],\n", 309 | " [0., 0., 0., 0., 0.],\n", 310 | " [0., 0., 0., 0., 0.],\n", 311 | " [0., 0., 0., 0., 0.],\n", 312 | " [0., 0., 0., 0., 0.]])" 313 | ] 314 | }, 315 | "execution_count": 12, 316 | "metadata": {}, 317 | "output_type": "execute_result" 318 | } 319 | ], 320 | "source": [ 321 | "# Add one to all elements\n", 322 | "a.add_(1)\n", 323 | "\n", 324 | "# Divide all elements by 2\n", 325 | "a.div_(2)\n", 326 | "\n", 327 | "# Set all elements to 0\n", 328 | "a.zero_()" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "Manipulate dimensions..." 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 13, 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "torch.Size([10, 10, 1])\n", 348 | "torch.Size([1, 10, 10])\n", 349 | "torch.Size([10, 1, 10])\n", 350 | "torch.Size([10, 10])\n", 351 | "torch.Size([100, 1])\n", 352 | "torch.Size([50, 2])\n", 353 | "tensor([[-0.1561],\n", 354 | " [ 0.1588]])\n", 355 | "tensor([[-0.1561, -0.1561, -0.1561],\n", 356 | " [ 0.1588, 0.1588, 0.1588]])\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "# Add a dummy dimension, e.g. 
(n, m) --> (n, m, 1)\n", 362 | "a = torch.randn(10, 10)\n", 363 | "\n", 364 | "# At the end\n", 365 | "print(a.unsqueeze(-1).size())\n", 366 | "\n", 367 | "# At the beginning\n", 368 | "print(a.unsqueeze(0).size())\n", 369 | "\n", 370 | "# In the middle\n", 371 | "print(a.unsqueeze(1).size())\n", 372 | "\n", 373 | "# What you give you can take away\n", 374 | "print(a.unsqueeze(0).squeeze(0).size())\n", 375 | "\n", 376 | "# View things differently, i.e. flat\n", 377 | "print(a.view(100, 1).size())\n", 378 | "\n", 379 | "# Or not flat\n", 380 | "print(a.view(50, 2).size())\n", 381 | "\n", 382 | "# Copy data across a new dummy dimension!\n", 383 | "a = torch.randn(2)\n", 384 | "a = a.unsqueeze(-1)\n", 385 | "print(a)\n", 386 | "print(a.expand(2, 3))" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "If you have a GPU..." 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 14, 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "name": "stdout", 403 | "output_type": "stream", 404 | "text": [ 405 | "CPU it is!\n" 406 | ] 407 | } 408 | ], 409 | "source": [ 410 | "# Check if you have it\n", 411 | "do_i_have_cuda = torch.cuda.is_available()\n", 412 | "\n", 413 | "if do_i_have_cuda:\n", 414 | " print('Using fancy GPUs')\n", 415 | " # One way\n", 416 | " a = a.cuda()\n", 417 | " a = a.cpu()\n", 418 | "\n", 419 | " # Another way\n", 420 | " device = torch.device('cuda')\n", 421 | " a = a.to(device)\n", 422 | "\n", 423 | " device = torch.device('cpu')\n", 424 | " a = a.to(device)\n", 425 | "else:\n", 426 | " print('CPU it is!')" 427 | ] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "metadata": {}, 432 | "source": [ 433 | "And many more!" 
434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "## A Quick Note about Batching\n", 441 | "\n", 442 | "In most ML applications we do mini-batch stochastic gradient descent instead of pure stochastic gradient descent.\n", 443 | "\n", 444 | "Mini-batch SGD is a middle ground between full gradient descent and stochastic gradient descent: it computes the average gradient over a small number of examples.\n", 445 | "\n", 446 | "In a nutshell, given `n` examples:\n", 447 | "- **Full GD:** dL/dw = average over all `n` examples. One step per `n` examples.\n", 448 | "- **SGD:** dL/dw = point estimate over a single example. `n` steps per `n` examples.\n", 449 | "- **Mini-batch SGD:** dL/dw = average over `m << n` examples. `n / m` steps per `n` examples.\n", 450 | "\n", 451 | "Advantages of mini-batch SGD include a more stable gradient estimate and computational efficiency on modern hardware (exploiting parallelism gives sub-linear to constant time complexity, especially on GPU).\n", 452 | "\n", 453 | "In PyTorch, batched tensors are represented as just another dimension. Most of the deep learning modules assume batched tensors as input (even if the batch size is just 1)." 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 15, 459 | "metadata": {}, 460 | "outputs": [ 461 | { 462 | "name": "stdout", 463 | "output_type": "stream", 464 | "text": [ 465 | "torch.Size([10, 5, 5])\n" 466 | ] 467 | } 468 | ], 469 | "source": [ 470 | "# Batched matrix multiply\n", 471 | "a = torch.randn(10, 5, 5)\n", 472 | "b = torch.randn(10, 5, 5)\n", 473 | "\n", 474 | "# The same as for i in 1 ... 
10, c_i = a[i].mm(b[i])\n", 475 | "c = a.bmm(b)\n", 476 | "\n", 477 | "print(c.size())" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "metadata": { 483 | "ExecuteTime": { 484 | "end_time": "2019-02-20T16:35:28.575416Z", 485 | "start_time": "2019-02-20T16:35:28.571576Z" 486 | } 487 | }, 488 | "source": [ 489 | "## Autograd: Automatic Differentiation!\n", 490 | "\n", 491 | "Along with the flexible deep learning modules (to follow), this is the best part of using a package like PyTorch.\n", 492 | "\n", 493 | "What is autograd? It *automatically* computes *gradients*. All those complicated functions you might be using for your model need gradients for back-propagation. Autograd does this auto-magically! (Sorry, you still need to do this by hand for homework 4.)\n", 494 | "\n", 495 | "Let's warm up." 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 16, 501 | "metadata": {}, 502 | "outputs": [ 503 | { 504 | "name": "stdout", 505 | "output_type": "stream", 506 | "text": [ 507 | "tensor([-0.1662], requires_grad=True)\n" 508 | ] 509 | } 510 | ], 511 | "source": [ 512 | "# A tensor that will remember gradients\n", 513 | "x = torch.randn(1, requires_grad=True)\n", 514 | "print(x)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "metadata": { 520 | "ExecuteTime": { 521 | "end_time": "2019-02-20T16:35:28.581665Z", 522 | "start_time": "2019-02-20T16:35:28.577506Z" 523 | } 524 | }, 525 | "source": [ 526 | "At first the 'grad' attribute is None:" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": 17, 532 | "metadata": { 533 | "ExecuteTime": { 534 | "end_time": "2019-02-20T16:35:28.585548Z", 535 | "start_time": "2019-02-20T16:35:28.583493Z" 536 | } 537 | }, 538 | "outputs": [ 539 | { 540 | "name": "stdout", 541 | "output_type": "stream", 542 | "text": [ 543 | "None\n" 544 | ] 545 | } 546 | ], 547 | "source": [ 548 | "print(x.grad)" 549 | ] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "metadata": 
{}, 554 | "source": [ 555 | "Let's do an operation. Take y = e^x." 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": 18, 561 | "metadata": { 562 | "ExecuteTime": { 563 | "end_time": "2019-02-20T16:35:28.591506Z", 564 | "start_time": "2019-02-20T16:35:28.587586Z" 565 | } 566 | }, 567 | "outputs": [], 568 | "source": [ 569 | "y = x.exp()" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": { 575 | "ExecuteTime": { 576 | "end_time": "2019-02-20T16:35:28.596276Z", 577 | "start_time": "2019-02-20T16:35:28.593558Z" 578 | } 579 | }, 580 | "source": [ 581 | "To run the gradient computing magic, call '.backward()' on a variable." 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 19, 587 | "metadata": { 588 | "ExecuteTime": { 589 | "end_time": "2019-02-20T16:35:28.602746Z", 590 | "start_time": "2019-02-20T16:35:28.598538Z" 591 | } 592 | }, 593 | "outputs": [], 594 | "source": [ 595 | "y.backward()" 596 | ] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "metadata": { 601 | "ExecuteTime": { 602 | "end_time": "2019-02-20T16:35:28.610474Z", 603 | "start_time": "2019-02-20T16:35:28.605443Z" 604 | } 605 | }, 606 | "source": [ 607 | "For all dependent variables {x_1, ..., x_n} that were used to compute y, dy/dx_i is computed and stored in the x_i.grad field.\n", 608 | "\n", 609 | "Here dy/dx = e^x = y. Let's see!" 
610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "execution_count": 21, 615 | "metadata": { 616 | "ExecuteTime": { 617 | "end_time": "2019-02-20T16:35:28.617215Z", 618 | "start_time": "2019-02-20T16:35:28.613422Z" 619 | } 620 | }, 621 | "outputs": [ 622 | { 623 | "name": "stdout", 624 | "output_type": "stream", 625 | "text": [ 626 | "tensor([0.8469]) tensor([0.8469], grad_fn=<ExpBackward>)\n" 627 | ] 628 | } 629 | ], 630 | "source": [ 631 | "print(x.grad, y)" 632 | ] 633 | }, 634 | { 635 | "cell_type": "markdown", 636 | "metadata": {}, 637 | "source": [ 638 | "**Important!** Remember to zero gradients before subsequent calls to '.backward()'." 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 22, 644 | "metadata": {}, 645 | "outputs": [ 646 | { 647 | "name": "stdout", 648 | "output_type": "stream", 649 | "text": [ 650 | "tensor([2.8469])\n" 651 | ] 652 | } 653 | ], 654 | "source": [ 655 | "# Compute another thingy with x.\n", 656 | "z = x * 2\n", 657 | "z.backward()\n", 658 | "\n", 659 | "# Should be 2! 
But it will be 2 + e^x.\n", 660 | "print(x.grad)" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 25, 666 | "metadata": {}, 667 | "outputs": [ 668 | { 669 | "name": "stdout", 670 | "output_type": "stream", 671 | "text": [ 672 | "tensor([-2.4034])\n", 673 | "tensor([-31.5575])\n", 674 | "tensor([0.8933])\n" 675 | ] 676 | } 677 | ], 678 | "source": [ 679 | "x_a = torch.randn(1, requires_grad=True)\n", 680 | "x_b = torch.randn(1, requires_grad=True)\n", 681 | "x = x_a * x_b\n", 682 | "x1 = x ** 2\n", 683 | "x2 = 1 / x1\n", 684 | "x3 = x2.exp()\n", 685 | "x4 = 1 + x3\n", 686 | "x5 = x4.log()\n", 687 | "x6 = x5 ** (1/3)\n", 688 | "x6.backward()\n", 689 | "print(x_a.grad)\n", 690 | "print(x_b.grad)\n", 691 | "\n", 692 | "\n", 693 | "x = torch.randn(1, requires_grad=True)\n", 694 | "y = torch.tanh(x)\n", 695 | "y.backward()\n", 696 | "print(x.grad)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "markdown", 701 | "metadata": {}, 702 | "source": [ 703 | "**Also important!** Under the hood PyTorch stores all the stuff required to compute gradients (call stack, cached values, etc). 
If you want to save a variable just to keep it around (say for logging or plotting) remember to call `.item()` to get the python value and free the PyTorch machinery memory.\n", 704 | "\n", 705 | "You can stop auto-grad from running in the background by using the `torch.no_grad()` context manager.\n", 706 | "\n", 707 | "```python\n", 708 | "with torch.no_grad():\n", 709 | " do_all_my_things()\n", 710 | "```" 711 | ] 712 | }, 713 | { 714 | "cell_type": "markdown", 715 | "metadata": { 716 | "ExecuteTime": { 717 | "end_time": "2019-02-20T16:35:28.624048Z", 718 | "start_time": "2019-02-20T16:35:28.619908Z" 719 | } 720 | }, 721 | "source": [ 722 | "## Manual Neural Net + Autograd SGD Example (read this while studying unit 3)\n", 723 | "\n", 724 | "Before we move on to the full PyTorch wrapper library, let's do a simple NN SGD example by hand.\n", 725 | "\n", 726 | "We'll train a one hidden layer feed forward NN on a toy dataset." 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 26, 732 | "metadata": {}, 733 | "outputs": [], 734 | "source": [ 735 | "# Set our random seeds\n", 736 | "import random\n", 737 | "import numpy as np\n", 738 | "\n", 739 | "def set_seed(seed):\n", 740 | " random.seed(seed)\n", 741 | " np.random.seed(seed)\n", 742 | " torch.manual_seed(seed)\n", 743 | " if torch.cuda.is_available():\n", 744 | " torch.cuda.manual_seed(seed)" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 27, 750 | "metadata": { 751 | "ExecuteTime": { 752 | "end_time": "2019-02-20T16:35:28.630303Z", 753 | "start_time": "2019-02-20T16:35:28.626019Z" 754 | } 755 | }, 756 | "outputs": [ 757 | { 758 | "name": "stdout", 759 | "output_type": "stream", 760 | "text": [ 761 | "Number of examples: 100\n", 762 | "Number of features: 2\n" 763 | ] 764 | }, 765 | { 766 | "data": { 767 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzs3XdcleX/x/HXzd4bmQIOFBcq4N7b\n3Nuc5cwyLTUrLUvL1DQrK9NSSzNz5d5aKg7EhQv3QkUQ2Xuec/3+oB9+LUuEAzfjej4ePh6ew7mv\n630EP9znuq/7uhQhBJIkSVLZoad2AEmSJEm3ZGGXJEkqY2RhlyRJKmNkYZckSSpjZGGXJEkqY2Rh\nlyRJKmNkYZckSSpjZGGXJEkqY2RhlyRJKmMM1OjUwcFBeHl5qdG1JElSqXX27NkYIYTj816nSmH3\n8vLizJkzanQtSZJUaimKci8/r5NDMZIkSWWMLOySJElljCzskiRJZYws7JIkSWWMLOySJElljCzs\nklSG3b9/n0uXLqHVatWOIhUjWdglqQxKS0ujb5+X8PerQZ9eLfCp7snFixfVjiUVE1XmsUuSVLTm\nzPkEkXWa+2ddMDJSWLUhmZcH9uTylTsoiqJ2PKmIFfqMXVGUioqiHFIU5YqiKJcVRXlLF8EkSSq4\nXTs38c7rphgb66EoCq8MsCQpKYY7d+6oHU0qBroYiskBpgghagKNgfGKotTUQbuSJBWQg4MD9x/m\n5D1OTtGSkpqDjY2Niqmk4lLooRghRCQQ+dffkxVFuQq4AVcK27YkSQUz5Z2PGTWyP5mZAidHfRZ8\nn0H//v2xt7dXO5pUDHQ6xq4oihdQHzipy3YlSXoxnTt3ZuWqzXy7aB7xCbH06j2Et96apHYsqZgo\nQgjdNKQoFkAg8JkQYvMzvj4WGAvg4eHhf+9evtaykSRJkv6iKMpZIUTA816nk+mOiqIYApuANc8q\n6gBCiB+FEAFCiABHx+euOilJkiQVkC5mxSjACuCqEOLLwkeSJEmSCkMXY+zNgGHAJUVRzv/13HQh\nxG4dtC1JkkoiIyPZtGkTiqLQr18/nJyc1I4k5VOhz9iFEMeEEIoQwlcIUe+vP7KoS1IpduTIEerU\nrsaZ459w8sgsatfyJigoSO1YUj7JO08lSfqHSW+PZel8S/p0tQBg7ZZkpkwex4lguSxBaSDXipEk\n6SlCCELOXadHJ/O853p2Mifk3FUVU0kvQhZ2SZKeoigKvnWqsO9wWt5zew+l4VvHW8VU0ouQQzEl\nTHJyMtnZ2djZ2akdRSrHvli4hEEv92ZAjyy0Wvh9Zzq/b/pN7VhSPskz9hIiMzOToYOGUsGhAm4u\nbjRv0oLIyEi1Y0nlVIcOHTh95hIe1adQudZUQs5doXXr1mrHkvJJZ3eevoiAgABx5syZYu+3JAoN\nDWXRV4sIPHyE5PA0fLL80UOfMIOruDZw4GjQEbUjSsUsKyuLLVu2EBoaSqNGjejSpQt6evIcTMr/\nnadyKEZFp0+fpl3r9jhnenBfc596NMdAMQTAK6cGx8/sJj4+HltbW5WTSsUlMzOTjh2ao826S+um\nghnTF7NqZVM2bNwh11GX8k2eBqho5oxZuKVXxVPrgyHG5JCd9zUNOSgKGBoaqphQKm4bN25ET3uX\nQ5ttmfWuPSd22hN68RiBgYFqR5NKEVnYVXTn9h0shBUA7lThGueIF9EkiwRuml6gf7/+WFhYqJxS\nKk7nz5+lQ0vQ08s9OzcyUmjTzFhuaye9EFnYVdSl20s8Nn6AEAI3KmGPE5f0ggl3vsbQN19m2U/L\n1I4oFbOAgEbs+kOg0eRe+0pP13LgSAZ+fn4qJ5NKE3nxVEWJiYm0bdWOe7fvY6pnRqImni3bNtOu\nXTu1o0kqyc7Opnu39kQ/ukSrJvrsPphNo8adWLlqnRxjl/J98
VQWdpUJITh16hSxsbG0bNlSDr1I\naDQa9u7dS2hoKA0bNqR169ayqEuALOySJEllTrFutCFJkiSVHLKwS5IklTGysEuSJJUxsrBLkiSV\nMbKwS5IklTGysEuSJJUxsrAXQEZGBrNnz8bP15+unbvJvSAlSSpRZGEvgAF9B7JkznK4ZM6dfZF0\n7tCZ4OBgtWNJkiQBsrC/sLt373Lo0CF80v2wV5xwVyrjlu7N3Nnz1I4mSZIEyML+wqKjozEzMEdP\n0c97zkSYEvUoSsVUkiRJT8jC/oLq16+P1lBDtIgAQCNyiDK7T/9B/VROJkmSlEsW9hdkaGjIth1b\nibC/xXmLI5w0OUDzl5owceJEtaNJkiQBcmu8AmnatCkRjx5y6dIlHB0dcXd3VzuSJElSHlnYC8jA\nwID69eurHUOSJOkf5FCMJElSGSMLuyRJUhkjC7skASkpKUyePIFq3m40aVyH33//Xe1IUjGJjY0l\nMDCQqKiyM2VZFnZJAoYM7k1k2Fo2/mjABxPimDJpJDt27FA7llTEvv56IVWrVmT6u/3x8anEzJkf\nqB1JJ+TWeFK5d//+ffz9avAgxBUjo9y9RddvS2bV5irs3XdM5XRSUbly5QptWjfk5G5HPNwNeRyT\nQ5Ousaxes4vmzZurHe+Z5NZ4kpRPqampmJsZYGj45Dlbaz1SkpPVCyUVuX379tGnixke7rnf+AoO\nBgzpY8iePbtVTlZ4srBLpUpOTg7Tp0/FwcESCwsTRo8eRkpKSqHa9PHxwcLSgW+XJ6HRCGJiNcxZ\nlE6/Aa/qJrRUIrm6unIz7OkRi5t39XBzK/33pcjCLhWZy5cv8+orA2nV0o+PPvqAZB2cAc+Z8wkn\njqzg5G4HbgW7kp6wh9fHvVqoNhVFYfOWPfy23RGn2uF4N42gfsMBTJgg7yYuy3r27ElElAWvTY1n\nz5+pTJkZz8lzegwePFjtaIUmx9ilInH79m2aNK7PlHEm1K9jxE9rM4iI8SLwyGkURSlwu5UrObN5\nhTG+NY0BSEzS4F4/nJiYBExNTQudOyoqCnNzcywsLArVzrVr1wgNDcXPz4/KlSsXOtf/io6OZs6c\nWZwIOkS1ajWZNv0TatSoodM+youYmBgWLvycs6ePUauOP1OnTsfV1VXtWP8qv2Ps8s5THRFCkJiY\niKWlJfr6+s8/oIz7/vtFjBpkwtTxNgC0bW5Knda3CQ4OpkmTJoVq+39/LxTid8QzOTk5Fep4IQTj\nx49my+b1NPK34PVxyYx7fQKffqqbZZ1zcnJo26YJzfwTWfCBCUGnA2ndqjFnzoZSsWJFnfRRnjg4\nODB37gK1Y+icHIrRgcOHD1PJozLOFVxwcnBm5cqVakdS3eNHD6nk+aTq6ukpeHkY8ejRo0K1+8qr\nY5j8cQphD7KJjslhwvREevfqppOzdV3Yv38/h//czLVjzmxeYcXlQGd+/mkxISEhOml/3759WJol\nsHieDc0amjJ1vA0DehizYsUynbQvlQ2ysBdSbGwsPbr1xCbclebZXamSUJe3xr9NeR9q6tq9Pz/8\nkk1ikgaA0+czCD6bQuvWrQvV7gcffEzDZiMI6BRNpYYPMbLqzJKlKwsfWEcOHz7IgB76WFrk/tdy\nsNenZydTAgMDddJ+TEwM7q4GTw1nebgJoqML9wvzReTk5PDgwQOysrKKrU/pxeiksCuK8pOiKI8V\nRQnVRXulya5du7BTHHFUXFEUBSvFFscMd9b8ukbtaKoaMGAATVv2o0qjSAI6xdF1aDw//fQrtra2\nhWrXwMCAuXO/IC4uhdTUDFas+BVLS0sdpS68ypWrEhL6pOgKITgXqtXZOHvHjh3540gy5y5lAPDo\ncQ7L1mTTs2fx7Aewfft2vDydadigBu5ujvKTQgmlqzP2lUBnHbVVqpiYmKBVNE89J/S1JWZoQC16\nenp8++2PXL5ymyU/7uLBg
8f07t1bp30U5iJsURk0aBC3wiwYPiGeVRuS6D0ijug4E8LCwoiMjCx0\n+y4uLixd+jOdByfg2yaWmi0fMXjoeDp06KCD9P8tPDycEa8OYu0SUx6ed+XP32348IPJnDt3rsj7\nll6MzmbFKIriBewUQtR+3mvL0qyYtLQ0vDwqYRPvjJPWnQRiuGd2jZNngjly5Ahbf9+Gh1dFpkyd\nQrVq1dSOW2JdvnyZ9evXYWxszNChw/D09FQ7UoHFx8ezdOn3BB7ex7HjJ+nRyQpDQz12HUhj67Y9\nOrmrMT09nevXr+Ph4YGdnZ0OUufOCAoPD6dWrVqYmJj84+tLliwhOHAmP39tnffch3PjUSxG89ln\nc3WSQfpv8s7TYmJmZsaxoKNUbutKqNUJTP1g196dzJwxk5mTPyX8j1gO/hxEA/+GXL16VZWMMTEx\nLFmyhC+//JK7d++qkuG/bNiwgTatG5ERt5TIO98Q4F+bEydOqB2rwGxtbZk27QPi4mJZ/qUdvy62\n5eevrVk8z4LJk17TSR+mpqbUq1dPJ0VdCMHbb7+Oj08lRgzviEdFJ3bu3PmP11lYWBCX8PSJYFyC\ngoVFyRkKk/4ihNDJH8ALCP2Pr48FzgBnPDw8RFl269YtYWFqKdrQW7RX+on2Sj/hrV9HDB/6SrFn\nuXDhgqjgaC0G9XEUY4c7Cns7c7F169Ziz/FvNBqN8KjoKI7tcBeaSG+hifQWq751Eq1bBagdrdD0\n9fVE6t0qee8rLayK0NNT1I71D2vXrhX169iI2GuVhSbSWxzf6S5sbc1FfHz8U69LTk4W7m4OYvb7\njuLKUU/x3dwKwtHBSoSHh6uUvPwBzoh81ONiO2MXQvwohAgQQgQ4OjoWV7eqCA8Px9LQGn3lyXx2\nM40ld27dKfYs06e9xYxJxvy62IYln9uwaYUtb00ci1arLfYsz5KSkkJMbAKN/Z989G/XwozQy9dV\nTKUbvnWqsu9wWt7jfYfTqOvrrWKiZ9u5fQPjXjHCxjr357WxvykN6pn/YyaPhYUFhwODCb3blO6v\nZLI/qA77DwTi5uamRmzpP8gblIqAv78/KTmJJIk4rBQ7tEJLtGk4Q/q8WexZQkIu8N2sJ2OiLRqb\nkpAYQ3x8PPb29sWe5+8sLS1xd3Ni36E0Orc1B2DzrhQaBJT+bQe/WLiEAf17sP9wNooCG7ans37D\narVj/YOjkyv3w58MsWi1ggcPs3jWCViVKlVYu25rccaTCkAnhV1RlLVAa8BBUZRw4GMhxApdtF0a\nWVhY8MuaXxg2ZBjWBvak5iQR0MifCRMmFHuWunVrs+fgdV4bnlvcT5xJx8rSAhsbm2LP8iyKorD4\n+58ZOLAXL7XNIS0dgk5n88efi9WOVmht27blbMhl1q9fjxCCMx8PxMvLS+1Y//DGG2/RrOkv2Fgn\n5C3/YOdQudB3CEvqkWvFFKGkpCSCgoJwdXXF19dXlQznzp2jU8fWdG1vjIW5YN3WdJYsXUm/fsUz\n7zm/Hj16xNatWzE2NqZ3794l5hdPeXHx4kUWzP+Eu3du0qrNS7z33nSsrKzUjiX9TX5nxcjCXg5E\nRUWxdu1a0tPT6du3r5x2KUmllFwETMrj5OTE22+/rXYMSZKKiZzHLhXYjRs3aNuqHWYmZlSv6sP2\n7dvVjiRJErKwSwWUmZlJy+ateHgsloaZHTC/7cjQQcM4e/as2tGkAkpKSuKtt96gahVXmjSuw+bN\nm9WOJBWQHIqRCuTAgQPoZxriIbxBAXucSclIZNmPy/H/wV/teFIBDBzQHUery2z5yZywB3GMn/gq\npqamvPTSS/k6/vLly/z226/o6+szdOhweS1HRfKMXSoQjUaD3t9/fIRCTna2OoGkQrlz5w7nzp1l\n2UJbalU3pmt7cz6bZsb33+VvE4otW7bQpnUjNEnLSYv5gWZN/fjzzz+LOLX0b2RhlwqkQ4c
OpOkl\nE8m93N2jRBxRpvcZMWqE2tHKhN9+W0PtWpWws7Pg5YE9iYiIKNL+UlNTsTA3wOB/PsPb2uiTkpr0\n3GOFELw79U3W/2DDnA9smf+RLT8ssGLa+8V/34aUSxZ2qUDMzMz44+Af6NfM4pCyhXv2oXzz/dc0\na9ZM7WhFRgjBoUOHWLp0KRcvXiyyfvbv38+0917n20+zCT1cAa8KQXTr2o6inJpcq1YtDI1sWLIy\nCa1W8DgmhzmL0uk/4Pm/qHNycrhzN5IWjZ8sVd2mmSlXrhb/EhpSLlnYpf+UlJT0rzvl1K9fnwuh\n50nPSOdR9CNeeeWVYk5XfLKzs+nWtR0T3ujD6aMf81LnZrz33uQi6Wv5sm/4cJIprZqa4VzBgM+m\n25Ca8uiFt9fTaDRk53NoTE9Pjy1b97JqkwOONcOp3iySJi2GMG7c68891tDQkLq+3mzalZL33Ibt\nKTRqWO+F8kq6Iwu79Ex37tyhoX8jHB0csbW2492p7/3rwmFGRkYlctMLXVq3bh1JcRcIOeDAsoXW\nXDxYgdW/LOPy5cs67ys7Owtjoyf/noqiYGSo5LtIZ2dnM2nSeGxszLG0NGNA/27ExcU99zgfHx9O\nn7nMtethRETEsHDht+jp5a9EfLf4Z8ZPS2HA2ET6jEzk4wVZLPxyab6OlXRPFnbpH4QQdOnUhcTz\nGTTP7oZfRmtWfb+aH374Qe1oqjlx4jB9u+phYJBbcG1t9OnQyrxI1o0fOuw15n6TwfVbWWRnCxb9\nmEhmtjkNGjTI1/GzZ8/kUsg6rh93JSrUE3vzk4wc8XK++3dycsLc3PyFMjdt2pTr18Po2nsefV5e\nwPUbYdSrJ8/Y1SILu/QP165dIyryMRW13ugpepgoprimVeanH39WO5pqfHzqcuTkk8dZWYITZzKo\nWbOmzvvq06cPY8Z9QKve8VhWucu2Pz3YuetP9PX1n38w8NualcyfYYFzBQMsLfT44mNr/vgzkKSk\n518ILQx7e3tGjBjB8OHDsba2fv4BfxMbG8uSJUv44osvuHXrVhEkLD9kYS9mV65cYdzYcfTu0Ye1\na9cW6QWxgjI2Nkaj1SB4kk1DTrnex3XEiBHcDLOmz8h4Fi6Jp1XvWOr4Ni2SFRAVRWHy5Kk8ioon\nKSmFw4GnX2hOuIGBPlnZT7532Tm5f//7sMrp06dZvHgxhw4dUv3n8PLly9SqWZWjf3zMncuf07hR\nXTZu3KhqplItP7tx6PqPv7+/DvcUKT3OnDkjLM2tRFX92qIG/sLR3Em8+cabasd6phZNWwhPI2/R\nlM7Cj5bC1sxObN68We1YqkpOThaLFy8WEyaME+vWrRPZ2dlqR3qmL774XDT0sxHn/vQQt095if49\n7MXQIX3zvq7VasUbb4wWnhUtxZhhFUTN6jaiR/cOqr6fXj07ii9nOebtNnVid0Xh6mJXYv+N1UI+\nd1CShb0Yde/SQ1RX6udtl9eKHsLMxExER0erHe0f4uPjxavDRwh7G3vhXbmaWLVqldqRpHzSaDRi\n7tzZwtMjd+u6N98cK1JSUvK+fvLkSeFZ0VIk3Mzdti/jflXRyN9ObNiwQbXMnh6O4nqQZ15h10R6\nC0cHMxEREaFappIov4VdLilQDG7evMn58+e5dvUa1sIV/prwYKgYYWZoTkREBA4ODuqG/BsbGxt+\nXvWT2jGkAtDT0+P99z/g/fc/eObXT506Rec2plha5A7NGBoq9OqsEBx8jP79+xdn1Dz16tVj78Hz\nvDnKCIAz5zMwNDR55i5O0vPJwl6EhBC8PXESP634CXtDJyJTH+Cgn4m1JndLumuEEJ8cj79fAKZG\npjg4ODBx0gQmvjUx39PMJOm/CCE4deoUoaGhNGjQAF9fX2rUqMHSxbkzbgwNldwbr4Jg4NC6quWc\n/dmXtGvbnNMXErC1Fqzdks633y3HwECWqIKQG20UoeP
Hj9O9Uw98U5tjqBiRIdI4pfcnZsbmoIXs\nzGzq0AQjjLlNKCkkYWxmyIT3x/PhjA/Vji+VclqtlmFD+xF84k+aNTTh4LE0+vYbyldfLaZnj47E\nRJ2jV2eFQ0EQm+jMkaOnVb1AHh0dzfr160lJSaFPnz5yEbFnkDsolQBz5sxh+UerqaKtnffcLUJp\nP6o5+3cfwCHCEzulAgBaoeEIO/GlCfftrhId+1it2FIZsWPHDj764BWOb7fDxESPpGQN9dpFs37j\nAfz9/dm0aRPBwceoXbsegwcPLteznkoLuYNSCVC5cmUyTVMRKQJFyf3Im22RTqtWrTgWeByF/71b\nM/fvhhiSlp6mTuBSLiYmhhMnTuDl5UWdOnXUjqO648eP0uclPUxMcof1rCz16dbBhKCgIBo1asTA\ngQMZOHCgyimloiAHcotQ7969sXGz4rppCBEijBum57BwMaVfv3689vpYHpjdIE2kkCOyucEFrLAl\n0vgeffv0VTt6nsePH7Nx40aCgoJUn+v8X1atWom3twfffTWari81o1/frvm+Bb+s8vGpybHTSt73\nTaMRnDijoXr16ionK7uysrLIzMxUO4ac7ljUkpKSxPz580WfHn3E559/LhITE4UQuVPSPvxghrA0\ntxQKijBUjISxobHo0qlr3muKU1ZWlti0aZOYO3euOHbsmNBqteKXX34RZibmwtOqirC3cBRNGzUV\nqampxZ7teaKjo4WNjam4fCR3ulz6vaqidTM7sWTJErWjqSotLU341a8hunVyEF9/6ijatrAXbds0\nETk5OWpHK3PS09PFa6+9KszNjYWJiaEYOKCHiI+P13k/yHns/23Dhg2iaiVvYWFmIbp36S4ePHig\nSg6tViu0Wq2IiYkRsbGxqmRIS0sTfnX9hYuFu6hk4CNsze3EiFdGCHNTc9GYjqK90k+0o69wN/US\nc+fOVSXjf9m2bZvo1NbpqTnQK79xEgP6d1U7mupSUlLE4sXfiXHjRoiffvpJZGRkqB2pTJoyZaLo\n3sleRIVWFgk3q4hRQxzEywN76ryf/Bb2cnnxNDAwkJ5delElrQ7mWBOhfxc8Mrlx63q5nGb4/fff\nM2fqfHzSAlAUhRyRzWnjPzE3NMc3tUXe6x6Lh9g2M+bwsUMqpv2nCxcu0KNbC64fd8Lor1URp85K\nQN9yCPPnf6VyOqk8cHWx49AmK7wr587DT0rW4OL7gISEZIyNjXXWT34vnpa/KgYs/mYxLumVsFOc\nMFZM8NL4kByTQnBwsNrRVHHieDCWaXZ5S+8aKIY46DuTmJmIRuTkvS7VIInqNUreFLS6deviH9CM\nrkPj+PX3JN79JIHftmTz5puTiqV/NU6OpJLFyMiQjMwnPwdZWaCvr6faiWKpKuzZ2dlcunSJx48L\nNxUwMzMLRTx564qioKD3rxtKlETLly2noqsHpiZm9Ojak8jIyAK31bBxA1LMEp5cZBM5JBBD+/bt\nuGx+knBxhzuGl4kzi+C9ae/p6i3o1Lr12+g3aDY7A/1QzAdz8tQFPDw8irTPsLAwurzUGkNDA9zd\nHPj++2+LtD+p5Bo95g3Gv5/MxSuZ3LidxajJiQwbNhhDQ0N1AuVnvEbXfwoyxn7o0CHhYOcoHCwd\nhZmxmRg1YnSBLwLt2LFD2Jrbi8Z0FG3pI2oofsKlgmupWXBo69atwtbMTjSgjWhJd1HFoKaoXaO2\n0Gq1BWovJSVF1PKpLVwtPEQlagh7c0cxbMhwkZOTI9asWSMG9B0gpr7zrrh3756O34kQsbGxYsaM\n6aJbl1Zi+vR3S+S6Oc+i0WhErZqVxez3HUXKnSoi5A8P4V3FSmzdulXtaJIKcnJyxOzZs4SXZwXh\n6mIn3nnnrSK5nkFZGmNPT0/HxcmVysm1sVecyRHZXDU/xWeLPmXUqFEFyvD114uY9fFMklOSqVen\nPr/8tqpI1tYuCh3
[remainder of base64-encoded PNG output (matplotlib scatter plot of the generated dataset) omitted]\n", 768 | "text/plain": [ 769 | "
" 770 | ] 771 | }, 772 | "metadata": {}, 773 | "output_type": "display_data" 774 | } 775 | ], 776 | "source": [ 777 | "# Get ourselves a simple dataset\n", 778 | "from sklearn.datasets import make_classification\n", 779 | "set_seed(7)\n", 780 | "X, Y = make_classification(n_features=2, n_redundant=0, n_informative=1, n_clusters_per_class=1)\n", 781 | "print('Number of examples: %d' % X.shape[0])\n", 782 | "print('Number of features: %d' % X.shape[1])\n", 783 | "\n", 784 | "# Take a peak\n", 785 | "%matplotlib inline\n", 786 | "import matplotlib\n", 787 | "import matplotlib.pyplot as plt\n", 788 | "plt.scatter(X[:, 0], X[:, 1], marker='o', c=Y, s=25, edgecolor='k')\n", 789 | "plt.show()" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 28, 795 | "metadata": { 796 | "collapsed": true 797 | }, 798 | "outputs": [], 799 | "source": [ 800 | "# Convert data to PyTorch\n", 801 | "X, Y = torch.from_numpy(X), torch.from_numpy(Y)\n", 802 | "\n", 803 | "# Gotcha: \"Expected object of scalar type Float but got scalar type Double\"\n", 804 | "# If you see this it's because numpy defaults to Doubles whereas pytorch has floats.\n", 805 | "X, Y = X.float(), Y.float()" 806 | ] 807 | }, 808 | { 809 | "cell_type": "markdown", 810 | "metadata": {}, 811 | "source": [ 812 | "We'll train a one layer neural net to classify this dataset. 
Let's define the parameter sizes:" 813 | ] 814 | }, 815 | { 816 | "cell_type": "code", 817 | "execution_count": 29, 818 | "metadata": { 819 | "ExecuteTime": { 820 | "end_time": "2019-02-19T17:16:34.417254Z", 821 | "start_time": "2019-02-19T17:16:34.411064Z" 822 | }, 823 | "collapsed": true 824 | }, 825 | "outputs": [], 826 | "source": [ 827 | "# Define dimensions\n", 828 | "num_feats = 2\n", 829 | "hidden_size = 100\n", 830 | "num_outputs = 1\n", 831 | "\n", 832 | "# Learning rate\n", 833 | "eta = 0.1\n", 834 | "num_steps = 1000" 835 | ] 836 | }, 837 | { 838 | "cell_type": "markdown", 839 | "metadata": {}, 840 | "source": [ 841 | "And now run a few steps of SGD!" 842 | ] 843 | }, 844 | { 845 | "cell_type": "code", 846 | "execution_count": 30, 847 | "metadata": { 848 | "ExecuteTime": { 849 | "end_time": "2019-02-20T16:35:28.634813Z", 850 | "start_time": "2019-02-20T16:35:28.632236Z" 851 | } 852 | }, 853 | "outputs": [ 854 | { 855 | "data": { 856 | "image/png": "[base64-encoded PNG output (matplotlib learning curve, loss vs. step) omitted]\n", 857 | "text/plain": [ 858 | "
" 859 | ] 860 | }, 861 | "metadata": {}, 862 | "output_type": "display_data" 863 | } 864 | ], 865 | "source": [ 866 | "# Input to hidden weights\n", 867 | "W1 = torch.randn(hidden_size, num_feats, requires_grad=True)\n", 868 | "b1 = torch.zeros(hidden_size, requires_grad=True)\n", 869 | "\n", 870 | "# Hidden to output\n", 871 | "W2 = torch.randn(num_outputs, hidden_size, requires_grad=True)\n", 872 | "b2 = torch.zeros(num_outputs, requires_grad=True)\n", 873 | "\n", 874 | "# Group parameters\n", 875 | "parameters = [W1, b1, W2, b2]\n", 876 | "\n", 877 | "# Get random order\n", 878 | "indices = torch.randperm(X.size(0))\n", 879 | "\n", 880 | "# Keep running average losses for a learning curve?\n", 881 | "avg_loss = []\n", 882 | "\n", 883 | "# Run!\n", 884 | "for step in range(num_steps):\n", 885 | " # Get example\n", 886 | " i = indices[step % indices.size(0)]\n", 887 | " x_i, y_i = X[i], Y[i]\n", 888 | " \n", 889 | " # Run example\n", 890 | " hidden = torch.relu(W1.matmul(x_i) + b1)\n", 891 | " y_hat = torch.sigmoid(W2.matmul(hidden) + b2)\n", 892 | " \n", 893 | " # Compute loss binary cross entropy: -(y_i * log(y_hat) + (1 - y_i) * log(1 - y_hat))\n", 894 | " # Epsilon for numerical stability\n", 895 | " eps = 1e-6\n", 896 | " loss = -(y_i * (y_hat + eps).log() + (1 - y_i) * (1 - y_hat + eps).log())\n", 897 | "\n", 898 | " # Add to our running average learning curve. 
Don't forget .item()!\n", 899 | " if step == 0:\n", 900 | " avg_loss.append(loss.item())\n", 901 | " else:\n", 902 | " old_avg = avg_loss[-1]\n", 903 | " new_avg = (loss.item() + old_avg * len(avg_loss)) / (len(avg_loss) + 1)\n", 904 | " avg_loss.append(new_avg)\n", 905 | " \n", 906 | " # Zero out all previous gradients\n", 907 | " for param in parameters:\n", 908 | " # It might start out as None\n", 909 | " if param.grad is not None:\n", 910 | " # In place\n", 911 | " param.grad.zero_()\n", 912 | "\n", 913 | " # Backward pass\n", 914 | " loss.backward()\n", 915 | " \n", 916 | " # Update parameters\n", 917 | " for param in parameters:\n", 918 | " # Gradient step (using .data so autograd doesn't track the update)\n", 919 | " param.data = param.data - eta * param.grad\n", 920 | " \n", 921 | "\n", 922 | "plt.plot(range(num_steps), avg_loss)\n", 923 | "plt.ylabel('Loss')\n", 924 | "plt.xlabel('Step')\n", 925 | "plt.show()" 926 | ] 927 | }, 928 | { 929 | "cell_type": "markdown", 930 | "metadata": { 931 | "ExecuteTime": { 932 | "end_time": "2019-02-20T16:35:28.639328Z", 933 | "start_time": "2019-02-20T16:35:28.636394Z" 934 | } 935 | }, 936 | "source": [ 937 | "## torch.nn\n", 938 | "\n", 939 | "The `nn` package is where all of the cool neural network stuff is. Layers, loss functions, etc.\n", 940 | "\n", 941 | "Let's dive in." 942 | ] 943 | }, 944 | { 945 | "cell_type": "markdown", 946 | "metadata": { 947 | "ExecuteTime": { 948 | "end_time": "2019-02-20T16:35:28.643927Z", 949 | "start_time": "2019-02-20T16:35:28.640936Z" 950 | } 951 | }, 952 | "source": [ 953 | "### Layers\n", 954 | "\n", 955 | "Before, we defined our linear layers manually. PyTorch provides them for you as sub-classes of `nn.Module`."
956 | ] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "execution_count": 31, 961 | "metadata": {}, 962 | "outputs": [ 963 | { 964 | "name": "stdout", 965 | "output_type": "stream", 966 | "text": [ 967 | "Linear(in_features=10, out_features=10, bias=True)\n", 968 | "Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", 969 | "RNN(10, 10)\n" 970 | ] 971 | } 972 | ], 973 | "source": [ 974 | "import torch.nn as nn\n", 975 | "\n", 976 | "# Linear layer: in_features, out_features\n", 977 | "linear = nn.Linear(10, 10)\n", 978 | "print(linear)\n", 979 | "\n", 980 | "# Convolution layer: in_channels, out_channels, kernel_size, stride\n", 981 | "conv = nn.Conv2d(1, 20, 5, 1)\n", 982 | "print(conv)\n", 983 | "\n", 984 | "# RNN: input_size, hidden_size, num_layers\n", 985 | "rnn = nn.RNN(10, 10, 1)\n", 986 | "print(rnn)"
-0.1927, 0.2537, 0.1788, -0.2183, -0.2614, -0.1386,\n", 1016 | " -0.1446, -0.1795],\n", 1017 | " [ 0.2228, 0.0777, -0.0397, -0.0215, 0.1316, 0.0324, -0.0392, 0.2808,\n", 1018 | " 0.2182, 0.0222]], requires_grad=True)\n", 1019 | "['weight', 'bias']\n" 1020 | ] 1021 | } 1022 | ], 1023 | "source": [ 1024 | "print(linear.weight)\n", 1025 | "print([k for k,v in conv.named_parameters()])" 1026 | ] 1027 | }, 1028 | { 1029 | "cell_type": "code", 1030 | "execution_count": 34, 1031 | "metadata": { 1032 | "ExecuteTime": { 1033 | "end_time": "2019-02-20T16:35:28.850451Z", 1034 | "start_time": "2019-02-20T16:35:28.645580Z" 1035 | }, 1036 | "collapsed": true 1037 | }, 1038 | "outputs": [], 1039 | "source": [ 1040 | "# Make our own model!\n", 1041 | "\n", 1042 | "class Net(nn.Module):\n", 1043 | " def __init__(self):\n", 1044 | " super(Net, self).__init__()\n", 1045 | " # 1 input channel to 20 feature maps of 5x5 kernel. Stride 1.\n", 1046 | " self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", 1047 | "\n", 1048 | " # 20 input channels to 50 feature maps of 5x5 kernel. 
Stride 1.\n", 1049 | " self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", 1050 | "\n", 1051 | " # Fully connected: from the final 4x4x50 feature maps to 500 features\n", 1052 | " self.fc1 = nn.Linear(4*4*50, 500)\n", 1053 | " \n", 1054 | " # From 500 to 10 classes\n", 1055 | " self.fc2 = nn.Linear(500, 10)\n", 1056 | "\n", 1057 | " def forward(self, x):\n", 1058 | " x = F.relu(self.conv1(x))\n", 1059 | " x = F.max_pool2d(x, 2, 2)\n", 1060 | " x = F.relu(self.conv2(x))\n", 1061 | " x = F.max_pool2d(x, 2, 2)\n", 1062 | " x = x.view(-1, 4*4*50)\n", 1063 | " x = F.relu(self.fc1(x))\n", 1064 | " x = self.fc2(x)\n", 1065 | " return F.log_softmax(x, dim=1)\n", 1066 | "\n", 1067 | "# Initialize it\n", 1068 | "model = Net()" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "markdown", 1073 | "metadata": {}, 1074 | "source": [ 1075 | "A note on convolution sizes:\n", 1076 | "\n", 1077 | "Running a kernel over the image reduces the image height/width by kernel_size - 1 (with stride 1 and no padding).\n", 1078 | "\n", 1079 | "Running a max pooling over the image reduces the image height/width by a factor of the kernel size.\n", 1080 | "\n", 1081 | "So starting from a 28 x 28 image:\n", 1082 | "\n", 1083 | "- Run 5x5 conv --> 24 x 24\n", 1084 | "- Apply 2x2 max pool --> 12 x 12\n", 1085 | "- Run 5x5 conv --> 8 x 8\n", 1086 | "- Apply 2x2 max pool --> 4 x 4" 1087 | ] 1088 | }, 1089 | { 1090 | "cell_type": "markdown", 1091 | "metadata": {}, 1092 | "source": [ 1093 | "### Optimizers\n", 1094 | "\n", 1095 | "PyTorch handles the optimization too: `torch.optim` provides several algorithms you can learn about later. 
Here's SGD:\n" 1096 | ] 1097 | }, 1098 | { 1099 | "cell_type": "code", 1100 | "execution_count": 35, 1101 | "metadata": { 1102 | "collapsed": true 1103 | }, 1104 | "outputs": [], 1105 | "source": [ 1106 | "import torch.optim as optim\n", 1107 | "\n", 1108 | "# Initialize with model parameters\n", 1109 | "optimizer = optim.SGD(model.parameters(), lr=0.01)" 1110 | ] 1111 | }, 1112 | { 1113 | "cell_type": "markdown", 1114 | "metadata": {}, 1115 | "source": [ 1116 | "Updating is now as easy as:\n", 1117 | "\n", 1118 | "```python\n", 1119 | "loss = loss_fn()\n", 1120 | "optimizer.zero_grad()\n", 1121 | "loss.backward()\n", 1122 | "optimizer.step()\n", 1123 | "```\n", 1124 | "\n", 1125 | "### Full train and test loops\n", 1126 | "Let's look at a full train loop now." 1127 | ] 1128 | }, 1129 | { 1130 | "cell_type": "code", 1131 | "execution_count": 36, 1132 | "metadata": { 1133 | "collapsed": true 1134 | }, 1135 | "outputs": [], 1136 | "source": [ 1137 | "import tqdm\n", 1138 | "import torch.nn.functional as F\n", 1139 | "\n", 1140 | "def train(model, train_loader, optimizer, epoch):\n", 1141 | " # For things like dropout\n", 1142 | " model.train()\n", 1143 | " \n", 1144 | " # Avg loss\n", 1145 | " total_loss = 0\n", 1146 | " \n", 1147 | " # Iterate through dataset\n", 1148 | " for data, target in tqdm.tqdm(train_loader):\n", 1149 | " # Zero grad\n", 1150 | " optimizer.zero_grad()\n", 1151 | "\n", 1152 | " # Forward pass\n", 1153 | " output = model(data)\n", 1154 | " \n", 1155 | " # Negative log likelihood loss function\n", 1156 | " loss = F.nll_loss(output, target)\n", 1157 | "\n", 1158 | " # Backward pass\n", 1159 | " loss.backward()\n", 1160 | " total_loss += loss.item()\n", 1161 | " \n", 1162 | " # Update\n", 1163 | " optimizer.step()\n", 1164 | "\n", 1165 | " # Print average loss\n", 1166 | " print(\"Train Epoch: {}\\t Loss: {:.6f}\".format(epoch, total_loss / len(train_loader)))" 1167 | ] 1168 | }, 1169 | { 1170 | "cell_type": "markdown", 1171 | "metadata": {}, 1172 
| "source": [ 1173 | "Testing loops are similar." 1174 | ] 1175 | }, 1176 | { 1177 | "cell_type": "code", 1178 | "execution_count": 37, 1179 | "metadata": { 1180 | "collapsed": true 1181 | }, 1182 | "outputs": [], 1183 | "source": [ 1184 | "def test(model, test_loader):\n", 1185 | " model.eval()\n", 1186 | " test_loss = 0\n", 1187 | " correct = 0\n", 1188 | " with torch.no_grad():\n", 1189 | " for data, target in test_loader:\n", 1190 | " output = model(data)\n", 1191 | " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", 1192 | " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", 1193 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 1194 | "\n", 1195 | " test_loss /= len(test_loader.dataset)\n", 1196 | "\n", 1197 | " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 1198 | " test_loss, correct, len(test_loader.dataset),\n", 1199 | " 100. * correct / len(test_loader.dataset)))" 1200 | ] 1201 | }, 1202 | { 1203 | "cell_type": "markdown", 1204 | "metadata": {}, 1205 | "source": [ 1206 | "## MNIST\n", 1207 | "\n", 1208 | "Now let's put it all together and train on MNIST!" 
1209 | ] 1210 | }, 1211 | { 1212 | "cell_type": "code", 1213 | "execution_count": 38, 1214 | "metadata": {}, 1215 | "outputs": [ 1216 | { 1217 | "name": "stderr", 1218 | "output_type": "stream", 1219 | "text": [ 1220 | "100%|██████████| 1875/1875 [00:57<00:00, 32.43it/s]\n" 1221 | ] 1222 | }, 1223 | { 1224 | "name": "stdout", 1225 | "output_type": "stream", 1226 | "text": [ 1227 | "Train Epoch: 1\t Loss: 0.314105\n" 1228 | ] 1229 | }, 1230 | { 1231 | "name": "stderr", 1232 | "output_type": "stream", 1233 | "text": [ 1234 | " 0%| | 3/1875 [00:00<01:03, 29.64it/s]" 1235 | ] 1236 | }, 1237 | { 1238 | "name": "stdout", 1239 | "output_type": "stream", 1240 | "text": [ 1241 | "\n", 1242 | "Test set: Average loss: 0.0967, Accuracy: 9701/10000 (97%)\n", 1243 | "\n" 1244 | ] 1245 | }, 1246 | { 1247 | "name": "stderr", 1248 | "output_type": "stream", 1249 | "text": [ 1250 | "100%|██████████| 1875/1875 [00:55<00:00, 33.76it/s]\n" 1251 | ] 1252 | }, 1253 | { 1254 | "name": "stdout", 1255 | "output_type": "stream", 1256 | "text": [ 1257 | "Train Epoch: 2\t Loss: 0.086927\n" 1258 | ] 1259 | }, 1260 | { 1261 | "name": "stderr", 1262 | "output_type": "stream", 1263 | "text": [ 1264 | " 0%| | 4/1875 [00:00<00:58, 31.88it/s]" 1265 | ] 1266 | }, 1267 | { 1268 | "name": "stdout", 1269 | "output_type": "stream", 1270 | "text": [ 1271 | "\n", 1272 | "Test set: Average loss: 0.0590, Accuracy: 9818/10000 (98%)\n", 1273 | "\n" 1274 | ] 1275 | }, 1276 | { 1277 | "name": "stderr", 1278 | "output_type": "stream", 1279 | "text": [ 1280 | "100%|██████████| 1875/1875 [00:55<00:00, 33.53it/s]\n" 1281 | ] 1282 | }, 1283 | { 1284 | "name": "stdout", 1285 | "output_type": "stream", 1286 | "text": [ 1287 | "Train Epoch: 3\t Loss: 0.061660\n" 1288 | ] 1289 | }, 1290 | { 1291 | "name": "stderr", 1292 | "output_type": "stream", 1293 | "text": [ 1294 | " 0%| | 4/1875 [00:00<00:58, 32.16it/s]" 1295 | ] 1296 | }, 1297 | { 1298 | "name": "stdout", 1299 | "output_type": "stream", 1300 | "text": [ 1301 | 
"\n", 1302 | "Test set: Average loss: 0.0493, Accuracy: 9839/10000 (98%)\n", 1303 | "\n" 1304 | ] 1305 | }, 1306 | { 1307 | "name": "stderr", 1308 | "output_type": "stream", 1309 | "text": [ 1310 | "100%|██████████| 1875/1875 [00:54<00:00, 34.44it/s]\n" 1311 | ] 1312 | }, 1313 | { 1314 | "name": "stdout", 1315 | "output_type": "stream", 1316 | "text": [ 1317 | "Train Epoch: 4\t Loss: 0.048638\n" 1318 | ] 1319 | }, 1320 | { 1321 | "name": "stderr", 1322 | "output_type": "stream", 1323 | "text": [ 1324 | " 0%| | 4/1875 [00:00<00:57, 32.47it/s]" 1325 | ] 1326 | }, 1327 | { 1328 | "name": "stdout", 1329 | "output_type": "stream", 1330 | "text": [ 1331 | "\n", 1332 | "Test set: Average loss: 0.0403, Accuracy: 9874/10000 (99%)\n", 1333 | "\n" 1334 | ] 1335 | }, 1336 | { 1337 | "name": "stderr", 1338 | "output_type": "stream", 1339 | "text": [ 1340 | "100%|██████████| 1875/1875 [00:54<00:00, 34.63it/s]\n" 1341 | ] 1342 | }, 1343 | { 1344 | "name": "stdout", 1345 | "output_type": "stream", 1346 | "text": [ 1347 | "Train Epoch: 5\t Loss: 0.040329\n" 1348 | ] 1349 | }, 1350 | { 1351 | "name": "stderr", 1352 | "output_type": "stream", 1353 | "text": [ 1354 | " 0%| | 4/1875 [00:00<00:55, 33.79it/s]" 1355 | ] 1356 | }, 1357 | { 1358 | "name": "stdout", 1359 | "output_type": "stream", 1360 | "text": [ 1361 | "\n", 1362 | "Test set: Average loss: 0.0337, Accuracy: 9885/10000 (99%)\n", 1363 | "\n" 1364 | ] 1365 | }, 1366 | { 1367 | "name": "stderr", 1368 | "output_type": "stream", 1369 | "text": [ 1370 | "100%|██████████| 1875/1875 [00:54<00:00, 34.14it/s]\n" 1371 | ] 1372 | }, 1373 | { 1374 | "name": "stdout", 1375 | "output_type": "stream", 1376 | "text": [ 1377 | "Train Epoch: 6\t Loss: 0.034306\n" 1378 | ] 1379 | }, 1380 | { 1381 | "name": "stderr", 1382 | "output_type": "stream", 1383 | "text": [ 1384 | " 0%| | 4/1875 [00:00<00:58, 31.85it/s]" 1385 | ] 1386 | }, 1387 | { 1388 | "name": "stdout", 1389 | "output_type": "stream", 1390 | "text": [ 1391 | "\n", 1392 | "Test set: 
Average loss: 0.0372, Accuracy: 9876/10000 (99%)\n", 1393 | "\n" 1394 | ] 1395 | }, 1396 | { 1397 | "name": "stderr", 1398 | "output_type": "stream", 1399 | "text": [ 1400 | "100%|██████████| 1875/1875 [00:53<00:00, 34.75it/s]\n" 1401 | ] 1402 | }, 1403 | { 1404 | "name": "stdout", 1405 | "output_type": "stream", 1406 | "text": [ 1407 | "Train Epoch: 7\t Loss: 0.029963\n" 1408 | ] 1409 | }, 1410 | { 1411 | "name": "stderr", 1412 | "output_type": "stream", 1413 | "text": [ 1414 | " 0%| | 4/1875 [00:00<00:56, 33.26it/s]" 1415 | ] 1416 | }, 1417 | { 1418 | "name": "stdout", 1419 | "output_type": "stream", 1420 | "text": [ 1421 | "\n", 1422 | "Test set: Average loss: 0.0337, Accuracy: 9889/10000 (99%)\n", 1423 | "\n" 1424 | ] 1425 | }, 1426 | { 1427 | "name": "stderr", 1428 | "output_type": "stream", 1429 | "text": [ 1430 | "100%|██████████| 1875/1875 [00:56<00:00, 33.23it/s]\n" 1431 | ] 1432 | }, 1433 | { 1434 | "name": "stdout", 1435 | "output_type": "stream", 1436 | "text": [ 1437 | "Train Epoch: 8\t Loss: 0.026079\n" 1438 | ] 1439 | }, 1440 | { 1441 | "name": "stderr", 1442 | "output_type": "stream", 1443 | "text": [ 1444 | " 0%| | 4/1875 [00:00<00:56, 33.39it/s]" 1445 | ] 1446 | }, 1447 | { 1448 | "name": "stdout", 1449 | "output_type": "stream", 1450 | "text": [ 1451 | "\n", 1452 | "Test set: Average loss: 0.0349, Accuracy: 9887/10000 (99%)\n", 1453 | "\n" 1454 | ] 1455 | }, 1456 | { 1457 | "name": "stderr", 1458 | "output_type": "stream", 1459 | "text": [ 1460 | "100%|██████████| 1875/1875 [00:53<00:00, 34.85it/s]\n" 1461 | ] 1462 | }, 1463 | { 1464 | "name": "stdout", 1465 | "output_type": "stream", 1466 | "text": [ 1467 | "Train Epoch: 9\t Loss: 0.022861\n" 1468 | ] 1469 | }, 1470 | { 1471 | "name": "stderr", 1472 | "output_type": "stream", 1473 | "text": [ 1474 | " 0%| | 4/1875 [00:00<00:56, 32.90it/s]" 1475 | ] 1476 | }, 1477 | { 1478 | "name": "stdout", 1479 | "output_type": "stream", 1480 | "text": [ 1481 | "\n", 1482 | "Test set: Average loss: 0.0374, 
Accuracy: 9877/10000 (99%)\n", 1483 | "\n" 1484 | ] 1485 | }, 1486 | { 1487 | "name": "stderr", 1488 | "output_type": "stream", 1489 | "text": [ 1490 | "100%|██████████| 1875/1875 [00:54<00:00, 34.36it/s]\n" 1491 | ] 1492 | }, 1493 | { 1494 | "name": "stdout", 1495 | "output_type": "stream", 1496 | "text": [ 1497 | "Train Epoch: 10\t Loss: 0.019654\n", 1498 | "\n", 1499 | "Test set: Average loss: 0.0319, Accuracy: 9905/10000 (99%)\n", 1500 | "\n" 1501 | ] 1502 | } 1503 | ], 1504 | "source": [ 1505 | "from torchvision import datasets, transforms\n", 1506 | "\n", 1507 | "# See the torch DataLoader for more details.\n", 1508 | "train_loader = torch.utils.data.DataLoader(\n", 1509 | " datasets.MNIST('../data', train=True, download=True,\n", 1510 | " transform=transforms.Compose([\n", 1511 | " transforms.ToTensor(),\n", 1512 | " transforms.Normalize((0.1307,), (0.3081,))\n", 1513 | " ])),\n", 1514 | " batch_size=32, shuffle=True)\n", 1515 | "\n", 1516 | "test_loader = torch.utils.data.DataLoader(\n", 1517 | " datasets.MNIST('../data', train=False,\n", 1518 | " transform=transforms.Compose([\n", 1519 | " transforms.ToTensor(),\n", 1520 | " transforms.Normalize((0.1307,), (0.3081,))\n", 1521 | " ])),\n", 1522 | " batch_size=32, shuffle=True)\n", 1523 | "\n", 1524 | "\n", 1525 | "for epoch in range(1, 10 + 1):\n", 1526 | " train(model, train_loader, optimizer, epoch)\n", 1527 | " test(model, test_loader)" 1528 | ] 1529 | }, 1530 | { 1531 | "cell_type": "code", 1532 | "execution_count": null, 1533 | "metadata": { 1534 | "collapsed": true 1535 | }, 1536 | "outputs": [], 1537 | "source": [] 1538 | } 1539 | ], 1540 | "metadata": { 1541 | "kernelspec": { 1542 | "display_name": "Python 3", 1543 | "language": "python", 1544 | "name": "python3" 1545 | }, 1546 | "language_info": { 1547 | "codemirror_mode": { 1548 | "name": "ipython", 1549 | "version": 3 1550 | }, 1551 | "file_extension": ".py", 1552 | "mimetype": "text/x-python", 1553 | "name": "python", 1554 | 
"nbconvert_exporter": "python", 1555 | "pygments_lexer": "ipython3", 1556 | "version": "3.6.5" 1557 | }, 1558 | "toc": { 1559 | "base_numbering": 1, 1560 | "nav_menu": {}, 1561 | "number_sections": true, 1562 | "sideBar": true, 1563 | "skip_h1_title": false, 1564 | "title_cell": "Table of Contents", 1565 | "title_sidebar": "Contents", 1566 | "toc_cell": false, 1567 | "toc_position": { 1568 | "height": "calc(100% - 180px)", 1569 | "left": "10px", 1570 | "top": "150px", 1571 | "width": "164px" 1572 | }, 1573 | "toc_section_display": true, 1574 | "toc_window_display": true 1575 | } 1576 | }, 1577 | "nbformat": 4, 1578 | "nbformat_minor": 2 1579 | } 1580 | --------------------------------------------------------------------------------