├── assets
├── framework.png
├── embeddings_vis.png
└── performance_summary.png
├── README.md
└── ImageNet_Subset
├── Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb
├── Fully_Supervised_Training_IMGNET_subset_Adam.ipynb
└── Fully_Supervised_Training_IMGNET_subset_SGD.ipynb
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2/HEAD/assets/framework.png
--------------------------------------------------------------------------------
/assets/embeddings_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2/HEAD/assets/embeddings_vis.png
--------------------------------------------------------------------------------
/assets/performance_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2/HEAD/assets/performance_summary.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Supervised-Contrastive-Learning-in-TensorFlow-2
2 |
3 | (Collaboratively done by [Shweta Shaw](https://www.linkedin.com/in/sweta-shaw-797540159/) and myself)
4 |
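5 | Implements the ideas presented in [Supervised Contrastive Learning](https://arxiv.org/pdf/2004.11362v1.pdf) by Khosla et al. The authors propose a two-stage framework (contrastive pre-training of an encoder using label information, followed by training a linear classifier on top of the frozen encoder) that improves the performance of image classifiers and achieves state-of-the-art (SoTA) results.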
6 |
7 | ![](assets/framework.png)
8 |
9 | (Figures gathered from the paper)
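The loss that drives the first (contrastive) stage can be sketched in plain Python. This is a minimal illustration of the paper's `L_out` form, not the code used in our notebooks. It assumes one L2-normalized embedding per sample (the paper uses two augmented views per image), averages over anchors instead of summing, skips anchors that have no positive in the batch, and uses a temperature of 0.1 only as an illustrative default. The function names are hypothetical, and the second stage (the linear classifier) is not shown.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def supcon_loss(embeddings, labels, temperature=0.1):
    """Supervised contrastive loss (the paper's L_out form), averaged over anchors.

    For each anchor i, every other sample with the same label is a positive;
    the denominator runs over all samples except i itself.
    """
    z = [l2_normalize(v) for v in embeddings]
    n = len(z)
    total, num_anchors = 0.0, 0
    for i in range(n):
        others = [a for a in range(n) if a != i]
        positives = [p for p in others if labels[p] == labels[i]]
        if not positives:
            continue  # an anchor without positives contributes no term
        # Temperature-scaled similarity between the anchor and every other sample.
        logits = {a: sum(x * y for x, y in zip(z[i], z[a])) / temperature for a in others}
        # Numerically stable log-sum-exp over the denominator set (all samples but i).
        m = max(logits.values())
        log_denom = m + math.log(sum(math.exp(s - m) for s in logits.values()))
        # Average negative log-probability assigned to this anchor's positives.
        total += -sum(logits[p] - log_denom for p in positives) / len(positives)
        num_anchors += 1
    return total / num_anchors if num_anchors else 0.0


if __name__ == "__main__":
    labels = [0, 0, 1, 1]
    clustered = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]  # classes grouped together
    mixed = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]      # classes scattered
    print(supcon_loss(clustered, labels), supcon_loss(mixed, labels))
```

Embeddings that cluster by class yield a much lower loss than embeddings where same-class samples point in different directions, which is the pressure the first stage puts on the encoder.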
10 |
11 | A detailed discussion of the paper and the results of our experiments is available in [this report](https://app.wandb.ai/authors/scl/reports/Improving-Image-Classification-with-Supervised-Contrastive-Learning--VmlldzoxMzQwNzE).
12 |
13 | This repository contains the notebooks (runnable on Colab) for the experiments we ran.
14 |
15 | ## Acknowledgements
16 | - [Contrastive loss for supervised classification](https://towardsdatascience.com/contrastive-loss-for-supervised-classification-224ae35692e7).
17 | - [Prannay Khosla](https://twitter.com/PrannayKhosla) for sharing his comments on our work.
18 |
19 | ## About the notebooks
20 | ```
21 | ├── Flowers
22 | │ ├── Contrastive_Training_Flowers.ipynb
23 | │ ├── Contrastive_Training_Flowers_Augmentation.ipynb
24 | │ ├── Fully_Supervised_Training_Flowers.ipynb
25 | │ └── Fully_Supervised_Training_Flowers_Augmentation.ipynb
26 | ├── ImageNet_Subset
27 | │ ├── Contrastive_Training_Imagenet_subset_Adam.ipynb
28 | │ ├── Contrastive_Training_Imagenet_subset_RMSprop.ipynb
29 | │ ├── Contrastive_Training_Imagenet_subset_SGD.ipynb
30 | │ ├── Fully_Supervised_Training_IMGNET_subset_Adam.ipynb
31 | │ ├── Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb
32 | │ └── Fully_Supervised_Training_IMGNET_subset_SGD.ipynb
33 | ├── Pets
34 | │ ├── Contrastive_Training_Pets.ipynb
35 | │ └── Fully_Supervised_Training_Pets.ipynb
36 | ├── Visualization_ImageNet_subset.ipynb
37 | └── Visualization_Pets.ipynb
38 | ```
39 |
40 | - `Contrastive_Training_*.ipynb` notebooks implement the supervised contrastive framework proposed in the paper.
41 | - `Fully_Supervised_Training_*.ipynb` notebooks show standard fully supervised training with a cross-entropy loss on each dataset.
42 | - `Visualization_*.ipynb` notebooks visualize the embeddings learned by the supervised contrastive learning framework.
43 |
44 | ## About the datasets
45 | - Flowers
46 | - Cats vs. Dogs
47 | - ImageNet subset with 5 categories, 1,250 training and 250 test images (https://github.com/thunderInfy/imagenet-5-categories)
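The `Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb` notebook, for example, takes each label from the image's file name (the part before the first underscore) and then encodes the class names as integers with scikit-learn's `LabelEncoder`. A dependency-free sketch of those two steps follows. It uses `os.path.basename` instead of the notebook's fixed `split("/")[2]` index so it does not depend on directory depth; the example paths and class names are hypothetical. Like `LabelEncoder`, it assigns integer ids in sorted order of the class names.

```python
import os


def label_from_path(image_path):
    """Return the class name encoded in a file name such as '<class>_<id>.jpg'."""
    file_name = os.path.basename(image_path)
    return file_name.split("_")[0]


def fit_label_encoder(labels):
    """Map each distinct class name to an integer id, in sorted order of names."""
    classes = sorted(set(labels))
    return {name: idx for idx, name in enumerate(classes)}


def encode(labels, mapping):
    """Encode class names with a mapping fitted on the training split.

    Raises KeyError for a class that was not seen while fitting.
    """
    return [mapping[name] for name in labels]


if __name__ == "__main__":
    # Hypothetical paths; the real file names in the dataset repository may differ.
    train_paths = [
        "imagenet-5-categories/train/alpha_001.jpg",
        "imagenet-5-categories/train/beta_002.jpg",
        "imagenet-5-categories/train/alpha_003.jpg",
    ]
    y_train = [label_from_path(p) for p in train_paths]
    mapping = fit_label_encoder(y_train)
    print(mapping, encode(y_train, mapping))
```

Fitting the mapping on the training split only and reusing it for the test split mirrors the notebook's `fit_transform` / `transform` pattern.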
48 |
49 | ## Things to note
50 | - The authors used AutoAugment in the paper. We instead used simple augmentation operations, which worked well for the datasets we tried. We did not use any augmentation for the Pets dataset, since it gave good results even without it.
51 | - The paper used the LARS optimizer; we used Adam instead. We also show the effect of other optimizers, such as SGD and RMSprop, along with learning rate schedules.
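The RMSprop notebook pairs the optimizer with `tf.keras.experimental.CosineDecay(initial_learning_rate=0.001, decay_steps=1000)`. With 1,250 training images and a batch size of 64 there are 20 steps per epoch, so 1,000 decay steps span the full 50-epoch budget (early stopping may end training sooner). The sketch below reproduces the cosine schedule in plain Python using the formula documented for Keras `CosineDecay` with its default `alpha=0`, and feeds it into a textbook RMSprop update on a toy one-parameter problem. The `rho=0.9` and `epsilon=1e-7` values mirror Keras's defaults, but the update is a simplified illustration rather than Keras's exact implementation, and the function names are hypothetical.

```python
import math


def cosine_decay(step, initial_learning_rate=0.001, decay_steps=1000, alpha=0.0):
    """Cosine learning-rate decay, following the formula in the Keras CosineDecay docs."""
    step = min(step, decay_steps)  # the rate stays at its floor after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    decayed = (1.0 - alpha) * cosine + alpha
    return initial_learning_rate * decayed


def rmsprop_step(weight, grad, mean_square, lr, rho=0.9, epsilon=1e-7):
    """One plain RMSprop update: divide the gradient by a running RMS of past gradients."""
    mean_square = rho * mean_square + (1.0 - rho) * grad * grad
    weight = weight - lr * grad / (math.sqrt(mean_square) + epsilon)
    return weight, mean_square


def minimize_quadratic(steps=1000, start=1.0):
    """Minimize f(w) = w**2 (gradient 2w) with RMSprop under the cosine schedule."""
    weight, mean_square = start, 0.0
    for step in range(steps):
        weight, mean_square = rmsprop_step(weight, 2.0 * weight, mean_square, cosine_decay(step))
    return weight


if __name__ == "__main__":
    for step in (0, 250, 500, 750, 1000):
        print(step, cosine_decay(step))
    print("final weight:", minimize_quadratic())
```

Because RMSprop normalizes by the gradient's running magnitude, each step moves the weight by roughly the current learning rate, so the schedule directly controls how far training can travel late in the run.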
52 |
53 | ## Results
54 |
55 | ![](assets/performance_summary.png)
56 |
57 | The plots above are from the experiments on the **Pets** dataset. Results on the other two datasets are discussed in the report mentioned above and are available at https://app.wandb.ai/authors/scl.
58 |
59 | ## Visualization of the embeddings learned by supervised contrastive learning
60 |
61 | ![](assets/embeddings_vis.png)
62 |
63 | ## About executing the notebooks
64 |
65 | Open any of the notebooks in Google Colab (for example, with a browser extension like "Open notebook in Google Colab") and you should be able to run the experiments as is.
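Colab can also open a notebook directly from GitHub through a URL of the form `https://colab.research.google.com/github/<owner>/<repo>/blob/<branch>/<path>`. The helper below builds such a link. The owner and repository names come from this repository's asset URLs; the default branch name `master` is an assumption, so pass the branch you actually want to run.

```python
from urllib.parse import quote


def colab_url(owner, repo, notebook_path, branch="master"):
    """Build a Colab link that opens a notebook hosted on GitHub.

    The default branch is an assumption; pass the branch you actually want.
    """
    path = quote(notebook_path.lstrip("/"))  # keeps '/' separators, escapes unsafe characters
    return f"https://colab.research.google.com/github/{owner}/{repo}/blob/{branch}/{path}"


if __name__ == "__main__":
    print(colab_url(
        "sayakpaul",
        "Supervised-Contrastive-Learning-in-TensorFlow-2",
        "ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb",
    ))
```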
66 |
67 | ## About the library versions
68 |
69 | We ran the experiments with TensorFlow 2.2 and did not pin the versions of the other libraries. All of our experiments were performed on [Google Colab](http://colab.research.google.com/).
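Since only the TensorFlow version is recorded, a quick check at the top of a notebook can flag a runtime that differs from TensorFlow 2.2; some APIs used in the notebooks may have moved or changed in later releases. This sketch assumes Python 3.8+ (for `importlib.metadata`) and looks up the distribution named `tensorflow`; other builds such as `tensorflow-gpu` are installed under a different distribution name. The function names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version

EXPECTED = (2, 2)  # TensorFlow major.minor used for the experiments


def major_minor(version_string):
    """Extract (major, minor) from a version string such as '2.2.0' or '2.2.0rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognised version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def check_tensorflow_version(expected=EXPECTED):
    """Return (installed_version or None, whether it matches the expected major.minor)."""
    try:
        installed = version("tensorflow")
    except PackageNotFoundError:
        return None, False
    return installed, major_minor(installed) == expected


if __name__ == "__main__":
    installed, matches = check_tensorflow_version()
    if not matches:
        print(f"Warning: found TensorFlow {installed}, experiments used 2.2")
```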
70 |
71 | ## Feedback
72 | Please open a GitHub issue.
73 |
--------------------------------------------------------------------------------
/ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Fully_Supervised_Training_IMGNET_subset_RMSprop.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "accelerator": "GPU",
15 | "widgets": {
16 | "application/vnd.jupyter.widget-state+json": {
17 | "0b4c5d36a97843e6bfc07d37a8e3cb6a": {
18 | "model_module": "@jupyter-widgets/controls",
19 | "model_name": "HBoxModel",
20 | "state": {
21 | "_view_name": "HBoxView",
22 | "_dom_classes": [],
23 | "_model_name": "HBoxModel",
24 | "_view_module": "@jupyter-widgets/controls",
25 | "_model_module_version": "1.5.0",
26 | "_view_count": null,
27 | "_view_module_version": "1.5.0",
28 | "box_style": "",
29 | "layout": "IPY_MODEL_e8ef5e7fafb142aaa238fd8ae119b315",
30 | "_model_module": "@jupyter-widgets/controls",
31 | "children": [
32 | "IPY_MODEL_fb1635d6c5da4d4893558f5d32c234f2",
33 | "IPY_MODEL_0dae8ce37add4e81825f319cb66620df"
34 | ]
35 | }
36 | },
37 | "e8ef5e7fafb142aaa238fd8ae119b315": {
38 | "model_module": "@jupyter-widgets/base",
39 | "model_name": "LayoutModel",
40 | "state": {
41 | "_view_name": "LayoutView",
42 | "grid_template_rows": null,
43 | "right": null,
44 | "justify_content": null,
45 | "_view_module": "@jupyter-widgets/base",
46 | "overflow": null,
47 | "_model_module_version": "1.2.0",
48 | "_view_count": null,
49 | "flex_flow": null,
50 | "width": null,
51 | "min_width": null,
52 | "border": null,
53 | "align_items": null,
54 | "bottom": null,
55 | "_model_module": "@jupyter-widgets/base",
56 | "top": null,
57 | "grid_column": null,
58 | "overflow_y": null,
59 | "overflow_x": null,
60 | "grid_auto_flow": null,
61 | "grid_area": null,
62 | "grid_template_columns": null,
63 | "flex": null,
64 | "_model_name": "LayoutModel",
65 | "justify_items": null,
66 | "grid_row": null,
67 | "max_height": null,
68 | "align_content": null,
69 | "visibility": null,
70 | "align_self": null,
71 | "height": null,
72 | "min_height": null,
73 | "padding": null,
74 | "grid_auto_rows": null,
75 | "grid_gap": null,
76 | "max_width": null,
77 | "order": null,
78 | "_view_module_version": "1.2.0",
79 | "grid_template_areas": null,
80 | "object_position": null,
81 | "object_fit": null,
82 | "grid_auto_columns": null,
83 | "margin": null,
84 | "display": null,
85 | "left": null
86 | }
87 | },
88 | "fb1635d6c5da4d4893558f5d32c234f2": {
89 | "model_module": "@jupyter-widgets/controls",
90 | "model_name": "FloatProgressModel",
91 | "state": {
92 | "_view_name": "ProgressView",
93 | "style": "IPY_MODEL_e80797a1b4e24f19a9ce53ce7e9e9299",
94 | "_dom_classes": [],
95 | "description": "100%",
96 | "_model_name": "FloatProgressModel",
97 | "bar_style": "success",
98 | "max": 1250,
99 | "_view_module": "@jupyter-widgets/controls",
100 | "_model_module_version": "1.5.0",
101 | "value": 1250,
102 | "_view_count": null,
103 | "_view_module_version": "1.5.0",
104 | "orientation": "horizontal",
105 | "min": 0,
106 | "description_tooltip": null,
107 | "_model_module": "@jupyter-widgets/controls",
108 | "layout": "IPY_MODEL_4d2932cd6daf495d972886f07bfcc227"
109 | }
110 | },
111 | "0dae8ce37add4e81825f319cb66620df": {
112 | "model_module": "@jupyter-widgets/controls",
113 | "model_name": "HTMLModel",
114 | "state": {
115 | "_view_name": "HTMLView",
116 | "style": "IPY_MODEL_717104f84c7044818e3f5e4346de1ec0",
117 | "_dom_classes": [],
118 | "description": "",
119 | "_model_name": "HTMLModel",
120 | "placeholder": "",
121 | "_view_module": "@jupyter-widgets/controls",
122 | "_model_module_version": "1.5.0",
123 | "value": " 1250/1250 [01:03<00:00, 19.77it/s]",
124 | "_view_count": null,
125 | "_view_module_version": "1.5.0",
126 | "description_tooltip": null,
127 | "_model_module": "@jupyter-widgets/controls",
128 | "layout": "IPY_MODEL_b0df6eda571748eaa741673ef3cf090f"
129 | }
130 | },
131 | "e80797a1b4e24f19a9ce53ce7e9e9299": {
132 | "model_module": "@jupyter-widgets/controls",
133 | "model_name": "ProgressStyleModel",
134 | "state": {
135 | "_view_name": "StyleView",
136 | "_model_name": "ProgressStyleModel",
137 | "description_width": "initial",
138 | "_view_module": "@jupyter-widgets/base",
139 | "_model_module_version": "1.5.0",
140 | "_view_count": null,
141 | "_view_module_version": "1.2.0",
142 | "bar_color": null,
143 | "_model_module": "@jupyter-widgets/controls"
144 | }
145 | },
146 | "4d2932cd6daf495d972886f07bfcc227": {
147 | "model_module": "@jupyter-widgets/base",
148 | "model_name": "LayoutModel",
149 | "state": {
150 | "_view_name": "LayoutView",
151 | "grid_template_rows": null,
152 | "right": null,
153 | "justify_content": null,
154 | "_view_module": "@jupyter-widgets/base",
155 | "overflow": null,
156 | "_model_module_version": "1.2.0",
157 | "_view_count": null,
158 | "flex_flow": null,
159 | "width": null,
160 | "min_width": null,
161 | "border": null,
162 | "align_items": null,
163 | "bottom": null,
164 | "_model_module": "@jupyter-widgets/base",
165 | "top": null,
166 | "grid_column": null,
167 | "overflow_y": null,
168 | "overflow_x": null,
169 | "grid_auto_flow": null,
170 | "grid_area": null,
171 | "grid_template_columns": null,
172 | "flex": null,
173 | "_model_name": "LayoutModel",
174 | "justify_items": null,
175 | "grid_row": null,
176 | "max_height": null,
177 | "align_content": null,
178 | "visibility": null,
179 | "align_self": null,
180 | "height": null,
181 | "min_height": null,
182 | "padding": null,
183 | "grid_auto_rows": null,
184 | "grid_gap": null,
185 | "max_width": null,
186 | "order": null,
187 | "_view_module_version": "1.2.0",
188 | "grid_template_areas": null,
189 | "object_position": null,
190 | "object_fit": null,
191 | "grid_auto_columns": null,
192 | "margin": null,
193 | "display": null,
194 | "left": null
195 | }
196 | },
197 | "717104f84c7044818e3f5e4346de1ec0": {
198 | "model_module": "@jupyter-widgets/controls",
199 | "model_name": "DescriptionStyleModel",
200 | "state": {
201 | "_view_name": "StyleView",
202 | "_model_name": "DescriptionStyleModel",
203 | "description_width": "",
204 | "_view_module": "@jupyter-widgets/base",
205 | "_model_module_version": "1.5.0",
206 | "_view_count": null,
207 | "_view_module_version": "1.2.0",
208 | "_model_module": "@jupyter-widgets/controls"
209 | }
210 | },
211 | "b0df6eda571748eaa741673ef3cf090f": {
212 | "model_module": "@jupyter-widgets/base",
213 | "model_name": "LayoutModel",
214 | "state": {
215 | "_view_name": "LayoutView",
216 | "grid_template_rows": null,
217 | "right": null,
218 | "justify_content": null,
219 | "_view_module": "@jupyter-widgets/base",
220 | "overflow": null,
221 | "_model_module_version": "1.2.0",
222 | "_view_count": null,
223 | "flex_flow": null,
224 | "width": null,
225 | "min_width": null,
226 | "border": null,
227 | "align_items": null,
228 | "bottom": null,
229 | "_model_module": "@jupyter-widgets/base",
230 | "top": null,
231 | "grid_column": null,
232 | "overflow_y": null,
233 | "overflow_x": null,
234 | "grid_auto_flow": null,
235 | "grid_area": null,
236 | "grid_template_columns": null,
237 | "flex": null,
238 | "_model_name": "LayoutModel",
239 | "justify_items": null,
240 | "grid_row": null,
241 | "max_height": null,
242 | "align_content": null,
243 | "visibility": null,
244 | "align_self": null,
245 | "height": null,
246 | "min_height": null,
247 | "padding": null,
248 | "grid_auto_rows": null,
249 | "grid_gap": null,
250 | "max_width": null,
251 | "order": null,
252 | "_view_module_version": "1.2.0",
253 | "grid_template_areas": null,
254 | "object_position": null,
255 | "object_fit": null,
256 | "grid_auto_columns": null,
257 | "margin": null,
258 | "display": null,
259 | "left": null
260 | }
261 | },
262 | "438699af2ed94a3eb67e4f408bd1e7a8": {
263 | "model_module": "@jupyter-widgets/controls",
264 | "model_name": "HBoxModel",
265 | "state": {
266 | "_view_name": "HBoxView",
267 | "_dom_classes": [],
268 | "_model_name": "HBoxModel",
269 | "_view_module": "@jupyter-widgets/controls",
270 | "_model_module_version": "1.5.0",
271 | "_view_count": null,
272 | "_view_module_version": "1.5.0",
273 | "box_style": "",
274 | "layout": "IPY_MODEL_d948c937f41a41a6b66587aaee7f2202",
275 | "_model_module": "@jupyter-widgets/controls",
276 | "children": [
277 | "IPY_MODEL_5777b6af6ec84082babcd7616c4e5c02",
278 | "IPY_MODEL_4c97103d541d4016be6c513edf6f4b4d"
279 | ]
280 | }
281 | },
282 | "d948c937f41a41a6b66587aaee7f2202": {
283 | "model_module": "@jupyter-widgets/base",
284 | "model_name": "LayoutModel",
285 | "state": {
286 | "_view_name": "LayoutView",
287 | "grid_template_rows": null,
288 | "right": null,
289 | "justify_content": null,
290 | "_view_module": "@jupyter-widgets/base",
291 | "overflow": null,
292 | "_model_module_version": "1.2.0",
293 | "_view_count": null,
294 | "flex_flow": null,
295 | "width": null,
296 | "min_width": null,
297 | "border": null,
298 | "align_items": null,
299 | "bottom": null,
300 | "_model_module": "@jupyter-widgets/base",
301 | "top": null,
302 | "grid_column": null,
303 | "overflow_y": null,
304 | "overflow_x": null,
305 | "grid_auto_flow": null,
306 | "grid_area": null,
307 | "grid_template_columns": null,
308 | "flex": null,
309 | "_model_name": "LayoutModel",
310 | "justify_items": null,
311 | "grid_row": null,
312 | "max_height": null,
313 | "align_content": null,
314 | "visibility": null,
315 | "align_self": null,
316 | "height": null,
317 | "min_height": null,
318 | "padding": null,
319 | "grid_auto_rows": null,
320 | "grid_gap": null,
321 | "max_width": null,
322 | "order": null,
323 | "_view_module_version": "1.2.0",
324 | "grid_template_areas": null,
325 | "object_position": null,
326 | "object_fit": null,
327 | "grid_auto_columns": null,
328 | "margin": null,
329 | "display": null,
330 | "left": null
331 | }
332 | },
333 | "5777b6af6ec84082babcd7616c4e5c02": {
334 | "model_module": "@jupyter-widgets/controls",
335 | "model_name": "FloatProgressModel",
336 | "state": {
337 | "_view_name": "ProgressView",
338 | "style": "IPY_MODEL_f5864dec9b1641d6a69059cdd0f7254e",
339 | "_dom_classes": [],
340 | "description": "100%",
341 | "_model_name": "FloatProgressModel",
342 | "bar_style": "success",
343 | "max": 250,
344 | "_view_module": "@jupyter-widgets/controls",
345 | "_model_module_version": "1.5.0",
346 | "value": 250,
347 | "_view_count": null,
348 | "_view_module_version": "1.5.0",
349 | "orientation": "horizontal",
350 | "min": 0,
351 | "description_tooltip": null,
352 | "_model_module": "@jupyter-widgets/controls",
353 | "layout": "IPY_MODEL_1b4aa502d7234a6c8d740ed7b1cb4325"
354 | }
355 | },
356 | "4c97103d541d4016be6c513edf6f4b4d": {
357 | "model_module": "@jupyter-widgets/controls",
358 | "model_name": "HTMLModel",
359 | "state": {
360 | "_view_name": "HTMLView",
361 | "style": "IPY_MODEL_9d2f753a736946b4939711ac254ca213",
362 | "_dom_classes": [],
363 | "description": "",
364 | "_model_name": "HTMLModel",
365 | "placeholder": "",
366 | "_view_module": "@jupyter-widgets/controls",
367 | "_model_module_version": "1.5.0",
368 | "value": " 250/250 [00:01<00:00, 216.34it/s]",
369 | "_view_count": null,
370 | "_view_module_version": "1.5.0",
371 | "description_tooltip": null,
372 | "_model_module": "@jupyter-widgets/controls",
373 | "layout": "IPY_MODEL_ec97824c3bd6422cbed5093a5db9acf6"
374 | }
375 | },
376 | "f5864dec9b1641d6a69059cdd0f7254e": {
377 | "model_module": "@jupyter-widgets/controls",
378 | "model_name": "ProgressStyleModel",
379 | "state": {
380 | "_view_name": "StyleView",
381 | "_model_name": "ProgressStyleModel",
382 | "description_width": "initial",
383 | "_view_module": "@jupyter-widgets/base",
384 | "_model_module_version": "1.5.0",
385 | "_view_count": null,
386 | "_view_module_version": "1.2.0",
387 | "bar_color": null,
388 | "_model_module": "@jupyter-widgets/controls"
389 | }
390 | },
391 | "1b4aa502d7234a6c8d740ed7b1cb4325": {
392 | "model_module": "@jupyter-widgets/base",
393 | "model_name": "LayoutModel",
394 | "state": {
395 | "_view_name": "LayoutView",
396 | "grid_template_rows": null,
397 | "right": null,
398 | "justify_content": null,
399 | "_view_module": "@jupyter-widgets/base",
400 | "overflow": null,
401 | "_model_module_version": "1.2.0",
402 | "_view_count": null,
403 | "flex_flow": null,
404 | "width": null,
405 | "min_width": null,
406 | "border": null,
407 | "align_items": null,
408 | "bottom": null,
409 | "_model_module": "@jupyter-widgets/base",
410 | "top": null,
411 | "grid_column": null,
412 | "overflow_y": null,
413 | "overflow_x": null,
414 | "grid_auto_flow": null,
415 | "grid_area": null,
416 | "grid_template_columns": null,
417 | "flex": null,
418 | "_model_name": "LayoutModel",
419 | "justify_items": null,
420 | "grid_row": null,
421 | "max_height": null,
422 | "align_content": null,
423 | "visibility": null,
424 | "align_self": null,
425 | "height": null,
426 | "min_height": null,
427 | "padding": null,
428 | "grid_auto_rows": null,
429 | "grid_gap": null,
430 | "max_width": null,
431 | "order": null,
432 | "_view_module_version": "1.2.0",
433 | "grid_template_areas": null,
434 | "object_position": null,
435 | "object_fit": null,
436 | "grid_auto_columns": null,
437 | "margin": null,
438 | "display": null,
439 | "left": null
440 | }
441 | },
442 | "9d2f753a736946b4939711ac254ca213": {
443 | "model_module": "@jupyter-widgets/controls",
444 | "model_name": "DescriptionStyleModel",
445 | "state": {
446 | "_view_name": "StyleView",
447 | "_model_name": "DescriptionStyleModel",
448 | "description_width": "",
449 | "_view_module": "@jupyter-widgets/base",
450 | "_model_module_version": "1.5.0",
451 | "_view_count": null,
452 | "_view_module_version": "1.2.0",
453 | "_model_module": "@jupyter-widgets/controls"
454 | }
455 | },
456 | "ec97824c3bd6422cbed5093a5db9acf6": {
457 | "model_module": "@jupyter-widgets/base",
458 | "model_name": "LayoutModel",
459 | "state": {
460 | "_view_name": "LayoutView",
461 | "grid_template_rows": null,
462 | "right": null,
463 | "justify_content": null,
464 | "_view_module": "@jupyter-widgets/base",
465 | "overflow": null,
466 | "_model_module_version": "1.2.0",
467 | "_view_count": null,
468 | "flex_flow": null,
469 | "width": null,
470 | "min_width": null,
471 | "border": null,
472 | "align_items": null,
473 | "bottom": null,
474 | "_model_module": "@jupyter-widgets/base",
475 | "top": null,
476 | "grid_column": null,
477 | "overflow_y": null,
478 | "overflow_x": null,
479 | "grid_auto_flow": null,
480 | "grid_area": null,
481 | "grid_template_columns": null,
482 | "flex": null,
483 | "_model_name": "LayoutModel",
484 | "justify_items": null,
485 | "grid_row": null,
486 | "max_height": null,
487 | "align_content": null,
488 | "visibility": null,
489 | "align_self": null,
490 | "height": null,
491 | "min_height": null,
492 | "padding": null,
493 | "grid_auto_rows": null,
494 | "grid_gap": null,
495 | "max_width": null,
496 | "order": null,
497 | "_view_module_version": "1.2.0",
498 | "grid_template_areas": null,
499 | "object_position": null,
500 | "object_fit": null,
501 | "grid_auto_columns": null,
502 | "margin": null,
503 | "display": null,
504 | "left": null
505 | }
506 | }
507 | }
508 | }
509 | },
510 | "cells": [
511 | {
512 | "cell_type": "markdown",
513 | "metadata": {
514 | "id": "JuiT6O71HUAy",
515 | "colab_type": "text"
516 | },
517 | "source": [
518 | "# Initial Setup"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "metadata": {
524 | "id": "FgWG4d-K3xRt",
525 | "colab_type": "code",
526 | "colab": {}
527 | },
528 | "source": [
529 | "import tensorflow as tf\n",
530 | "print(tf.__version__)"
531 | ],
532 | "execution_count": 0,
533 | "outputs": []
534 | },
535 | {
536 | "cell_type": "code",
537 | "metadata": {
538 | "id": "gtxvkdsm338L",
539 | "colab_type": "code",
540 | "colab": {}
541 | },
542 | "source": [
543 | "!pip install wandb\n",
544 | "import wandb\n",
545 | "wandb.login()"
546 | ],
547 | "execution_count": 0,
548 | "outputs": []
549 | },
550 | {
551 | "cell_type": "code",
552 | "metadata": {
553 | "id": "E8cv8vit3ydm",
554 | "colab_type": "code",
555 | "colab": {}
556 | },
557 | "source": [
558 | "from tensorflow.keras.layers import *\n",
559 | "from tensorflow.keras.models import *\n",
560 | "from wandb.keras import WandbCallback\n",
561 | "import tensorflow_datasets as tfds\n",
562 | "import matplotlib.pyplot as plt\n",
563 | "import numpy as np\n",
564 | "import time\n",
565 | "import cv2\n",
566 | "from tqdm.notebook import tqdm\n",
567 | "from imutils import paths\n",
568 | "tf.random.set_seed(666)\n",
569 | "np.random.seed(666)\n",
570 | "\n",
571 | "tfds.disable_progress_bar()"
572 | ],
573 | "execution_count": 0,
574 | "outputs": []
575 | },
576 | {
577 | "cell_type": "markdown",
578 | "metadata": {
579 | "id": "ebM6CaFsHcya",
580 | "colab_type": "text"
581 | },
582 | "source": [
583 | "# Imagenet Subset "
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "metadata": {
589 | "id": "4vPz9Alk31qZ",
590 | "colab_type": "code",
591 | "outputId": "5961df4c-b0fd-4c31-99a6-3ecbffba9eec",
592 | "colab": {
593 | "base_uri": "https://localhost:8080/",
594 | "height": 34
595 | }
596 | },
597 | "source": [
598 | "!git clone https://github.com/thunderInfy/imagenet-5-categories\n"
599 | ],
600 | "execution_count": 0,
601 | "outputs": [
602 | {
603 | "output_type": "stream",
604 | "text": [
605 | "fatal: destination path 'imagenet-5-categories' already exists and is not an empty directory.\n"
606 | ],
607 | "name": "stdout"
608 | }
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "metadata": {
614 | "id": "5vVPALgj4Ogg",
615 | "colab_type": "code",
616 | "colab": {}
617 | },
618 | "source": [
619 | "# Train and test image paths\n",
620 | "train_images = list(paths.list_images(\"imagenet-5-categories/train\"))\n",
621 | "test_images = list(paths.list_images(\"imagenet-5-categories/test\"))\n"
622 | ],
623 | "execution_count": 0,
624 | "outputs": []
625 | },
626 | {
627 | "cell_type": "code",
628 | "metadata": {
629 | "id": "YM_w3yZi4RQf",
630 | "colab_type": "code",
631 | "colab": {}
632 | },
633 | "source": [
634 | "def prepare_images(image_paths):\n",
635 | " images = []\n",
636 | " labels = []\n",
637 | "\n",
638 | " for image in tqdm(image_paths):\n",
639 | " image_pixels = plt.imread(image)\n",
640 | " image_pixels = cv2.resize(image_pixels, (128,128))\n",
641 | " image_pixels = image_pixels/255.\n",
642 | "\n",
643 | " label = image.split(\"/\")[2].split(\"_\")[0]\n",
644 | "\n",
645 | " images.append(image_pixels)\n",
646 | " labels.append(label)\n",
647 | "\n",
648 | " images = np.array(images)\n",
649 | " labels = np.array(labels)\n",
650 | "\n",
651 | " print(images.shape, labels.shape)\n",
652 | "\n",
653 | " return images, labels"
654 | ],
655 | "execution_count": 0,
656 | "outputs": []
657 | },
658 | {
659 | "cell_type": "code",
660 | "metadata": {
661 | "id": "KeNWTqpG4b0e",
662 | "colab_type": "code",
663 | "outputId": "2b8f6cf4-97c6-4265-dbf2-9b6262043410",
664 | "colab": {
665 | "base_uri": "https://localhost:8080/",
666 | "height": 148,
667 | "referenced_widgets": [
668 | "0b4c5d36a97843e6bfc07d37a8e3cb6a",
669 | "e8ef5e7fafb142aaa238fd8ae119b315",
670 | "fb1635d6c5da4d4893558f5d32c234f2",
671 | "0dae8ce37add4e81825f319cb66620df",
672 | "e80797a1b4e24f19a9ce53ce7e9e9299",
673 | "4d2932cd6daf495d972886f07bfcc227",
674 | "717104f84c7044818e3f5e4346de1ec0",
675 | "b0df6eda571748eaa741673ef3cf090f",
676 | "438699af2ed94a3eb67e4f408bd1e7a8",
677 | "d948c937f41a41a6b66587aaee7f2202",
678 | "5777b6af6ec84082babcd7616c4e5c02",
679 | "4c97103d541d4016be6c513edf6f4b4d",
680 | "f5864dec9b1641d6a69059cdd0f7254e",
681 | "1b4aa502d7234a6c8d740ed7b1cb4325",
682 | "9d2f753a736946b4939711ac254ca213",
683 | "ec97824c3bd6422cbed5093a5db9acf6"
684 | ]
685 | }
686 | },
687 | "source": [
688 | "X_train, y_train = prepare_images(train_images)\n",
689 | "X_test, y_test = prepare_images(test_images)"
690 | ],
691 | "execution_count": 0,
692 | "outputs": [
693 | {
694 | "output_type": "display_data",
695 | "data": {
696 | "application/vnd.jupyter.widget-view+json": {
697 | "model_id": "0b4c5d36a97843e6bfc07d37a8e3cb6a",
698 | "version_minor": 0,
699 | "version_major": 2
700 | },
701 | "text/plain": [
702 | "HBox(children=(FloatProgress(value=0.0, max=1250.0), HTML(value='')))"
703 | ]
704 | },
705 | "metadata": {
706 | "tags": []
707 | }
708 | },
709 | {
710 | "output_type": "stream",
711 | "text": [
712 | "\n",
713 | "(1250, 128, 128, 3) (1250,)\n"
714 | ],
715 | "name": "stdout"
716 | },
717 | {
718 | "output_type": "display_data",
719 | "data": {
720 | "application/vnd.jupyter.widget-view+json": {
721 | "model_id": "438699af2ed94a3eb67e4f408bd1e7a8",
722 | "version_minor": 0,
723 | "version_major": 2
724 | },
725 | "text/plain": [
726 | "HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))"
727 | ]
728 | },
729 | "metadata": {
730 | "tags": []
731 | }
732 | },
733 | {
734 | "output_type": "stream",
735 | "text": [
736 | "\n",
737 | "(250, 128, 128, 3) (250,)\n"
738 | ],
739 | "name": "stdout"
740 | }
741 | ]
742 | },
743 | {
744 | "cell_type": "code",
745 | "metadata": {
746 | "id": "qGdwXZJk4eDH",
747 | "colab_type": "code",
748 | "colab": {}
749 | },
750 | "source": [
751 | "from sklearn import preprocessing\n",
752 | "le = preprocessing.LabelEncoder()\n",
753 | "y_train_enc = le.fit_transform(y_train)\n",
754 | "y_test_enc = le.transform(y_test)\n"
755 | ],
756 | "execution_count": 0,
757 | "outputs": []
758 | },
759 | {
760 | "cell_type": "code",
761 | "metadata": {
762 | "id": "nmX3x8wE4zBo",
763 | "colab_type": "code",
764 | "colab": {}
765 | },
766 | "source": [
767 | "train_ds=tf.data.Dataset.from_tensor_slices((X_train,y_train_enc))\n",
768 | "validation_ds=tf.data.Dataset.from_tensor_slices((X_test,y_test_enc))"
769 | ],
770 | "execution_count": 0,
771 | "outputs": []
772 | },
773 | {
774 | "cell_type": "code",
775 | "metadata": {
776 | "id": "9yBgpLe443a-",
777 | "colab_type": "code",
778 | "colab": {}
779 | },
780 | "source": [
781 | "@tf.function\n",
782 | "def aug(image, label):\n",
783 | " x=tf.image.random_brightness(image,max_delta=0)\n",
784 | " x=tf.image.random_contrast(x,lower=0.2, upper=1.8)\n",
785 | " x = tf.image.random_saturation(x, lower=0.2, upper=1.5)\n",
786 | " x = tf.image.random_hue(x, max_delta=0.4)\n",
787 | " x = tf.clip_by_value(x, 0, 1)\n",
788 | "\n",
789 | " return x, label"
790 | ],
791 | "execution_count": 0,
792 | "outputs": []
793 | },
794 | {
795 | "cell_type": "code",
796 | "metadata": {
797 | "id": "icCj5VGk45ce",
798 | "colab_type": "code",
799 | "colab": {}
800 | },
801 | "source": [
802 | "IMG_SHAPE = 128\n",
803 | "BS = 64\n",
804 | "AUTO = tf.data.experimental.AUTOTUNE\n",
805 | "train_ds = (\n",
806 | " train_ds\n",
807 | " .shuffle(100)\n",
808 | " .batch(BS)\n",
809 | " .map(aug, num_parallel_calls=AUTO)\n",
810 | " .prefetch(AUTO)\n",
811 | ")\n",
812 | "validation_ds = (\n",
813 | " validation_ds\n",
814 | " .shuffle(100)\n",
815 | " .batch(BS)\n",
816 | " .prefetch(AUTO)\n",
817 | ")"
818 | ],
819 | "execution_count": 0,
820 | "outputs": []
821 | },
822 | {
823 | "cell_type": "markdown",
824 | "metadata": {
825 | "id": "tkxjWEeIHrCf",
826 | "colab_type": "text"
827 | },
828 | "source": [
829 | "# Model building and training wih RMSprop\n"
830 | ]
831 | },
832 | {
833 | "cell_type": "code",
834 | "metadata": {
835 | "id": "umbRNW-A4755",
836 | "colab_type": "code",
837 | "colab": {}
838 | },
839 | "source": [
840 | "resnet50 = tf.keras.applications.ResNet50(weights=None, include_top=False)\n",
841 | "model = tf.keras.Sequential([resnet50,GlobalAveragePooling2D(),Dropout(0.25),Dense(5,activation='softmax')])"
842 | ],
843 | "execution_count": 0,
844 | "outputs": []
845 | },
846 | {
847 | "cell_type": "code",
848 | "metadata": {
849 | "id": "WVilaFIu5Hft",
850 | "colab_type": "code",
851 | "colab": {}
852 | },
853 | "source": [
854 | "decay_steps = 1000\n",
855 | "lr_decayed_fn = tf.keras.experimental.CosineDecay(\n",
856 | " initial_learning_rate=0.001, decay_steps=decay_steps)\n",
857 | "\n",
858 | "model.compile(optimizer=tf.keras.optimizers.RMSprop(lr_decayed_fn),\n",
859 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
860 | " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
861 | ],
862 | "execution_count": 0,
863 | "outputs": []
864 | },
865 | {
866 | "cell_type": "code",
867 | "metadata": {
868 | "id": "X3PxnSYd5x2W",
869 | "colab_type": "code",
870 | "colab": {}
871 | },
872 | "source": [
873 | "es = tf.keras.callbacks.EarlyStopping(monitor=\"val_sparse_categorical_accuracy\", patience=2,\n",
874 | "\trestore_best_weights=True, verbose=2)"
875 | ],
876 | "execution_count": 0,
877 | "outputs": []
878 | },
879 | {
880 | "cell_type": "code",
881 | "metadata": {
882 | "id": "S5lMKwCQ54KX",
883 | "colab_type": "code",
884 | "outputId": "c6e1573d-c593-4447-c182-34674d7ab6ef",
885 | "colab": {
886 | "base_uri": "https://localhost:8080/",
887 | "height": 390
888 | }
889 | },
890 | "source": [
891 | "import time\n",
892 | "import wandb\n",
893 | "\n",
894 | "wandb.init(entity='authors',project='scl')\n",
895 | "start = time.time()\n",
896 | "model.fit(train_ds,\n",
897 | " validation_data=validation_ds,\n",
898 | " epochs=50,\n",
899 | " callbacks=[wandb.keras.WandbCallback(), es])\n",
900 | "end = time.time()\n",
901 | "wandb.log({\"training_time\": end - start})"
902 | ],
903 | "execution_count": 0,
904 | "outputs": [
905 | {
906 | "output_type": "display_data",
907 | "data": {
908 | "text/html": [
909 | "\n",
910 | " Logging results to Weights & Biases (Documentation).\n",
911 | " Project page: https://app.wandb.ai/authors/scl\n",
912 | " Run page: https://app.wandb.ai/authors/scl/runs/2h40mbhd\n",
913 | " "
914 | ],
915 | "text/plain": [
916 | ""
917 | ]
918 | },
919 | "metadata": {
920 | "tags": []
921 | }
922 | },
923 | {
924 | "output_type": "stream",
925 | "text": [
926 | "Epoch 1/50\n",
927 | "20/20 [==============================] - 6s 312ms/step - loss: 1.6051 - sparse_categorical_accuracy: 0.2976 - val_loss: 1.7048 - val_sparse_categorical_accuracy: 0.2000\n",
928 | "Epoch 2/50\n",
929 | "20/20 [==============================] - 5s 263ms/step - loss: 1.5929 - sparse_categorical_accuracy: 0.3056 - val_loss: 1.7039 - val_sparse_categorical_accuracy: 0.2000\n",
930 | "Epoch 3/50\n",
931 | "20/20 [==============================] - 5s 271ms/step - loss: 1.6160 - sparse_categorical_accuracy: 0.2792 - val_loss: 1.6867 - val_sparse_categorical_accuracy: 0.2120\n",
932 | "Epoch 4/50\n",
933 | "20/20 [==============================] - 4s 223ms/step - loss: 1.5946 - sparse_categorical_accuracy: 0.2776 - val_loss: 1.6938 - val_sparse_categorical_accuracy: 0.2000\n",
934 | "Epoch 5/50\n",
935 | "20/20 [==============================] - 5s 273ms/step - loss: 1.5765 - sparse_categorical_accuracy: 0.2880 - val_loss: 1.5870 - val_sparse_categorical_accuracy: 0.3040\n",
936 | "Epoch 6/50\n",
937 | "20/20 [==============================] - 6s 276ms/step - loss: 1.5180 - sparse_categorical_accuracy: 0.3552 - val_loss: 1.5748 - val_sparse_categorical_accuracy: 0.3080\n",
938 | "Epoch 7/50\n",
939 | "20/20 [==============================] - 5s 227ms/step - loss: 1.5345 - sparse_categorical_accuracy: 0.3512 - val_loss: 1.6253 - val_sparse_categorical_accuracy: 0.2680\n",
940 | "Epoch 8/50\n",
941 | "20/20 [==============================] - ETA: 0s - loss: 1.5690 - sparse_categorical_accuracy: 0.3136Restoring model weights from the end of the best epoch.\n",
942 | "20/20 [==============================] - 5s 231ms/step - loss: 1.5690 - sparse_categorical_accuracy: 0.3136 - val_loss: 1.6235 - val_sparse_categorical_accuracy: 0.2800\n",
943 | "Epoch 00008: early stopping\n"
944 | ],
945 | "name": "stdout"
946 | }
947 | ]
948 | },
949 | {
950 | "cell_type": "code",
951 | "metadata": {
952 | "id": "edc3Fu_C6AJO",
953 | "colab_type": "code",
954 | "colab": {}
955 | },
956 | "source": [
957 | "model.save_weights(\"full_supervised_learning.h5\")"
958 | ],
959 | "execution_count": 0,
960 | "outputs": []
961 | },
962 | {
963 | "cell_type": "code",
964 | "metadata": {
965 | "id": "wOPN7pPwBN0V",
966 | "colab_type": "code",
967 | "outputId": "47e31ee4-110f-44b9-8de8-d377fc2fafbd",
968 | "colab": {
969 | "base_uri": "https://localhost:8080/",
970 | "height": 34
971 | }
972 | },
973 | "source": [
974 | "wandb.save(\"full_supervised_learning.h5\")"
975 | ],
976 | "execution_count": 0,
977 | "outputs": [
978 | {
979 | "output_type": "execute_result",
980 | "data": {
981 | "text/plain": [
982 | "['/content/wandb/run-20200528_111108-2h40mbhd/full_supervised_learning.h5']"
983 | ]
984 | },
985 | "metadata": {
986 | "tags": []
987 | },
988 | "execution_count": 94
989 | }
990 | ]
991 | }
1005 | ]
1006 | }
--------------------------------------------------------------------------------
/ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_Adam.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Fully_Supervised_Training_IMGNET_subset_Adam.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "accelerator": "GPU"
509 | },
510 | "cells": [
511 | {
512 | "cell_type": "markdown",
513 | "metadata": {
514 | "id": "JuiT6O71HUAy",
515 | "colab_type": "text"
516 | },
517 | "source": [
518 | "# Initial Setup"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "metadata": {
524 | "id": "FgWG4d-K3xRt",
525 | "colab_type": "code",
526 | "outputId": "140af918-8fea-4aac-d44d-80191d1a2877",
527 | "colab": {
528 | "base_uri": "https://localhost:8080/",
529 | "height": 35
530 | }
531 | },
532 | "source": [
533 | "import tensorflow as tf\n",
534 | "print(tf.__version__)"
535 | ],
536 | "execution_count": 1,
537 | "outputs": [
538 | {
539 | "output_type": "stream",
540 | "text": [
541 | "2.2.0\n"
542 | ],
543 | "name": "stdout"
544 | }
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "metadata": {
550 | "id": "gtxvkdsm338L",
551 | "colab_type": "code",
552 | "colab": {
553 | "base_uri": "https://localhost:8080/",
554 | "height": 1000
555 | },
556 | "outputId": "56474563-95cf-45af-a1c0-af5a447d8fa7"
557 | },
558 | "source": [
559 | "!pip install wandb\n",
560 | "import wandb\n",
561 | "wandb.login()"
562 | ],
563 | "execution_count": 2,
564 | "outputs": []
711 | },
712 | {
713 | "cell_type": "code",
714 | "metadata": {
715 | "id": "E8cv8vit3ydm",
716 | "colab_type": "code",
717 | "colab": {}
718 | },
719 | "source": [
720 | "from tensorflow.keras.layers import *\n",
721 | "from tensorflow.keras.models import *\n",
722 | "from wandb.keras import WandbCallback\n",
723 | "import tensorflow_datasets as tfds\n",
724 | "import matplotlib.pyplot as plt\n",
725 | "import numpy as np\n",
726 | "import time\n",
727 | "import cv2\n",
728 | "from tqdm.notebook import tqdm\n",
729 | "from imutils import paths\n",
730 | "tf.random.set_seed(666)\n",
731 | "np.random.seed(666)\n",
732 | "\n",
733 | "tfds.disable_progress_bar()"
734 | ],
735 | "execution_count": 0,
736 | "outputs": []
737 | },
738 | {
739 | "cell_type": "markdown",
740 | "metadata": {
741 | "id": "ebM6CaFsHcya",
742 | "colab_type": "text"
743 | },
744 | "source": [
745 | "# ImageNet Subset"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "metadata": {
751 | "id": "4vPz9Alk31qZ",
752 | "colab_type": "code",
753 | "outputId": "14ab0cb9-a64e-4ff7-cdf4-f873752cec89",
754 | "colab": {
755 | "base_uri": "https://localhost:8080/",
756 | "height": 106
757 | }
758 | },
759 | "source": [
760 | "!git clone https://github.com/thunderInfy/imagenet-5-categories\n"
761 | ],
762 | "execution_count": 4,
763 | "outputs": [
764 | {
765 | "output_type": "stream",
766 | "text": [
767 | "Cloning into 'imagenet-5-categories'...\n",
768 | "remote: Enumerating objects: 1532, done.\u001b[K\n",
769 | "remote: Total 1532 (delta 0), reused 0 (delta 0), pack-reused 1532\u001b[K\n",
770 | "Receiving objects: 100% (1532/1532), 88.56 MiB | 51.26 MiB/s, done.\n",
771 | "Resolving deltas: 100% (1/1), done.\n"
772 | ],
773 | "name": "stdout"
774 | }
775 | ]
776 | },
777 | {
778 | "cell_type": "code",
779 | "metadata": {
780 | "id": "5vVPALgj4Ogg",
781 | "colab_type": "code",
782 | "colab": {}
783 | },
784 | "source": [
785 | "# Train and test image paths\n",
786 | "train_images = list(paths.list_images(\"imagenet-5-categories/train\"))\n",
787 | "test_images = list(paths.list_images(\"imagenet-5-categories/test\"))\n"
788 | ],
789 | "execution_count": 0,
790 | "outputs": []
791 | },
792 | {
793 | "cell_type": "code",
794 | "metadata": {
795 | "id": "YM_w3yZi4RQf",
796 | "colab_type": "code",
797 | "colab": {}
798 | },
799 | "source": [
800 | "def prepare_images(image_paths):\n",
801 | " images = []\n",
802 | " labels = []\n",
803 | "\n",
804 | " for image in tqdm(image_paths):\n",
805 | " image_pixels = plt.imread(image)\n",
806 | " image_pixels = cv2.resize(image_pixels, (128,128))\n",
807 | " image_pixels = image_pixels/255.\n",
808 | "\n",
809 | " label = image.split(\"/\")[2].split(\"_\")[0]\n",
810 | "\n",
811 | " images.append(image_pixels)\n",
812 | " labels.append(label)\n",
813 | "\n",
814 | " images = np.array(images)\n",
815 | " labels = np.array(labels)\n",
816 | "\n",
817 | " print(images.shape, labels.shape)\n",
818 | "\n",
819 | " return images, labels"
820 | ],
821 | "execution_count": 0,
822 | "outputs": []
823 | },
824 | {
825 | "cell_type": "code",
826 | "metadata": {
827 | "id": "KeNWTqpG4b0e",
828 | "colab_type": "code",
829 | "outputId": "66c2ec48-d6df-4f2b-a638-db3adb024e79",
830 | "colab": {
831 | "base_uri": "https://localhost:8080/",
832 | "height": 152
851 | }
852 | },
853 | "source": [
854 | "X_train, y_train = prepare_images(train_images)\n",
855 | "X_test, y_test = prepare_images(test_images)"
856 | ],
857 | "execution_count": 7,
858 | "outputs": [
859 | {
860 | "output_type": "display_data",
861 | "data": {
862 | "application/vnd.jupyter.widget-view+json": {
863 | "model_id": "2a8b0bfc58804efea484cd4cd71dd00f",
864 | "version_minor": 0,
865 | "version_major": 2
866 | },
867 | "text/plain": [
868 | "HBox(children=(FloatProgress(value=0.0, max=1250.0), HTML(value='')))"
869 | ]
870 | },
871 | "metadata": {
872 | "tags": []
873 | }
874 | },
875 | {
876 | "output_type": "stream",
877 | "text": [
878 | "\n",
879 | "(1250, 128, 128, 3) (1250,)\n"
880 | ],
881 | "name": "stdout"
882 | },
883 | {
884 | "output_type": "display_data",
885 | "data": {
886 | "application/vnd.jupyter.widget-view+json": {
887 | "model_id": "c7e81694439c41408bb1090cb82cd2b1",
888 | "version_minor": 0,
889 | "version_major": 2
890 | },
891 | "text/plain": [
892 | "HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))"
893 | ]
894 | },
895 | "metadata": {
896 | "tags": []
897 | }
898 | },
899 | {
900 | "output_type": "stream",
901 | "text": [
902 | "\n",
903 | "(250, 128, 128, 3) (250,)\n"
904 | ],
905 | "name": "stdout"
906 | }
907 | ]
908 | },
909 | {
910 | "cell_type": "code",
911 | "metadata": {
912 | "id": "qGdwXZJk4eDH",
913 | "colab_type": "code",
914 | "colab": {}
915 | },
916 | "source": [
917 | "from sklearn import preprocessing\n",
918 | "le = preprocessing.LabelEncoder()\n",
919 | "y_train_enc = le.fit_transform(y_train)\n",
920 | "y_test_enc = le.transform(y_test)\n"
921 | ],
922 | "execution_count": 0,
923 | "outputs": []
924 | },
925 | {
926 | "cell_type": "code",
927 | "metadata": {
928 | "id": "nmX3x8wE4zBo",
929 | "colab_type": "code",
930 | "colab": {}
931 | },
932 | "source": [
933 | "train_ds=tf.data.Dataset.from_tensor_slices((X_train,y_train_enc))\n",
934 | "validation_ds=tf.data.Dataset.from_tensor_slices((X_test,y_test_enc))"
935 | ],
936 | "execution_count": 0,
937 | "outputs": []
938 | },
939 | {
940 | "cell_type": "code",
941 | "metadata": {
942 | "id": "9yBgpLe443a-",
943 | "colab_type": "code",
944 | "colab": {}
945 | },
946 | "source": [
947 | "@tf.function\n",
948 | "def aug(image, label):\n",
949 | " x=tf.image.random_brightness(image,max_delta=0)  # max_delta=0 leaves brightness unchanged\n",
950 | " x=tf.image.random_contrast(x,lower=0.2, upper=1.8)\n",
951 | " x = tf.image.random_saturation(x, lower=0.2, upper=1.5)\n",
952 | " x = tf.image.random_hue(x, max_delta=0.4)\n",
953 | " x = tf.clip_by_value(x, 0, 1)\n",
954 | "\n",
955 | " return x, label"
956 | ],
957 | "execution_count": 0,
958 | "outputs": []
959 | },
960 | {
961 | "cell_type": "code",
962 | "metadata": {
963 | "id": "icCj5VGk45ce",
964 | "colab_type": "code",
965 | "colab": {}
966 | },
967 | "source": [
968 | "IMG_SHAPE = 128\n",
969 | "BS = 64\n",
970 | "AUTO = tf.data.experimental.AUTOTUNE\n",
971 | "train_ds = (\n",
972 | " train_ds\n",
973 | " .shuffle(100)\n",
974 | " .map(aug, num_parallel_calls=AUTO)  # augment each image independently\n",
975 | " .batch(BS)\n",
976 | " .prefetch(AUTO)\n",
977 | ")\n",
978 | "validation_ds = (\n",
979 | " validation_ds\n",
980 | " .shuffle(100)\n",
981 | " .batch(BS)\n",
982 | " .prefetch(AUTO)\n",
983 | ")"
984 | ],
985 | "execution_count": 0,
986 | "outputs": []
987 | },
988 | {
989 | "cell_type": "markdown",
990 | "metadata": {
991 | "id": "tkxjWEeIHrCf",
992 | "colab_type": "text"
993 | },
994 | "source": [
995 | "# Model building and training with Adam\n"
996 | ]
997 | },
998 | {
999 | "cell_type": "code",
1000 | "metadata": {
1001 | "id": "umbRNW-A4755",
1002 | "colab_type": "code",
1003 | "colab": {}
1004 | },
1005 | "source": [
1006 | "resnet50 = tf.keras.applications.ResNet50(weights=None, include_top=False)\n",
1007 | "model = tf.keras.Sequential([resnet50,GlobalAveragePooling2D(),Dropout(0.25),Dense(5,activation='softmax')])"
1008 | ],
1009 | "execution_count": 0,
1010 | "outputs": []
1011 | },
1012 | {
1013 | "cell_type": "code",
1014 | "metadata": {
1015 | "id": "WVilaFIu5Hft",
1016 | "colab_type": "code",
1017 | "colab": {}
1018 | },
1019 | "source": [
1020 | "decay_steps = 1000\n",
1021 | "lr_decayed_fn = tf.keras.experimental.CosineDecay(\n",
1022 | " initial_learning_rate=0.001, decay_steps=decay_steps)\n",
1023 | "\n",
1024 | "model.compile(optimizer=tf.keras.optimizers.Adam(lr_decayed_fn),\n",
1025 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  # the model already ends in softmax\n",
1026 | " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
1027 | ],
1028 | "execution_count": 0,
1029 | "outputs": []
1030 | },
1031 | {
1032 | "cell_type": "code",
1033 | "metadata": {
1034 | "id": "X3PxnSYd5x2W",
1035 | "colab_type": "code",
1036 | "colab": {}
1037 | },
1038 | "source": [
1039 | "es = tf.keras.callbacks.EarlyStopping(monitor=\"val_sparse_categorical_accuracy\", patience=2,\n",
1040 | " restore_best_weights=True, verbose=2)"
1041 | ],
1042 | "execution_count": 0,
1043 | "outputs": []
1044 | },
1045 | {
1046 | "cell_type": "code",
1047 | "metadata": {
1048 | "id": "S5lMKwCQ54KX",
1049 | "colab_type": "code",
1050 | "outputId": "be75d42a-5d10-4dd9-bf66-a780f8a52c0c",
1051 | "colab": {
1052 | "base_uri": "https://localhost:8080/",
1053 | "height": 210
1054 | }
1055 | },
1056 | "source": [
1057 | "import time\n",
1058 | "import wandb\n",
1059 | "\n",
1060 | "wandb.init(entity='authors',project='scl',id=\"ADA\")\n",
1061 | "start = time.time()\n",
1062 | "model.fit(train_ds,\n",
1063 | " validation_data=validation_ds,\n",
1064 | " epochs=50,\n",
1065 | " callbacks=[wandb.keras.WandbCallback(), es])\n",
1066 | "end = time.time()\n",
1067 | "wandb.log({\"training_time\": end - start})"
1068 | ],
1069 | "execution_count": 24,
1070 | "outputs": [
1071 | {
1072 | "output_type": "display_data",
1073 | "data": {
1074 | "text/html": [
1075 | "\n",
1076 | " Logging results to Weights & Biases (Documentation).\n",
1077 | " Project page: https://app.wandb.ai/authors/scl\n",
1078 | " Run page: https://app.wandb.ai/authors/scl/runs/ADA\n",
1079 | " "
1080 | ],
1081 | "text/plain": [
1082 | ""
1083 | ]
1084 | },
1085 | "metadata": {
1086 | "tags": []
1087 | }
1088 | },
1089 | {
1090 | "output_type": "stream",
1091 | "text": [
1092 | "Epoch 1/50\n",
1093 | "20/20 [==============================] - 5s 234ms/step - loss: 1.6424 - sparse_categorical_accuracy: 0.2544 - val_loss: 1.6301 - val_sparse_categorical_accuracy: 0.2480\n",
1094 | "Epoch 2/50\n",
1095 | "20/20 [==============================] - 3s 143ms/step - loss: 1.5880 - sparse_categorical_accuracy: 0.3136 - val_loss: 1.7048 - val_sparse_categorical_accuracy: 0.2000\n",
1096 | "Epoch 3/50\n",
1097 | "20/20 [==============================] - ETA: 0s - loss: 1.5839 - sparse_categorical_accuracy: 0.3200Restoring model weights from the end of the best epoch.\n",
1098 | "20/20 [==============================] - 3s 145ms/step - loss: 1.5839 - sparse_categorical_accuracy: 0.3200 - val_loss: 1.7048 - val_sparse_categorical_accuracy: 0.2000\n",
1099 | "Epoch 00003: early stopping\n"
1100 | ],
1101 | "name": "stdout"
1102 | }
1103 | ]
1104 | },
1105 | {
1106 | "cell_type": "code",
1107 | "metadata": {
1108 | "id": "edc3Fu_C6AJO",
1109 | "colab_type": "code",
1110 | "colab": {}
1111 | },
1112 | "source": [
1113 | "model.save_weights(\"full_supervised_learning.h5\")"
1114 | ],
1115 | "execution_count": 0,
1116 | "outputs": []
1117 | },
1118 | {
1119 | "cell_type": "code",
1120 | "metadata": {
1121 | "id": "wOPN7pPwBN0V",
1122 | "colab_type": "code",
1123 | "outputId": "47e31ee4-110f-44b9-8de8-d377fc2fafbd",
1124 | "colab": {
1125 | "base_uri": "https://localhost:8080/",
1126 | "height": 34
1127 | }
1128 | },
1129 | "source": [
1130 | "wandb.save(\"full_supervised_learning.h5\")"
1131 | ],
1132 | "execution_count": 0,
1133 | "outputs": [
1134 | {
1135 | "output_type": "execute_result",
1136 | "data": {
1137 | "text/plain": [
1138 | "['/content/wandb/run-20200528_111108-2h40mbhd/full_supervised_learning.h5']"
1139 | ]
1140 | },
1141 | "metadata": {
1142 | "tags": []
1143 | },
1144 | "execution_count": 94
1145 | }
1146 | ]
1147 | },
1148 | {
1149 | "cell_type": "code",
1150 | "metadata": {
1151 | "id": "FNk0NhWFBSYe",
1152 | "colab_type": "code",
1153 | "colab": {}
1154 | },
1155 | "source": [
1156 | ""
1157 | ],
1158 | "execution_count": 0,
1159 | "outputs": []
1160 | }
1161 | ]
1162 | }
--------------------------------------------------------------------------------
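In the Adam notebook above, `prepare_images` takes each class name from the file name (`image.split("/")[2]` on paths such as `imagenet-5-categories/train/<file>`) up to the first underscore, and `LabelEncoder` then maps the sorted class names to integers. A stdlib-only sketch of both steps follows, assuming the `<class>_<id>` naming scheme implied by that code; the helper names and example file names are hypothetical. Using `os.path.basename` removes the dependence on how deep the file sits in the directory tree.

```python
import os
from typing import Dict, List, Tuple


def label_from_path(path: str) -> str:
    """Hypothetical helper: class name = file-name prefix before the first '_'."""
    return os.path.basename(path).split("_")[0]


def label_encode(train_labels: List[str],
                 test_labels: List[str]) -> Tuple[List[str], List[int], List[int]]:
    """Mimics LabelEncoder.fit_transform / transform: sorted classes -> 0..n-1.

    As with LabelEncoder.transform, an unseen test label raises an error
    (a KeyError in this sketch).
    """
    classes = sorted(set(train_labels))
    index: Dict[str, int] = {name: i for i, name in enumerate(classes)}
    return (classes,
            [index[name] for name in train_labels],
            [index[name] for name in test_labels])


# Hypothetical file names following the assumed <class>_<id> scheme.
train_paths = ["imagenet-5-categories/train/dog_001.jpg",
               "imagenet-5-categories/train/cat_002.jpg",
               "imagenet-5-categories/train/car_003.jpg"]
test_paths = ["imagenet-5-categories/test/cat_101.jpg"]

classes, train_ids, test_ids = label_encode(
    [label_from_path(p) for p in train_paths],
    [label_from_path(p) for p in test_paths])
print(classes, train_ids, test_ids)  # ['car', 'cat', 'dog'] [2, 1, 0] [1]
```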
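The Adam notebook above sets `decay_steps = 1000` for `CosineDecay`. With 1,250 training images and `BS = 64`, each epoch has 20 batches (the `20/20` in the log), so 1,000 steps corresponds to the full 50-epoch budget. The sketch below re-implements the cosine-decay formula described in the TensorFlow documentation in plain Python, assuming the default `alpha=0.0`, to show that the learning rate had barely moved from 0.001 when training stopped at epoch 3. Helper names are hypothetical.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept (drop_remainder=False)."""
    return math.ceil(num_examples / batch_size)


def cosine_decay_lr(step: int, initial_lr: float = 1e-3,
                    decay_steps: int = 1000, alpha: float = 0.0) -> float:
    """Cosine decay as described in the TensorFlow docs for CosineDecay."""
    step = min(step, decay_steps)  # the rate stays at its floor after decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


train_steps = steps_per_epoch(1250, 64)  # "20/20" in the training log
val_steps = steps_per_epoch(250, 64)     # last validation batch holds 58 images
print(train_steps, val_steps)            # 20 4
print(train_steps * 50)                  # 1000, the same as decay_steps

# Learning rate at the end of a few epochs (epoch 3 is where early stopping fired).
for epoch in (1, 3, 25, 50):
    print(epoch, cosine_decay_lr(epoch * train_steps))
```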
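The classifier head in the Adam notebook above ends in `Dense(5, activation='softmax')`, so the model already outputs probabilities. If those probabilities are passed to `SparseCategoricalCrossentropy(from_logits=True)`, the loss treats them as logits and applies a second softmax. The plain-Python sketch below (helper names hypothetical) shows the consequence for five classes: a perfectly confident, correct prediction still costs log(e + 4) - 1 ≈ 0.905 instead of 0, so the loss can never approach zero.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce(probs: List[float], label: int) -> float:
    """Cross-entropy for one example, given class probabilities."""
    return math.log(1.0 / probs[label])


label = 0
# A perfectly confident, correct softmax output for 5 classes.
probs = [1.0, 0.0, 0.0, 0.0, 0.0]

loss_as_probs = sparse_ce(probs, label)            # from_logits=False: loss is 0
loss_as_logits = sparse_ce(softmax(probs), label)  # from_logits=True: softmax applied twice
print(f"{loss_as_probs:.4f} {loss_as_logits:.4f}")  # 0.0000 0.9048
```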
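The Adam run above asked for `epochs=50` but logged `Epoch 00003: early stopping`. A plain-Python sketch of the patience rule that `EarlyStopping` applies to a maximized metric (simplified: it ignores `baseline`; the function name is hypothetical) reproduces that from the logged `val_sparse_categorical_accuracy` values: epoch 1 was the best, epochs 2 and 3 did not improve on it, so `patience=2` was exhausted and `restore_best_weights=True` rolled the model back to its epoch-1 weights.

```python
from typing import List, Optional, Tuple


def early_stopping_epoch(history: List[float], patience: int,
                         min_delta: float = 0.0) -> Tuple[Optional[int], Optional[int]]:
    """Simplified patience rule for a metric to maximize (ignores `baseline`).

    Returns (stop_epoch, best_epoch), both 1-indexed; stop_epoch is None if
    training would run through every epoch in `history`.
    """
    best = float("-inf")
    best_epoch = None
    wait = 0
    for epoch, value in enumerate(history, start=1):
        if value - min_delta > best:   # improvement: remember it, reset the counter
            best, best_epoch, wait = value, epoch, 0
        else:                          # no improvement: count toward patience
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# val_sparse_categorical_accuracy values from the Adam training log above.
print(early_stopping_epoch([0.2480, 0.2000, 0.2000], patience=2))  # (3, 1)
```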
/ImageNet_Subset/Fully_Supervised_Training_IMGNET_subset_SGD.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Fully_Supervised_Training_IMGNET_subset_SGD.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "accelerator": "GPU",
15 | "widgets": {
16 | "application/vnd.jupyter.widget-state+json": {
17 | "2a8b0bfc58804efea484cd4cd71dd00f": {
18 | "model_module": "@jupyter-widgets/controls",
19 | "model_name": "HBoxModel",
20 | "state": {
21 | "_view_name": "HBoxView",
22 | "_dom_classes": [],
23 | "_model_name": "HBoxModel",
24 | "_view_module": "@jupyter-widgets/controls",
25 | "_model_module_version": "1.5.0",
26 | "_view_count": null,
27 | "_view_module_version": "1.5.0",
28 | "box_style": "",
29 | "layout": "IPY_MODEL_126f1d14800843aa9332f00aead8395b",
30 | "_model_module": "@jupyter-widgets/controls",
31 | "children": [
32 | "IPY_MODEL_805d93bf8f4446d3871cd93392101584",
33 | "IPY_MODEL_33d042e9035942359936f5d4bd6a108a"
34 | ]
35 | }
36 | },
37 | "126f1d14800843aa9332f00aead8395b": {
38 | "model_module": "@jupyter-widgets/base",
39 | "model_name": "LayoutModel",
40 | "state": {
41 | "_view_name": "LayoutView",
42 | "grid_template_rows": null,
43 | "right": null,
44 | "justify_content": null,
45 | "_view_module": "@jupyter-widgets/base",
46 | "overflow": null,
47 | "_model_module_version": "1.2.0",
48 | "_view_count": null,
49 | "flex_flow": null,
50 | "width": null,
51 | "min_width": null,
52 | "border": null,
53 | "align_items": null,
54 | "bottom": null,
55 | "_model_module": "@jupyter-widgets/base",
56 | "top": null,
57 | "grid_column": null,
58 | "overflow_y": null,
59 | "overflow_x": null,
60 | "grid_auto_flow": null,
61 | "grid_area": null,
62 | "grid_template_columns": null,
63 | "flex": null,
64 | "_model_name": "LayoutModel",
65 | "justify_items": null,
66 | "grid_row": null,
67 | "max_height": null,
68 | "align_content": null,
69 | "visibility": null,
70 | "align_self": null,
71 | "height": null,
72 | "min_height": null,
73 | "padding": null,
74 | "grid_auto_rows": null,
75 | "grid_gap": null,
76 | "max_width": null,
77 | "order": null,
78 | "_view_module_version": "1.2.0",
79 | "grid_template_areas": null,
80 | "object_position": null,
81 | "object_fit": null,
82 | "grid_auto_columns": null,
83 | "margin": null,
84 | "display": null,
85 | "left": null
86 | }
87 | },
88 | "805d93bf8f4446d3871cd93392101584": {
89 | "model_module": "@jupyter-widgets/controls",
90 | "model_name": "FloatProgressModel",
91 | "state": {
92 | "_view_name": "ProgressView",
93 | "style": "IPY_MODEL_05b5252d8859432c8724c8fb9636ef33",
94 | "_dom_classes": [],
95 | "description": "100%",
96 | "_model_name": "FloatProgressModel",
97 | "bar_style": "success",
98 | "max": 1250,
99 | "_view_module": "@jupyter-widgets/controls",
100 | "_model_module_version": "1.5.0",
101 | "value": 1250,
102 | "_view_count": null,
103 | "_view_module_version": "1.5.0",
104 | "orientation": "horizontal",
105 | "min": 0,
106 | "description_tooltip": null,
107 | "_model_module": "@jupyter-widgets/controls",
108 | "layout": "IPY_MODEL_77179de1c044473abb9da855d5ab8eca"
109 | }
110 | },
111 | "33d042e9035942359936f5d4bd6a108a": {
112 | "model_module": "@jupyter-widgets/controls",
113 | "model_name": "HTMLModel",
114 | "state": {
115 | "_view_name": "HTMLView",
116 | "style": "IPY_MODEL_eb72220a381347d1bce57c5706554460",
117 | "_dom_classes": [],
118 | "description": "",
119 | "_model_name": "HTMLModel",
120 | "placeholder": "",
121 | "_view_module": "@jupyter-widgets/controls",
122 | "_model_module_version": "1.5.0",
123 | "value": " 1250/1250 [00:06<00:00, 187.65it/s]",
124 | "_view_count": null,
125 | "_view_module_version": "1.5.0",
126 | "description_tooltip": null,
127 | "_model_module": "@jupyter-widgets/controls",
128 | "layout": "IPY_MODEL_fe2c9cb4e208436db3c15d342f8d08e7"
129 | }
130 | },
131 | "05b5252d8859432c8724c8fb9636ef33": {
132 | "model_module": "@jupyter-widgets/controls",
133 | "model_name": "ProgressStyleModel",
134 | "state": {
135 | "_view_name": "StyleView",
136 | "_model_name": "ProgressStyleModel",
137 | "description_width": "initial",
138 | "_view_module": "@jupyter-widgets/base",
139 | "_model_module_version": "1.5.0",
140 | "_view_count": null,
141 | "_view_module_version": "1.2.0",
142 | "bar_color": null,
143 | "_model_module": "@jupyter-widgets/controls"
144 | }
145 | },
146 | "77179de1c044473abb9da855d5ab8eca": {
147 | "model_module": "@jupyter-widgets/base",
148 | "model_name": "LayoutModel",
149 | "state": {
150 | "_view_name": "LayoutView",
151 | "grid_template_rows": null,
152 | "right": null,
153 | "justify_content": null,
154 | "_view_module": "@jupyter-widgets/base",
155 | "overflow": null,
156 | "_model_module_version": "1.2.0",
157 | "_view_count": null,
158 | "flex_flow": null,
159 | "width": null,
160 | "min_width": null,
161 | "border": null,
162 | "align_items": null,
163 | "bottom": null,
164 | "_model_module": "@jupyter-widgets/base",
165 | "top": null,
166 | "grid_column": null,
167 | "overflow_y": null,
168 | "overflow_x": null,
169 | "grid_auto_flow": null,
170 | "grid_area": null,
171 | "grid_template_columns": null,
172 | "flex": null,
173 | "_model_name": "LayoutModel",
174 | "justify_items": null,
175 | "grid_row": null,
176 | "max_height": null,
177 | "align_content": null,
178 | "visibility": null,
179 | "align_self": null,
180 | "height": null,
181 | "min_height": null,
182 | "padding": null,
183 | "grid_auto_rows": null,
184 | "grid_gap": null,
185 | "max_width": null,
186 | "order": null,
187 | "_view_module_version": "1.2.0",
188 | "grid_template_areas": null,
189 | "object_position": null,
190 | "object_fit": null,
191 | "grid_auto_columns": null,
192 | "margin": null,
193 | "display": null,
194 | "left": null
195 | }
196 | },
197 | "eb72220a381347d1bce57c5706554460": {
198 | "model_module": "@jupyter-widgets/controls",
199 | "model_name": "DescriptionStyleModel",
200 | "state": {
201 | "_view_name": "StyleView",
202 | "_model_name": "DescriptionStyleModel",
203 | "description_width": "",
204 | "_view_module": "@jupyter-widgets/base",
205 | "_model_module_version": "1.5.0",
206 | "_view_count": null,
207 | "_view_module_version": "1.2.0",
208 | "_model_module": "@jupyter-widgets/controls"
209 | }
210 | },
211 | "fe2c9cb4e208436db3c15d342f8d08e7": {
212 | "model_module": "@jupyter-widgets/base",
213 | "model_name": "LayoutModel",
214 | "state": {
215 | "_view_name": "LayoutView",
216 | "grid_template_rows": null,
217 | "right": null,
218 | "justify_content": null,
219 | "_view_module": "@jupyter-widgets/base",
220 | "overflow": null,
221 | "_model_module_version": "1.2.0",
222 | "_view_count": null,
223 | "flex_flow": null,
224 | "width": null,
225 | "min_width": null,
226 | "border": null,
227 | "align_items": null,
228 | "bottom": null,
229 | "_model_module": "@jupyter-widgets/base",
230 | "top": null,
231 | "grid_column": null,
232 | "overflow_y": null,
233 | "overflow_x": null,
234 | "grid_auto_flow": null,
235 | "grid_area": null,
236 | "grid_template_columns": null,
237 | "flex": null,
238 | "_model_name": "LayoutModel",
239 | "justify_items": null,
240 | "grid_row": null,
241 | "max_height": null,
242 | "align_content": null,
243 | "visibility": null,
244 | "align_self": null,
245 | "height": null,
246 | "min_height": null,
247 | "padding": null,
248 | "grid_auto_rows": null,
249 | "grid_gap": null,
250 | "max_width": null,
251 | "order": null,
252 | "_view_module_version": "1.2.0",
253 | "grid_template_areas": null,
254 | "object_position": null,
255 | "object_fit": null,
256 | "grid_auto_columns": null,
257 | "margin": null,
258 | "display": null,
259 | "left": null
260 | }
261 | },
262 | "c7e81694439c41408bb1090cb82cd2b1": {
263 | "model_module": "@jupyter-widgets/controls",
264 | "model_name": "HBoxModel",
265 | "state": {
266 | "_view_name": "HBoxView",
267 | "_dom_classes": [],
268 | "_model_name": "HBoxModel",
269 | "_view_module": "@jupyter-widgets/controls",
270 | "_model_module_version": "1.5.0",
271 | "_view_count": null,
272 | "_view_module_version": "1.5.0",
273 | "box_style": "",
274 | "layout": "IPY_MODEL_7683c42a14e14250bf6e375a42028b87",
275 | "_model_module": "@jupyter-widgets/controls",
276 | "children": [
277 | "IPY_MODEL_aac930db883d4977a9f3f08b1a0fa7a5",
278 | "IPY_MODEL_3862f5cde98f49f783d8fc54ada40a11"
279 | ]
280 | }
281 | },
282 | "7683c42a14e14250bf6e375a42028b87": {
283 | "model_module": "@jupyter-widgets/base",
284 | "model_name": "LayoutModel",
285 | "state": {
286 | "_view_name": "LayoutView",
287 | "grid_template_rows": null,
288 | "right": null,
289 | "justify_content": null,
290 | "_view_module": "@jupyter-widgets/base",
291 | "overflow": null,
292 | "_model_module_version": "1.2.0",
293 | "_view_count": null,
294 | "flex_flow": null,
295 | "width": null,
296 | "min_width": null,
297 | "border": null,
298 | "align_items": null,
299 | "bottom": null,
300 | "_model_module": "@jupyter-widgets/base",
301 | "top": null,
302 | "grid_column": null,
303 | "overflow_y": null,
304 | "overflow_x": null,
305 | "grid_auto_flow": null,
306 | "grid_area": null,
307 | "grid_template_columns": null,
308 | "flex": null,
309 | "_model_name": "LayoutModel",
310 | "justify_items": null,
311 | "grid_row": null,
312 | "max_height": null,
313 | "align_content": null,
314 | "visibility": null,
315 | "align_self": null,
316 | "height": null,
317 | "min_height": null,
318 | "padding": null,
319 | "grid_auto_rows": null,
320 | "grid_gap": null,
321 | "max_width": null,
322 | "order": null,
323 | "_view_module_version": "1.2.0",
324 | "grid_template_areas": null,
325 | "object_position": null,
326 | "object_fit": null,
327 | "grid_auto_columns": null,
328 | "margin": null,
329 | "display": null,
330 | "left": null
331 | }
332 | },
333 | "aac930db883d4977a9f3f08b1a0fa7a5": {
334 | "model_module": "@jupyter-widgets/controls",
335 | "model_name": "FloatProgressModel",
336 | "state": {
337 | "_view_name": "ProgressView",
338 | "style": "IPY_MODEL_e9b5eef0caa148a8aaaa2ed576fe68cc",
339 | "_dom_classes": [],
340 | "description": "100%",
341 | "_model_name": "FloatProgressModel",
342 | "bar_style": "success",
343 | "max": 250,
344 | "_view_module": "@jupyter-widgets/controls",
345 | "_model_module_version": "1.5.0",
346 | "value": 250,
347 | "_view_count": null,
348 | "_view_module_version": "1.5.0",
349 | "orientation": "horizontal",
350 | "min": 0,
351 | "description_tooltip": null,
352 | "_model_module": "@jupyter-widgets/controls",
353 | "layout": "IPY_MODEL_2d671cb0ab744e509c98e0d1b2fcac17"
354 | }
355 | },
356 | "3862f5cde98f49f783d8fc54ada40a11": {
357 | "model_module": "@jupyter-widgets/controls",
358 | "model_name": "HTMLModel",
359 | "state": {
360 | "_view_name": "HTMLView",
361 | "style": "IPY_MODEL_79ff0547109643e8b8002a9bf1e6ebcc",
362 | "_dom_classes": [],
363 | "description": "",
364 | "_model_name": "HTMLModel",
365 | "placeholder": "",
366 | "_view_module": "@jupyter-widgets/controls",
367 | "_model_module_version": "1.5.0",
368 | "value": " 250/250 [00:01<00:00, 221.46it/s]",
369 | "_view_count": null,
370 | "_view_module_version": "1.5.0",
371 | "description_tooltip": null,
372 | "_model_module": "@jupyter-widgets/controls",
373 | "layout": "IPY_MODEL_8c382843a4f9475fb1a084b08ea0a7ae"
374 | }
375 | },
376 | "e9b5eef0caa148a8aaaa2ed576fe68cc": {
377 | "model_module": "@jupyter-widgets/controls",
378 | "model_name": "ProgressStyleModel",
379 | "state": {
380 | "_view_name": "StyleView",
381 | "_model_name": "ProgressStyleModel",
382 | "description_width": "initial",
383 | "_view_module": "@jupyter-widgets/base",
384 | "_model_module_version": "1.5.0",
385 | "_view_count": null,
386 | "_view_module_version": "1.2.0",
387 | "bar_color": null,
388 | "_model_module": "@jupyter-widgets/controls"
389 | }
390 | },
391 | "2d671cb0ab744e509c98e0d1b2fcac17": {
392 | "model_module": "@jupyter-widgets/base",
393 | "model_name": "LayoutModel",
394 | "state": {
395 | "_view_name": "LayoutView",
396 | "grid_template_rows": null,
397 | "right": null,
398 | "justify_content": null,
399 | "_view_module": "@jupyter-widgets/base",
400 | "overflow": null,
401 | "_model_module_version": "1.2.0",
402 | "_view_count": null,
403 | "flex_flow": null,
404 | "width": null,
405 | "min_width": null,
406 | "border": null,
407 | "align_items": null,
408 | "bottom": null,
409 | "_model_module": "@jupyter-widgets/base",
410 | "top": null,
411 | "grid_column": null,
412 | "overflow_y": null,
413 | "overflow_x": null,
414 | "grid_auto_flow": null,
415 | "grid_area": null,
416 | "grid_template_columns": null,
417 | "flex": null,
418 | "_model_name": "LayoutModel",
419 | "justify_items": null,
420 | "grid_row": null,
421 | "max_height": null,
422 | "align_content": null,
423 | "visibility": null,
424 | "align_self": null,
425 | "height": null,
426 | "min_height": null,
427 | "padding": null,
428 | "grid_auto_rows": null,
429 | "grid_gap": null,
430 | "max_width": null,
431 | "order": null,
432 | "_view_module_version": "1.2.0",
433 | "grid_template_areas": null,
434 | "object_position": null,
435 | "object_fit": null,
436 | "grid_auto_columns": null,
437 | "margin": null,
438 | "display": null,
439 | "left": null
440 | }
441 | },
442 | "79ff0547109643e8b8002a9bf1e6ebcc": {
443 | "model_module": "@jupyter-widgets/controls",
444 | "model_name": "DescriptionStyleModel",
445 | "state": {
446 | "_view_name": "StyleView",
447 | "_model_name": "DescriptionStyleModel",
448 | "description_width": "",
449 | "_view_module": "@jupyter-widgets/base",
450 | "_model_module_version": "1.5.0",
451 | "_view_count": null,
452 | "_view_module_version": "1.2.0",
453 | "_model_module": "@jupyter-widgets/controls"
454 | }
455 | },
456 | "8c382843a4f9475fb1a084b08ea0a7ae": {
457 | "model_module": "@jupyter-widgets/base",
458 | "model_name": "LayoutModel",
459 | "state": {
460 | "_view_name": "LayoutView",
461 | "grid_template_rows": null,
462 | "right": null,
463 | "justify_content": null,
464 | "_view_module": "@jupyter-widgets/base",
465 | "overflow": null,
466 | "_model_module_version": "1.2.0",
467 | "_view_count": null,
468 | "flex_flow": null,
469 | "width": null,
470 | "min_width": null,
471 | "border": null,
472 | "align_items": null,
473 | "bottom": null,
474 | "_model_module": "@jupyter-widgets/base",
475 | "top": null,
476 | "grid_column": null,
477 | "overflow_y": null,
478 | "overflow_x": null,
479 | "grid_auto_flow": null,
480 | "grid_area": null,
481 | "grid_template_columns": null,
482 | "flex": null,
483 | "_model_name": "LayoutModel",
484 | "justify_items": null,
485 | "grid_row": null,
486 | "max_height": null,
487 | "align_content": null,
488 | "visibility": null,
489 | "align_self": null,
490 | "height": null,
491 | "min_height": null,
492 | "padding": null,
493 | "grid_auto_rows": null,
494 | "grid_gap": null,
495 | "max_width": null,
496 | "order": null,
497 | "_view_module_version": "1.2.0",
498 | "grid_template_areas": null,
499 | "object_position": null,
500 | "object_fit": null,
501 | "grid_auto_columns": null,
502 | "margin": null,
503 | "display": null,
504 | "left": null
505 | }
506 | }
507 | }
508 | }
509 | },
510 | "cells": [
511 | {
512 | "cell_type": "markdown",
513 | "metadata": {
514 | "id": "JuiT6O71HUAy",
515 | "colab_type": "text"
516 | },
517 | "source": [
518 | "# Initial Setup"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "metadata": {
524 | "id": "FgWG4d-K3xRt",
525 | "colab_type": "code",
526 | "outputId": "140af918-8fea-4aac-d44d-80191d1a2877",
527 | "colab": {
528 | "base_uri": "https://localhost:8080/",
529 | "height": 35
530 | }
531 | },
532 | "source": [
533 | "import tensorflow as tf\n",
534 | "print(tf.__version__)"
535 | ],
536 | "execution_count": 1,
537 | "outputs": [
538 | {
539 | "output_type": "stream",
540 | "text": [
541 | "2.2.0\n"
542 | ],
543 | "name": "stdout"
544 | }
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "metadata": {
550 | "id": "gtxvkdsm338L",
551 | "colab_type": "code",
552 | "colab": {
553 | "base_uri": "https://localhost:8080/",
554 | "height": 1000
555 | },
556 | "outputId": "56474563-95cf-45af-a1c0-af5a447d8fa7"
557 | },
558 | "source": [
559 | "!pip install wandb\n",
560 | "import wandb\n",
561 | "wandb.login()"
562 | ],
563 | "execution_count": 2,
564 | "outputs": [
565 | {
566 | "output_type": "stream",
567 | "text": [
568 | "Collecting wandb\n",
569 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d1/c7/8bf2c62c3f133f45e135a8a116e4e0f162043248e3db54de30996eaf1a8a/wandb-0.8.36-py2.py3-none-any.whl (1.4MB)\n",
570 | "\u001b[K |████████████████████████████████| 1.4MB 4.8MB/s \n",
571 | "\u001b[?25hCollecting sentry-sdk>=0.4.0\n",
572 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/1b/95/9a20eebcedab2c1c63fad59fe19a0469edfc2a25b8576497e8084629c2ff/sentry_sdk-0.14.4-py2.py3-none-any.whl (104kB)\n",
573 | "\u001b[K |████████████████████████████████| 112kB 29.8MB/s \n",
574 | "\u001b[?25hRequirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)\n",
575 | "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)\n",
576 | "Collecting configparser>=3.8.1\n",
577 | " Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl\n",
578 | "Requirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.2)\n",
579 | "Collecting gql==0.2.0\n",
580 | " Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz\n",
581 | "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)\n",
582 | "Collecting watchdog>=0.8.3\n",
583 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/73/c3/ed6d992006837e011baca89476a4bbffb0a91602432f73bd4473816c76e2/watchdog-0.10.2.tar.gz (95kB)\n",
584 | "\u001b[K |████████████████████████████████| 102kB 9.5MB/s \n",
585 | "\u001b[?25hRequirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)\n",
586 | "Collecting subprocess32>=3.5.3\n",
587 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n",
588 | "\u001b[K |████████████████████████████████| 102kB 10.6MB/s \n",
589 | "\u001b[?25hCollecting shortuuid>=0.5.0\n",
590 | " Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n",
591 | "Collecting GitPython>=1.0.0\n",
592 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/44/33/917e6fde1cad13daa7053f39b7c8af3be287314f75f1b1ea8d3fe37a8571/GitPython-3.1.2-py3-none-any.whl (451kB)\n",
593 | "\u001b[K |████████████████████████████████| 460kB 23.4MB/s \n",
594 | "\u001b[?25hRequirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.23.0)\n",
595 | "Collecting docker-pycreds>=0.4.0\n",
596 | " Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n",
597 | "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)\n",
598 | "Requirement already satisfied: urllib3>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (1.24.3)\n",
599 | "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (2020.4.5.1)\n",
600 | "Collecting graphql-core<2,>=0.5.0\n",
601 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)\n",
602 | "\u001b[K |████████████████████████████████| 71kB 9.4MB/s \n",
603 | "\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)\n",
604 | "Collecting pathtools>=0.1.1\n",
605 | " Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n",
606 | "Collecting gitdb<5,>=4.0.1\n",
607 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
608 | "\u001b[K |████████████████████████████████| 71kB 11.3MB/s \n",
609 | "\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)\n",
610 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.9)\n",
611 | "Collecting smmap<4,>=3.0.1\n",
612 | " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
613 | "Building wheels for collected packages: gql, watchdog, subprocess32, graphql-core, pathtools\n",
614 | " Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
615 | " Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=bd431a7c7f187272be19a9836b1e7616f86e03d34c0a4de4b4f6810140d9de42\n",
616 | " Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23\n",
617 | " Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
618 | " Created wheel for watchdog: filename=watchdog-0.10.2-cp36-none-any.whl size=73605 sha256=c9fed8e42385522813f4c68267548bd110f3f869f0abc97b6fe982888c9aabf6\n",
619 | " Stored in directory: /root/.cache/pip/wheels/bc/ed/6c/028dea90d31b359cd2a7c8b0da4db80e41d24a59614154072e\n",
620 | " Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
621 | " Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=0dc181757046168379a173ffa1dddec577f6fe87f03c62125ecfbdb3ab14bd4d\n",
622 | " Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n",
623 | " Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
624 | " Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=e4e49bb64de6665036a068a85324e126b1c685b4540c8caa5d1ffb89190c6b50\n",
625 | " Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5\n",
626 | " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
627 | " Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=de6e98828915023d5115bd39dd81f22676351c6a70e74bf13ce914cacd0cee48\n",
628 | " Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n",
629 | "Successfully built gql watchdog subprocess32 graphql-core pathtools\n",
630 | "Installing collected packages: sentry-sdk, configparser, graphql-core, gql, pathtools, watchdog, subprocess32, shortuuid, smmap, gitdb, GitPython, docker-pycreds, wandb\n",
631 | "Successfully installed GitPython-3.1.2 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.5 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.14.4 shortuuid-1.0.1 smmap-3.0.4 subprocess32-3.5.4 wandb-0.8.36 watchdog-0.10.2\n"
632 | ],
633 | "name": "stdout"
634 | },
635 | {
636 | "output_type": "display_data",
637 | "data": {
638 | "application/javascript": [
639 | "\n",
640 | " window._wandbApiKey = new Promise((resolve, reject) => {\n",
641 | " function loadScript(url) {\n",
642 | " return new Promise(function(resolve, reject) {\n",
643 | " let newScript = document.createElement(\"script\");\n",
644 | " newScript.onerror = reject;\n",
645 | " newScript.onload = resolve;\n",
646 | " document.body.appendChild(newScript);\n",
647 | " newScript.src = url;\n",
648 | " });\n",
649 | " }\n",
650 | " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
651 | " const iframe = document.createElement('iframe')\n",
652 | " iframe.style.cssText = \"width:0;height:0;border:none\"\n",
653 | " document.body.appendChild(iframe)\n",
654 | " const handshake = new Postmate({\n",
655 | " container: iframe,\n",
656 | " url: 'https://app.wandb.ai/authorize'\n",
657 | " });\n",
658 | " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
659 | " handshake.then(function(child) {\n",
660 | " child.on('authorize', data => {\n",
661 | " clearTimeout(timeout)\n",
662 | " resolve(data)\n",
663 | " });\n",
664 | " });\n",
665 | " })\n",
666 | " });\n",
667 | " "
668 | ],
669 | "text/plain": [
670 | ""
671 | ]
672 | },
673 | "metadata": {
674 | "tags": []
675 | }
676 | },
677 | {
678 | "output_type": "stream",
679 | "text": [
680 | "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Not authenticated. Copy a key from https://app.wandb.ai/authorize\n"
681 | ],
682 | "name": "stderr"
683 | },
684 | {
685 | "output_type": "stream",
686 | "text": [
687 | "API Key: ··········\n"
688 | ],
689 | "name": "stdout"
690 | },
691 | {
692 | "output_type": "stream",
693 | "text": [
694 | "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
695 | ],
696 | "name": "stderr"
697 | },
698 | {
699 | "output_type": "execute_result",
700 | "data": {
701 | "text/plain": [
702 | "True"
703 | ]
704 | },
705 | "metadata": {
706 | "tags": []
707 | },
708 | "execution_count": 2
709 | }
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "metadata": {
715 | "id": "E8cv8vit3ydm",
716 | "colab_type": "code",
717 | "colab": {}
718 | },
719 | "source": [
720 | "from tensorflow.keras.layers import *\n",
721 | "from tensorflow.keras.models import *\n",
722 | "from wandb.keras import WandbCallback\n",
723 | "import tensorflow_datasets as tfds\n",
724 | "import matplotlib.pyplot as plt\n",
725 | "import numpy as np\n",
726 | "import time\n",
727 | "import cv2\n",
728 | "from tqdm.notebook import tqdm\n",
729 | "from imutils import paths\n",
730 | "tf.random.set_seed(666)\n",
731 | "np.random.seed(666)\n",
732 | "\n",
733 | "tfds.disable_progress_bar()"
734 | ],
735 | "execution_count": 0,
736 | "outputs": []
737 | },
738 | {
739 | "cell_type": "markdown",
740 | "metadata": {
741 | "id": "ebM6CaFsHcya",
742 | "colab_type": "text"
743 | },
744 | "source": [
745 | "# ImageNet Subset"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "metadata": {
751 | "id": "4vPz9Alk31qZ",
752 | "colab_type": "code",
753 | "outputId": "14ab0cb9-a64e-4ff7-cdf4-f873752cec89",
754 | "colab": {
755 | "base_uri": "https://localhost:8080/",
756 | "height": 106
757 | }
758 | },
759 | "source": [
760 | "!git clone https://github.com/thunderInfy/imagenet-5-categories\n"
761 | ],
762 | "execution_count": 4,
763 | "outputs": [
764 | {
765 | "output_type": "stream",
766 | "text": [
767 | "Cloning into 'imagenet-5-categories'...\n",
768 | "remote: Enumerating objects: 1532, done.\u001b[K\n",
769 | "remote: Total 1532 (delta 0), reused 0 (delta 0), pack-reused 1532\u001b[K\n",
770 | "Receiving objects: 100% (1532/1532), 88.56 MiB | 51.26 MiB/s, done.\n",
771 | "Resolving deltas: 100% (1/1), done.\n"
772 | ],
773 | "name": "stdout"
774 | }
775 | ]
776 | },
777 | {
778 | "cell_type": "code",
779 | "metadata": {
780 | "id": "5vVPALgj4Ogg",
781 | "colab_type": "code",
782 | "colab": {}
783 | },
784 | "source": [
785 | "# Train and test image paths\n",
786 | "train_images = list(paths.list_images(\"imagenet-5-categories/train\"))\n",
787 | "test_images = list(paths.list_images(\"imagenet-5-categories/test\"))\n"
788 | ],
789 | "execution_count": 0,
790 | "outputs": []
791 | },
792 | {
793 | "cell_type": "code",
794 | "metadata": {
795 | "id": "YM_w3yZi4RQf",
796 | "colab_type": "code",
797 | "colab": {}
798 | },
799 | "source": [
800 | "def prepare_images(image_paths):\n",
801 | " images = []\n",
802 | " labels = []\n",
803 | "\n",
804 | " for image in tqdm(image_paths):\n",
805 | " image_pixels = plt.imread(image)\n",
806 | " image_pixels = cv2.resize(image_pixels, (128,128))\n",
807 | " image_pixels = image_pixels/255.\n",
808 | "\n",
809 | " label = image.split(\"/\")[2].split(\"_\")[0]\n",
810 | "\n",
811 | " images.append(image_pixels)\n",
812 | " labels.append(label)\n",
813 | "\n",
814 | " images = np.array(images)\n",
815 | " labels = np.array(labels)\n",
816 | "\n",
817 | " print(images.shape, labels.shape)\n",
818 | "\n",
819 | " return images, labels"
820 | ],
821 | "execution_count": 0,
822 | "outputs": []
823 | },
824 | {
825 | "cell_type": "code",
826 | "metadata": {
827 | "id": "KeNWTqpG4b0e",
828 | "colab_type": "code",
829 | "outputId": "66c2ec48-d6df-4f2b-a638-db3adb024e79",
830 | "colab": {
831 | "base_uri": "https://localhost:8080/",
832 | "height": 152,
833 | "referenced_widgets": [
834 | "2a8b0bfc58804efea484cd4cd71dd00f",
835 | "126f1d14800843aa9332f00aead8395b",
836 | "805d93bf8f4446d3871cd93392101584",
837 | "33d042e9035942359936f5d4bd6a108a",
838 | "05b5252d8859432c8724c8fb9636ef33",
839 | "77179de1c044473abb9da855d5ab8eca",
840 | "eb72220a381347d1bce57c5706554460",
841 | "fe2c9cb4e208436db3c15d342f8d08e7",
842 | "c7e81694439c41408bb1090cb82cd2b1",
843 | "7683c42a14e14250bf6e375a42028b87",
844 | "aac930db883d4977a9f3f08b1a0fa7a5",
845 | "3862f5cde98f49f783d8fc54ada40a11",
846 | "e9b5eef0caa148a8aaaa2ed576fe68cc",
847 | "2d671cb0ab744e509c98e0d1b2fcac17",
848 | "79ff0547109643e8b8002a9bf1e6ebcc",
849 | "8c382843a4f9475fb1a084b08ea0a7ae"
850 | ]
851 | }
852 | },
853 | "source": [
854 | "X_train, y_train = prepare_images(train_images)\n",
855 | "X_test, y_test = prepare_images(test_images)"
856 | ],
857 | "execution_count": 7,
858 | "outputs": [
859 | {
860 | "output_type": "display_data",
861 | "data": {
862 | "application/vnd.jupyter.widget-view+json": {
863 | "model_id": "2a8b0bfc58804efea484cd4cd71dd00f",
864 | "version_minor": 0,
865 | "version_major": 2
866 | },
867 | "text/plain": [
868 | "HBox(children=(FloatProgress(value=0.0, max=1250.0), HTML(value='')))"
869 | ]
870 | },
871 | "metadata": {
872 | "tags": []
873 | }
874 | },
875 | {
876 | "output_type": "stream",
877 | "text": [
878 | "\n",
879 | "(1250, 128, 128, 3) (1250,)\n"
880 | ],
881 | "name": "stdout"
882 | },
883 | {
884 | "output_type": "display_data",
885 | "data": {
886 | "application/vnd.jupyter.widget-view+json": {
887 | "model_id": "c7e81694439c41408bb1090cb82cd2b1",
888 | "version_minor": 0,
889 | "version_major": 2
890 | },
891 | "text/plain": [
892 | "HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))"
893 | ]
894 | },
895 | "metadata": {
896 | "tags": []
897 | }
898 | },
899 | {
900 | "output_type": "stream",
901 | "text": [
902 | "\n",
903 | "(250, 128, 128, 3) (250,)\n"
904 | ],
905 | "name": "stdout"
906 | }
907 | ]
908 | },
909 | {
910 | "cell_type": "code",
911 | "metadata": {
912 | "id": "qGdwXZJk4eDH",
913 | "colab_type": "code",
914 | "colab": {}
915 | },
916 | "source": [
917 | "from sklearn import preprocessing\n",
918 | "le = preprocessing.LabelEncoder()\n",
919 | "y_train_enc = le.fit_transform(y_train)\n",
920 | "y_test_enc = le.transform(y_test)\n"
921 | ],
922 | "execution_count": 0,
923 | "outputs": []
924 | },
925 | {
926 | "cell_type": "code",
927 | "metadata": {
928 | "id": "nmX3x8wE4zBo",
929 | "colab_type": "code",
930 | "colab": {}
931 | },
932 | "source": [
933 | "train_ds=tf.data.Dataset.from_tensor_slices((X_train,y_train_enc))\n",
934 | "validation_ds=tf.data.Dataset.from_tensor_slices((X_test,y_test_enc))"
935 | ],
936 | "execution_count": 0,
937 | "outputs": []
938 | },
939 | {
940 | "cell_type": "code",
941 | "metadata": {
942 | "id": "9yBgpLe443a-",
943 | "colab_type": "code",
944 | "colab": {}
945 | },
946 | "source": [
947 | "@tf.function\n",
948 | "def aug(image, label):\n",
949 | " x=tf.image.random_brightness(image,max_delta=0)\n",
950 | " x=tf.image.random_contrast(x,lower=0.2, upper=1.8)\n",
951 | " x = tf.image.random_saturation(x, lower=0.2, upper=1.5)\n",
952 | " x = tf.image.random_hue(x, max_delta=0.4)\n",
953 | " x = tf.clip_by_value(x, 0, 1)\n",
954 | "\n",
955 | " return x, label"
956 | ],
957 | "execution_count": 0,
958 | "outputs": []
959 | },
960 | {
961 | "cell_type": "code",
962 | "metadata": {
963 | "id": "icCj5VGk45ce",
964 | "colab_type": "code",
965 | "colab": {}
966 | },
967 | "source": [
968 | "IMG_SHAPE = 128\n",
969 | "BS = 64\n",
970 | "AUTO = tf.data.experimental.AUTOTUNE\n",
971 | "train_ds = (\n",
972 | " train_ds\n",
973 | " .shuffle(100)\n",
974 | " .map(aug, num_parallel_calls=AUTO)\n",
975 | " .batch(BS)\n",
976 | " .prefetch(AUTO)\n",
977 | ")\n",
978 | "validation_ds = (\n",
979 | " validation_ds\n",
980 | " .shuffle(100)\n",
981 | " .batch(BS)\n",
982 | " .prefetch(AUTO)\n",
983 | ")"
984 | ],
985 | "execution_count": 0,
986 | "outputs": []
987 | },
988 | {
989 | "cell_type": "markdown",
990 | "metadata": {
991 | "id": "tkxjWEeIHrCf",
992 | "colab_type": "text"
993 | },
994 | "source": [
995 | "# Model building and training with SGD\n"
996 | ]
997 | },
998 | {
999 | "cell_type": "code",
1000 | "metadata": {
1001 | "id": "umbRNW-A4755",
1002 | "colab_type": "code",
1003 | "colab": {}
1004 | },
1005 | "source": [
1006 | "resnet50 = tf.keras.applications.ResNet50(weights=None, include_top=False)\n",
1007 | "model = tf.keras.Sequential([resnet50,GlobalAveragePooling2D(),Dropout(0.25),Dense(5,activation='softmax')])"
1008 | ],
1009 | "execution_count": 0,
1010 | "outputs": []
1011 | },
1012 | {
1013 | "cell_type": "code",
1014 | "metadata": {
1015 | "id": "WVilaFIu5Hft",
1016 | "colab_type": "code",
1017 | "colab": {}
1018 | },
1019 | "source": [
1020 | "decay_steps = 1000\n",
1021 | "lr_decayed_fn = tf.keras.experimental.CosineDecay(\n",
1022 | " initial_learning_rate=0.001, decay_steps=decay_steps)\n",
1023 | "\n",
1024 | "model.compile(optimizer=tf.keras.optimizers.SGD(lr_decayed_fn),\n",
1025 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n",
1026 | " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
1027 | ],
1028 | "execution_count": 0,
1029 | "outputs": []
1030 | },
1031 | {
1032 | "cell_type": "code",
1033 | "metadata": {
1034 | "id": "X3PxnSYd5x2W",
1035 | "colab_type": "code",
1036 | "colab": {}
1037 | },
1038 | "source": [
1039 | "es = tf.keras.callbacks.EarlyStopping(monitor=\"val_sparse_categorical_accuracy\", patience=2,\n",
1040 | "    restore_best_weights=True, verbose=2)"
1041 | ],
1042 | "execution_count": 0,
1043 | "outputs": []
1044 | },
1045 | {
1046 | "cell_type": "code",
1047 | "metadata": {
1048 | "id": "S5lMKwCQ54KX",
1049 | "colab_type": "code",
1050 | "outputId": "4104b50e-107a-43c5-fda0-6f864532d24d",
1051 | "colab": {
1052 | "base_uri": "https://localhost:8080/",
1053 | "height": 210
1054 | }
1055 | },
1056 | "source": [
1057 | "import time\n",
1058 | "import wandb\n",
1059 | "\n",
1060 | "wandb.init(entity='authors',project='scl',id='cr_Entropy_SGD')\n",
1061 | "start = time.time()\n",
1062 | "model.fit(train_ds,\n",
1063 | " validation_data=validation_ds,\n",
1064 | " epochs=50,\n",
1065 | " callbacks=[wandb.keras.WandbCallback(), es])\n",
1066 | "end = time.time()\n",
1067 | "wandb.log({\"training_time\": end - start})"
1068 | ],
1069 | "execution_count": 18,
1070 | "outputs": [
1071 | {
1072 | "output_type": "display_data",
1073 | "data": {
1074 | "text/html": [
1075 | "\n",
1076 | " Logging results to Weights & Biases (Documentation).\n",
1077 | " Project page: https://app.wandb.ai/authors/scl\n",
1078 | " Run page: https://app.wandb.ai/authors/scl/runs/cr_Entropy_SGD\n",
1079 | " "
1080 | ],
1081 | "text/plain": [
1082 | ""
1083 | ]
1084 | },
1085 | "metadata": {
1086 | "tags": []
1087 | }
1088 | },
1089 | {
1090 | "output_type": "stream",
1091 | "text": [
1092 | "Epoch 1/50\n",
1093 | "20/20 [==============================] - 4s 200ms/step - loss: 1.5290 - sparse_categorical_accuracy: 0.3736 - val_loss: 1.6414 - val_sparse_categorical_accuracy: 0.2640\n",
1094 | "Epoch 2/50\n",
1095 | "20/20 [==============================] - 3s 141ms/step - loss: 1.5102 - sparse_categorical_accuracy: 0.3896 - val_loss: 1.6797 - val_sparse_categorical_accuracy: 0.2240\n",
1096 | "Epoch 3/50\n",
1097 | "20/20 [==============================] - ETA: 0s - loss: 1.5144 - sparse_categorical_accuracy: 0.3888Restoring model weights from the end of the best epoch.\n",
1098 | "20/20 [==============================] - 3s 151ms/step - loss: 1.5144 - sparse_categorical_accuracy: 0.3888 - val_loss: 1.6787 - val_sparse_categorical_accuracy: 0.2280\n",
1099 | "Epoch 00003: early stopping\n"
1100 | ],
1101 | "name": "stdout"
1102 | }
1103 | ]
1104 | },
1105 | {
1106 | "cell_type": "code",
1107 | "metadata": {
1108 | "id": "edc3Fu_C6AJO",
1109 | "colab_type": "code",
1110 | "colab": {}
1111 | },
1112 | "source": [
1113 | "model.save_weights(\"full_supervised_learning.h5\")"
1114 | ],
1115 | "execution_count": 0,
1116 | "outputs": []
1117 | },
1118 | {
1119 | "cell_type": "code",
1120 | "metadata": {
1121 | "id": "wOPN7pPwBN0V",
1122 | "colab_type": "code",
1123 | "outputId": "47e31ee4-110f-44b9-8de8-d377fc2fafbd",
1124 | "colab": {
1125 | "base_uri": "https://localhost:8080/",
1126 | "height": 34
1127 | }
1128 | },
1129 | "source": [
1130 | "wandb.save(\"full_supervised_learning.h5\")"
1131 | ],
1132 | "execution_count": 0,
1133 | "outputs": [
1134 | {
1135 | "output_type": "execute_result",
1136 | "data": {
1137 | "text/plain": [
1138 | "['/content/wandb/run-20200528_111108-2h40mbhd/full_supervised_learning.h5']"
1139 | ]
1140 | },
1141 | "metadata": {
1142 | "tags": []
1143 | },
1144 | "execution_count": 94
1145 | }
1146 | ]
1147 | },
1148 | {
1149 | "cell_type": "code",
1150 | "metadata": {
1151 | "id": "FNk0NhWFBSYe",
1152 | "colab_type": "code",
1153 | "colab": {}
1154 | },
1155 | "source": [
1156 | ""
1157 | ],
1158 | "execution_count": 0,
1159 | "outputs": []
1160 | }
1161 | ]
1162 | }
--------------------------------------------------------------------------------
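A note on the classification head used in this notebook: `Dense(5, activation='softmax')` already outputs probabilities, so the loss needs `from_logits=False` (alternatively, drop the activation and keep `from_logits=True`). If probabilities are fed to a loss that expects logits, softmax is effectively applied twice. The NumPy-only sketch below is not part of the original notebooks. The helpers `softmax`, `sparse_ce_from_logits` and `sparse_ce_from_probs` are hypothetical stand-ins that mimic the two settings of `SparseCategoricalCrossentropy`, and the example assumes a single 5-class sample.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def sparse_ce_from_logits(logits, labels):
    """Hypothetical stand-in for from_logits=True: softmax is applied first."""
    probs = softmax(logits)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.log(picked).mean())


def sparse_ce_from_probs(probs, labels, eps=1e-7):
    """Hypothetical stand-in for from_logits=False: inputs are already
    probabilities, clipped to [eps, 1 - eps] to avoid log(0)."""
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0 - eps)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.log(picked).mean())


# A perfectly confident, correct prediction from a 5-way softmax head.
probs = np.array([[1.0, 0.0, 0.0, 0.0, 0.0]])
labels = np.array([0])

correct = sparse_ce_from_probs(probs, labels)           # essentially 0
double_softmax = sparse_ce_from_logits(probs, labels)   # log(e + 4) - 1, about 0.905

print(f"from_logits=False: {correct:.4f}")
print(f"from_logits=True : {double_softmax:.4f}")

# Largest probability the second softmax can assign to the target with 5 classes.
print(f"max reachable p  : {softmax(probs)[0, 0]:.4f}")  # e / (e + 4), about 0.405
```

With 5 classes, the second softmax can assign the target class at most e/(e+4), roughly 0.405, so the per-example loss can never drop below ln(e+4) - 1, roughly 0.905, even for a perfect prediction, and the gradients reaching the network are weakened accordingly.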