├── A Disciplined Approach To Neural Network Hyper-Parameters.ipynb ├── README.md └── images ├── BS_LR.png ├── HighBSHighLR.png ├── HighBSLowLR.png ├── LowBSHighLR.png ├── LowBSLowLR.png ├── Under_over.png ├── loss.png ├── one_cycle.png └── test /A Disciplined Approach To Neural Network Hyper-Parameters.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### A DISCIPLINED APPROACH TO NEURAL NETWORK HYPER-PARAMETERS: PART 1 – LEARNING RATE, BATCH SIZE, MOMENTUM, AND WEIGHT DECAY" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "- Reviewing the approach for setting Hyperparameters by Leslie Smith. \n", 15 | "- 'Setting the hyper-parameters remains a black art that requires years of experience to acquire' - Leslie Smith" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "heading_collapsed": true 22 | }, 23 | "source": [ 24 | "### Load Dependancies" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": { 31 | "hidden": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# Uncomment the below if you need to reset your precomputed activations\n", 36 | "!rm -rf {PATH}tmp" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": { 43 | "hidden": true 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "%reload_ext autoreload\n", 48 | "%autoreload 2\n", 49 | "%matplotlib inline" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": { 56 | "hidden": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "#Import the fastai libraries\n", 61 | "from fastai.imports import *\n", 62 | "from fastai.transforms import *\n", 63 | "from fastai.conv_learner import *\n", 64 | "from fastai.model import *\n", 65 | "from fastai.dataset import *\n", 66 | "from fastai.sgdr import 
*\n", 67 | "from fastai.plots import *" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "metadata": { 74 | "hidden": true 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "# specify the path of the folder we will be working with\n", 79 | "PATH = 'data/pill/small/'" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": { 86 | "hidden": true 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "# Specify the csv file that contains the labels for the corresponding images in the training folder\n", 91 | "labels_csv = f'{PATH}/classify_test_double.csv'\n", 92 | "#labels_csv = f'{PATH}/labels.csv'\n", 93 | "n = len(list(open(labels_csv))) -1\n", 94 | "val_idxs = get_cv_idxs(n) #Create a validation set = in this case 20% set for validation\n", 95 | "#val_idxs = []#use this for not creating a validation set" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 6, 101 | "metadata": { 102 | "hidden": true 103 | }, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "206" 109 | ] 110 | }, 111 | "execution_count": 6, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "len(val_idxs) #number in validation set" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 7, 123 | "metadata": { 124 | "hidden": true 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "f_model = resnet50 #choose the pretrained model\n", 129 | "sz=185\n", 130 | "bs=8" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 8, 136 | "metadata": { 137 | "hidden": true 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "#aug_tfms = [RandomLighting(0.25,0.150)] + [RandomDihedral()]\n", 142 | "#aug_tfms = [RandomDihedral()] + [RandomRotate(27)]\n", 143 | "#aug_tfms = [Cutout(n_holes=200, length=7.5, tfm_y=TfmType.NO)]\n", 144 | "aug_tfms = [RandomRotateZoom(deg=45, zoom=2, stretch=1)] + [AddPadding(pad=50, 
mode=cv2.BORDER_WRAP)] + [RandomDihedral()]\n", 145 | "#aug_tfms = RandomRotate(57)\n", 146 | "#aug_tfms = RandomFlip()\n", 147 | "#aug_tfms = RandomDihedral()\n", 148 | "#aug_tfms =[RandomRotate(27)]\n", 149 | "#aug_tfms = [AddPadding(pad=50, mode=cv2.BORDER_CONSTANT)] #padding\n", 150 | "#aug_tfms = transforms_top_down\n", 151 | "#aug_tfms = RandomLighting(0.25,0.15) #No transformations\n", 152 | "#aug_tfms = [RandomRotate(10), RandomLighting(0.05, 0.05)] #transforms basic\n", 153 | "#aug_tfms = transforms_side_on = transforms_basic + [RandomFlip()] #transforms side on\n", 154 | "#aug_tfms = RandomBlur(blur_strengths=99, probability=100) #random blur\n", 155 | "#aug_tfms = transforms_top_down = transforms_basic + [RandomDihedral()] # transforms top down\n", 156 | "#aug_tfms = RandomRotate(90, p=0.75, mode=cv2.BORDER_REFLECT, tfm_y=TfmType.NO) #Random rotatateaug_tfms = [RandomLighting(b=0.5, c=0.1, tfm_y=TfmType.NO)] # Random Lighting\n", 157 | "#aug_tfms = [RandomRotateZoom(deg=45, zoom=2, stretch=1)] #Random Zoom Rotate\n", 158 | "#aug_tfms = [RandomRotate(27), RandomLighting(0.15, 0.15), RandomDihedral(), RandomBlur(blur_strengths=71, probability=0.5,\n", 159 | " # tfm_y=TfmType.NO), RandomRotateZoom(deg=45, zoom=2, stretch=1)] #aug_full\n", 160 | "#aug_tfms =[RandomRotate(10), RandomLighting(0.05, 0.05)] + [RandomFlip()] + [RandomDihedral()] + [RandomZoom(zoom_max=1)] + [RandomStretch(max_stretch=0.5)]" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 9, 166 | "metadata": { 167 | "hidden": true 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "tfms = tfms_from_model(f_model, sz, aug_tfms)\n", 172 | "md = ImageClassifierData.from_csv(PATH, 'train_double', labels_csv, tfms=tfms,\n", 173 | " val_idxs=val_idxs, test_name='test_600', bs=bs)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 10, 179 | "metadata": { 180 | "hidden": true 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | 
"x,y=next(iter((md.val_dl)))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 11, 190 | "metadata": { 191 | "hidden": true 192 | }, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "image/png": "\n", 197 | "text/plain": [ 198 | "
" 199 | ] 200 | }, 201 | "metadata": {}, 202 | "output_type": "display_data" 203 | } 204 | ], 205 | "source": [ 206 | "plt.imshow(md.trn_ds.denorm(x)[7]); #look at picture 7 from training set" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 12, 212 | "metadata": { 213 | "hidden": true 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "learn = ConvLearner.pretrained(f_model, md)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "### Cyclic Learning Rates" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "- If the Learning Rate (LR) is too small = Overfitting\n", 232 | "- If LR is too large = Divergence\n", 233 | "- However large LRs help to regularize the training\n", 234 | "- For Cyclic Learning Rates (CLR) you specify specific minimum and maximum learning rate boundaries and a stepsize" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 13, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "application/vnd.jupyter.widget-view+json": { 245 | "model_id": "b4063c70cc2747f8a77c6826ffb4655e", 246 | "version_major": 2, 247 | "version_minor": 0 248 | }, 249 | "text/plain": [ 250 | "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" 251 | ] 252 | }, 253 | "metadata": {}, 254 | "output_type": "display_data" 255 | }, 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | " 97%|█████████▋| 100/103 [00:16<00:00, 6.20it/s, loss=0.641]" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "learn.lr_find2(end_lr=100)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 14, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "images/loss.png": "\n", 276 | "text/plain": [ 277 | "
" 278 | ] 279 | }, 280 | "metadata": {}, 281 | "output_type": "display_data" 282 | } 283 | ], 284 | "source": [ 285 | "learn.sched.plot(1,5)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "## Understanding Fastai code to implement 1-Cycle Policy" 293 | ] 294 | }, 295 | { 296 | "cell_type": "raw", 297 | "metadata": {}, 298 | "source": [ 299 | "Using Leslie Smith's 1 cycle policy involves using the following code and parameters:\n", 300 | "\n", 301 | "learn.fit(lr,1,cycle_len=20,use_clr_beta=(10,0.25,0.95,0.85),wds=1e-5)\n", 302 | "\n", 303 | "This portion involves looking at the first 3 hyperparameters (lr, 1 and cycle_len)\n", 304 | "\n", 305 | "- lr = This is the learning rate; based on this paper a higher LR is beneficial, but how high depends on the architecture of your model and the data\n", 306 | "- 1 = the number of cycles (here a single cycle)\n", 307 | "- cycle_len = the length of the cycle in epochs. The paper provides details of what cycle length to use: they vary from as low as 12 (MNIST with LeNet architecture) to 800 (Cifar-10 with a wide ResNet architecture), but the main point is that cycle lengths between 12 and 50 produce good results that in some cases are comparable to those obtained with a much higher length (which is time and resource consuming)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "- The 1 cycle policy involves a cycle with 2 steps of equal length\n", 315 | "- Step 1 where the learning rate increases linearly from the minimum to the maximum\n", 316 | "- Step 2 where it linearly decreases back to the minimum" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 21, 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "data": { 326 | "application/vnd.jupyter.widget-view+json": { 327 | "model_id": "2e2d55f062db464b8b7ae66f2c76601f", 328 | "version_major": 2, 329 | "version_minor": 0 330 | }, 331 | "text/plain": [ 332 | "HBox(children=(IntProgress(value=0, 
description='Epoch', max=2), HTML(value='')))" 333 | ] 334 | }, 335 | "metadata": {}, 336 | "output_type": "display_data" 337 | }, 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "epoch trn_loss val_loss \n", 343 | " 0 0.075126 0.062634 0.98183 \n", 344 | " 1 0.072194 0.060749 0.982329 \n" 345 | ] 346 | }, 347 | { 348 | "data": { 349 | "text/plain": [ 350 | "[array([0.06075]), 0.9823285595884601]" 351 | ] 352 | }, 353 | "execution_count": 21, 354 | "metadata": {}, 355 | "output_type": "execute_result" 356 | } 357 | ], 358 | "source": [ 359 | "learn.fit(2.0,1,cycle_len=2,use_clr_beta=(10,5,0.95,0.85),wds=1e-5)\n", 360 | "#Just an example in order to generate the cycle with 2 steps of equal length" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 22, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "image/png": "\n", 371 | "text/plain": [ 372 | "
" 373 | ] 374 | }, 375 | "metadata": {}, 376 | "output_type": "display_data" 377 | } 378 | ], 379 | "source": [ 380 | "learn.sched.plot_lr()" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": {}, 386 | "source": [ 387 | "- The peak in the middle of the cycle (at 100 iterations) acts as a regularization method to prevent overfitting" 388 | ] 389 | }, 390 | { 391 | "cell_type": "raw", 392 | "metadata": {}, 393 | "source": [ 394 | "This portion involves looking at the parameters within use_clr_beta which takes 4 parameters\n", 395 | "\n", 396 | "use_clr_beta=(10,5,0.95,0.85),wds=1e-5)\n", 397 | "use_clr_beta=(A,B,C,D), wds=E)\n", 398 | "\n", 399 | "A = ratio between the maximum learning rate and the initial one, i.e. the initial LR is the maximum divided by A (paper suggests an initial LR of 1/10 or 1/20 of the maximum)\n", 400 | "B = the % of the cycle you want to dedicate to annealing the learning rate at the end. Can vary from 5% to 25% depending on architecture and data\n", 401 | "C and D = Momentum can be described as the moving average of the gradients. In this case the momentum and the learning rate are closely related. The optimal learning rate is dependent on the momentum and the momentum is dependent on the learning rate. Optimal momentum will improve network training and the paper recommends a value between 0.9 and 0.99. 
Value C is the maximum momentum in this case 0.95 and Value D is the minimum momentum value which in this case is 0.85" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": { 407 | "heading_collapsed": true 408 | }, 409 | "source": [ 410 | "## Batch Size and LR" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": { 417 | "hidden": true 418 | }, 419 | "outputs": [], 420 | "source": [ 421 | "# Uncomment the below if you need to reset your precomputed activations\n", 422 | "!rm -rf {PATH}tmp" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": { 429 | "hidden": true 430 | }, 431 | "outputs": [], 432 | "source": [ 433 | "#Use low BS and low LR\n", 434 | "bs=8\n", 435 | "lr=0.01" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "metadata": { 442 | "hidden": true 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "%time\n", 447 | "learn.fit(lr,2,cycle_len=2,use_clr_beta=(10,0.25,0.95,0.85),wds=1e-5)\n", 448 | "#better results with using lower 5% of the cycle dedicated to the end of the cycle" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": { 455 | "hidden": true 456 | }, 457 | "outputs": [], 458 | "source": [ 459 | "fig,ax = plt.subplots(1,1,figsize=(4,4))\n", 460 | "ax.plot(list(range(4)),learn.sched.rec_metrics)\n", 461 | "ax.set_title('Low BS and Low LR')\n", 462 | "ax.set_xlabel('Epoch')\n", 463 | "ax.set_ylabel('Accuracy')" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": { 470 | "hidden": true 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "# Uncomment the below if you need to reset your precomputed activations\n", 475 | "!rm -rf {PATH}tmp" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": { 482 | "hidden": true 483 | }, 484 | "outputs": [], 485 
| "source": [ 486 | "#High BS and Low LR\n", 487 | "bs=128\n", 488 | "lr=0.01" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": { 495 | "hidden": true 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "%time\n", 500 | "learn.fit(lr,2,cycle_len=2,use_clr_beta=(10,0.25,0.95,0.85),wds=1e-5)" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": { 507 | "hidden": true 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "fig,ax = plt.subplots(1,1,figsize=(4,4))\n", 512 | "ax.plot(list(range(4)),learn.sched.rec_metrics)\n", 513 | "ax.set_title('High BS and Low LR')\n", 514 | "ax.set_xlabel('Epoch')\n", 515 | "ax.set_ylabel('Accuracy')" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": null, 521 | "metadata": { 522 | "hidden": true 523 | }, 524 | "outputs": [], 525 | "source": [ 526 | "# Uncomment the below if you need to reset your precomputed activations\n", 527 | "!rm -rf {PATH}tmp" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "metadata": { 534 | "hidden": true 535 | }, 536 | "outputs": [], 537 | "source": [ 538 | "#Low BS and High LR\n", 539 | "bs=8\n", 540 | "lr=2.0" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": null, 546 | "metadata": { 547 | "hidden": true 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "%time\n", 552 | "learn.fit(lr,2,cycle_len=2,use_clr_beta=(10,0.25,0.95,0.85),wds=1e-5)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": { 559 | "hidden": true 560 | }, 561 | "outputs": [], 562 | "source": [ 563 | "fig,ax = plt.subplots(1,1,figsize=(4,4))\n", 564 | "ax.plot(list(range(4)),learn.sched.rec_metrics)\n", 565 | "ax.set_title('Low BS and High LR')\n", 566 | "ax.set_xlabel('Epoch')\n", 567 | "ax.set_ylabel('Accuracy')" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 
null, 573 | "metadata": { 574 | "hidden": true 575 | }, 576 | "outputs": [], 577 | "source": [ 578 | "# Uncomment the below if you need to reset your precomputed activations\n", 579 | "!rm -rf {PATH}tmp" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": null, 585 | "metadata": { 586 | "hidden": true 587 | }, 588 | "outputs": [], 589 | "source": [ 590 | "#High BS and High LR\n", 591 | "bs=128\n", 592 | "lr=2.0" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": null, 598 | "metadata": { 599 | "hidden": true 600 | }, 601 | "outputs": [], 602 | "source": [ 603 | "%time\n", 604 | "learn.fit(lr,2,cycle_len=2,use_clr_beta=(10,0.25,0.95,0.85),wds=1e-5)" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": null, 610 | "metadata": { 611 | "hidden": true 612 | }, 613 | "outputs": [], 614 | "source": [ 615 | "fig,ax = plt.subplots(1,1,figsize=(4,4))\n", 616 | "ax.plot(list(range(4)),learn.sched.rec_metrics)\n", 617 | "ax.set_title('High BS and High LR')\n", 618 | "ax.set_xlabel('Epoch')\n", 619 | "ax.set_ylabel('Accuracy')" 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "metadata": {}, 625 | "source": [ 626 | "## BS and LR analysis" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "metadata": {}, 632 | "source": [ 633 | "![title](images/BS_LR.png)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "metadata": {}, 639 | "source": [ 640 | "- Low BS and High LR as well as High BS and High LR produce the highest accuracy" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": { 646 | "heading_collapsed": true 647 | }, 648 | "source": [ 649 | "## Learning & Validation Losses Weight Decay " 650 | ] 651 | }, 652 | { 653 | "cell_type": "markdown", 654 | "metadata": { 655 | "hidden": true 656 | }, 657 | "source": [ 658 | "- Use this to test the Weight Decay" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 664 | 
"metadata": { 665 | "hidden": true 666 | }, 667 | "outputs": [], 668 | "source": [ 669 | "%time\n", 670 | "learn.fit(3.0,1,cycle_len=30,use_clr_beta=(10,5,0.95,0.85),wds=1e-1)\n", 671 | "#better results with using lower 5% of the cycle dedicated to the end of the cycle" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": null, 677 | "metadata": { 678 | "hidden": true 679 | }, 680 | "outputs": [], 681 | "source": [ 682 | "learn.sched.plot_lr()" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": null, 688 | "metadata": { 689 | "hidden": true 690 | }, 691 | "outputs": [], 692 | "source": [ 693 | "fig,ax = plt.subplots(2,1,figsize=(8,12))\n", 694 | "ax[0].plot(list(range(30)),learn.sched.val_losses, label='Validation loss')\n", 695 | "ax[0].plot(list(range(30)),[learn.sched.losses[i] for i in range(32,30*34,33)], label='Training loss')\n", 696 | "ax[0].set_xlabel('Epoch')\n", 697 | "ax[0].set_ylabel('Loss')\n", 698 | "ax[0].legend(loc='upper right')\n", 699 | "ax[1].plot(list(range(30)),learn.sched.rec_metrics)\n", 700 | "ax[1].set_xlabel('Epoch')\n", 701 | "ax[1].set_ylabel('Accuracy')" 702 | ] 703 | }, 704 | { 705 | "cell_type": "markdown", 706 | "metadata": {}, 707 | "source": [ 708 | "## Learning & Validation Losses based on Weight Decay Analysis" 709 | ] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "metadata": {}, 714 | "source": [ 715 | "![title](images/Under_over.png)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": {}, 721 | "source": [ 722 | "- Pictorial explanation of the tradeoff between underfitting and overfitting\n" 723 | ] 724 | }, 725 | { 726 | "cell_type": "markdown", 727 | "metadata": {}, 728 | "source": [ 729 | "### 1e5, 1e4, 1e3, 1e2" 730 | ] 731 | }, 732 | { 733 | "cell_type": "markdown", 734 | "metadata": {}, 735 | "source": [ 736 | "![title](images/loss.png)" 737 | ] 738 | }, 739 | { 740 | "cell_type": "markdown", 741 | "metadata": {}, 742 | "source": [ 743 | "- The 
graphs from left to right portray Training (Orange) and Validation (Blue) Loss plots with a Weight Decay (wds) of 1e5, 1e4, 1e3 and 1e2\n", 744 | "- The graphs show that the Training loss is above the Validation loss when the wds is 1e5 and 1e4 but the two losses then intersect when the wds is 1e3 and 1e2" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": null, 750 | "metadata": {}, 751 | "outputs": [], 752 | "source": [] 753 | } 754 | ], 755 | "metadata": { 756 | "kernelspec": { 757 | "display_name": "Python 3", 758 | "language": "python", 759 | "name": "python3" 760 | }, 761 | "language_info": { 762 | "codemirror_mode": { 763 | "name": "ipython", 764 | "version": 3 765 | }, 766 | "file_extension": ".py", 767 | "mimetype": "text/x-python", 768 | "name": "python", 769 | "nbconvert_exporter": "python", 770 | "pygments_lexer": "ipython3", 771 | "version": "3.6.5" 772 | } 773 | }, 774 | "nbformat": 4, 775 | "nbformat_minor": 2 776 | } 777 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Disciplined Approach to Neural Network Hyper-parameters: Part 1 - Learning Rate, Batch Size, Momentum and Weight Decay 2 | 3 | - Reviewing the approach for setting Hyperparameters by Leslie Smith. 4 | - 'Setting the hyper-parameters remains a black art that requires years of experience to acquire' - Leslie Smith 5 | 6 | You can review the paper here: (https://arxiv.org/abs/1803.09820) 7 | 8 | The 1 cycle policy involves a cycle with 2 steps of equal length: Step 1 where the learning rate increases linearly from the minimum to the maximum and Step 2 where it linearly decreases back to the minimum. 
9 | 10 | ![](/images/one_cycle.png) 11 | 12 | The peak in the middle of the cycle (at 100 iterations) acts as a regularization method to prevent overfitting 13 | 14 | ### Batch Size and Learning Rate Analysis 15 | 16 | ![](/images/BS_LR.png) 17 | 18 | Low BS and High LR as well as High BS and High LR produce the highest accuracy 19 | 20 | ### Learning and Validation Loss Analysis based on Weight Decay 21 | 22 | ![](/images/Under_over.png) 23 | 24 | Pictorial explanation of the tradeoff between underfitting and overfitting 25 | 26 | ![](/images/loss.png) 27 | 28 | - The graphs from left to right portray Training (Orange) and Validation (Blue) Loss plots with a Weight Decay (wds) of 1e5, 1e4, 1e3 and 1e2 29 | - The graphs show that the Training loss is above the Validation loss when the wds is 1e5 and 1e4 but the two losses then intersect when the wds is 1e3 and 1e2 30 | -------------------------------------------------------------------------------- /images/BS_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/BS_LR.png -------------------------------------------------------------------------------- /images/HighBSHighLR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/HighBSHighLR.png -------------------------------------------------------------------------------- /images/HighBSLowLR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/HighBSLowLR.png -------------------------------------------------------------------------------- /images/LowBSHighLR.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/LowBSHighLR.png -------------------------------------------------------------------------------- /images/LowBSLowLR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/LowBSLowLR.png -------------------------------------------------------------------------------- /images/Under_over.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/Under_over.png -------------------------------------------------------------------------------- /images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/loss.png -------------------------------------------------------------------------------- /images/one_cycle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asvcode/1_cycle/0887e7877a7579e52c6fae85a7f8db22c8a1dd33/images/one_cycle.png -------------------------------------------------------------------------------- /images/test: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------