├── assets
    ├── framework.png
    ├── embeddings_vis.png
    └── performance_summary.png
├── README.md
└── ImageNet_Subset
    ├── Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb
    ├── Fully_Supervised_Training_IMGNET_subset_Adam.ipynb
    └── Fully_Supervised_Training_IMGNET_subset_SGD.ipynb


/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2/HEAD/assets/framework.png
--------------------------------------------------------------------------------
/assets/embeddings_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2/HEAD/assets/embeddings_vis.png
--------------------------------------------------------------------------------
/assets/performance_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2/HEAD/assets/performance_summary.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Supervised-Contrastive-Learning-in-TensorFlow-2
2 | 
3 | (Collaboratively done by [Shweta Shaw](https://www.linkedin.com/in/sweta-shaw-797540159/) and myself)
4 | 
5 | Implements the ideas presented in [Supervised Contrastive Learning](https://arxiv.org/pdf/2004.11362v1.pdf) by Khosla et al. The authors propose a two-stage framework that improves the performance of image classifiers and achieves state-of-the-art (SoTA) results.
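In the first stage of the framework, an encoder and a projection head are trained with a supervised contrastive loss. In the second stage, the encoder is frozen and a linear classifier is trained on top of it with cross-entropy.

The loss can be sketched in a few lines of NumPy. This is a minimal illustration, not the TensorFlow code used in the notebooks. The function name `supervised_contrastive_loss`, the default temperature of 0.1 and the choice to skip anchors that have no positives are our own assumptions.

For each anchor, the sketch averages the log-softmax similarity to every other sample with the same label. The softmax runs over all samples except the anchor itself.

```python
import numpy as np


def supervised_contrastive_loss(embeddings, labels, temperature=0.1):
    """Sketch of a supervised contrastive (SupCon) loss in NumPy.

    embeddings: (N, D) array with one non-zero row per sample (for example the
                projection-head outputs for all augmented views in a batch).
    labels:     (N,) integer class labels.
    Returns the loss averaged over anchors that have at least one positive.
    """
    z = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)
    n = z.shape[0]

    self_mask = np.eye(n, dtype=bool)
    # Positives for anchor i: every other sample that shares its label.
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    num_pos = pos_mask.sum(axis=1)
    valid = num_pos > 0
    if not valid.any():
        return 0.0  # assumption: a batch without any positive pair contributes no loss

    # L2-normalize so that dot products are cosine similarities.
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    logits = (z @ z.T) / temperature
    # Drop the anchor from its own softmax denominator.
    logits = np.where(self_mask, -np.inf, logits)

    # Numerically stable row-wise log-softmax.
    row_max = logits.max(axis=1, keepdims=True)
    shifted = logits - row_max
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Mean log-probability of the positives for each anchor that has any.
    pos_log_prob = np.where(pos_mask, log_prob, 0.0).sum(axis=1)
    mean_pos = pos_log_prob[valid] / num_pos[valid]
    return float(-mean_pos.mean())


if __name__ == "__main__":
    emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
    print(supervised_contrastive_loss(emb, [0, 0, 1, 1]))  # small: same-label samples are close
    print(supervised_contrastive_loss(emb, [0, 1, 0, 1]))  # large: same-label samples are far apart
```

The loss is low when samples with the same label lie close together on the unit hypersphere, and high when they are spread apart.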
6 | 
7 | ![](assets/framework.png)
8 | 
9 | (Figure taken from the paper)
10 | 
11 | A detailed discussion of the paper and the results of our experiments is available in [this report](https://app.wandb.ai/authors/scl/reports/Improving-Image-Classification-with-Supervised-Contrastive-Learning--VmlldzoxMzQwNzE).
12 | 
13 | This repository contains the notebooks (runnable on Colab) for the experiments we ran.
14 | 
15 | ## Acknowledgements
16 | - [Contrastive loss for supervised classification](https://towardsdatascience.com/contrastive-loss-for-supervised-classification-224ae35692e7).
17 | - [Prannay Khosla](https://twitter.com/PrannayKhosla) for sharing his comments on our work.
18 | 
19 | ## About the notebooks
20 | ```
21 | ├── Flowers
22 | │   ├── Contrastive_Training_Flowers.ipynb
23 | │   ├── Contrastive_Training_Flowers_Augmentation.ipynb
24 | │   ├── Fully_Supervised_Training_Flowers.ipynb
25 | │   └── Fully_Supervised_Training_Flowers_Augmentation.ipynb
26 | ├── ImageNet_Subset
27 | │   ├── Contrastive_Training_Imagenet_subset_Adam.ipynb
28 | │   ├── Contrastive_Training_Imagenet_subset_RMSprop.ipynb
29 | │   ├── Contrastive_Training_Imagenet_subset_SGD.ipynb
30 | │   ├── Fully_Supervised_Training_IMGNET_subset_Adam.ipynb
31 | │   ├── Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb
32 | │   └── Fully_Supervised_Training_IMGNET_subset_SGD.ipynb
33 | ├── Pets
34 | │   ├── Contrastive_Training_Pets.ipynb
35 | │   └── Fully_Supervised_Training_Pets.ipynb
36 | ├── Visualization_ImageNet_subset.ipynb
37 | └── Visualization_Pets.ipynb
38 | ```
39 | 
40 | - `Contrastive_Training_*.ipynb` notebooks implement the supervised contrastive framework proposed in the paper.
41 | - `Fully_Supervised_Training_*.ipynb` notebooks show standard fully supervised training on each dataset.
42 | - `Visualization_*.ipynb` notebooks visualize the embeddings learned by the supervised contrastive learning framework.
43 | 
44 | ## About the datasets
45 | - Flowers
46 | - Cats vs. Dogs
47 | - ImageNet Subset (https://github.com/thunderInfy/imagenet-5-categories)
48 | 
49 | ## Things to note
50 | - The authors used AutoAugment in the paper. We used simple augmentation operations instead, which worked well for the datasets we tried. We did not use any augmentation for the Pets dataset, since we got good results on it without any.
51 | - The paper uses the LARS optimizer; we used Adam instead. We also show the effect of other optimizers, such as SGD and RMSprop, along with learning-rate schedules.
52 | 
53 | ## Results
54 | 
55 | ![](assets/performance_summary.png)
56 | 
57 | The plots above come from the experiments on the **Pets** dataset. Results for the other two datasets are discussed in the report mentioned above and are available at https://app.wandb.ai/authors/scl.
58 | 
59 | ## Visualization of the embeddings learned by supervised contrastive learning
60 | 
61 | ![](assets/embeddings_vis.png)
62 | 
63 | ## About executing the notebooks
64 | 
65 | Open any of the notebooks in Google Colab (for example, with a browser extension such as "Open notebook in Google Colab"). You should then be able to run the experiments directly.
66 | 
67 | ## About the library versions
68 | 
69 | We ran the experiments with TensorFlow 2.2 and did not pin the versions of the other libraries. All of our experiments were run on [Google Colab](http://colab.research.google.com/).
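The notebooks pair the optimizers with a cosine learning-rate schedule. The RMSprop notebook, for example, uses `tf.keras.experimental.CosineDecay(initial_learning_rate=0.001, decay_steps=decay_steps)` with `decay_steps = 1000`.

The plain-Python sketch below re-implements the formula documented for that schedule, so its shape is easy to inspect. It is an illustration, not code from the repository, and the `cosine_decay` function name is our own.

The ImageNet subset has 1250 training images. With a batch size of 64 that gives 20 steps per epoch, so 1000 decay steps cover the 50-epoch budget used in that notebook.

```python
import math


def cosine_decay(step, initial_learning_rate=0.001, decay_steps=1000, alpha=0.0):
    """Cosine-decayed learning rate at a given optimizer step.

    Follows the formula documented for tf.keras.experimental.CosineDecay:
    the rate falls from initial_learning_rate to alpha * initial_learning_rate
    over decay_steps steps and stays there afterwards.
    """
    step = min(step, decay_steps)  # clamp: no further decay after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_learning_rate * ((1.0 - alpha) * cosine + alpha)


if __name__ == "__main__":
    # 1250 training images with a batch size of 64 -> 20 steps per epoch.
    steps_per_epoch = math.ceil(1250 / 64)
    for epoch in (0, 10, 25, 40, 50):
        step = epoch * steps_per_epoch
        print(f"epoch {epoch:2d}  step {step:4d}  lr {cosine_decay(step):.6f}")
```

Because the schedule is indexed by optimizer steps rather than epochs, changing the batch size changes how many epochs the decay spans.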
70 | 71 | ## Feedback 72 | Via GitHub issues 73 | -------------------------------------------------------------------------------- /ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU", 15 | "widgets": { 16 | "application/vnd.jupyter.widget-state+json": { 17 | "0b4c5d36a97843e6bfc07d37a8e3cb6a": { 18 | "model_module": "@jupyter-widgets/controls", 19 | "model_name": "HBoxModel", 20 | "state": { 21 | "_view_name": "HBoxView", 22 | "_dom_classes": [], 23 | "_model_name": "HBoxModel", 24 | "_view_module": "@jupyter-widgets/controls", 25 | "_model_module_version": "1.5.0", 26 | "_view_count": null, 27 | "_view_module_version": "1.5.0", 28 | "box_style": "", 29 | "layout": "IPY_MODEL_e8ef5e7fafb142aaa238fd8ae119b315", 30 | "_model_module": "@jupyter-widgets/controls", 31 | "children": [ 32 | "IPY_MODEL_fb1635d6c5da4d4893558f5d32c234f2", 33 | "IPY_MODEL_0dae8ce37add4e81825f319cb66620df" 34 | ] 35 | } 36 | }, 37 | "e8ef5e7fafb142aaa238fd8ae119b315": { 38 | "model_module": "@jupyter-widgets/base", 39 | "model_name": "LayoutModel", 40 | "state": { 41 | "_view_name": "LayoutView", 42 | "grid_template_rows": null, 43 | "right": null, 44 | "justify_content": null, 45 | "_view_module": "@jupyter-widgets/base", 46 | "overflow": null, 47 | "_model_module_version": "1.2.0", 48 | "_view_count": null, 49 | "flex_flow": null, 50 | "width": null, 51 | "min_width": null, 52 | "border": null, 53 | "align_items": null, 54 | "bottom": null, 55 | "_model_module": "@jupyter-widgets/base", 56 | "top": null, 57 | "grid_column": null, 58 | "overflow_y": null, 59 | "overflow_x": 
null, 60 | "grid_auto_flow": null, 61 | "grid_area": null, 62 | "grid_template_columns": null, 63 | "flex": null, 64 | "_model_name": "LayoutModel", 65 | "justify_items": null, 66 | "grid_row": null, 67 | "max_height": null, 68 | "align_content": null, 69 | "visibility": null, 70 | "align_self": null, 71 | "height": null, 72 | "min_height": null, 73 | "padding": null, 74 | "grid_auto_rows": null, 75 | "grid_gap": null, 76 | "max_width": null, 77 | "order": null, 78 | "_view_module_version": "1.2.0", 79 | "grid_template_areas": null, 80 | "object_position": null, 81 | "object_fit": null, 82 | "grid_auto_columns": null, 83 | "margin": null, 84 | "display": null, 85 | "left": null 86 | } 87 | }, 88 | "fb1635d6c5da4d4893558f5d32c234f2": { 89 | "model_module": "@jupyter-widgets/controls", 90 | "model_name": "FloatProgressModel", 91 | "state": { 92 | "_view_name": "ProgressView", 93 | "style": "IPY_MODEL_e80797a1b4e24f19a9ce53ce7e9e9299", 94 | "_dom_classes": [], 95 | "description": "100%", 96 | "_model_name": "FloatProgressModel", 97 | "bar_style": "success", 98 | "max": 1250, 99 | "_view_module": "@jupyter-widgets/controls", 100 | "_model_module_version": "1.5.0", 101 | "value": 1250, 102 | "_view_count": null, 103 | "_view_module_version": "1.5.0", 104 | "orientation": "horizontal", 105 | "min": 0, 106 | "description_tooltip": null, 107 | "_model_module": "@jupyter-widgets/controls", 108 | "layout": "IPY_MODEL_4d2932cd6daf495d972886f07bfcc227" 109 | } 110 | }, 111 | "0dae8ce37add4e81825f319cb66620df": { 112 | "model_module": "@jupyter-widgets/controls", 113 | "model_name": "HTMLModel", 114 | "state": { 115 | "_view_name": "HTMLView", 116 | "style": "IPY_MODEL_717104f84c7044818e3f5e4346de1ec0", 117 | "_dom_classes": [], 118 | "description": "", 119 | "_model_name": "HTMLModel", 120 | "placeholder": "​", 121 | "_view_module": "@jupyter-widgets/controls", 122 | "_model_module_version": "1.5.0", 123 | "value": " 1250/1250 [01:03<00:00, 19.77it/s]", 124 | "_view_count": 
null, 125 | "_view_module_version": "1.5.0", 126 | "description_tooltip": null, 127 | "_model_module": "@jupyter-widgets/controls", 128 | "layout": "IPY_MODEL_b0df6eda571748eaa741673ef3cf090f" 129 | } 130 | }, 131 | "e80797a1b4e24f19a9ce53ce7e9e9299": { 132 | "model_module": "@jupyter-widgets/controls", 133 | "model_name": "ProgressStyleModel", 134 | "state": { 135 | "_view_name": "StyleView", 136 | "_model_name": "ProgressStyleModel", 137 | "description_width": "initial", 138 | "_view_module": "@jupyter-widgets/base", 139 | "_model_module_version": "1.5.0", 140 | "_view_count": null, 141 | "_view_module_version": "1.2.0", 142 | "bar_color": null, 143 | "_model_module": "@jupyter-widgets/controls" 144 | } 145 | }, 146 | "4d2932cd6daf495d972886f07bfcc227": { 147 | "model_module": "@jupyter-widgets/base", 148 | "model_name": "LayoutModel", 149 | "state": { 150 | "_view_name": "LayoutView", 151 | "grid_template_rows": null, 152 | "right": null, 153 | "justify_content": null, 154 | "_view_module": "@jupyter-widgets/base", 155 | "overflow": null, 156 | "_model_module_version": "1.2.0", 157 | "_view_count": null, 158 | "flex_flow": null, 159 | "width": null, 160 | "min_width": null, 161 | "border": null, 162 | "align_items": null, 163 | "bottom": null, 164 | "_model_module": "@jupyter-widgets/base", 165 | "top": null, 166 | "grid_column": null, 167 | "overflow_y": null, 168 | "overflow_x": null, 169 | "grid_auto_flow": null, 170 | "grid_area": null, 171 | "grid_template_columns": null, 172 | "flex": null, 173 | "_model_name": "LayoutModel", 174 | "justify_items": null, 175 | "grid_row": null, 176 | "max_height": null, 177 | "align_content": null, 178 | "visibility": null, 179 | "align_self": null, 180 | "height": null, 181 | "min_height": null, 182 | "padding": null, 183 | "grid_auto_rows": null, 184 | "grid_gap": null, 185 | "max_width": null, 186 | "order": null, 187 | "_view_module_version": "1.2.0", 188 | "grid_template_areas": null, 189 | "object_position": null, 
190 | "object_fit": null, 191 | "grid_auto_columns": null, 192 | "margin": null, 193 | "display": null, 194 | "left": null 195 | } 196 | }, 197 | "717104f84c7044818e3f5e4346de1ec0": { 198 | "model_module": "@jupyter-widgets/controls", 199 | "model_name": "DescriptionStyleModel", 200 | "state": { 201 | "_view_name": "StyleView", 202 | "_model_name": "DescriptionStyleModel", 203 | "description_width": "", 204 | "_view_module": "@jupyter-widgets/base", 205 | "_model_module_version": "1.5.0", 206 | "_view_count": null, 207 | "_view_module_version": "1.2.0", 208 | "_model_module": "@jupyter-widgets/controls" 209 | } 210 | }, 211 | "b0df6eda571748eaa741673ef3cf090f": { 212 | "model_module": "@jupyter-widgets/base", 213 | "model_name": "LayoutModel", 214 | "state": { 215 | "_view_name": "LayoutView", 216 | "grid_template_rows": null, 217 | "right": null, 218 | "justify_content": null, 219 | "_view_module": "@jupyter-widgets/base", 220 | "overflow": null, 221 | "_model_module_version": "1.2.0", 222 | "_view_count": null, 223 | "flex_flow": null, 224 | "width": null, 225 | "min_width": null, 226 | "border": null, 227 | "align_items": null, 228 | "bottom": null, 229 | "_model_module": "@jupyter-widgets/base", 230 | "top": null, 231 | "grid_column": null, 232 | "overflow_y": null, 233 | "overflow_x": null, 234 | "grid_auto_flow": null, 235 | "grid_area": null, 236 | "grid_template_columns": null, 237 | "flex": null, 238 | "_model_name": "LayoutModel", 239 | "justify_items": null, 240 | "grid_row": null, 241 | "max_height": null, 242 | "align_content": null, 243 | "visibility": null, 244 | "align_self": null, 245 | "height": null, 246 | "min_height": null, 247 | "padding": null, 248 | "grid_auto_rows": null, 249 | "grid_gap": null, 250 | "max_width": null, 251 | "order": null, 252 | "_view_module_version": "1.2.0", 253 | "grid_template_areas": null, 254 | "object_position": null, 255 | "object_fit": null, 256 | "grid_auto_columns": null, 257 | "margin": null, 258 | "display": 
null, 259 | "left": null 260 | } 261 | }, 262 | "438699af2ed94a3eb67e4f408bd1e7a8": { 263 | "model_module": "@jupyter-widgets/controls", 264 | "model_name": "HBoxModel", 265 | "state": { 266 | "_view_name": "HBoxView", 267 | "_dom_classes": [], 268 | "_model_name": "HBoxModel", 269 | "_view_module": "@jupyter-widgets/controls", 270 | "_model_module_version": "1.5.0", 271 | "_view_count": null, 272 | "_view_module_version": "1.5.0", 273 | "box_style": "", 274 | "layout": "IPY_MODEL_d948c937f41a41a6b66587aaee7f2202", 275 | "_model_module": "@jupyter-widgets/controls", 276 | "children": [ 277 | "IPY_MODEL_5777b6af6ec84082babcd7616c4e5c02", 278 | "IPY_MODEL_4c97103d541d4016be6c513edf6f4b4d" 279 | ] 280 | } 281 | }, 282 | "d948c937f41a41a6b66587aaee7f2202": { 283 | "model_module": "@jupyter-widgets/base", 284 | "model_name": "LayoutModel", 285 | "state": { 286 | "_view_name": "LayoutView", 287 | "grid_template_rows": null, 288 | "right": null, 289 | "justify_content": null, 290 | "_view_module": "@jupyter-widgets/base", 291 | "overflow": null, 292 | "_model_module_version": "1.2.0", 293 | "_view_count": null, 294 | "flex_flow": null, 295 | "width": null, 296 | "min_width": null, 297 | "border": null, 298 | "align_items": null, 299 | "bottom": null, 300 | "_model_module": "@jupyter-widgets/base", 301 | "top": null, 302 | "grid_column": null, 303 | "overflow_y": null, 304 | "overflow_x": null, 305 | "grid_auto_flow": null, 306 | "grid_area": null, 307 | "grid_template_columns": null, 308 | "flex": null, 309 | "_model_name": "LayoutModel", 310 | "justify_items": null, 311 | "grid_row": null, 312 | "max_height": null, 313 | "align_content": null, 314 | "visibility": null, 315 | "align_self": null, 316 | "height": null, 317 | "min_height": null, 318 | "padding": null, 319 | "grid_auto_rows": null, 320 | "grid_gap": null, 321 | "max_width": null, 322 | "order": null, 323 | "_view_module_version": "1.2.0", 324 | "grid_template_areas": null, 325 | "object_position": null, 326 | 
"object_fit": null, 327 | "grid_auto_columns": null, 328 | "margin": null, 329 | "display": null, 330 | "left": null 331 | } 332 | }, 333 | "5777b6af6ec84082babcd7616c4e5c02": { 334 | "model_module": "@jupyter-widgets/controls", 335 | "model_name": "FloatProgressModel", 336 | "state": { 337 | "_view_name": "ProgressView", 338 | "style": "IPY_MODEL_f5864dec9b1641d6a69059cdd0f7254e", 339 | "_dom_classes": [], 340 | "description": "100%", 341 | "_model_name": "FloatProgressModel", 342 | "bar_style": "success", 343 | "max": 250, 344 | "_view_module": "@jupyter-widgets/controls", 345 | "_model_module_version": "1.5.0", 346 | "value": 250, 347 | "_view_count": null, 348 | "_view_module_version": "1.5.0", 349 | "orientation": "horizontal", 350 | "min": 0, 351 | "description_tooltip": null, 352 | "_model_module": "@jupyter-widgets/controls", 353 | "layout": "IPY_MODEL_1b4aa502d7234a6c8d740ed7b1cb4325" 354 | } 355 | }, 356 | "4c97103d541d4016be6c513edf6f4b4d": { 357 | "model_module": "@jupyter-widgets/controls", 358 | "model_name": "HTMLModel", 359 | "state": { 360 | "_view_name": "HTMLView", 361 | "style": "IPY_MODEL_9d2f753a736946b4939711ac254ca213", 362 | "_dom_classes": [], 363 | "description": "", 364 | "_model_name": "HTMLModel", 365 | "placeholder": "​", 366 | "_view_module": "@jupyter-widgets/controls", 367 | "_model_module_version": "1.5.0", 368 | "value": " 250/250 [00:01<00:00, 216.34it/s]", 369 | "_view_count": null, 370 | "_view_module_version": "1.5.0", 371 | "description_tooltip": null, 372 | "_model_module": "@jupyter-widgets/controls", 373 | "layout": "IPY_MODEL_ec97824c3bd6422cbed5093a5db9acf6" 374 | } 375 | }, 376 | "f5864dec9b1641d6a69059cdd0f7254e": { 377 | "model_module": "@jupyter-widgets/controls", 378 | "model_name": "ProgressStyleModel", 379 | "state": { 380 | "_view_name": "StyleView", 381 | "_model_name": "ProgressStyleModel", 382 | "description_width": "initial", 383 | "_view_module": "@jupyter-widgets/base", 384 | "_model_module_version": 
"1.5.0", 385 | "_view_count": null, 386 | "_view_module_version": "1.2.0", 387 | "bar_color": null, 388 | "_model_module": "@jupyter-widgets/controls" 389 | } 390 | }, 391 | "1b4aa502d7234a6c8d740ed7b1cb4325": { 392 | "model_module": "@jupyter-widgets/base", 393 | "model_name": "LayoutModel", 394 | "state": { 395 | "_view_name": "LayoutView", 396 | "grid_template_rows": null, 397 | "right": null, 398 | "justify_content": null, 399 | "_view_module": "@jupyter-widgets/base", 400 | "overflow": null, 401 | "_model_module_version": "1.2.0", 402 | "_view_count": null, 403 | "flex_flow": null, 404 | "width": null, 405 | "min_width": null, 406 | "border": null, 407 | "align_items": null, 408 | "bottom": null, 409 | "_model_module": "@jupyter-widgets/base", 410 | "top": null, 411 | "grid_column": null, 412 | "overflow_y": null, 413 | "overflow_x": null, 414 | "grid_auto_flow": null, 415 | "grid_area": null, 416 | "grid_template_columns": null, 417 | "flex": null, 418 | "_model_name": "LayoutModel", 419 | "justify_items": null, 420 | "grid_row": null, 421 | "max_height": null, 422 | "align_content": null, 423 | "visibility": null, 424 | "align_self": null, 425 | "height": null, 426 | "min_height": null, 427 | "padding": null, 428 | "grid_auto_rows": null, 429 | "grid_gap": null, 430 | "max_width": null, 431 | "order": null, 432 | "_view_module_version": "1.2.0", 433 | "grid_template_areas": null, 434 | "object_position": null, 435 | "object_fit": null, 436 | "grid_auto_columns": null, 437 | "margin": null, 438 | "display": null, 439 | "left": null 440 | } 441 | }, 442 | "9d2f753a736946b4939711ac254ca213": { 443 | "model_module": "@jupyter-widgets/controls", 444 | "model_name": "DescriptionStyleModel", 445 | "state": { 446 | "_view_name": "StyleView", 447 | "_model_name": "DescriptionStyleModel", 448 | "description_width": "", 449 | "_view_module": "@jupyter-widgets/base", 450 | "_model_module_version": "1.5.0", 451 | "_view_count": null, 452 | "_view_module_version": 
"1.2.0", 453 | "_model_module": "@jupyter-widgets/controls" 454 | } 455 | }, 456 | "ec97824c3bd6422cbed5093a5db9acf6": { 457 | "model_module": "@jupyter-widgets/base", 458 | "model_name": "LayoutModel", 459 | "state": { 460 | "_view_name": "LayoutView", 461 | "grid_template_rows": null, 462 | "right": null, 463 | "justify_content": null, 464 | "_view_module": "@jupyter-widgets/base", 465 | "overflow": null, 466 | "_model_module_version": "1.2.0", 467 | "_view_count": null, 468 | "flex_flow": null, 469 | "width": null, 470 | "min_width": null, 471 | "border": null, 472 | "align_items": null, 473 | "bottom": null, 474 | "_model_module": "@jupyter-widgets/base", 475 | "top": null, 476 | "grid_column": null, 477 | "overflow_y": null, 478 | "overflow_x": null, 479 | "grid_auto_flow": null, 480 | "grid_area": null, 481 | "grid_template_columns": null, 482 | "flex": null, 483 | "_model_name": "LayoutModel", 484 | "justify_items": null, 485 | "grid_row": null, 486 | "max_height": null, 487 | "align_content": null, 488 | "visibility": null, 489 | "align_self": null, 490 | "height": null, 491 | "min_height": null, 492 | "padding": null, 493 | "grid_auto_rows": null, 494 | "grid_gap": null, 495 | "max_width": null, 496 | "order": null, 497 | "_view_module_version": "1.2.0", 498 | "grid_template_areas": null, 499 | "object_position": null, 500 | "object_fit": null, 501 | "grid_auto_columns": null, 502 | "margin": null, 503 | "display": null, 504 | "left": null 505 | } 506 | } 507 | } 508 | } 509 | }, 510 | "cells": [ 511 | { 512 | "cell_type": "markdown", 513 | "metadata": { 514 | "id": "JuiT6O71HUAy", 515 | "colab_type": "text" 516 | }, 517 | "source": [ 518 | "# Initial Setup" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "metadata": { 524 | "id": "FgWG4d-K3xRt", 525 | "colab_type": "code", 526 | "colab": {} 527 | }, 528 | "source": [ 529 | "import tensorflow as tf\n", 530 | "print(tf.__version__)" 531 | ], 532 | "execution_count": 0, 533 | "outputs": [] 534 | }, 
535 | { 536 | "cell_type": "code", 537 | "metadata": { 538 | "id": "gtxvkdsm338L", 539 | "colab_type": "code", 540 | "colab": {} 541 | }, 542 | "source": [ 543 | "!pip install wandb\n", 544 | "import wandb\n", 545 | "wandb.login()" 546 | ], 547 | "execution_count": 0, 548 | "outputs": [] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "metadata": { 553 | "id": "E8cv8vit3ydm", 554 | "colab_type": "code", 555 | "colab": {} 556 | }, 557 | "source": [ 558 | "from tensorflow.keras.layers import *\n", 559 | "from tensorflow.keras.models import *\n", 560 | "from wandb.keras import WandbCallback\n", 561 | "import tensorflow_datasets as tfds\n", 562 | "import matplotlib.pyplot as plt\n", 563 | "import numpy as np\n", 564 | "import time\n", 565 | "import cv2\n", 566 | "from tqdm.notebook import tqdm\n", 567 | "from imutils import paths\n", 568 | "tf.random.set_seed(666)\n", 569 | "np.random.seed(666)\n", 570 | "\n", 571 | "tfds.disable_progress_bar()" 572 | ], 573 | "execution_count": 0, 574 | "outputs": [] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": { 579 | "id": "ebM6CaFsHcya", 580 | "colab_type": "text" 581 | }, 582 | "source": [ 583 | "# Imagenet Subset " 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "metadata": { 589 | "id": "4vPz9Alk31qZ", 590 | "colab_type": "code", 591 | "outputId": "5961df4c-b0fd-4c31-99a6-3ecbffba9eec", 592 | "colab": { 593 | "base_uri": "https://localhost:8080/", 594 | "height": 34 595 | } 596 | }, 597 | "source": [ 598 | "!git clone https://github.com/thunderInfy/imagenet-5-categories\n" 599 | ], 600 | "execution_count": 0, 601 | "outputs": [ 602 | { 603 | "output_type": "stream", 604 | "text": [ 605 | "fatal: destination path 'imagenet-5-categories' already exists and is not an empty directory.\n" 606 | ], 607 | "name": "stdout" 608 | } 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "metadata": { 614 | "id": "5vVPALgj4Ogg", 615 | "colab_type": "code", 616 | "colab": {} 617 | }, 618 | "source": [ 619 
| "# Train and test image paths\n", 620 | "train_images = list(paths.list_images(\"imagenet-5-categories/train\"))\n", 621 | "test_images = list(paths.list_images(\"imagenet-5-categories/test\"))\n" 622 | ], 623 | "execution_count": 0, 624 | "outputs": [] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "metadata": { 629 | "id": "YM_w3yZi4RQf", 630 | "colab_type": "code", 631 | "colab": {} 632 | }, 633 | "source": [ 634 | "def prepare_images(image_paths):\n", 635 | " images = []\n", 636 | " labels = []\n", 637 | "\n", 638 | " for image in tqdm(image_paths):\n", 639 | " image_pixels = plt.imread(image)\n", 640 | " image_pixels = cv2.resize(image_pixels, (128,128))\n", 641 | " image_pixels = image_pixels/255.\n", 642 | "\n", 643 | " label = image.split(\"/\")[2].split(\"_\")[0]\n", 644 | "\n", 645 | " images.append(image_pixels)\n", 646 | " labels.append(label)\n", 647 | "\n", 648 | " images = np.array(images)\n", 649 | " labels = np.array(labels)\n", 650 | "\n", 651 | " print(images.shape, labels.shape)\n", 652 | "\n", 653 | " return images, labels" 654 | ], 655 | "execution_count": 0, 656 | "outputs": [] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "metadata": { 661 | "id": "KeNWTqpG4b0e", 662 | "colab_type": "code", 663 | "outputId": "2b8f6cf4-97c6-4265-dbf2-9b6262043410", 664 | "colab": { 665 | "base_uri": "https://localhost:8080/", 666 | "height": 148, 667 | "referenced_widgets": [ 668 | "0b4c5d36a97843e6bfc07d37a8e3cb6a", 669 | "e8ef5e7fafb142aaa238fd8ae119b315", 670 | "fb1635d6c5da4d4893558f5d32c234f2", 671 | "0dae8ce37add4e81825f319cb66620df", 672 | "e80797a1b4e24f19a9ce53ce7e9e9299", 673 | "4d2932cd6daf495d972886f07bfcc227", 674 | "717104f84c7044818e3f5e4346de1ec0", 675 | "b0df6eda571748eaa741673ef3cf090f", 676 | "438699af2ed94a3eb67e4f408bd1e7a8", 677 | "d948c937f41a41a6b66587aaee7f2202", 678 | "5777b6af6ec84082babcd7616c4e5c02", 679 | "4c97103d541d4016be6c513edf6f4b4d", 680 | "f5864dec9b1641d6a69059cdd0f7254e", 681 | 
"1b4aa502d7234a6c8d740ed7b1cb4325", 682 | "9d2f753a736946b4939711ac254ca213", 683 | "ec97824c3bd6422cbed5093a5db9acf6" 684 | ] 685 | } 686 | }, 687 | "source": [ 688 | "X_train, y_train = prepare_images(train_images)\n", 689 | "X_test, y_test = prepare_images(test_images)" 690 | ], 691 | "execution_count": 0, 692 | "outputs": [ 693 | { 694 | "output_type": "display_data", 695 | "data": { 696 | "application/vnd.jupyter.widget-view+json": { 697 | "model_id": "0b4c5d36a97843e6bfc07d37a8e3cb6a", 698 | "version_minor": 0, 699 | "version_major": 2 700 | }, 701 | "text/plain": [ 702 | "HBox(children=(FloatProgress(value=0.0, max=1250.0), HTML(value='')))" 703 | ] 704 | }, 705 | "metadata": { 706 | "tags": [] 707 | } 708 | }, 709 | { 710 | "output_type": "stream", 711 | "text": [ 712 | "\n", 713 | "(1250, 128, 128, 3) (1250,)\n" 714 | ], 715 | "name": "stdout" 716 | }, 717 | { 718 | "output_type": "display_data", 719 | "data": { 720 | "application/vnd.jupyter.widget-view+json": { 721 | "model_id": "438699af2ed94a3eb67e4f408bd1e7a8", 722 | "version_minor": 0, 723 | "version_major": 2 724 | }, 725 | "text/plain": [ 726 | "HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))" 727 | ] 728 | }, 729 | "metadata": { 730 | "tags": [] 731 | } 732 | }, 733 | { 734 | "output_type": "stream", 735 | "text": [ 736 | "\n", 737 | "(250, 128, 128, 3) (250,)\n" 738 | ], 739 | "name": "stdout" 740 | } 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "metadata": { 746 | "id": "qGdwXZJk4eDH", 747 | "colab_type": "code", 748 | "colab": {} 749 | }, 750 | "source": [ 751 | "from sklearn import preprocessing\n", 752 | "le = preprocessing.LabelEncoder()\n", 753 | "y_train_enc = le.fit_transform(y_train)\n", 754 | "y_test_enc = le.transform(y_test)\n" 755 | ], 756 | "execution_count": 0, 757 | "outputs": [] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "metadata": { 762 | "id": "nmX3x8wE4zBo", 763 | "colab_type": "code", 764 | "colab": {} 765 | }, 766 | "source": [ 767 | 
"train_ds=tf.data.Dataset.from_tensor_slices((X_train,y_train_enc))\n", 768 | "validation_ds=tf.data.Dataset.from_tensor_slices((X_test,y_test_enc))" 769 | ], 770 | "execution_count": 0, 771 | "outputs": [] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "metadata": { 776 | "id": "9yBgpLe443a-", 777 | "colab_type": "code", 778 | "colab": {} 779 | }, 780 | "source": [ 781 | "@tf.function\n", 782 | "def aug(image, label):\n", 783 | " x=tf.image.random_brightness(image,max_delta=0)\n", 784 | " x=tf.image.random_contrast(x,lower=0.2, upper=1.8)\n", 785 | " x = tf.image.random_saturation(x, lower=0.2, upper=1.5)\n", 786 | " x = tf.image.random_hue(x, max_delta=0.4)\n", 787 | " x = tf.clip_by_value(x, 0, 1)\n", 788 | "\n", 789 | " return x, label" 790 | ], 791 | "execution_count": 0, 792 | "outputs": [] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "metadata": { 797 | "id": "icCj5VGk45ce", 798 | "colab_type": "code", 799 | "colab": {} 800 | }, 801 | "source": [ 802 | "IMG_SHAPE = 128\n", 803 | "BS = 64\n", 804 | "AUTO = tf.data.experimental.AUTOTUNE\n", 805 | "train_ds = (\n", 806 | " train_ds\n", 807 | " .shuffle(100)\n", 808 | " .batch(BS)\n", 809 | " .map(aug, num_parallel_calls=AUTO)\n", 810 | " .prefetch(AUTO)\n", 811 | ")\n", 812 | "validation_ds = (\n", 813 | " validation_ds\n", 814 | " .shuffle(100)\n", 815 | " .batch(BS)\n", 816 | " .prefetch(AUTO)\n", 817 | ")" 818 | ], 819 | "execution_count": 0, 820 | "outputs": [] 821 | }, 822 | { 823 | "cell_type": "markdown", 824 | "metadata": { 825 | "id": "tkxjWEeIHrCf", 826 | "colab_type": "text" 827 | }, 828 | "source": [ 829 | "# Model building and training wih RMSprop\n" 830 | ] 831 | }, 832 | { 833 | "cell_type": "code", 834 | "metadata": { 835 | "id": "umbRNW-A4755", 836 | "colab_type": "code", 837 | "colab": {} 838 | }, 839 | "source": [ 840 | "resnet50 = tf.keras.applications.ResNet50(weights=None, include_top=False)\n", 841 | "model = 
tf.keras.Sequential([resnet50,GlobalAveragePooling2D(),Dropout(0.25),Dense(5,activation='softmax')])" 842 | ], 843 | "execution_count": 0, 844 | "outputs": [] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "metadata": { 849 | "id": "WVilaFIu5Hft", 850 | "colab_type": "code", 851 | "colab": {} 852 | }, 853 | "source": [ 854 | "decay_steps = 1000\n", 855 | "lr_decayed_fn = tf.keras.experimental.CosineDecay(\n", 856 | " initial_learning_rate=0.001, decay_steps=decay_steps)\n", 857 | "\n", 858 | "model.compile(optimizer=tf.keras.optimizers.RMSprop(lr_decayed_fn),\n", 859 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n", 860 | " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])" 861 | ], 862 | "execution_count": 0, 863 | "outputs": [] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "metadata": { 868 | "id": "X3PxnSYd5x2W", 869 | "colab_type": "code", 870 | "colab": {} 871 | }, 872 | "source": [ 873 | "es = tf.keras.callbacks.EarlyStopping(monitor=\"val_sparse_categorical_accuracy\", patience=2,\n", 874 | "\trestore_best_weights=True, verbose=2)" 875 | ], 876 | "execution_count": 0, 877 | "outputs": [] 878 | }, 879 | { 880 | "cell_type": "code", 881 | "metadata": { 882 | "id": "S5lMKwCQ54KX", 883 | "colab_type": "code", 884 | "outputId": "c6e1573d-c593-4447-c182-34674d7ab6ef", 885 | "colab": { 886 | "base_uri": "https://localhost:8080/", 887 | "height": 390 888 | } 889 | }, 890 | "source": [ 891 | "import time\n", 892 | "import wandb\n", 893 | "\n", 894 | "wandb.init(entity='authors',project='scl')\n", 895 | "start = time.time()\n", 896 | "model.fit(train_ds,\n", 897 | " validation_data=validation_ds,\n", 898 | " epochs=50,\n", 899 | " callbacks=[wandb.keras.WandbCallback(), es])\n", 900 | "end = time.time()\n", 901 | "wandb.log({\"training_time\": end - start})" 902 | ], 903 | "execution_count": 0, 904 | "outputs": [ 905 | { 906 | "output_type": "display_data", 907 | "data": { 908 | "text/html": [ 909 | "\n", 910 | " Logging
results to Weights & Biases (Documentation).
\n", 911 | " Project page: https://app.wandb.ai/authors/scl
\n", 912 | " Run page: https://app.wandb.ai/authors/scl/runs/2h40mbhd
\n", 913 | " " 914 | ], 915 | "text/plain": [ 916 | "" 917 | ] 918 | }, 919 | "metadata": { 920 | "tags": [] 921 | } 922 | }, 923 | { 924 | "output_type": "stream", 925 | "text": [ 926 | "Epoch 1/50\n", 927 | "20/20 [==============================] - 6s 312ms/step - loss: 1.6051 - sparse_categorical_accuracy: 0.2976 - val_loss: 1.7048 - val_sparse_categorical_accuracy: 0.2000\n", 928 | "Epoch 2/50\n", 929 | "20/20 [==============================] - 5s 263ms/step - loss: 1.5929 - sparse_categorical_accuracy: 0.3056 - val_loss: 1.7039 - val_sparse_categorical_accuracy: 0.2000\n", 930 | "Epoch 3/50\n", 931 | "20/20 [==============================] - 5s 271ms/step - loss: 1.6160 - sparse_categorical_accuracy: 0.2792 - val_loss: 1.6867 - val_sparse_categorical_accuracy: 0.2120\n", 932 | "Epoch 4/50\n", 933 | "20/20 [==============================] - 4s 223ms/step - loss: 1.5946 - sparse_categorical_accuracy: 0.2776 - val_loss: 1.6938 - val_sparse_categorical_accuracy: 0.2000\n", 934 | "Epoch 5/50\n", 935 | "20/20 [==============================] - 5s 273ms/step - loss: 1.5765 - sparse_categorical_accuracy: 0.2880 - val_loss: 1.5870 - val_sparse_categorical_accuracy: 0.3040\n", 936 | "Epoch 6/50\n", 937 | "20/20 [==============================] - 6s 276ms/step - loss: 1.5180 - sparse_categorical_accuracy: 0.3552 - val_loss: 1.5748 - val_sparse_categorical_accuracy: 0.3080\n", 938 | "Epoch 7/50\n", 939 | "20/20 [==============================] - 5s 227ms/step - loss: 1.5345 - sparse_categorical_accuracy: 0.3512 - val_loss: 1.6253 - val_sparse_categorical_accuracy: 0.2680\n", 940 | "Epoch 8/50\n", 941 | "20/20 [==============================] - ETA: 0s - loss: 1.5690 - sparse_categorical_accuracy: 0.3136Restoring model weights from the end of the best epoch.\n", 942 | "20/20 [==============================] - 5s 231ms/step - loss: 1.5690 - sparse_categorical_accuracy: 0.3136 - val_loss: 1.6235 - val_sparse_categorical_accuracy: 0.2800\n", 943 | "Epoch 00008: early 
stopping\n" 944 | ], 945 | "name": "stdout" 946 | } 947 | ] 948 | }, 949 | { 950 | "cell_type": "code", 951 | "metadata": { 952 | "id": "edc3Fu_C6AJO", 953 | "colab_type": "code", 954 | "colab": {} 955 | }, 956 | "source": [ 957 | "model.save_weights(\"full_supervised_learning.h5\")" 958 | ], 959 | "execution_count": 0, 960 | "outputs": [] 961 | }, 962 | { 963 | "cell_type": "code", 964 | "metadata": { 965 | "id": "wOPN7pPwBN0V", 966 | "colab_type": "code", 967 | "outputId": "47e31ee4-110f-44b9-8de8-d377fc2fafbd", 968 | "colab": { 969 | "base_uri": "https://localhost:8080/", 970 | "height": 34 971 | } 972 | }, 973 | "source": [ 974 | "wandb.save(\"full_supervised_learning.h5\")" 975 | ], 976 | "execution_count": 0, 977 | "outputs": [ 978 | { 979 | "output_type": "execute_result", 980 | "data": { 981 | "text/plain": [ 982 | "['/content/wandb/run-20200528_111108-2h40mbhd/full_supervised_learning.h5']" 983 | ] 984 | }, 985 | "metadata": { 986 | "tags": [] 987 | }, 988 | "execution_count": 94 989 | } 990 | ] 991 | }, 992 | { 993 | "cell_type": "code", 994 | "metadata": { 995 | "id": "FNk0NhWFBSYe", 996 | "colab_type": "code", 997 | "colab": {} 998 | }, 999 | "source": [ 1000 | "" 1001 | ], 1002 | "execution_count": 0, 1003 | "outputs": [] 1004 | } 1005 | ] 1006 | } -------------------------------------------------------------------------------- /ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_Adam.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Fully_Supervised_Training_IMGNET_subset_Adam.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU", 15 | "widgets": { 16 | "application/vnd.jupyter.widget-state+json": { 17 | "2a8b0bfc58804efea484cd4cd71dd00f": { 18 | "model_module": "@jupyter-widgets/controls", 19 | 
"model_name": "HBoxModel", 20 | "state": { 21 | "_view_name": "HBoxView", 22 | "_dom_classes": [], 23 | "_model_name": "HBoxModel", 24 | "_view_module": "@jupyter-widgets/controls", 25 | "_model_module_version": "1.5.0", 26 | "_view_count": null, 27 | "_view_module_version": "1.5.0", 28 | "box_style": "", 29 | "layout": "IPY_MODEL_126f1d14800843aa9332f00aead8395b", 30 | "_model_module": "@jupyter-widgets/controls", 31 | "children": [ 32 | "IPY_MODEL_805d93bf8f4446d3871cd93392101584", 33 | "IPY_MODEL_33d042e9035942359936f5d4bd6a108a" 34 | ] 35 | } 36 | }, 37 | "126f1d14800843aa9332f00aead8395b": { 38 | "model_module": "@jupyter-widgets/base", 39 | "model_name": "LayoutModel", 40 | "state": { 41 | "_view_name": "LayoutView", 42 | "grid_template_rows": null, 43 | "right": null, 44 | "justify_content": null, 45 | "_view_module": "@jupyter-widgets/base", 46 | "overflow": null, 47 | "_model_module_version": "1.2.0", 48 | "_view_count": null, 49 | "flex_flow": null, 50 | "width": null, 51 | "min_width": null, 52 | "border": null, 53 | "align_items": null, 54 | "bottom": null, 55 | "_model_module": "@jupyter-widgets/base", 56 | "top": null, 57 | "grid_column": null, 58 | "overflow_y": null, 59 | "overflow_x": null, 60 | "grid_auto_flow": null, 61 | "grid_area": null, 62 | "grid_template_columns": null, 63 | "flex": null, 64 | "_model_name": "LayoutModel", 65 | "justify_items": null, 66 | "grid_row": null, 67 | "max_height": null, 68 | "align_content": null, 69 | "visibility": null, 70 | "align_self": null, 71 | "height": null, 72 | "min_height": null, 73 | "padding": null, 74 | "grid_auto_rows": null, 75 | "grid_gap": null, 76 | "max_width": null, 77 | "order": null, 78 | "_view_module_version": "1.2.0", 79 | "grid_template_areas": null, 80 | "object_position": null, 81 | "object_fit": null, 82 | "grid_auto_columns": null, 83 | "margin": null, 84 | "display": null, 85 | "left": null 86 | } 87 | }, 88 | "805d93bf8f4446d3871cd93392101584": { 89 | "model_module": 
"@jupyter-widgets/controls", 90 | "model_name": "FloatProgressModel", 91 | "state": { 92 | "_view_name": "ProgressView", 93 | "style": "IPY_MODEL_05b5252d8859432c8724c8fb9636ef33", 94 | "_dom_classes": [], 95 | "description": "100%", 96 | "_model_name": "FloatProgressModel", 97 | "bar_style": "success", 98 | "max": 1250, 99 | "_view_module": "@jupyter-widgets/controls", 100 | "_model_module_version": "1.5.0", 101 | "value": 1250, 102 | "_view_count": null, 103 | "_view_module_version": "1.5.0", 104 | "orientation": "horizontal", 105 | "min": 0, 106 | "description_tooltip": null, 107 | "_model_module": "@jupyter-widgets/controls", 108 | "layout": "IPY_MODEL_77179de1c044473abb9da855d5ab8eca" 109 | } 110 | }, 111 | "33d042e9035942359936f5d4bd6a108a": { 112 | "model_module": "@jupyter-widgets/controls", 113 | "model_name": "HTMLModel", 114 | "state": { 115 | "_view_name": "HTMLView", 116 | "style": "IPY_MODEL_eb72220a381347d1bce57c5706554460", 117 | "_dom_classes": [], 118 | "description": "", 119 | "_model_name": "HTMLModel", 120 | "placeholder": "​", 121 | "_view_module": "@jupyter-widgets/controls", 122 | "_model_module_version": "1.5.0", 123 | "value": " 1250/1250 [00:06<00:00, 187.65it/s]", 124 | "_view_count": null, 125 | "_view_module_version": "1.5.0", 126 | "description_tooltip": null, 127 | "_model_module": "@jupyter-widgets/controls", 128 | "layout": "IPY_MODEL_fe2c9cb4e208436db3c15d342f8d08e7" 129 | } 130 | }, 131 | "05b5252d8859432c8724c8fb9636ef33": { 132 | "model_module": "@jupyter-widgets/controls", 133 | "model_name": "ProgressStyleModel", 134 | "state": { 135 | "_view_name": "StyleView", 136 | "_model_name": "ProgressStyleModel", 137 | "description_width": "initial", 138 | "_view_module": "@jupyter-widgets/base", 139 | "_model_module_version": "1.5.0", 140 | "_view_count": null, 141 | "_view_module_version": "1.2.0", 142 | "bar_color": null, 143 | "_model_module": "@jupyter-widgets/controls" 144 | } 145 | }, 146 | "77179de1c044473abb9da855d5ab8eca": { 
147 | "model_module": "@jupyter-widgets/base", 148 | "model_name": "LayoutModel", 149 | "state": { 150 | "_view_name": "LayoutView", 151 | "grid_template_rows": null, 152 | "right": null, 153 | "justify_content": null, 154 | "_view_module": "@jupyter-widgets/base", 155 | "overflow": null, 156 | "_model_module_version": "1.2.0", 157 | "_view_count": null, 158 | "flex_flow": null, 159 | "width": null, 160 | "min_width": null, 161 | "border": null, 162 | "align_items": null, 163 | "bottom": null, 164 | "_model_module": "@jupyter-widgets/base", 165 | "top": null, 166 | "grid_column": null, 167 | "overflow_y": null, 168 | "overflow_x": null, 169 | "grid_auto_flow": null, 170 | "grid_area": null, 171 | "grid_template_columns": null, 172 | "flex": null, 173 | "_model_name": "LayoutModel", 174 | "justify_items": null, 175 | "grid_row": null, 176 | "max_height": null, 177 | "align_content": null, 178 | "visibility": null, 179 | "align_self": null, 180 | "height": null, 181 | "min_height": null, 182 | "padding": null, 183 | "grid_auto_rows": null, 184 | "grid_gap": null, 185 | "max_width": null, 186 | "order": null, 187 | "_view_module_version": "1.2.0", 188 | "grid_template_areas": null, 189 | "object_position": null, 190 | "object_fit": null, 191 | "grid_auto_columns": null, 192 | "margin": null, 193 | "display": null, 194 | "left": null 195 | } 196 | }, 197 | "eb72220a381347d1bce57c5706554460": { 198 | "model_module": "@jupyter-widgets/controls", 199 | "model_name": "DescriptionStyleModel", 200 | "state": { 201 | "_view_name": "StyleView", 202 | "_model_name": "DescriptionStyleModel", 203 | "description_width": "", 204 | "_view_module": "@jupyter-widgets/base", 205 | "_model_module_version": "1.5.0", 206 | "_view_count": null, 207 | "_view_module_version": "1.2.0", 208 | "_model_module": "@jupyter-widgets/controls" 209 | } 210 | }, 211 | "fe2c9cb4e208436db3c15d342f8d08e7": { 212 | "model_module": "@jupyter-widgets/base", 213 | "model_name": "LayoutModel", 214 | "state": { 
215 | "_view_name": "LayoutView", 216 | "grid_template_rows": null, 217 | "right": null, 218 | "justify_content": null, 219 | "_view_module": "@jupyter-widgets/base", 220 | "overflow": null, 221 | "_model_module_version": "1.2.0", 222 | "_view_count": null, 223 | "flex_flow": null, 224 | "width": null, 225 | "min_width": null, 226 | "border": null, 227 | "align_items": null, 228 | "bottom": null, 229 | "_model_module": "@jupyter-widgets/base", 230 | "top": null, 231 | "grid_column": null, 232 | "overflow_y": null, 233 | "overflow_x": null, 234 | "grid_auto_flow": null, 235 | "grid_area": null, 236 | "grid_template_columns": null, 237 | "flex": null, 238 | "_model_name": "LayoutModel", 239 | "justify_items": null, 240 | "grid_row": null, 241 | "max_height": null, 242 | "align_content": null, 243 | "visibility": null, 244 | "align_self": null, 245 | "height": null, 246 | "min_height": null, 247 | "padding": null, 248 | "grid_auto_rows": null, 249 | "grid_gap": null, 250 | "max_width": null, 251 | "order": null, 252 | "_view_module_version": "1.2.0", 253 | "grid_template_areas": null, 254 | "object_position": null, 255 | "object_fit": null, 256 | "grid_auto_columns": null, 257 | "margin": null, 258 | "display": null, 259 | "left": null 260 | } 261 | }, 262 | "c7e81694439c41408bb1090cb82cd2b1": { 263 | "model_module": "@jupyter-widgets/controls", 264 | "model_name": "HBoxModel", 265 | "state": { 266 | "_view_name": "HBoxView", 267 | "_dom_classes": [], 268 | "_model_name": "HBoxModel", 269 | "_view_module": "@jupyter-widgets/controls", 270 | "_model_module_version": "1.5.0", 271 | "_view_count": null, 272 | "_view_module_version": "1.5.0", 273 | "box_style": "", 274 | "layout": "IPY_MODEL_7683c42a14e14250bf6e375a42028b87", 275 | "_model_module": "@jupyter-widgets/controls", 276 | "children": [ 277 | "IPY_MODEL_aac930db883d4977a9f3f08b1a0fa7a5", 278 | "IPY_MODEL_3862f5cde98f49f783d8fc54ada40a11" 279 | ] 280 | } 281 | }, 282 | "7683c42a14e14250bf6e375a42028b87": { 283 | 
"model_module": "@jupyter-widgets/base", 284 | "model_name": "LayoutModel", 285 | "state": { 286 | "_view_name": "LayoutView", 287 | "grid_template_rows": null, 288 | "right": null, 289 | "justify_content": null, 290 | "_view_module": "@jupyter-widgets/base", 291 | "overflow": null, 292 | "_model_module_version": "1.2.0", 293 | "_view_count": null, 294 | "flex_flow": null, 295 | "width": null, 296 | "min_width": null, 297 | "border": null, 298 | "align_items": null, 299 | "bottom": null, 300 | "_model_module": "@jupyter-widgets/base", 301 | "top": null, 302 | "grid_column": null, 303 | "overflow_y": null, 304 | "overflow_x": null, 305 | "grid_auto_flow": null, 306 | "grid_area": null, 307 | "grid_template_columns": null, 308 | "flex": null, 309 | "_model_name": "LayoutModel", 310 | "justify_items": null, 311 | "grid_row": null, 312 | "max_height": null, 313 | "align_content": null, 314 | "visibility": null, 315 | "align_self": null, 316 | "height": null, 317 | "min_height": null, 318 | "padding": null, 319 | "grid_auto_rows": null, 320 | "grid_gap": null, 321 | "max_width": null, 322 | "order": null, 323 | "_view_module_version": "1.2.0", 324 | "grid_template_areas": null, 325 | "object_position": null, 326 | "object_fit": null, 327 | "grid_auto_columns": null, 328 | "margin": null, 329 | "display": null, 330 | "left": null 331 | } 332 | }, 333 | "aac930db883d4977a9f3f08b1a0fa7a5": { 334 | "model_module": "@jupyter-widgets/controls", 335 | "model_name": "FloatProgressModel", 336 | "state": { 337 | "_view_name": "ProgressView", 338 | "style": "IPY_MODEL_e9b5eef0caa148a8aaaa2ed576fe68cc", 339 | "_dom_classes": [], 340 | "description": "100%", 341 | "_model_name": "FloatProgressModel", 342 | "bar_style": "success", 343 | "max": 250, 344 | "_view_module": "@jupyter-widgets/controls", 345 | "_model_module_version": "1.5.0", 346 | "value": 250, 347 | "_view_count": null, 348 | "_view_module_version": "1.5.0", 349 | "orientation": "horizontal", 350 | "min": 0, 351 | 
"description_tooltip": null, 352 | "_model_module": "@jupyter-widgets/controls", 353 | "layout": "IPY_MODEL_2d671cb0ab744e509c98e0d1b2fcac17" 354 | } 355 | }, 356 | "3862f5cde98f49f783d8fc54ada40a11": { 357 | "model_module": "@jupyter-widgets/controls", 358 | "model_name": "HTMLModel", 359 | "state": { 360 | "_view_name": "HTMLView", 361 | "style": "IPY_MODEL_79ff0547109643e8b8002a9bf1e6ebcc", 362 | "_dom_classes": [], 363 | "description": "", 364 | "_model_name": "HTMLModel", 365 | "placeholder": "​", 366 | "_view_module": "@jupyter-widgets/controls", 367 | "_model_module_version": "1.5.0", 368 | "value": " 250/250 [00:01<00:00, 221.46it/s]", 369 | "_view_count": null, 370 | "_view_module_version": "1.5.0", 371 | "description_tooltip": null, 372 | "_model_module": "@jupyter-widgets/controls", 373 | "layout": "IPY_MODEL_8c382843a4f9475fb1a084b08ea0a7ae" 374 | } 375 | }, 376 | "e9b5eef0caa148a8aaaa2ed576fe68cc": { 377 | "model_module": "@jupyter-widgets/controls", 378 | "model_name": "ProgressStyleModel", 379 | "state": { 380 | "_view_name": "StyleView", 381 | "_model_name": "ProgressStyleModel", 382 | "description_width": "initial", 383 | "_view_module": "@jupyter-widgets/base", 384 | "_model_module_version": "1.5.0", 385 | "_view_count": null, 386 | "_view_module_version": "1.2.0", 387 | "bar_color": null, 388 | "_model_module": "@jupyter-widgets/controls" 389 | } 390 | }, 391 | "2d671cb0ab744e509c98e0d1b2fcac17": { 392 | "model_module": "@jupyter-widgets/base", 393 | "model_name": "LayoutModel", 394 | "state": { 395 | "_view_name": "LayoutView", 396 | "grid_template_rows": null, 397 | "right": null, 398 | "justify_content": null, 399 | "_view_module": "@jupyter-widgets/base", 400 | "overflow": null, 401 | "_model_module_version": "1.2.0", 402 | "_view_count": null, 403 | "flex_flow": null, 404 | "width": null, 405 | "min_width": null, 406 | "border": null, 407 | "align_items": null, 408 | "bottom": null, 409 | "_model_module": "@jupyter-widgets/base", 410 | 
"top": null, 411 | "grid_column": null, 412 | "overflow_y": null, 413 | "overflow_x": null, 414 | "grid_auto_flow": null, 415 | "grid_area": null, 416 | "grid_template_columns": null, 417 | "flex": null, 418 | "_model_name": "LayoutModel", 419 | "justify_items": null, 420 | "grid_row": null, 421 | "max_height": null, 422 | "align_content": null, 423 | "visibility": null, 424 | "align_self": null, 425 | "height": null, 426 | "min_height": null, 427 | "padding": null, 428 | "grid_auto_rows": null, 429 | "grid_gap": null, 430 | "max_width": null, 431 | "order": null, 432 | "_view_module_version": "1.2.0", 433 | "grid_template_areas": null, 434 | "object_position": null, 435 | "object_fit": null, 436 | "grid_auto_columns": null, 437 | "margin": null, 438 | "display": null, 439 | "left": null 440 | } 441 | }, 442 | "79ff0547109643e8b8002a9bf1e6ebcc": { 443 | "model_module": "@jupyter-widgets/controls", 444 | "model_name": "DescriptionStyleModel", 445 | "state": { 446 | "_view_name": "StyleView", 447 | "_model_name": "DescriptionStyleModel", 448 | "description_width": "", 449 | "_view_module": "@jupyter-widgets/base", 450 | "_model_module_version": "1.5.0", 451 | "_view_count": null, 452 | "_view_module_version": "1.2.0", 453 | "_model_module": "@jupyter-widgets/controls" 454 | } 455 | }, 456 | "8c382843a4f9475fb1a084b08ea0a7ae": { 457 | "model_module": "@jupyter-widgets/base", 458 | "model_name": "LayoutModel", 459 | "state": { 460 | "_view_name": "LayoutView", 461 | "grid_template_rows": null, 462 | "right": null, 463 | "justify_content": null, 464 | "_view_module": "@jupyter-widgets/base", 465 | "overflow": null, 466 | "_model_module_version": "1.2.0", 467 | "_view_count": null, 468 | "flex_flow": null, 469 | "width": null, 470 | "min_width": null, 471 | "border": null, 472 | "align_items": null, 473 | "bottom": null, 474 | "_model_module": "@jupyter-widgets/base", 475 | "top": null, 476 | "grid_column": null, 477 | "overflow_y": null, 478 | "overflow_x": null, 479 | 
"grid_auto_flow": null, 480 | "grid_area": null, 481 | "grid_template_columns": null, 482 | "flex": null, 483 | "_model_name": "LayoutModel", 484 | "justify_items": null, 485 | "grid_row": null, 486 | "max_height": null, 487 | "align_content": null, 488 | "visibility": null, 489 | "align_self": null, 490 | "height": null, 491 | "min_height": null, 492 | "padding": null, 493 | "grid_auto_rows": null, 494 | "grid_gap": null, 495 | "max_width": null, 496 | "order": null, 497 | "_view_module_version": "1.2.0", 498 | "grid_template_areas": null, 499 | "object_position": null, 500 | "object_fit": null, 501 | "grid_auto_columns": null, 502 | "margin": null, 503 | "display": null, 504 | "left": null 505 | } 506 | } 507 | } 508 | } 509 | }, 510 | "cells": [ 511 | { 512 | "cell_type": "markdown", 513 | "metadata": { 514 | "id": "JuiT6O71HUAy", 515 | "colab_type": "text" 516 | }, 517 | "source": [ 518 | "# Initial Setup" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "metadata": { 524 | "id": "FgWG4d-K3xRt", 525 | "colab_type": "code", 526 | "outputId": "140af918-8fea-4aac-d44d-80191d1a2877", 527 | "colab": { 528 | "base_uri": "https://localhost:8080/", 529 | "height": 35 530 | } 531 | }, 532 | "source": [ 533 | "import tensorflow as tf\n", 534 | "print(tf.__version__)" 535 | ], 536 | "execution_count": 1, 537 | "outputs": [ 538 | { 539 | "output_type": "stream", 540 | "text": [ 541 | "2.2.0\n" 542 | ], 543 | "name": "stdout" 544 | } 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "metadata": { 550 | "id": "gtxvkdsm338L", 551 | "colab_type": "code", 552 | "colab": { 553 | "base_uri": "https://localhost:8080/", 554 | "height": 1000 555 | }, 556 | "outputId": "56474563-95cf-45af-a1c0-af5a447d8fa7" 557 | }, 558 | "source": [ 559 | "!pip install wandb\n", 560 | "import wandb\n", 561 | "wandb.login()" 562 | ], 563 | "execution_count": 2, 564 | "outputs": [ 565 | { 566 | "output_type": "stream", 567 | "text": [ 568 | "Collecting wandb\n", 569 | "\u001b[?25l 
Downloading https://files.pythonhosted.org/packages/d1/c7/8bf2c62c3f133f45e135a8a116e4e0f162043248e3db54de30996eaf1a8a/wandb-0.8.36-py2.py3-none-any.whl (1.4MB)\n", 570 | "\u001b[K |████████████████████████████████| 1.4MB 4.8MB/s \n", 571 | "\u001b[?25hCollecting sentry-sdk>=0.4.0\n", 572 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1b/95/9a20eebcedab2c1c63fad59fe19a0469edfc2a25b8576497e8084629c2ff/sentry_sdk-0.14.4-py2.py3-none-any.whl (104kB)\n", 573 | "\u001b[K |████████████████████████████████| 112kB 29.8MB/s \n", 574 | "\u001b[?25hRequirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)\n", 575 | "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)\n", 576 | "Collecting configparser>=3.8.1\n", 577 | " Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl\n", 578 | "Requirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.2)\n", 579 | "Collecting gql==0.2.0\n", 580 | " Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz\n", 581 | "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)\n", 582 | "Collecting watchdog>=0.8.3\n", 583 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/73/c3/ed6d992006837e011baca89476a4bbffb0a91602432f73bd4473816c76e2/watchdog-0.10.2.tar.gz (95kB)\n", 584 | "\u001b[K |████████████████████████████████| 102kB 9.5MB/s \n", 585 | "\u001b[?25hRequirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)\n", 586 | "Collecting subprocess32>=3.5.3\n", 587 | "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n", 588 | "\u001b[K |████████████████████████████████| 102kB 10.6MB/s \n", 589 | "\u001b[?25hCollecting shortuuid>=0.5.0\n", 590 | " Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n", 591 | "Collecting GitPython>=1.0.0\n", 592 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/44/33/917e6fde1cad13daa7053f39b7c8af3be287314f75f1b1ea8d3fe37a8571/GitPython-3.1.2-py3-none-any.whl (451kB)\n", 593 | "\u001b[K |████████████████████████████████| 460kB 23.4MB/s \n", 594 | "\u001b[?25hRequirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.23.0)\n", 595 | "Collecting docker-pycreds>=0.4.0\n", 596 | " Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n", 597 | "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)\n", 598 | "Requirement already satisfied: urllib3>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (1.24.3)\n", 599 | "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (2020.4.5.1)\n", 600 | "Collecting graphql-core<2,>=0.5.0\n", 601 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)\n", 602 | "\u001b[K |████████████████████████████████| 71kB 9.4MB/s \n", 603 | "\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)\n", 604 | "Collecting pathtools>=0.1.1\n", 605 | " Downloading 
https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n", 606 | "Collecting gitdb<5,>=4.0.1\n", 607 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", 608 | "\u001b[K |████████████████████████████████| 71kB 11.3MB/s \n", 609 | "\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)\n", 610 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.9)\n", 611 | "Collecting smmap<4,>=3.0.1\n", 612 | " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", 613 | "Building wheels for collected packages: gql, watchdog, subprocess32, graphql-core, pathtools\n", 614 | " Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 615 | " Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=bd431a7c7f187272be19a9836b1e7616f86e03d34c0a4de4b4f6810140d9de42\n", 616 | " Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23\n", 617 | " Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 618 | " Created wheel for watchdog: filename=watchdog-0.10.2-cp36-none-any.whl size=73605 sha256=c9fed8e42385522813f4c68267548bd110f3f869f0abc97b6fe982888c9aabf6\n", 619 | " Stored in directory: /root/.cache/pip/wheels/bc/ed/6c/028dea90d31b359cd2a7c8b0da4db80e41d24a59614154072e\n", 620 | " Building wheel for subprocess32 (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 621 | " Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=0dc181757046168379a173ffa1dddec577f6fe87f03c62125ecfbdb3ab14bd4d\n", 622 | " Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n", 623 | " Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 624 | " Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=e4e49bb64de6665036a068a85324e126b1c685b4540c8caa5d1ffb89190c6b50\n", 625 | " Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5\n", 626 | " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 627 | " Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=de6e98828915023d5115bd39dd81f22676351c6a70e74bf13ce914cacd0cee48\n", 628 | " Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n", 629 | "Successfully built gql watchdog subprocess32 graphql-core pathtools\n", 630 | "Installing collected packages: sentry-sdk, configparser, graphql-core, gql, pathtools, watchdog, subprocess32, shortuuid, smmap, gitdb, GitPython, docker-pycreds, wandb\n", 631 | "Successfully installed GitPython-3.1.2 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.5 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.14.4 shortuuid-1.0.1 smmap-3.0.4 subprocess32-3.5.4 wandb-0.8.36 watchdog-0.10.2\n" 632 | ], 633 | "name": "stdout" 634 | }, 635 | { 636 | "output_type": "display_data", 637 | "data": { 638 | "application/javascript": [ 639 | "\n", 640 | " window._wandbApiKey = new Promise((resolve, reject) => {\n", 641 | " function loadScript(url) {\n", 642 | " return new Promise(function(resolve, reject) {\n", 643 | " let newScript = document.createElement(\"script\");\n", 644 | " newScript.onerror = reject;\n", 645 | " 
newScript.onload = resolve;\n", 646 | " document.body.appendChild(newScript);\n", 647 | " newScript.src = url;\n", 648 | " });\n", 649 | " }\n", 650 | " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", 651 | " const iframe = document.createElement('iframe')\n", 652 | " iframe.style.cssText = \"width:0;height:0;border:none\"\n", 653 | " document.body.appendChild(iframe)\n", 654 | " const handshake = new Postmate({\n", 655 | " container: iframe,\n", 656 | " url: 'https://app.wandb.ai/authorize'\n", 657 | " });\n", 658 | " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", 659 | " handshake.then(function(child) {\n", 660 | " child.on('authorize', data => {\n", 661 | " clearTimeout(timeout)\n", 662 | " resolve(data)\n", 663 | " });\n", 664 | " });\n", 665 | " })\n", 666 | " });\n", 667 | " " 668 | ], 669 | "text/plain": [ 670 | "" 671 | ] 672 | }, 673 | "metadata": { 674 | "tags": [] 675 | } 676 | }, 677 | { 678 | "output_type": "stream", 679 | "text": [ 680 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Not authenticated. 
Copy a key from https://app.wandb.ai/authorize\n" 681 | ], 682 | "name": "stderr" 683 | }, 684 | { 685 | "output_type": "stream", 686 | "text": [ 687 | "API Key: ··········\n" 688 | ], 689 | "name": "stdout" 690 | }, 691 | { 692 | "output_type": "stream", 693 | "text": [ 694 | "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" 695 | ], 696 | "name": "stderr" 697 | }, 698 | { 699 | "output_type": "execute_result", 700 | "data": { 701 | "text/plain": [ 702 | "True" 703 | ] 704 | }, 705 | "metadata": { 706 | "tags": [] 707 | }, 708 | "execution_count": 2 709 | } 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "metadata": { 715 | "id": "E8cv8vit3ydm", 716 | "colab_type": "code", 717 | "colab": {} 718 | }, 719 | "source": [ 720 | "from tensorflow.keras.layers import *\n", 721 | "from tensorflow.keras.models import *\n", 722 | "from wandb.keras import WandbCallback\n", 723 | "import tensorflow_datasets as tfds\n", 724 | "import matplotlib.pyplot as plt\n", 725 | "import numpy as np\n", 726 | "import time\n", 727 | "import cv2\n", 728 | "from tqdm.notebook import tqdm\n", 729 | "from imutils import paths\n", 730 | "tf.random.set_seed(666)\n", 731 | "np.random.seed(666)\n", 732 | "\n", 733 | "tfds.disable_progress_bar()" 734 | ], 735 | "execution_count": 0, 736 | "outputs": [] 737 | }, 738 | { 739 | "cell_type": "markdown", 740 | "metadata": { 741 | "id": "ebM6CaFsHcya", 742 | "colab_type": "text" 743 | }, 744 | "source": [ 745 | "# Imagenet Subset " 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "metadata": { 751 | "id": "4vPz9Alk31qZ", 752 | "colab_type": "code", 753 | "outputId": "14ab0cb9-a64e-4ff7-cdf4-f873752cec89", 754 | "colab": { 755 | "base_uri": "https://localhost:8080/", 756 | "height": 106 757 | } 758 | }, 759 | "source": [ 760 | "!git clone https://github.com/thunderInfy/imagenet-5-categories\n" 761 | ], 762 | "execution_count": 4, 763 | "outputs": [ 764 | { 765 | "output_type": 
"stream", 766 | "text": [ 767 | "Cloning into 'imagenet-5-categories'...\n", 768 | "remote: Enumerating objects: 1532, done.\u001b[K\n", 769 | "remote: Total 1532 (delta 0), reused 0 (delta 0), pack-reused 1532\u001b[K\n", 770 | "Receiving objects: 100% (1532/1532), 88.56 MiB | 51.26 MiB/s, done.\n", 771 | "Resolving deltas: 100% (1/1), done.\n" 772 | ], 773 | "name": "stdout" 774 | } 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "metadata": { 780 | "id": "5vVPALgj4Ogg", 781 | "colab_type": "code", 782 | "colab": {} 783 | }, 784 | "source": [ 785 | "# Train and test image paths\n", 786 | "train_images = list(paths.list_images(\"imagenet-5-categories/train\"))\n", 787 | "test_images = list(paths.list_images(\"imagenet-5-categories/test\"))\n" 788 | ], 789 | "execution_count": 0, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "metadata": { 795 | "id": "YM_w3yZi4RQf", 796 | "colab_type": "code", 797 | "colab": {} 798 | }, 799 | "source": [ 800 | "def prepare_images(image_paths):\n", 801 | " images = []\n", 802 | " labels = []\n", 803 | "\n", 804 | " for image in tqdm(image_paths):\n", 805 | " image_pixels = plt.imread(image)\n", 806 | " image_pixels = cv2.resize(image_pixels, (128,128))\n", 807 | " image_pixels = image_pixels/255.\n", 808 | "\n", 809 | " label = image.split(\"/\")[2].split(\"_\")[0]\n", 810 | "\n", 811 | " images.append(image_pixels)\n", 812 | " labels.append(label)\n", 813 | "\n", 814 | " images = np.array(images)\n", 815 | " labels = np.array(labels)\n", 816 | "\n", 817 | " print(images.shape, labels.shape)\n", 818 | "\n", 819 | " return images, labels" 820 | ], 821 | "execution_count": 0, 822 | "outputs": [] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "metadata": { 827 | "id": "KeNWTqpG4b0e", 828 | "colab_type": "code", 829 | "outputId": "66c2ec48-d6df-4f2b-a638-db3adb024e79", 830 | "colab": { 831 | "base_uri": "https://localhost:8080/", 832 | "height": 152, 833 | "referenced_widgets": [ 834 | 
"2a8b0bfc58804efea484cd4cd71dd00f", 835 | "126f1d14800843aa9332f00aead8395b", 836 | "805d93bf8f4446d3871cd93392101584", 837 | "33d042e9035942359936f5d4bd6a108a", 838 | "05b5252d8859432c8724c8fb9636ef33", 839 | "77179de1c044473abb9da855d5ab8eca", 840 | "eb72220a381347d1bce57c5706554460", 841 | "fe2c9cb4e208436db3c15d342f8d08e7", 842 | "c7e81694439c41408bb1090cb82cd2b1", 843 | "7683c42a14e14250bf6e375a42028b87", 844 | "aac930db883d4977a9f3f08b1a0fa7a5", 845 | "3862f5cde98f49f783d8fc54ada40a11", 846 | "e9b5eef0caa148a8aaaa2ed576fe68cc", 847 | "2d671cb0ab744e509c98e0d1b2fcac17", 848 | "79ff0547109643e8b8002a9bf1e6ebcc", 849 | "8c382843a4f9475fb1a084b08ea0a7ae" 850 | ] 851 | } 852 | }, 853 | "source": [ 854 | "X_train, y_train = prepare_images(train_images)\n", 855 | "X_test, y_test = prepare_images(test_images)" 856 | ], 857 | "execution_count": 7, 858 | "outputs": [ 859 | { 860 | "output_type": "display_data", 861 | "data": { 862 | "application/vnd.jupyter.widget-view+json": { 863 | "model_id": "2a8b0bfc58804efea484cd4cd71dd00f", 864 | "version_minor": 0, 865 | "version_major": 2 866 | }, 867 | "text/plain": [ 868 | "HBox(children=(FloatProgress(value=0.0, max=1250.0), HTML(value='')))" 869 | ] 870 | }, 871 | "metadata": { 872 | "tags": [] 873 | } 874 | }, 875 | { 876 | "output_type": "stream", 877 | "text": [ 878 | "\n", 879 | "(1250, 128, 128, 3) (1250,)\n" 880 | ], 881 | "name": "stdout" 882 | }, 883 | { 884 | "output_type": "display_data", 885 | "data": { 886 | "application/vnd.jupyter.widget-view+json": { 887 | "model_id": "c7e81694439c41408bb1090cb82cd2b1", 888 | "version_minor": 0, 889 | "version_major": 2 890 | }, 891 | "text/plain": [ 892 | "HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))" 893 | ] 894 | }, 895 | "metadata": { 896 | "tags": [] 897 | } 898 | }, 899 | { 900 | "output_type": "stream", 901 | "text": [ 902 | "\n", 903 | "(250, 128, 128, 3) (250,)\n" 904 | ], 905 | "name": "stdout" 906 | } 907 | ] 908 | }, 909 | { 910 | 
"cell_type": "code", 911 | "metadata": { 912 | "id": "qGdwXZJk4eDH", 913 | "colab_type": "code", 914 | "colab": {} 915 | }, 916 | "source": [ 917 | "from sklearn import preprocessing\n", 918 | "le = preprocessing.LabelEncoder()\n", 919 | "y_train_enc = le.fit_transform(y_train)\n", 920 | "y_test_enc = le.transform(y_test)\n" 921 | ], 922 | "execution_count": 0, 923 | "outputs": [] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "metadata": { 928 | "id": "nmX3x8wE4zBo", 929 | "colab_type": "code", 930 | "colab": {} 931 | }, 932 | "source": [ 933 | "train_ds=tf.data.Dataset.from_tensor_slices((X_train,y_train_enc))\n", 934 | "validation_ds=tf.data.Dataset.from_tensor_slices((X_test,y_test_enc))" 935 | ], 936 | "execution_count": 0, 937 | "outputs": [] 938 | }, 939 | { 940 | "cell_type": "code", 941 | "metadata": { 942 | "id": "9yBgpLe443a-", 943 | "colab_type": "code", 944 | "colab": {} 945 | }, 946 | "source": [ 947 | "@tf.function\n", 948 | "def aug(image, label):\n", 949 | " x=tf.image.random_brightness(image,max_delta=0)\n", 950 | " x=tf.image.random_contrast(x,lower=0.2, upper=1.8)\n", 951 | " x = tf.image.random_saturation(x, lower=0.2, upper=1.5)\n", 952 | " x = tf.image.random_hue(x, max_delta=0.4)\n", 953 | " x = tf.clip_by_value(x, 0, 1)\n", 954 | "\n", 955 | " return x, label" 956 | ], 957 | "execution_count": 0, 958 | "outputs": [] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "metadata": { 963 | "id": "icCj5VGk45ce", 964 | "colab_type": "code", 965 | "colab": {} 966 | }, 967 | "source": [ 968 | "IMG_SHAPE = 128\n", 969 | "BS = 64\n", 970 | "AUTO = tf.data.experimental.AUTOTUNE\n", 971 | "train_ds = (\n", 972 | " train_ds\n", 973 | " .shuffle(100)\n", 974 | " .batch(BS)\n", 975 | " .map(aug, num_parallel_calls=AUTO)\n", 976 | " .prefetch(AUTO)\n", 977 | ")\n", 978 | "validation_ds = (\n", 979 | " validation_ds\n", 980 | " .shuffle(100)\n", 981 | " .batch(BS)\n", 982 | " .prefetch(AUTO)\n", 983 | ")" 984 | ], 985 | "execution_count": 0, 986 | 
"outputs": [] 987 | }, 988 | { 989 | "cell_type": "markdown", 990 | "metadata": { 991 | "id": "tkxjWEeIHrCf", 992 | "colab_type": "text" 993 | }, 994 | "source": [ 995 | "# Model building and training with Adam\n" 996 | ] 997 | }, 998 | { 999 | "cell_type": "code", 1000 | "metadata": { 1001 | "id": "umbRNW-A4755", 1002 | "colab_type": "code", 1003 | "colab": {} 1004 | }, 1005 | "source": [ 1006 | "resnet50 = tf.keras.applications.ResNet50(weights=None, include_top=False)\n", 1007 | "model = tf.keras.Sequential([resnet50,GlobalAveragePooling2D(),Dropout(0.25),Dense(5,activation='softmax')])" 1008 | ], 1009 | "execution_count": 0, 1010 | "outputs": [] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "metadata": { 1015 | "id": "WVilaFIu5Hft", 1016 | "colab_type": "code", 1017 | "colab": {} 1018 | }, 1019 | "source": [ 1020 | "decay_steps = 1000\n", 1021 | "lr_decayed_fn = tf.keras.experimental.CosineDecay(\n", 1022 | " initial_learning_rate=0.001, decay_steps=decay_steps)\n", 1023 | "\n", 1024 | "model.compile(optimizer=tf.keras.optimizers.Adam(lr_decayed_fn),\n", 1025 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), # the Dense head already applies softmax\n", 1026 | " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])" 1027 | ], 1028 | "execution_count": 0, 1029 | "outputs": [] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "metadata": { 1034 | "id": "X3PxnSYd5x2W", 1035 | "colab_type": "code", 1036 | "colab": {} 1037 | }, 1038 | "source": [ 1039 | "es = tf.keras.callbacks.EarlyStopping(monitor=\"val_sparse_categorical_accuracy\", patience=2,\n", 1040 | "\trestore_best_weights=True, verbose=2)" 1041 | ], 1042 | "execution_count": 0, 1043 | "outputs": [] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "metadata": { 1048 | "id": "S5lMKwCQ54KX", 1049 | "colab_type": "code", 1050 | "outputId": "be75d42a-5d10-4dd9-bf66-a780f8a52c0c", 1051 | "colab": { 1052 | "base_uri": "https://localhost:8080/", 1053 | "height": 210 1054 | } 1055 | }, 1056 | "source": [ 1057 |
"import time\n", 1058 | "import wandb\n", 1059 | "\n", 1060 | "wandb.init(entity='authors',project='scl',id=\"ADA\")\n", 1061 | "start = time.time()\n", 1062 | "model.fit(train_ds,\n", 1063 | " validation_data=validation_ds,\n", 1064 | " epochs=50,\n", 1065 | " callbacks=[wandb.keras.WandbCallback(), es])\n", 1066 | "end = time.time()\n", 1067 | "wandb.log({\"training_time\": end - start})" 1068 | ], 1069 | "execution_count": 24, 1070 | "outputs": [ 1071 | { 1072 | "output_type": "display_data", 1073 | "data": { 1074 | "text/html": [ 1075 | "\n", 1076 | " Logging results to Weights & Biases (Documentation).\n", 1077 | " Project page: https://app.wandb.ai/authors/scl\n", 1078 | " Run page: https://app.wandb.ai/authors/scl/runs/ADA\n", 1079 | " " 1080 | ], 1081 | "text/plain": [ 1082 | "" 1083 | ] 1084 | }, 1085 | "metadata": { 1086 | "tags": [] 1087 | } 1088 | }, 1089 | { 1090 | "output_type": "stream", 1091 | "text": [ 1092 | "Epoch 1/50\n", 1093 | "20/20 [==============================] - 5s 234ms/step - loss: 1.6424 - sparse_categorical_accuracy: 0.2544 - val_loss: 1.6301 - val_sparse_categorical_accuracy: 0.2480\n", 1094 | "Epoch 2/50\n", 1095 | "20/20 [==============================] - 3s 143ms/step - loss: 1.5880 - sparse_categorical_accuracy: 0.3136 - val_loss: 1.7048 - val_sparse_categorical_accuracy: 0.2000\n", 1096 | "Epoch 3/50\n", 1097 | "20/20 [==============================] - ETA: 0s - loss: 1.5839 - sparse_categorical_accuracy: 0.3200Restoring model weights from the end of the best epoch.\n", 1098 | "20/20 [==============================] - 3s 145ms/step - loss: 1.5839 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.7048 - val_sparse_categorical_accuracy: 0.2000\n", 1099 | "Epoch 00003: early stopping\n" 1100 | ], 1101 | "name": "stdout" 1102 | } 1103 | ] 1104 | }, 1105 | { 1106 | "cell_type": "code", 1107 | "metadata": { 1108 | "id": "edc3Fu_C6AJO", 1109 | "colab_type": "code", 1110 | "colab": {} 1111 | }, 1112 | "source": [ 1113 | "model.save_weights(\"full_supervised_learning.h5\")" 1114 | ], 1115 | "execution_count": 0, 1116 | "outputs": [] 1117 | }, 1118 | { 1119 | "cell_type": "code", 1120 | "metadata": { 1121 | "id": "wOPN7pPwBN0V", 1122 | "colab_type": "code", 1123 | "outputId": "47e31ee4-110f-44b9-8de8-d377fc2fafbd", 1124 | "colab": { 1125 | "base_uri": "https://localhost:8080/", 1126 | "height": 34 1127 | } 1128 | }, 1129 | "source": [ 1130 | "wandb.save(\"full_supervised_learning.h5\")" 1131 | ], 1132 | "execution_count": 0, 1133 | "outputs": [ 1134 | { 1135 | "output_type": "execute_result", 1136 | "data": { 1137 | "text/plain": [ 1138 | "['/content/wandb/run-20200528_111108-2h40mbhd/full_supervised_learning.h5']" 1139 | ] 1140 | }, 1141 | "metadata": { 1142
| "tags": [] 1143 | }, 1144 | "execution_count": 94 1145 | } 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "metadata": { 1151 | "id": "FNk0NhWFBSYe", 1152 | "colab_type": "code", 1153 | "colab": {} 1154 | }, 1155 | "source": [ 1156 | "" 1157 | ], 1158 | "execution_count": 0, 1159 | "outputs": [] 1160 | } 1161 | ] 1162 | } -------------------------------------------------------------------------------- /ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_SGD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Fully_Supervised_Training_IMGNET_subset_SGD.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU", 15 | "widgets": { 16 | "application/vnd.jupyter.widget-state+json": { 17 | "2a8b0bfc58804efea484cd4cd71dd00f": { 18 | "model_module": "@jupyter-widgets/controls", 19 | "model_name": "HBoxModel", 20 | "state": { 21 | "_view_name": "HBoxView", 22 | "_dom_classes": [], 23 | "_model_name": "HBoxModel", 24 | "_view_module": "@jupyter-widgets/controls", 25 | "_model_module_version": "1.5.0", 26 | "_view_count": null, 27 | "_view_module_version": "1.5.0", 28 | "box_style": "", 29 | "layout": "IPY_MODEL_126f1d14800843aa9332f00aead8395b", 30 | "_model_module": "@jupyter-widgets/controls", 31 | "children": [ 32 | "IPY_MODEL_805d93bf8f4446d3871cd93392101584", 33 | "IPY_MODEL_33d042e9035942359936f5d4bd6a108a" 34 | ] 35 | } 36 | }, 37 | "126f1d14800843aa9332f00aead8395b": { 38 | "model_module": "@jupyter-widgets/base", 39 | "model_name": "LayoutModel", 40 | "state": { 41 | "_view_name": "LayoutView", 42 | "grid_template_rows": null, 43 | "right": null, 44 | "justify_content": null, 45 | "_view_module": "@jupyter-widgets/base", 46 | "overflow": null, 47 | "_model_module_version": "1.2.0", 48 | 
"_view_count": null, 49 | "flex_flow": null, 50 | "width": null, 51 | "min_width": null, 52 | "border": null, 53 | "align_items": null, 54 | "bottom": null, 55 | "_model_module": "@jupyter-widgets/base", 56 | "top": null, 57 | "grid_column": null, 58 | "overflow_y": null, 59 | "overflow_x": null, 60 | "grid_auto_flow": null, 61 | "grid_area": null, 62 | "grid_template_columns": null, 63 | "flex": null, 64 | "_model_name": "LayoutModel", 65 | "justify_items": null, 66 | "grid_row": null, 67 | "max_height": null, 68 | "align_content": null, 69 | "visibility": null, 70 | "align_self": null, 71 | "height": null, 72 | "min_height": null, 73 | "padding": null, 74 | "grid_auto_rows": null, 75 | "grid_gap": null, 76 | "max_width": null, 77 | "order": null, 78 | "_view_module_version": "1.2.0", 79 | "grid_template_areas": null, 80 | "object_position": null, 81 | "object_fit": null, 82 | "grid_auto_columns": null, 83 | "margin": null, 84 | "display": null, 85 | "left": null 86 | } 87 | }, 88 | "805d93bf8f4446d3871cd93392101584": { 89 | "model_module": "@jupyter-widgets/controls", 90 | "model_name": "FloatProgressModel", 91 | "state": { 92 | "_view_name": "ProgressView", 93 | "style": "IPY_MODEL_05b5252d8859432c8724c8fb9636ef33", 94 | "_dom_classes": [], 95 | "description": "100%", 96 | "_model_name": "FloatProgressModel", 97 | "bar_style": "success", 98 | "max": 1250, 99 | "_view_module": "@jupyter-widgets/controls", 100 | "_model_module_version": "1.5.0", 101 | "value": 1250, 102 | "_view_count": null, 103 | "_view_module_version": "1.5.0", 104 | "orientation": "horizontal", 105 | "min": 0, 106 | "description_tooltip": null, 107 | "_model_module": "@jupyter-widgets/controls", 108 | "layout": "IPY_MODEL_77179de1c044473abb9da855d5ab8eca" 109 | } 110 | }, 111 | "33d042e9035942359936f5d4bd6a108a": { 112 | "model_module": "@jupyter-widgets/controls", 113 | "model_name": "HTMLModel", 114 | "state": { 115 | "_view_name": "HTMLView", 116 | "style": 
"IPY_MODEL_eb72220a381347d1bce57c5706554460", 117 | "_dom_classes": [], 118 | "description": "", 119 | "_model_name": "HTMLModel", 120 | "placeholder": "​", 121 | "_view_module": "@jupyter-widgets/controls", 122 | "_model_module_version": "1.5.0", 123 | "value": " 1250/1250 [00:06<00:00, 187.65it/s]", 124 | "_view_count": null, 125 | "_view_module_version": "1.5.0", 126 | "description_tooltip": null, 127 | "_model_module": "@jupyter-widgets/controls", 128 | "layout": "IPY_MODEL_fe2c9cb4e208436db3c15d342f8d08e7" 129 | } 130 | }, 131 | "05b5252d8859432c8724c8fb9636ef33": { 132 | "model_module": "@jupyter-widgets/controls", 133 | "model_name": "ProgressStyleModel", 134 | "state": { 135 | "_view_name": "StyleView", 136 | "_model_name": "ProgressStyleModel", 137 | "description_width": "initial", 138 | "_view_module": "@jupyter-widgets/base", 139 | "_model_module_version": "1.5.0", 140 | "_view_count": null, 141 | "_view_module_version": "1.2.0", 142 | "bar_color": null, 143 | "_model_module": "@jupyter-widgets/controls" 144 | } 145 | }, 146 | "77179de1c044473abb9da855d5ab8eca": { 147 | "model_module": "@jupyter-widgets/base", 148 | "model_name": "LayoutModel", 149 | "state": { 150 | "_view_name": "LayoutView", 151 | "grid_template_rows": null, 152 | "right": null, 153 | "justify_content": null, 154 | "_view_module": "@jupyter-widgets/base", 155 | "overflow": null, 156 | "_model_module_version": "1.2.0", 157 | "_view_count": null, 158 | "flex_flow": null, 159 | "width": null, 160 | "min_width": null, 161 | "border": null, 162 | "align_items": null, 163 | "bottom": null, 164 | "_model_module": "@jupyter-widgets/base", 165 | "top": null, 166 | "grid_column": null, 167 | "overflow_y": null, 168 | "overflow_x": null, 169 | "grid_auto_flow": null, 170 | "grid_area": null, 171 | "grid_template_columns": null, 172 | "flex": null, 173 | "_model_name": "LayoutModel", 174 | "justify_items": null, 175 | "grid_row": null, 176 | "max_height": null, 177 | "align_content": null, 178 | 
"visibility": null, 179 | "align_self": null, 180 | "height": null, 181 | "min_height": null, 182 | "padding": null, 183 | "grid_auto_rows": null, 184 | "grid_gap": null, 185 | "max_width": null, 186 | "order": null, 187 | "_view_module_version": "1.2.0", 188 | "grid_template_areas": null, 189 | "object_position": null, 190 | "object_fit": null, 191 | "grid_auto_columns": null, 192 | "margin": null, 193 | "display": null, 194 | "left": null 195 | } 196 | }, 197 | "eb72220a381347d1bce57c5706554460": { 198 | "model_module": "@jupyter-widgets/controls", 199 | "model_name": "DescriptionStyleModel", 200 | "state": { 201 | "_view_name": "StyleView", 202 | "_model_name": "DescriptionStyleModel", 203 | "description_width": "", 204 | "_view_module": "@jupyter-widgets/base", 205 | "_model_module_version": "1.5.0", 206 | "_view_count": null, 207 | "_view_module_version": "1.2.0", 208 | "_model_module": "@jupyter-widgets/controls" 209 | } 210 | }, 211 | "fe2c9cb4e208436db3c15d342f8d08e7": { 212 | "model_module": "@jupyter-widgets/base", 213 | "model_name": "LayoutModel", 214 | "state": { 215 | "_view_name": "LayoutView", 216 | "grid_template_rows": null, 217 | "right": null, 218 | "justify_content": null, 219 | "_view_module": "@jupyter-widgets/base", 220 | "overflow": null, 221 | "_model_module_version": "1.2.0", 222 | "_view_count": null, 223 | "flex_flow": null, 224 | "width": null, 225 | "min_width": null, 226 | "border": null, 227 | "align_items": null, 228 | "bottom": null, 229 | "_model_module": "@jupyter-widgets/base", 230 | "top": null, 231 | "grid_column": null, 232 | "overflow_y": null, 233 | "overflow_x": null, 234 | "grid_auto_flow": null, 235 | "grid_area": null, 236 | "grid_template_columns": null, 237 | "flex": null, 238 | "_model_name": "LayoutModel", 239 | "justify_items": null, 240 | "grid_row": null, 241 | "max_height": null, 242 | "align_content": null, 243 | "visibility": null, 244 | "align_self": null, 245 | "height": null, 246 | "min_height": null, 247 
| "padding": null, 248 | "grid_auto_rows": null, 249 | "grid_gap": null, 250 | "max_width": null, 251 | "order": null, 252 | "_view_module_version": "1.2.0", 253 | "grid_template_areas": null, 254 | "object_position": null, 255 | "object_fit": null, 256 | "grid_auto_columns": null, 257 | "margin": null, 258 | "display": null, 259 | "left": null 260 | } 261 | }, 262 | "c7e81694439c41408bb1090cb82cd2b1": { 263 | "model_module": "@jupyter-widgets/controls", 264 | "model_name": "HBoxModel", 265 | "state": { 266 | "_view_name": "HBoxView", 267 | "_dom_classes": [], 268 | "_model_name": "HBoxModel", 269 | "_view_module": "@jupyter-widgets/controls", 270 | "_model_module_version": "1.5.0", 271 | "_view_count": null, 272 | "_view_module_version": "1.5.0", 273 | "box_style": "", 274 | "layout": "IPY_MODEL_7683c42a14e14250bf6e375a42028b87", 275 | "_model_module": "@jupyter-widgets/controls", 276 | "children": [ 277 | "IPY_MODEL_aac930db883d4977a9f3f08b1a0fa7a5", 278 | "IPY_MODEL_3862f5cde98f49f783d8fc54ada40a11" 279 | ] 280 | } 281 | }, 282 | "7683c42a14e14250bf6e375a42028b87": { 283 | "model_module": "@jupyter-widgets/base", 284 | "model_name": "LayoutModel", 285 | "state": { 286 | "_view_name": "LayoutView", 287 | "grid_template_rows": null, 288 | "right": null, 289 | "justify_content": null, 290 | "_view_module": "@jupyter-widgets/base", 291 | "overflow": null, 292 | "_model_module_version": "1.2.0", 293 | "_view_count": null, 294 | "flex_flow": null, 295 | "width": null, 296 | "min_width": null, 297 | "border": null, 298 | "align_items": null, 299 | "bottom": null, 300 | "_model_module": "@jupyter-widgets/base", 301 | "top": null, 302 | "grid_column": null, 303 | "overflow_y": null, 304 | "overflow_x": null, 305 | "grid_auto_flow": null, 306 | "grid_area": null, 307 | "grid_template_columns": null, 308 | "flex": null, 309 | "_model_name": "LayoutModel", 310 | "justify_items": null, 311 | "grid_row": null, 312 | "max_height": null, 313 | "align_content": null, 314 | 
"visibility": null, 315 | "align_self": null, 316 | "height": null, 317 | "min_height": null, 318 | "padding": null, 319 | "grid_auto_rows": null, 320 | "grid_gap": null, 321 | "max_width": null, 322 | "order": null, 323 | "_view_module_version": "1.2.0", 324 | "grid_template_areas": null, 325 | "object_position": null, 326 | "object_fit": null, 327 | "grid_auto_columns": null, 328 | "margin": null, 329 | "display": null, 330 | "left": null 331 | } 332 | }, 333 | "aac930db883d4977a9f3f08b1a0fa7a5": { 334 | "model_module": "@jupyter-widgets/controls", 335 | "model_name": "FloatProgressModel", 336 | "state": { 337 | "_view_name": "ProgressView", 338 | "style": "IPY_MODEL_e9b5eef0caa148a8aaaa2ed576fe68cc", 339 | "_dom_classes": [], 340 | "description": "100%", 341 | "_model_name": "FloatProgressModel", 342 | "bar_style": "success", 343 | "max": 250, 344 | "_view_module": "@jupyter-widgets/controls", 345 | "_model_module_version": "1.5.0", 346 | "value": 250, 347 | "_view_count": null, 348 | "_view_module_version": "1.5.0", 349 | "orientation": "horizontal", 350 | "min": 0, 351 | "description_tooltip": null, 352 | "_model_module": "@jupyter-widgets/controls", 353 | "layout": "IPY_MODEL_2d671cb0ab744e509c98e0d1b2fcac17" 354 | } 355 | }, 356 | "3862f5cde98f49f783d8fc54ada40a11": { 357 | "model_module": "@jupyter-widgets/controls", 358 | "model_name": "HTMLModel", 359 | "state": { 360 | "_view_name": "HTMLView", 361 | "style": "IPY_MODEL_79ff0547109643e8b8002a9bf1e6ebcc", 362 | "_dom_classes": [], 363 | "description": "", 364 | "_model_name": "HTMLModel", 365 | "placeholder": "​", 366 | "_view_module": "@jupyter-widgets/controls", 367 | "_model_module_version": "1.5.0", 368 | "value": " 250/250 [00:01<00:00, 221.46it/s]", 369 | "_view_count": null, 370 | "_view_module_version": "1.5.0", 371 | "description_tooltip": null, 372 | "_model_module": "@jupyter-widgets/controls", 373 | "layout": "IPY_MODEL_8c382843a4f9475fb1a084b08ea0a7ae" 374 | } 375 | }, 376 | 
"e9b5eef0caa148a8aaaa2ed576fe68cc": { 377 | "model_module": "@jupyter-widgets/controls", 378 | "model_name": "ProgressStyleModel", 379 | "state": { 380 | "_view_name": "StyleView", 381 | "_model_name": "ProgressStyleModel", 382 | "description_width": "initial", 383 | "_view_module": "@jupyter-widgets/base", 384 | "_model_module_version": "1.5.0", 385 | "_view_count": null, 386 | "_view_module_version": "1.2.0", 387 | "bar_color": null, 388 | "_model_module": "@jupyter-widgets/controls" 389 | } 390 | }, 391 | "2d671cb0ab744e509c98e0d1b2fcac17": { 392 | "model_module": "@jupyter-widgets/base", 393 | "model_name": "LayoutModel", 394 | "state": { 395 | "_view_name": "LayoutView", 396 | "grid_template_rows": null, 397 | "right": null, 398 | "justify_content": null, 399 | "_view_module": "@jupyter-widgets/base", 400 | "overflow": null, 401 | "_model_module_version": "1.2.0", 402 | "_view_count": null, 403 | "flex_flow": null, 404 | "width": null, 405 | "min_width": null, 406 | "border": null, 407 | "align_items": null, 408 | "bottom": null, 409 | "_model_module": "@jupyter-widgets/base", 410 | "top": null, 411 | "grid_column": null, 412 | "overflow_y": null, 413 | "overflow_x": null, 414 | "grid_auto_flow": null, 415 | "grid_area": null, 416 | "grid_template_columns": null, 417 | "flex": null, 418 | "_model_name": "LayoutModel", 419 | "justify_items": null, 420 | "grid_row": null, 421 | "max_height": null, 422 | "align_content": null, 423 | "visibility": null, 424 | "align_self": null, 425 | "height": null, 426 | "min_height": null, 427 | "padding": null, 428 | "grid_auto_rows": null, 429 | "grid_gap": null, 430 | "max_width": null, 431 | "order": null, 432 | "_view_module_version": "1.2.0", 433 | "grid_template_areas": null, 434 | "object_position": null, 435 | "object_fit": null, 436 | "grid_auto_columns": null, 437 | "margin": null, 438 | "display": null, 439 | "left": null 440 | } 441 | }, 442 | "79ff0547109643e8b8002a9bf1e6ebcc": { 443 | "model_module": 
"@jupyter-widgets/controls", 444 | "model_name": "DescriptionStyleModel", 445 | "state": { 446 | "_view_name": "StyleView", 447 | "_model_name": "DescriptionStyleModel", 448 | "description_width": "", 449 | "_view_module": "@jupyter-widgets/base", 450 | "_model_module_version": "1.5.0", 451 | "_view_count": null, 452 | "_view_module_version": "1.2.0", 453 | "_model_module": "@jupyter-widgets/controls" 454 | } 455 | }, 456 | "8c382843a4f9475fb1a084b08ea0a7ae": { 457 | "model_module": "@jupyter-widgets/base", 458 | "model_name": "LayoutModel", 459 | "state": { 460 | "_view_name": "LayoutView", 461 | "grid_template_rows": null, 462 | "right": null, 463 | "justify_content": null, 464 | "_view_module": "@jupyter-widgets/base", 465 | "overflow": null, 466 | "_model_module_version": "1.2.0", 467 | "_view_count": null, 468 | "flex_flow": null, 469 | "width": null, 470 | "min_width": null, 471 | "border": null, 472 | "align_items": null, 473 | "bottom": null, 474 | "_model_module": "@jupyter-widgets/base", 475 | "top": null, 476 | "grid_column": null, 477 | "overflow_y": null, 478 | "overflow_x": null, 479 | "grid_auto_flow": null, 480 | "grid_area": null, 481 | "grid_template_columns": null, 482 | "flex": null, 483 | "_model_name": "LayoutModel", 484 | "justify_items": null, 485 | "grid_row": null, 486 | "max_height": null, 487 | "align_content": null, 488 | "visibility": null, 489 | "align_self": null, 490 | "height": null, 491 | "min_height": null, 492 | "padding": null, 493 | "grid_auto_rows": null, 494 | "grid_gap": null, 495 | "max_width": null, 496 | "order": null, 497 | "_view_module_version": "1.2.0", 498 | "grid_template_areas": null, 499 | "object_position": null, 500 | "object_fit": null, 501 | "grid_auto_columns": null, 502 | "margin": null, 503 | "display": null, 504 | "left": null 505 | } 506 | } 507 | } 508 | } 509 | }, 510 | "cells": [ 511 | { 512 | "cell_type": "markdown", 513 | "metadata": { 514 | "id": "JuiT6O71HUAy", 515 | "colab_type": "text" 516 | }, 
517 | "source": [ 518 | "# Initial Setup" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "metadata": { 524 | "id": "FgWG4d-K3xRt", 525 | "colab_type": "code", 526 | "outputId": "140af918-8fea-4aac-d44d-80191d1a2877", 527 | "colab": { 528 | "base_uri": "https://localhost:8080/", 529 | "height": 35 530 | } 531 | }, 532 | "source": [ 533 | "import tensorflow as tf\n", 534 | "print(tf.__version__)" 535 | ], 536 | "execution_count": 1, 537 | "outputs": [ 538 | { 539 | "output_type": "stream", 540 | "text": [ 541 | "2.2.0\n" 542 | ], 543 | "name": "stdout" 544 | } 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "metadata": { 550 | "id": "gtxvkdsm338L", 551 | "colab_type": "code", 552 | "colab": { 553 | "base_uri": "https://localhost:8080/", 554 | "height": 1000 555 | }, 556 | "outputId": "56474563-95cf-45af-a1c0-af5a447d8fa7" 557 | }, 558 | "source": [ 559 | "!pip install wandb\n", 560 | "import wandb\n", 561 | "wandb.login()" 562 | ], 563 | "execution_count": 2, 564 | "outputs": [ 565 | { 566 | "output_type": "stream", 567 | "text": [ 568 | "Collecting wandb\n", 569 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d1/c7/8bf2c62c3f133f45e135a8a116e4e0f162043248e3db54de30996eaf1a8a/wandb-0.8.36-py2.py3-none-any.whl (1.4MB)\n", 570 | "\u001b[K |████████████████████████████████| 1.4MB 4.8MB/s \n", 571 | "\u001b[?25hCollecting sentry-sdk>=0.4.0\n", 572 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1b/95/9a20eebcedab2c1c63fad59fe19a0469edfc2a25b8576497e8084629c2ff/sentry_sdk-0.14.4-py2.py3-none-any.whl (104kB)\n", 573 | "\u001b[K |████████████████████████████████| 112kB 29.8MB/s \n", 574 | "\u001b[?25hRequirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)\n", 575 | "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)\n", 576 | "Collecting configparser>=3.8.1\n", 577 | " Downloading 
https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl\n", 578 | "Requirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.2)\n", 579 | "Collecting gql==0.2.0\n", 580 | " Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz\n", 581 | "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)\n", 582 | "Collecting watchdog>=0.8.3\n", 583 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/73/c3/ed6d992006837e011baca89476a4bbffb0a91602432f73bd4473816c76e2/watchdog-0.10.2.tar.gz (95kB)\n", 584 | "\u001b[K |████████████████████████████████| 102kB 9.5MB/s \n", 585 | "\u001b[?25hRequirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)\n", 586 | "Collecting subprocess32>=3.5.3\n", 587 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n", 588 | "\u001b[K |████████████████████████████████| 102kB 10.6MB/s \n", 589 | "\u001b[?25hCollecting shortuuid>=0.5.0\n", 590 | " Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n", 591 | "Collecting GitPython>=1.0.0\n", 592 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/44/33/917e6fde1cad13daa7053f39b7c8af3be287314f75f1b1ea8d3fe37a8571/GitPython-3.1.2-py3-none-any.whl (451kB)\n", 593 | "\u001b[K |████████████████████████████████| 460kB 23.4MB/s \n", 594 | "\u001b[?25hRequirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.23.0)\n", 595 | "Collecting docker-pycreds>=0.4.0\n", 596 | " Downloading 
https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n", 597 | "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)\n", 598 | "Requirement already satisfied: urllib3>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (1.24.3)\n", 599 | "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (2020.4.5.1)\n", 600 | "Collecting graphql-core<2,>=0.5.0\n", 601 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)\n", 602 | "\u001b[K |████████████████████████████████| 71kB 9.4MB/s \n", 603 | "\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)\n", 604 | "Collecting pathtools>=0.1.1\n", 605 | " Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n", 606 | "Collecting gitdb<5,>=4.0.1\n", 607 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", 608 | "\u001b[K |████████████████████████████████| 71kB 11.3MB/s \n", 609 | "\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)\n", 610 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.9)\n", 611 | "Collecting smmap<4,>=3.0.1\n", 612 | " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", 613 | "Building wheels for collected packages: gql, 
watchdog, subprocess32, graphql-core, pathtools\n", 614 | " Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 615 | " Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=bd431a7c7f187272be19a9836b1e7616f86e03d34c0a4de4b4f6810140d9de42\n", 616 | " Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23\n", 617 | " Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 618 | " Created wheel for watchdog: filename=watchdog-0.10.2-cp36-none-any.whl size=73605 sha256=c9fed8e42385522813f4c68267548bd110f3f869f0abc97b6fe982888c9aabf6\n", 619 | " Stored in directory: /root/.cache/pip/wheels/bc/ed/6c/028dea90d31b359cd2a7c8b0da4db80e41d24a59614154072e\n", 620 | " Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 621 | " Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=0dc181757046168379a173ffa1dddec577f6fe87f03c62125ecfbdb3ab14bd4d\n", 622 | " Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n", 623 | " Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 624 | " Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=e4e49bb64de6665036a068a85324e126b1c685b4540c8caa5d1ffb89190c6b50\n", 625 | " Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5\n", 626 | " Building wheel for pathtools (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 627 | " Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=de6e98828915023d5115bd39dd81f22676351c6a70e74bf13ce914cacd0cee48\n", 628 | " Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n", 629 | "Successfully built gql watchdog subprocess32 graphql-core pathtools\n", 630 | "Installing collected packages: sentry-sdk, configparser, graphql-core, gql, pathtools, watchdog, subprocess32, shortuuid, smmap, gitdb, GitPython, docker-pycreds, wandb\n", 631 | "Successfully installed GitPython-3.1.2 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.5 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.14.4 shortuuid-1.0.1 smmap-3.0.4 subprocess32-3.5.4 wandb-0.8.36 watchdog-0.10.2\n" 632 | ], 633 | "name": "stdout" 634 | }, 635 | { 636 | "output_type": "display_data", 637 | "data": { 638 | "application/javascript": [ 639 | "\n", 640 | " window._wandbApiKey = new Promise((resolve, reject) => {\n", 641 | " function loadScript(url) {\n", 642 | " return new Promise(function(resolve, reject) {\n", 643 | " let newScript = document.createElement(\"script\");\n", 644 | " newScript.onerror = reject;\n", 645 | " newScript.onload = resolve;\n", 646 | " document.body.appendChild(newScript);\n", 647 | " newScript.src = url;\n", 648 | " });\n", 649 | " }\n", 650 | " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", 651 | " const iframe = document.createElement('iframe')\n", 652 | " iframe.style.cssText = \"width:0;height:0;border:none\"\n", 653 | " document.body.appendChild(iframe)\n", 654 | " const handshake = new Postmate({\n", 655 | " container: iframe,\n", 656 | " url: 'https://app.wandb.ai/authorize'\n", 657 | " });\n", 658 | " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", 659 | " handshake.then(function(child) {\n", 660 | " child.on('authorize', data => {\n", 661 | " 
clearTimeout(timeout)\n", 662 | " resolve(data)\n", 663 | " });\n", 664 | " });\n", 665 | " })\n", 666 | " });\n", 667 | " " 668 | ], 669 | "text/plain": [ 670 | "" 671 | ] 672 | }, 673 | "metadata": { 674 | "tags": [] 675 | } 676 | }, 677 | { 678 | "output_type": "stream", 679 | "text": [ 680 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Not authenticated. Copy a key from https://app.wandb.ai/authorize\n" 681 | ], 682 | "name": "stderr" 683 | }, 684 | { 685 | "output_type": "stream", 686 | "text": [ 687 | "API Key: ··········\n" 688 | ], 689 | "name": "stdout" 690 | }, 691 | { 692 | "output_type": "stream", 693 | "text": [ 694 | "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" 695 | ], 696 | "name": "stderr" 697 | }, 698 | { 699 | "output_type": "execute_result", 700 | "data": { 701 | "text/plain": [ 702 | "True" 703 | ] 704 | }, 705 | "metadata": { 706 | "tags": [] 707 | }, 708 | "execution_count": 2 709 | } 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "metadata": { 715 | "id": "E8cv8vit3ydm", 716 | "colab_type": "code", 717 | "colab": {} 718 | }, 719 | "source": [ 720 | "from tensorflow.keras.layers import *\n", 721 | "from tensorflow.keras.models import *\n", 722 | "from wandb.keras import WandbCallback\n", 723 | "import tensorflow_datasets as tfds\n", 724 | "import matplotlib.pyplot as plt\n", 725 | "import numpy as np\n", 726 | "import time\n", 727 | "import cv2\n", 728 | "from tqdm.notebook import tqdm\n", 729 | "from imutils import paths\n", 730 | "tf.random.set_seed(666)\n", 731 | "np.random.seed(666)\n", 732 | "\n", 733 | "tfds.disable_progress_bar()" 734 | ], 735 | "execution_count": 0, 736 | "outputs": [] 737 | }, 738 | { 739 | "cell_type": "markdown", 740 | "metadata": { 741 | "id": "ebM6CaFsHcya", 742 | "colab_type": "text" 743 | }, 744 | "source": [ 745 | "# Imagenet Subset " 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "metadata": { 751 | "id": 
"4vPz9Alk31qZ", 752 | "colab_type": "code", 753 | "outputId": "14ab0cb9-a64e-4ff7-cdf4-f873752cec89", 754 | "colab": { 755 | "base_uri": "https://localhost:8080/", 756 | "height": 106 757 | } 758 | }, 759 | "source": [ 760 | "!git clone https://github.com/thunderInfy/imagenet-5-categories\n" 761 | ], 762 | "execution_count": 4, 763 | "outputs": [ 764 | { 765 | "output_type": "stream", 766 | "text": [ 767 | "Cloning into 'imagenet-5-categories'...\n", 768 | "remote: Enumerating objects: 1532, done.\u001b[K\n", 769 | "remote: Total 1532 (delta 0), reused 0 (delta 0), pack-reused 1532\u001b[K\n", 770 | "Receiving objects: 100% (1532/1532), 88.56 MiB | 51.26 MiB/s, done.\n", 771 | "Resolving deltas: 100% (1/1), done.\n" 772 | ], 773 | "name": "stdout" 774 | } 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "metadata": { 780 | "id": "5vVPALgj4Ogg", 781 | "colab_type": "code", 782 | "colab": {} 783 | }, 784 | "source": [ 785 | "# Train and test image paths\n", 786 | "train_images = list(paths.list_images(\"imagenet-5-categories/train\"))\n", 787 | "test_images = list(paths.list_images(\"imagenet-5-categories/test\"))\n" 788 | ], 789 | "execution_count": 0, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "metadata": { 795 | "id": "YM_w3yZi4RQf", 796 | "colab_type": "code", 797 | "colab": {} 798 | }, 799 | "source": [ 800 | "def prepare_images(image_paths):\n", 801 | " images = []\n", 802 | " labels = []\n", 803 | "\n", 804 | " for image in tqdm(image_paths):\n", 805 | " image_pixels = plt.imread(image)\n", 806 | " image_pixels = cv2.resize(image_pixels, (128,128))\n", 807 | " image_pixels = image_pixels/255.\n", 808 | "\n", 809 | " label = image.split(\"/\")[2].split(\"_\")[0]\n", 810 | "\n", 811 | " images.append(image_pixels)\n", 812 | " labels.append(label)\n", 813 | "\n", 814 | " images = np.array(images)\n", 815 | " labels = np.array(labels)\n", 816 | "\n", 817 | " print(images.shape, labels.shape)\n", 818 | "\n", 819 | " return 
images, labels" 820 | ], 821 | "execution_count": 0, 822 | "outputs": [] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "metadata": { 827 | "id": "KeNWTqpG4b0e", 828 | "colab_type": "code", 829 | "outputId": "66c2ec48-d6df-4f2b-a638-db3adb024e79", 830 | "colab": { 831 | "base_uri": "https://localhost:8080/", 832 | "height": 152, 833 | "referenced_widgets": [ 834 | "2a8b0bfc58804efea484cd4cd71dd00f", 835 | "126f1d14800843aa9332f00aead8395b", 836 | "805d93bf8f4446d3871cd93392101584", 837 | "33d042e9035942359936f5d4bd6a108a", 838 | "05b5252d8859432c8724c8fb9636ef33", 839 | "77179de1c044473abb9da855d5ab8eca", 840 | "eb72220a381347d1bce57c5706554460", 841 | "fe2c9cb4e208436db3c15d342f8d08e7", 842 | "c7e81694439c41408bb1090cb82cd2b1", 843 | "7683c42a14e14250bf6e375a42028b87", 844 | "aac930db883d4977a9f3f08b1a0fa7a5", 845 | "3862f5cde98f49f783d8fc54ada40a11", 846 | "e9b5eef0caa148a8aaaa2ed576fe68cc", 847 | "2d671cb0ab744e509c98e0d1b2fcac17", 848 | "79ff0547109643e8b8002a9bf1e6ebcc", 849 | "8c382843a4f9475fb1a084b08ea0a7ae" 850 | ] 851 | } 852 | }, 853 | "source": [ 854 | "X_train, y_train = prepare_images(train_images)\n", 855 | "X_test, y_test = prepare_images(test_images)" 856 | ], 857 | "execution_count": 7, 858 | "outputs": [ 859 | { 860 | "output_type": "display_data", 861 | "data": { 862 | "application/vnd.jupyter.widget-view+json": { 863 | "model_id": "2a8b0bfc58804efea484cd4cd71dd00f", 864 | "version_minor": 0, 865 | "version_major": 2 866 | }, 867 | "text/plain": [ 868 | "HBox(children=(FloatProgress(value=0.0, max=1250.0), HTML(value='')))" 869 | ] 870 | }, 871 | "metadata": { 872 | "tags": [] 873 | } 874 | }, 875 | { 876 | "output_type": "stream", 877 | "text": [ 878 | "\n", 879 | "(1250, 128, 128, 3) (1250,)\n" 880 | ], 881 | "name": "stdout" 882 | }, 883 | { 884 | "output_type": "display_data", 885 | "data": { 886 | "application/vnd.jupyter.widget-view+json": { 887 | "model_id": "c7e81694439c41408bb1090cb82cd2b1", 888 | "version_minor": 0, 889 | 
"version_major": 2 890 | }, 891 | "text/plain": [ 892 | "HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))" 893 | ] 894 | }, 895 | "metadata": { 896 | "tags": [] 897 | } 898 | }, 899 | { 900 | "output_type": "stream", 901 | "text": [ 902 | "\n", 903 | "(250, 128, 128, 3) (250,)\n" 904 | ], 905 | "name": "stdout" 906 | } 907 | ] 908 | }, 909 | { 910 | "cell_type": "code", 911 | "metadata": { 912 | "id": "qGdwXZJk4eDH", 913 | "colab_type": "code", 914 | "colab": {} 915 | }, 916 | "source": [ 917 | "from sklearn import preprocessing\n", 918 | "le = preprocessing.LabelEncoder()\n", 919 | "y_train_enc = le.fit_transform(y_train)\n", 920 | "y_test_enc = le.transform(y_test)\n" 921 | ], 922 | "execution_count": 0, 923 | "outputs": [] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "metadata": { 928 | "id": "nmX3x8wE4zBo", 929 | "colab_type": "code", 930 | "colab": {} 931 | }, 932 | "source": [ 933 | "train_ds=tf.data.Dataset.from_tensor_slices((X_train,y_train_enc))\n", 934 | "validation_ds=tf.data.Dataset.from_tensor_slices((X_test,y_test_enc))" 935 | ], 936 | "execution_count": 0, 937 | "outputs": [] 938 | }, 939 | { 940 | "cell_type": "code", 941 | "metadata": { 942 | "id": "9yBgpLe443a-", 943 | "colab_type": "code", 944 | "colab": {} 945 | }, 946 | "source": [ 947 | "@tf.function\n", 948 | "def aug(image, label):\n", 949 | " x=tf.image.random_brightness(image,max_delta=0)\n", 950 | " x=tf.image.random_contrast(x,lower=0.2, upper=1.8)\n", 951 | " x = tf.image.random_saturation(x, lower=0.2, upper=1.5)\n", 952 | " x = tf.image.random_hue(x, max_delta=0.4)\n", 953 | " x = tf.clip_by_value(x, 0, 1)\n", 954 | "\n", 955 | " return x, label" 956 | ], 957 | "execution_count": 0, 958 | "outputs": [] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "metadata": { 963 | "id": "icCj5VGk45ce", 964 | "colab_type": "code", 965 | "colab": {} 966 | }, 967 | "source": [ 968 | "IMG_SHAPE = 128\n", 969 | "BS = 64\n", 970 | "AUTO = tf.data.experimental.AUTOTUNE\n", 
971 | "train_ds = (\n", 972 | " train_ds\n", 973 | " .shuffle(100)\n", 974 | " .batch(BS)\n", 975 | " .map(aug, num_parallel_calls=AUTO)\n", 976 | " .prefetch(AUTO)\n", 977 | ")\n", 978 | "validation_ds = (\n", 979 | " validation_ds\n", 980 | " .shuffle(100)\n", 981 | " .batch(BS)\n", 982 | " .prefetch(AUTO)\n", 983 | ")" 984 | ], 985 | "execution_count": 0, 986 | "outputs": [] 987 | }, 988 | { 989 | "cell_type": "markdown", 990 | "metadata": { 991 | "id": "tkxjWEeIHrCf", 992 | "colab_type": "text" 993 | }, 994 | "source": [ 995 | "# Model building and training with SGD\n" 996 | ] 997 | }, 998 | { 999 | "cell_type": "code", 1000 | "metadata": { 1001 | "id": "umbRNW-A4755", 1002 | "colab_type": "code", 1003 | "colab": {} 1004 | }, 1005 | "source": [ 1006 | "resnet50 = tf.keras.applications.ResNet50(weights=None, include_top=False)\n", 1007 | "model = tf.keras.Sequential([resnet50,GlobalAveragePooling2D(),Dropout(0.25),Dense(5,activation='softmax')])" 1008 | ], 1009 | "execution_count": 0, 1010 | "outputs": [] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "metadata": { 1015 | "id": "WVilaFIu5Hft", 1016 | "colab_type": "code", 1017 | "colab": {} 1018 | }, 1019 | "source": [ 1020 | "decay_steps = 1000\n", 1021 | "lr_decayed_fn = tf.keras.experimental.CosineDecay(\n", 1022 | " initial_learning_rate=0.001, decay_steps=decay_steps)\n", 1023 | "\n", 1024 | "model.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn),\n", 1025 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n", 1026 | " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])" 1027 | ], 1028 | "execution_count": 0, 1029 | "outputs": [] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "metadata": { 1034 | "id": "X3PxnSYd5x2W", 1035 | "colab_type": "code", 1036 | "colab": {} 1037 | }, 1038 | "source": [ 1039 | "es = tf.keras.callbacks.EarlyStopping(monitor=\"val_sparse_categorical_accuracy\", patience=2,\n", 1040 | "\trestore_best_weights=True, verbose=2)" 1041 | ],
1042 | "execution_count": 0, 1043 | "outputs": [] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "metadata": { 1048 | "id": "S5lMKwCQ54KX", 1049 | "colab_type": "code", 1050 | "outputId": "4104b50e-107a-43c5-fda0-6f864532d24d", 1051 | "colab": { 1052 | "base_uri": "https://localhost:8080/", 1053 | "height": 210 1054 | } 1055 | }, 1056 | "source": [ 1057 | "import time\n", 1058 | "import wandb\n", 1059 | "\n", 1060 | "wandb.init(entity='authors',project='scl',id='cr_Entropy_SGD')\n", 1061 | "start = time.time()\n", 1062 | "model.fit(train_ds,\n", 1063 | " validation_data=validation_ds,\n", 1064 | " epochs=50,\n", 1065 | " callbacks=[wandb.keras.WandbCallback(), es])\n", 1066 | "end = time.time()\n", 1067 | "wandb.log({\"training_time\": end - start})" 1068 | ], 1069 | "execution_count": 18, 1070 | "outputs": [ 1071 | { 1072 | "output_type": "display_data", 1073 | "data": { 1074 | "text/html": [ 1075 | "\n", 1076 | " Logging results to Weights & Biases (Documentation).
\n", 1077 | " Project page: https://app.wandb.ai/authors/scl
\n", 1078 | " Run page: https://app.wandb.ai/authors/scl/runs/cr_Entropy_SGD
\n", 1079 | " " 1080 | ], 1081 | "text/plain": [ 1082 | "" 1083 | ] 1084 | }, 1085 | "metadata": { 1086 | "tags": [] 1087 | } 1088 | }, 1089 | { 1090 | "output_type": "stream", 1091 | "text": [ 1092 | "Epoch 1/50\n", 1093 | "20/20 [==============================] - 4s 200ms/step - loss: 1.5290 - sparse_categorical_accuracy: 0.3736 - val_loss: 1.6414 - val_sparse_categorical_accuracy: 0.2640\n", 1094 | "Epoch 2/50\n", 1095 | "20/20 [==============================] - 3s 141ms/step - loss: 1.5102 - sparse_categorical_accuracy: 0.3896 - val_loss: 1.6797 - val_sparse_categorical_accuracy: 0.2240\n", 1096 | "Epoch 3/50\n", 1097 | "20/20 [==============================] - ETA: 0s - loss: 1.5144 - sparse_categorical_accuracy: 0.3888Restoring model weights from the end of the best epoch.\n", 1098 | "20/20 [==============================] - 3s 151ms/step - loss: 1.5144 - sparse_categorical_accuracy: 0.3888 - val_loss: 1.6787 - val_sparse_categorical_accuracy: 0.2280\n", 1099 | "Epoch 00003: early stopping\n" 1100 | ], 1101 | "name": "stdout" 1102 | } 1103 | ] 1104 | }, 1105 | { 1106 | "cell_type": "code", 1107 | "metadata": { 1108 | "id": "edc3Fu_C6AJO", 1109 | "colab_type": "code", 1110 | "colab": {} 1111 | }, 1112 | "source": [ 1113 | "model.save_weights(\"full_supervised_learning.h5\")" 1114 | ], 1115 | "execution_count": 0, 1116 | "outputs": [] 1117 | }, 1118 | { 1119 | "cell_type": "code", 1120 | "metadata": { 1121 | "id": "wOPN7pPwBN0V", 1122 | "colab_type": "code", 1123 | "outputId": "47e31ee4-110f-44b9-8de8-d377fc2fafbd", 1124 | "colab": { 1125 | "base_uri": "https://localhost:8080/", 1126 | "height": 34 1127 | } 1128 | }, 1129 | "source": [ 1130 | "wandb.save(\"full_supervised_learning.h5\")" 1131 | ], 1132 | "execution_count": 0, 1133 | "outputs": [ 1134 | { 1135 | "output_type": "execute_result", 1136 | "data": { 1137 | "text/plain": [ 1138 | "['/content/wandb/run-20200528_111108-2h40mbhd/full_supervised_learning.h5']" 1139 | ] 1140 | }, 1141 | "metadata": { 1142 
| "tags": [] 1143 | }, 1144 | "execution_count": 94 1145 | } 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "metadata": { 1151 | "id": "FNk0NhWFBSYe", 1152 | "colab_type": "code", 1153 | "colab": {} 1154 | }, 1155 | "source": [ 1156 | "" 1157 | ], 1158 | "execution_count": 0, 1159 | "outputs": [] 1160 | } 1161 | ] 1162 | } --------------------------------------------------------------------------------
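
Whether `tf.keras.losses.SparseCategoricalCrossentropy` should be built with `from_logits=True` depends on the model's last layer: `True` when the final `Dense` emits raw scores, `False` when it already applies `softmax`. The model in this notebook ends in `Dense(5, activation='softmax')`, so its outputs are already probabilities. The standard-library sketch below (the helpers `softmax` and `sparse_ce` are hypothetical; they mirror the math, not Keras internals) shows what happens when probabilities are fed to a loss that expects logits: softmax is applied twice, and with 5 classes the per-sample loss can never fall below log(1 + 4/e), roughly 0.905, even for a perfect prediction.

```python
import math


def softmax(z):
    # Numerically stable softmax over a list of scores.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce(outputs, label, from_logits):
    # Cross-entropy for one sample with an integer label.
    # from_logits=True: outputs are raw scores, so softmax is applied first.
    # from_logits=False: outputs are taken as probabilities as-is
    # (clipped away from zero, like Keras' default epsilon of 1e-7).
    probs = softmax(outputs) if from_logits else outputs
    return -math.log(max(probs[label], 1e-7))


logits = [8.0, 0.0, 0.0, 0.0, 0.0]  # a confident, correct prediction for class 0
probs = softmax(logits)             # what a softmax-activated Dense(5) would emit

print(sparse_ce(logits, 0, from_logits=True))   # correct pairing, ~0.0013
print(sparse_ce(probs, 0, from_logits=False))   # correct pairing, ~0.0013
print(sparse_ce(probs, 0, from_logits=True))    # softmax applied twice, ~0.906
print(math.log(1 + 4 / math.e))                 # loss floor for 5 classes, ~0.905
```

Dropping `activation='softmax'` from the last `Dense` layer and keeping `from_logits=True` is an equivalent fix; the Keras documentation notes that working from logits can be more numerically stable.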
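
The SGD run pairs `decay_steps = 1000` with `epochs=50`. With 1250 training images and `BS = 64`, one epoch is ceil(1250 / 64) = 20 batches (the `20/20` in the training log), so 50 epochs is exactly 1000 optimizer steps and the cosine schedule reaches zero at the end of the epoch budget. The sketch below reimplements the formula documented for `tf.keras.experimental.CosineDecay` using only the standard library (the helper name `cosine_decay` is hypothetical; `alpha` defaults to 0 as in Keras). It also shows that a run stopped early after 3 epochs trains at a learning rate that has barely moved from 0.001.

```python
import math


def cosine_decay(step, initial_learning_rate=0.001, decay_steps=1000, alpha=0.0):
    # Formula documented for tf.keras.experimental.CosineDecay:
    # the step is capped at decay_steps, then scaled by a half cosine.
    step = min(step, decay_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    return initial_learning_rate * ((1 - alpha) * cosine + alpha)


num_train, batch_size, epochs = 1250, 64, 50
steps_per_epoch = math.ceil(num_train / batch_size)  # 20 batches per epoch
total_steps = steps_per_epoch * epochs               # 1000, same as decay_steps

for epoch in (1, 3, 25, 50):
    step = epoch * steps_per_epoch
    print(f"epoch {epoch:2d}  step {step:4d}  lr {cosine_decay(step):.6f}")
```

Because `decay_steps` matches the full step budget here, changing `BS` or `epochs` without recomputing `decay_steps` would either leave the rate well above zero at the end or pin it at zero for the remaining epochs.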