├── 1080Ti Notebook.ipynb
├── 2080Ti.ipynb
└── README.md
/1080Ti Notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import functools\n",
21 | "import sys, traceback\n",
22 | "def get_ref_free_exc_info():\n",
23 | "    \"Return the sys.exc_info() triple of the exception being handled after clearing its traceback frames, so references to locals cannot form cycles that gc.collect() is unable to reclaim\"\n",
24 | "    exc_type, exc_val, exc_tb = sys.exc_info()\n",
25 | "    traceback.clear_frames(exc_tb)  # drops the local variables held by every frame in the traceback\n",
26 | "    return (exc_type, exc_val, exc_tb)\n",
27 | "\n",
28 | "def gpu_mem_restore(func):\n",
29 | "    \"Decorator to reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted: re-raises whatever `func` raised with a reference-free traceback\"\n",
30 | "    @functools.wraps(func)\n",
31 | "    def wrapper(*args, **kwargs):\n",
32 | "        try:\n",
33 | "            return func(*args, **kwargs)\n",
34 | "        except BaseException:  # deliberately broad: an interrupt (KeyboardInterrupt) must be handled too\n",
35 | "            _, exc_val, exc_tb = get_ref_free_exc_info() # must!\n",
36 | "            raise exc_val.with_traceback(exc_tb) from None  # same instance: type(val) would nest it and fails for multi-argument constructors\n",
37 | "    return wrapper"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "from fastai.vision import *\n",
47 | "from fastai.metrics import error_rate"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 4,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "path = Path('/home/ekami/workspace/cifar100/')  # dataset root; absolute, machine-specific path - change it when running elsewhere"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 5,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "train = path/'train'  # passed as train= to ImageDataBunch.from_folder in the cells below\n",
66 | "test = path/'test'  # NOTE(review): not referenced again in the cells visible here"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 6,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "bs = 108  # base batch size, used by the first DataBunch; later sections hard-code their own bs"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 7,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "data": {
85 | "text/plain": [
86 | "[PosixPath('/home/ekami/workspace/cifar100/test'),\n",
87 | " PosixPath('/home/ekami/workspace/cifar100/models'),\n",
88 | " PosixPath('/home/ekami/workspace/cifar100/train')]"
89 | ]
90 | },
91 | "execution_count": 7,
92 | "metadata": {},
93 | "output_type": "execute_result"
94 | }
95 | ],
96 | "source": [
97 | "path.ls()  # list the entries of the dataset root"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 8,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "np.random.seed(42)  # the same seed precedes every DataBunch build; presumably keeps the valid_pct split repeatable - confirm\n",
107 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,  # NOTE(review): recorded accuracies are multiples of 1/12000, suggesting the 20% split is drawn from everything under path (train/ and test/ together) - confirm\n",
108 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=bs).normalize(cifar_stats)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 9,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "class gpu_mem_restore_ctx():\n",
118 | "    \"Context manager to reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted: clears the traceback frames of any exception leaving the `with` body, then re-raises it\"\n",
119 | "    def __enter__(self): return self\n",
120 | "    def __exit__(self, exc_type, exc_val, exc_tb):\n",
121 | "        if exc_type is None: return False  # clean exit; test the type so a falsy exception instance is never swallowed\n",
122 | "        traceback.clear_frames(exc_tb)  # release the locals still referenced by the traceback frames\n",
123 | "        raise exc_val.with_traceback(exc_tb) from None  # same instance: exc_type(exc_val) fails for multi-argument constructors"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 10,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "# Allow crashing - NOTE(review): intent unclear; this learner is replaced in the next section before any training runs\n",
133 | "learn = create_cnn(data, models.resnet18 , metrics=accuracy)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "## Resnet 18"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 11,
146 | "metadata": {},
147 | "outputs": [
148 | {
149 | "data": {
150 | "text/plain": [
151 | "405"
152 | ]
153 | },
154 | "execution_count": 11,
155 | "metadata": {},
156 | "output_type": "execute_result"
157 | }
158 | ],
159 | "source": [
160 | "int(bs*3.75)  # 3.75x the base batch size (405); NOTE(review): the next cell hard-codes bs=408 rather than this value"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 12,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "np.random.seed(42)  # same seed as the first DataBunch build\n",
170 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
171 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=408).normalize(cifar_stats)  # same call as the first DataBunch except bs=408\n",
172 | "learn = create_cnn(data, models.resnet18, metrics=accuracy)  # fresh resnet18 learner on the bs=408 DataBunch"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 13,
178 | "metadata": {},
179 | "outputs": [
180 | {
181 | "data": {
182 | "text/html": [
183 | "Total time: 01:14
\n",
184 | " \n",
185 | " epoch | \n",
186 | " train_loss | \n",
187 | " valid_loss | \n",
188 | " accuracy | \n",
189 | "
\n",
190 | " \n",
191 | " 1 | \n",
192 | " 2.441905 | \n",
193 | " 1.621814 | \n",
194 | " 0.561000 | \n",
195 | "
\n",
196 | "
\n"
197 | ],
198 | "text/plain": [
199 | ""
200 | ]
201 | },
202 | "metadata": {},
203 | "output_type": "display_data"
204 | }
205 | ],
206 | "source": [
207 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
208 | " learn.fit_one_cycle(1)  # one epoch with fit_one_cycle defaults"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": 14,
214 | "metadata": {},
215 | "outputs": [
216 | {
217 | "data": {
218 | "text/html": [
219 | "Total time: 35:26 \n",
220 | " \n",
221 | " epoch | \n",
222 | " train_loss | \n",
223 | " valid_loss | \n",
224 | " accuracy | \n",
225 | "
\n",
226 | " \n",
227 | " 1 | \n",
228 | " 1.686131 | \n",
229 | " 1.220904 | \n",
230 | " 0.651333 | \n",
231 | "
\n",
232 | " \n",
233 | " 2 | \n",
234 | " 1.388124 | \n",
235 | " 0.996638 | \n",
236 | " 0.703500 | \n",
237 | "
\n",
238 | " \n",
239 | " 3 | \n",
240 | " 1.216391 | \n",
241 | " 0.915040 | \n",
242 | " 0.721250 | \n",
243 | "
\n",
244 | " \n",
245 | " 4 | \n",
246 | " 1.165618 | \n",
247 | " 0.937397 | \n",
248 | " 0.716500 | \n",
249 | "
\n",
250 | " \n",
251 | " 5 | \n",
252 | " 1.166939 | \n",
253 | " 1.005142 | \n",
254 | " 0.696083 | \n",
255 | "
\n",
256 | " \n",
257 | " 6 | \n",
258 | " 1.167172 | \n",
259 | " 1.036181 | \n",
260 | " 0.690250 | \n",
261 | "
\n",
262 | " \n",
263 | " 7 | \n",
264 | " 1.155099 | \n",
265 | " 1.098619 | \n",
266 | " 0.678833 | \n",
267 | "
\n",
268 | " \n",
269 | " 8 | \n",
270 | " 1.141750 | \n",
271 | " 1.031859 | \n",
272 | " 0.692500 | \n",
273 | "
\n",
274 | " \n",
275 | " 9 | \n",
276 | " 1.122253 | \n",
277 | " 0.973910 | \n",
278 | " 0.700917 | \n",
279 | "
\n",
280 | " \n",
281 | " 10 | \n",
282 | " 1.078584 | \n",
283 | " 0.987744 | \n",
284 | " 0.700667 | \n",
285 | "
\n",
286 | " \n",
287 | " 11 | \n",
288 | " 1.049187 | \n",
289 | " 0.925535 | \n",
290 | " 0.718917 | \n",
291 | "
\n",
292 | " \n",
293 | " 12 | \n",
294 | " 1.023786 | \n",
295 | " 0.923716 | \n",
296 | " 0.723000 | \n",
297 | "
\n",
298 | " \n",
299 | " 13 | \n",
300 | " 0.994283 | \n",
301 | " 0.950193 | \n",
302 | " 0.714333 | \n",
303 | "
\n",
304 | " \n",
305 | " 14 | \n",
306 | " 0.978886 | \n",
307 | " 0.939824 | \n",
308 | " 0.715583 | \n",
309 | "
\n",
310 | " \n",
311 | " 15 | \n",
312 | " 0.947734 | \n",
313 | " 0.900096 | \n",
314 | " 0.729167 | \n",
315 | "
\n",
316 | " \n",
317 | " 16 | \n",
318 | " 0.912281 | \n",
319 | " 0.827112 | \n",
320 | " 0.747167 | \n",
321 | "
\n",
322 | " \n",
323 | " 17 | \n",
324 | " 0.877508 | \n",
325 | " 0.794671 | \n",
326 | " 0.759000 | \n",
327 | "
\n",
328 | " \n",
329 | " 18 | \n",
330 | " 0.848763 | \n",
331 | " 0.807713 | \n",
332 | " 0.751250 | \n",
333 | "
\n",
334 | " \n",
335 | " 19 | \n",
336 | " 0.832954 | \n",
337 | " 0.764302 | \n",
338 | " 0.766583 | \n",
339 | "
\n",
340 | " \n",
341 | " 20 | \n",
342 | " 0.783972 | \n",
343 | " 0.722574 | \n",
344 | " 0.776667 | \n",
345 | "
\n",
346 | " \n",
347 | " 21 | \n",
348 | " 0.765427 | \n",
349 | " 0.726920 | \n",
350 | " 0.779167 | \n",
351 | "
\n",
352 | " \n",
353 | " 22 | \n",
354 | " 0.740291 | \n",
355 | " 0.716551 | \n",
356 | " 0.779667 | \n",
357 | "
\n",
358 | " \n",
359 | " 23 | \n",
360 | " 0.711279 | \n",
361 | " 0.689046 | \n",
362 | " 0.785083 | \n",
363 | "
\n",
364 | " \n",
365 | " 24 | \n",
366 | " 0.683742 | \n",
367 | " 0.683559 | \n",
368 | " 0.788750 | \n",
369 | "
\n",
370 | " \n",
371 | " 25 | \n",
372 | " 0.657032 | \n",
373 | " 0.672859 | \n",
374 | " 0.791583 | \n",
375 | "
\n",
376 | " \n",
377 | " 26 | \n",
378 | " 0.634756 | \n",
379 | " 0.666527 | \n",
380 | " 0.795333 | \n",
381 | "
\n",
382 | " \n",
383 | " 27 | \n",
384 | " 0.622170 | \n",
385 | " 0.663841 | \n",
386 | " 0.796250 | \n",
387 | "
\n",
388 | " \n",
389 | " 28 | \n",
390 | " 0.613146 | \n",
391 | " 0.658232 | \n",
392 | " 0.797417 | \n",
393 | "
\n",
394 | " \n",
395 | " 29 | \n",
396 | " 0.613548 | \n",
397 | " 0.658717 | \n",
398 | " 0.797667 | \n",
399 | "
\n",
400 | " \n",
401 | " 30 | \n",
402 | " 0.612183 | \n",
403 | " 0.658733 | \n",
404 | " 0.798000 | \n",
405 | "
\n",
406 | "
\n"
407 | ],
408 | "text/plain": [
409 | ""
410 | ]
411 | },
412 | "metadata": {},
413 | "output_type": "display_data"
414 | }
415 | ],
416 | "source": [
417 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
418 | " learn.fit_one_cycle(30, max_lr=1e-2)  # 30 epochs; NOTE(review): the only run that passes max_lr, the other 30-epoch runs use the default"
419 | ]
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {},
424 | "source": [
425 | "## Resnet 18 (Mixed Precision)"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": 15,
431 | "metadata": {},
432 | "outputs": [
433 | {
434 | "data": {
435 | "text/plain": [
436 | "108"
437 | ]
438 | },
439 | "execution_count": 15,
440 | "metadata": {},
441 | "output_type": "execute_result"
442 | }
443 | ],
444 | "source": [
445 | "bs  # still the base value (108); the DataBunch in the next cell hard-codes bs=720"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 16,
451 | "metadata": {},
452 | "outputs": [],
453 | "source": [
454 | "np.random.seed(42)  # same seed as the earlier DataBunch builds\n",
455 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
456 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=720).normalize(cifar_stats)  # same call as before except bs=720"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 17,
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "learn = to_fp16(create_cnn(data, models.resnet18, metrics=accuracy))  # resnet18 learner passed through to_fp16 for the mixed-precision run"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": 18,
471 | "metadata": {},
472 | "outputs": [
473 | {
474 | "data": {
475 | "text/html": [
476 | "Total time: 01:12 \n",
477 | " \n",
478 | " epoch | \n",
479 | " train_loss | \n",
480 | " valid_loss | \n",
481 | " accuracy | \n",
482 | "
\n",
483 | " \n",
484 | " 1 | \n",
485 | " 2.870623 | \n",
486 | " 1.839730 | \n",
487 | " 0.519083 | \n",
488 | "
\n",
489 | "
\n"
490 | ],
491 | "text/plain": [
492 | ""
493 | ]
494 | },
495 | "metadata": {},
496 | "output_type": "display_data"
497 | }
498 | ],
499 | "source": [
500 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
501 | " learn.fit_one_cycle(1)  # one epoch with fit_one_cycle defaults"
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": 19,
507 | "metadata": {},
508 | "outputs": [
509 | {
510 | "data": {
511 | "text/html": [
512 | "Total time: 34:08 \n",
513 | " \n",
514 | " epoch | \n",
515 | " train_loss | \n",
516 | " valid_loss | \n",
517 | " accuracy | \n",
518 | "
\n",
519 | " \n",
520 | " 1 | \n",
521 | " 2.207675 | \n",
522 | " 1.754365 | \n",
523 | " 0.537417 | \n",
524 | "
\n",
525 | " \n",
526 | " 2 | \n",
527 | " 2.115157 | \n",
528 | " 1.612026 | \n",
529 | " 0.565667 | \n",
530 | "
\n",
531 | " \n",
532 | " 3 | \n",
533 | " 1.955724 | \n",
534 | " 1.431037 | \n",
535 | " 0.600417 | \n",
536 | "
\n",
537 | " \n",
538 | " 4 | \n",
539 | " 1.766447 | \n",
540 | " 1.264467 | \n",
541 | " 0.639083 | \n",
542 | "
\n",
543 | " \n",
544 | " 5 | \n",
545 | " 1.590725 | \n",
546 | " 1.155779 | \n",
547 | " 0.663833 | \n",
548 | "
\n",
549 | " \n",
550 | " 6 | \n",
551 | " 1.441305 | \n",
552 | " 1.066461 | \n",
553 | " 0.682167 | \n",
554 | "
\n",
555 | " \n",
556 | " 7 | \n",
557 | " 1.334857 | \n",
558 | " 1.007020 | \n",
559 | " 0.702500 | \n",
560 | "
\n",
561 | " \n",
562 | " 8 | \n",
563 | " 1.246409 | \n",
564 | " 0.961123 | \n",
565 | " 0.711833 | \n",
566 | "
\n",
567 | " \n",
568 | " 9 | \n",
569 | " 1.181972 | \n",
570 | " 0.911834 | \n",
571 | " 0.725500 | \n",
572 | "
\n",
573 | " \n",
574 | " 10 | \n",
575 | " 1.128325 | \n",
576 | " 0.881794 | \n",
577 | " 0.732333 | \n",
578 | "
\n",
579 | " \n",
580 | " 11 | \n",
581 | " 1.078243 | \n",
582 | " 0.861121 | \n",
583 | " 0.739583 | \n",
584 | "
\n",
585 | " \n",
586 | " 12 | \n",
587 | " 1.031876 | \n",
588 | " 0.839541 | \n",
589 | " 0.744167 | \n",
590 | "
\n",
591 | " \n",
592 | " 13 | \n",
593 | " 1.008928 | \n",
594 | " 0.831407 | \n",
595 | " 0.744000 | \n",
596 | "
\n",
597 | " \n",
598 | " 14 | \n",
599 | " 0.977578 | \n",
600 | " 0.809504 | \n",
601 | " 0.751833 | \n",
602 | "
\n",
603 | " \n",
604 | " 15 | \n",
605 | " 0.945696 | \n",
606 | " 0.797486 | \n",
607 | " 0.757000 | \n",
608 | "
\n",
609 | " \n",
610 | " 16 | \n",
611 | " 0.919172 | \n",
612 | " 0.782975 | \n",
613 | " 0.761917 | \n",
614 | "
\n",
615 | " \n",
616 | " 17 | \n",
617 | " 0.899146 | \n",
618 | " 0.776590 | \n",
619 | " 0.766583 | \n",
620 | "
\n",
621 | " \n",
622 | " 18 | \n",
623 | " 0.882939 | \n",
624 | " 0.767809 | \n",
625 | " 0.766000 | \n",
626 | "
\n",
627 | " \n",
628 | " 19 | \n",
629 | " 0.856347 | \n",
630 | " 0.763404 | \n",
631 | " 0.767417 | \n",
632 | "
\n",
633 | " \n",
634 | " 20 | \n",
635 | " 0.841534 | \n",
636 | " 0.754744 | \n",
637 | " 0.767250 | \n",
638 | "
\n",
639 | " \n",
640 | " 21 | \n",
641 | " 0.817957 | \n",
642 | " 0.747933 | \n",
643 | " 0.771083 | \n",
644 | "
\n",
645 | " \n",
646 | " 22 | \n",
647 | " 0.803184 | \n",
648 | " 0.746400 | \n",
649 | " 0.774833 | \n",
650 | "
\n",
651 | " \n",
652 | " 23 | \n",
653 | " 0.791340 | \n",
654 | " 0.743390 | \n",
655 | " 0.773417 | \n",
656 | "
\n",
657 | " \n",
658 | " 24 | \n",
659 | " 0.781753 | \n",
660 | " 0.739683 | \n",
661 | " 0.774750 | \n",
662 | "
\n",
663 | " \n",
664 | " 25 | \n",
665 | " 0.769866 | \n",
666 | " 0.738879 | \n",
667 | " 0.776167 | \n",
668 | "
\n",
669 | " \n",
670 | " 26 | \n",
671 | " 0.767070 | \n",
672 | " 0.737335 | \n",
673 | " 0.775583 | \n",
674 | "
\n",
675 | " \n",
676 | " 27 | \n",
677 | " 0.755852 | \n",
678 | " 0.735069 | \n",
679 | " 0.775833 | \n",
680 | "
\n",
681 | " \n",
682 | " 28 | \n",
683 | " 0.747134 | \n",
684 | " 0.734977 | \n",
685 | " 0.775833 | \n",
686 | "
\n",
687 | " \n",
688 | " 29 | \n",
689 | " 0.744960 | \n",
690 | " 0.733925 | \n",
691 | " 0.776667 | \n",
692 | "
\n",
693 | " \n",
694 | " 30 | \n",
695 | " 0.742359 | \n",
696 | " 0.735230 | \n",
697 | " 0.776667 | \n",
698 | "
\n",
699 | "
\n"
700 | ],
701 | "text/plain": [
702 | ""
703 | ]
704 | },
705 | "metadata": {},
706 | "output_type": "display_data"
707 | }
708 | ],
709 | "source": [
710 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
711 | " learn.fit_one_cycle(30)  # 30 epochs, no max_lr given (unlike the fp32 resnet18 run)"
712 | ]
713 | },
714 | {
715 | "cell_type": "markdown",
716 | "metadata": {},
717 | "source": [
718 | "## Resnet 34"
719 | ]
720 | },
721 | {
722 | "cell_type": "code",
723 | "execution_count": 20,
724 | "metadata": {},
725 | "outputs": [],
726 | "source": [
727 | "np.random.seed(42)  # same seed as the earlier DataBunch builds\n",
728 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
729 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=248).normalize(cifar_stats)  # same call as before except bs=248"
730 | ]
731 | },
732 | {
733 | "cell_type": "code",
734 | "execution_count": 21,
735 | "metadata": {},
736 | "outputs": [
737 | {
738 | "name": "stderr",
739 | "output_type": "stream",
740 | "text": [
741 | "Downloading: \"https://download.pytorch.org/models/resnet34-333f7ec4.pth\" to /home/ekami/.torch/models/resnet34-333f7ec4.pth\n",
742 | "100%|██████████| 87306240/87306240 [00:07<00:00, 11767995.32it/s]\n"
743 | ]
744 | }
745 | ],
746 | "source": [
747 | "learn = create_cnn(data, models.resnet34, metrics=accuracy)  # resnet34 learner; the recorded run downloaded the pretrained weights here"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "execution_count": 22,
753 | "metadata": {},
754 | "outputs": [
755 | {
756 | "data": {
757 | "text/html": [
758 | "Total time: 51:19 \n",
759 | " \n",
760 | " epoch | \n",
761 | " train_loss | \n",
762 | " valid_loss | \n",
763 | " accuracy | \n",
764 | "
\n",
765 | " \n",
766 | " 1 | \n",
767 | " 4.100102 | \n",
768 | " 3.264899 | \n",
769 | " 0.285417 | \n",
770 | "
\n",
771 | " \n",
772 | " 2 | \n",
773 | " 2.793747 | \n",
774 | " 1.978519 | \n",
775 | " 0.517667 | \n",
776 | "
\n",
777 | " \n",
778 | " 3 | \n",
779 | " 1.910982 | \n",
780 | " 1.300082 | \n",
781 | " 0.636833 | \n",
782 | "
\n",
783 | " \n",
784 | " 4 | \n",
785 | " 1.455095 | \n",
786 | " 0.999775 | \n",
787 | " 0.702500 | \n",
788 | "
\n",
789 | " \n",
790 | " 5 | \n",
791 | " 1.241369 | \n",
792 | " 0.867122 | \n",
793 | " 0.738250 | \n",
794 | "
\n",
795 | " \n",
796 | " 6 | \n",
797 | " 1.109905 | \n",
798 | " 0.796427 | \n",
799 | " 0.754583 | \n",
800 | "
\n",
801 | " \n",
802 | " 7 | \n",
803 | " 1.048591 | \n",
804 | " 0.770618 | \n",
805 | " 0.762250 | \n",
806 | "
\n",
807 | " \n",
808 | " 8 | \n",
809 | " 0.979058 | \n",
810 | " 0.730820 | \n",
811 | " 0.773833 | \n",
812 | "
\n",
813 | " \n",
814 | " 9 | \n",
815 | " 0.928035 | \n",
816 | " 0.709470 | \n",
817 | " 0.780250 | \n",
818 | "
\n",
819 | " \n",
820 | " 10 | \n",
821 | " 0.890598 | \n",
822 | " 0.684087 | \n",
823 | " 0.786833 | \n",
824 | "
\n",
825 | " \n",
826 | " 11 | \n",
827 | " 0.864331 | \n",
828 | " 0.679980 | \n",
829 | " 0.785333 | \n",
830 | "
\n",
831 | " \n",
832 | " 12 | \n",
833 | " 0.813269 | \n",
834 | " 0.656503 | \n",
835 | " 0.796083 | \n",
836 | "
\n",
837 | " \n",
838 | " 13 | \n",
839 | " 0.793111 | \n",
840 | " 0.645735 | \n",
841 | " 0.799333 | \n",
842 | "
\n",
843 | " \n",
844 | " 14 | \n",
845 | " 0.758749 | \n",
846 | " 0.637417 | \n",
847 | " 0.799167 | \n",
848 | "
\n",
849 | " \n",
850 | " 15 | \n",
851 | " 0.733792 | \n",
852 | " 0.624778 | \n",
853 | " 0.803500 | \n",
854 | "
\n",
855 | " \n",
856 | " 16 | \n",
857 | " 0.713488 | \n",
858 | " 0.622848 | \n",
859 | " 0.804250 | \n",
860 | "
\n",
861 | " \n",
862 | " 17 | \n",
863 | " 0.695918 | \n",
864 | " 0.618520 | \n",
865 | " 0.805333 | \n",
866 | "
\n",
867 | " \n",
868 | " 18 | \n",
869 | " 0.670478 | \n",
870 | " 0.600957 | \n",
871 | " 0.812333 | \n",
872 | "
\n",
873 | " \n",
874 | " 19 | \n",
875 | " 0.655789 | \n",
876 | " 0.600830 | \n",
877 | " 0.811750 | \n",
878 | "
\n",
879 | " \n",
880 | " 20 | \n",
881 | " 0.638686 | \n",
882 | " 0.596192 | \n",
883 | " 0.814667 | \n",
884 | "
\n",
885 | " \n",
886 | " 21 | \n",
887 | " 0.617020 | \n",
888 | " 0.593554 | \n",
889 | " 0.813500 | \n",
890 | "
\n",
891 | " \n",
892 | " 22 | \n",
893 | " 0.598800 | \n",
894 | " 0.587362 | \n",
895 | " 0.814333 | \n",
896 | "
\n",
897 | " \n",
898 | " 23 | \n",
899 | " 0.588485 | \n",
900 | " 0.585536 | \n",
901 | " 0.816500 | \n",
902 | "
\n",
903 | " \n",
904 | " 24 | \n",
905 | " 0.560206 | \n",
906 | " 0.583934 | \n",
907 | " 0.818000 | \n",
908 | "
\n",
909 | " \n",
910 | " 25 | \n",
911 | " 0.552242 | \n",
912 | " 0.579802 | \n",
913 | " 0.818750 | \n",
914 | "
\n",
915 | " \n",
916 | " 26 | \n",
917 | " 0.544388 | \n",
918 | " 0.576987 | \n",
919 | " 0.820917 | \n",
920 | "
\n",
921 | " \n",
922 | " 27 | \n",
923 | " 0.544016 | \n",
924 | " 0.577123 | \n",
925 | " 0.820667 | \n",
926 | "
\n",
927 | " \n",
928 | " 28 | \n",
929 | " 0.544756 | \n",
930 | " 0.576735 | \n",
931 | " 0.819333 | \n",
932 | "
\n",
933 | " \n",
934 | " 29 | \n",
935 | " 0.535235 | \n",
936 | " 0.576290 | \n",
937 | " 0.820167 | \n",
938 | "
\n",
939 | " \n",
940 | " 30 | \n",
941 | " 0.533571 | \n",
942 | " 0.576551 | \n",
943 | " 0.819833 | \n",
944 | "
\n",
945 | "
\n"
946 | ],
947 | "text/plain": [
948 | ""
949 | ]
950 | },
951 | "metadata": {},
952 | "output_type": "display_data"
953 | }
954 | ],
955 | "source": [
956 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
957 | " learn.fit_one_cycle(30)  # 30 epochs, no max_lr given"
958 | ]
959 | },
960 | {
961 | "cell_type": "markdown",
962 | "metadata": {},
963 | "source": [
964 | "## Resnet 34 (Mixed Precision)"
965 | ]
966 | },
967 | {
968 | "cell_type": "code",
969 | "execution_count": 23,
970 | "metadata": {},
971 | "outputs": [],
972 | "source": [
973 | "np.random.seed(42)  # same seed as the earlier DataBunch builds\n",
974 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
975 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=496).normalize(cifar_stats)  # same call as before except bs=496 (twice the fp32 resnet34 value)"
976 | ]
977 | },
978 | {
979 | "cell_type": "code",
980 | "execution_count": 24,
981 | "metadata": {},
982 | "outputs": [],
983 | "source": [
984 | "learn = to_fp16(create_cnn(data, models.resnet34, metrics=accuracy))  # resnet34 learner passed through to_fp16 for the mixed-precision run"
985 | ]
986 | },
987 | {
988 | "cell_type": "code",
989 | "execution_count": 25,
990 | "metadata": {},
991 | "outputs": [
992 | {
993 | "data": {
994 | "text/html": [
995 | "Total time: 44:14 \n",
996 | " \n",
997 | " epoch | \n",
998 | " train_loss | \n",
999 | " valid_loss | \n",
1000 | " accuracy | \n",
1001 | "
\n",
1002 | " \n",
1003 | " 1 | \n",
1004 | " 4.522862 | \n",
1005 | " 3.640730 | \n",
1006 | " 0.200583 | \n",
1007 | "
\n",
1008 | " \n",
1009 | " 2 | \n",
1010 | " 3.459070 | \n",
1011 | " 2.374401 | \n",
1012 | " 0.443583 | \n",
1013 | "
\n",
1014 | " \n",
1015 | " 3 | \n",
1016 | " 2.469919 | \n",
1017 | " 1.560894 | \n",
1018 | " 0.582833 | \n",
1019 | "
\n",
1020 | " \n",
1021 | " 4 | \n",
1022 | " 1.817846 | \n",
1023 | " 1.165402 | \n",
1024 | " 0.662000 | \n",
1025 | "
\n",
1026 | " \n",
1027 | " 5 | \n",
1028 | " 1.449373 | \n",
1029 | " 0.959060 | \n",
1030 | " 0.714667 | \n",
1031 | "
\n",
1032 | " \n",
1033 | " 6 | \n",
1034 | " 1.236690 | \n",
1035 | " 0.863116 | \n",
1036 | " 0.740500 | \n",
1037 | "
\n",
1038 | " \n",
1039 | " 7 | \n",
1040 | " 1.113157 | \n",
1041 | " 0.795895 | \n",
1042 | " 0.756417 | \n",
1043 | "
\n",
1044 | " \n",
1045 | " 8 | \n",
1046 | " 1.027290 | \n",
1047 | " 0.765271 | \n",
1048 | " 0.763417 | \n",
1049 | "
\n",
1050 | " \n",
1051 | " 9 | \n",
1052 | " 0.967804 | \n",
1053 | " 0.721662 | \n",
1054 | " 0.774667 | \n",
1055 | "
\n",
1056 | " \n",
1057 | " 10 | \n",
1058 | " 0.904266 | \n",
1059 | " 0.705837 | \n",
1060 | " 0.780333 | \n",
1061 | "
\n",
1062 | " \n",
1063 | " 11 | \n",
1064 | " 0.865761 | \n",
1065 | " 0.687858 | \n",
1066 | " 0.786917 | \n",
1067 | "
\n",
1068 | " \n",
1069 | " 12 | \n",
1070 | " 0.844098 | \n",
1071 | " 0.668408 | \n",
1072 | " 0.791083 | \n",
1073 | "
\n",
1074 | " \n",
1075 | " 13 | \n",
1076 | " 0.802036 | \n",
1077 | " 0.665113 | \n",
1078 | " 0.791667 | \n",
1079 | "
\n",
1080 | " \n",
1081 | " 14 | \n",
1082 | " 0.783970 | \n",
1083 | " 0.649133 | \n",
1084 | " 0.798250 | \n",
1085 | "
\n",
1086 | " \n",
1087 | " 15 | \n",
1088 | " 0.745470 | \n",
1089 | " 0.634941 | \n",
1090 | " 0.805250 | \n",
1091 | "
\n",
1092 | " \n",
1093 | " 16 | \n",
1094 | " 0.727000 | \n",
1095 | " 0.629780 | \n",
1096 | " 0.804333 | \n",
1097 | "
\n",
1098 | " \n",
1099 | " 17 | \n",
1100 | " 0.709905 | \n",
1101 | " 0.619154 | \n",
1102 | " 0.805417 | \n",
1103 | "
\n",
1104 | " \n",
1105 | " 18 | \n",
1106 | " 0.688805 | \n",
1107 | " 0.614638 | \n",
1108 | " 0.810250 | \n",
1109 | "
\n",
1110 | " \n",
1111 | " 19 | \n",
1112 | " 0.673002 | \n",
1113 | " 0.612835 | \n",
1114 | " 0.808083 | \n",
1115 | "
\n",
1116 | " \n",
1117 | " 20 | \n",
1118 | " 0.647279 | \n",
1119 | " 0.606833 | \n",
1120 | " 0.811917 | \n",
1121 | "
\n",
1122 | " \n",
1123 | " 21 | \n",
1124 | " 0.630066 | \n",
1125 | " 0.601341 | \n",
1126 | " 0.813167 | \n",
1127 | "
\n",
1128 | " \n",
1129 | " 22 | \n",
1130 | " 0.615739 | \n",
1131 | " 0.599353 | \n",
1132 | " 0.814083 | \n",
1133 | "
\n",
1134 | " \n",
1135 | " 23 | \n",
1136 | " 0.600889 | \n",
1137 | " 0.592636 | \n",
1138 | " 0.815833 | \n",
1139 | "
\n",
1140 | " \n",
1141 | " 24 | \n",
1142 | " 0.595916 | \n",
1143 | " 0.593515 | \n",
1144 | " 0.817667 | \n",
1145 | "
\n",
1146 | " \n",
1147 | " 25 | \n",
1148 | " 0.578139 | \n",
1149 | " 0.589500 | \n",
1150 | " 0.815500 | \n",
1151 | "
\n",
1152 | " \n",
1153 | " 26 | \n",
1154 | " 0.576346 | \n",
1155 | " 0.590978 | \n",
1156 | " 0.815750 | \n",
1157 | "
\n",
1158 | " \n",
1159 | " 27 | \n",
1160 | " 0.571478 | \n",
1161 | " 0.586962 | \n",
1162 | " 0.818833 | \n",
1163 | "
\n",
1164 | " \n",
1165 | " 28 | \n",
1166 | " 0.560052 | \n",
1167 | " 0.587171 | \n",
1168 | " 0.818000 | \n",
1169 | "
\n",
1170 | " \n",
1171 | " 29 | \n",
1172 | " 0.556342 | \n",
1173 | " 0.586621 | \n",
1174 | " 0.818083 | \n",
1175 | "
\n",
1176 | " \n",
1177 | " 30 | \n",
1178 | " 0.552511 | \n",
1179 | " 0.587052 | \n",
1180 | " 0.818333 | \n",
1181 | "
\n",
1182 | "
\n"
1183 | ],
1184 | "text/plain": [
1185 | ""
1186 | ]
1187 | },
1188 | "metadata": {},
1189 | "output_type": "display_data"
1190 | }
1191 | ],
1192 | "source": [
1193 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
1194 | " learn.fit_one_cycle(30)  # 30 epochs, no max_lr given"
1195 | ]
1196 | },
1197 | {
1198 | "cell_type": "markdown",
1199 | "metadata": {},
1200 | "source": [
1201 | "## Resnet 50"
1202 | ]
1203 | },
1204 | {
1205 | "cell_type": "code",
1206 | "execution_count": 26,
1207 | "metadata": {},
1208 | "outputs": [],
1209 | "source": [
1210 | "# bs = 512  (disabled; the DataBunch in the next cell hard-codes bs=98 instead)"
1211 | ]
1212 | },
1213 | {
1214 | "cell_type": "code",
1215 | "execution_count": 27,
1216 | "metadata": {},
1217 | "outputs": [],
1218 | "source": [
1219 | "np.random.seed(42)  # same seed as the earlier DataBunch builds\n",
1220 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1221 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=98).normalize(cifar_stats)  # same call as before except bs=98"
1222 | ]
1223 | },
1224 | {
1225 | "cell_type": "code",
1226 | "execution_count": 28,
1227 | "metadata": {},
1228 | "outputs": [
1229 | {
1230 | "name": "stderr",
1231 | "output_type": "stream",
1232 | "text": [
1233 | "Downloading: \"https://download.pytorch.org/models/resnet50-19c8e357.pth\" to /home/ekami/.torch/models/resnet50-19c8e357.pth\n",
1234 | "100%|██████████| 102502400/102502400 [00:19<00:00, 5150001.64it/s]\n"
1235 | ]
1236 | }
1237 | ],
1238 | "source": [
1239 | "learn = create_cnn(data, models.resnet50, metrics=accuracy)  # resnet50 learner; the recorded run downloaded the pretrained weights here"
1240 | ]
1241 | },
1242 | {
1243 | "cell_type": "code",
1244 | "execution_count": 29,
1245 | "metadata": {},
1246 | "outputs": [
1247 | {
1248 | "data": {
1249 | "text/html": [
1250 | "Total time: 1:45:22 \n",
1251 | " \n",
1252 | " epoch | \n",
1253 | " train_loss | \n",
1254 | " valid_loss | \n",
1255 | " accuracy | \n",
1256 | "
\n",
1257 | " \n",
1258 | " 1 | \n",
1259 | " 2.515773 | \n",
1260 | " 1.886690 | \n",
1261 | " 0.541167 | \n",
1262 | "
\n",
1263 | " \n",
1264 | " 2 | \n",
1265 | " 1.696823 | \n",
1266 | " 1.202539 | \n",
1267 | " 0.663750 | \n",
1268 | "
\n",
1269 | " \n",
1270 | " 3 | \n",
1271 | " 1.346115 | \n",
1272 | " 0.940126 | \n",
1273 | " 0.719500 | \n",
1274 | "
\n",
1275 | " \n",
1276 | " 4 | \n",
1277 | " 1.105000 | \n",
1278 | " 0.807197 | \n",
1279 | " 0.751083 | \n",
1280 | "
\n",
1281 | " \n",
1282 | " 5 | \n",
1283 | " 1.012399 | \n",
1284 | " 0.757759 | \n",
1285 | " 0.765250 | \n",
1286 | "
\n",
1287 | " \n",
1288 | " 6 | \n",
1289 | " 0.941379 | \n",
1290 | " 0.734523 | \n",
1291 | " 0.770833 | \n",
1292 | "
\n",
1293 | " \n",
1294 | " 7 | \n",
1295 | " 0.891324 | \n",
1296 | " 0.690761 | \n",
1297 | " 0.783500 | \n",
1298 | "
\n",
1299 | " \n",
1300 | " 8 | \n",
1301 | " 0.819676 | \n",
1302 | " 0.679745 | \n",
1303 | " 0.790583 | \n",
1304 | "
\n",
1305 | " \n",
1306 | " 9 | \n",
1307 | " 0.785965 | \n",
1308 | " 0.657922 | \n",
1309 | " 0.794000 | \n",
1310 | "
\n",
1311 | " \n",
1312 | " 10 | \n",
1313 | " 0.737616 | \n",
1314 | " 0.631472 | \n",
1315 | " 0.806583 | \n",
1316 | "
\n",
1317 | " \n",
1318 | " 11 | \n",
1319 | " 0.716565 | \n",
1320 | " 0.626581 | \n",
1321 | " 0.805500 | \n",
1322 | "
\n",
1323 | " \n",
1324 | " 12 | \n",
1325 | " 0.703448 | \n",
1326 | " 0.611971 | \n",
1327 | " 0.810333 | \n",
1328 | "
\n",
1329 | " \n",
1330 | " 13 | \n",
1331 | " 0.643908 | \n",
1332 | " 0.610946 | \n",
1333 | " 0.809333 | \n",
1334 | "
\n",
1335 | " \n",
1336 | " 14 | \n",
1337 | " 0.619197 | \n",
1338 | " 0.596491 | \n",
1339 | " 0.818167 | \n",
1340 | "
\n",
1341 | " \n",
1342 | " 15 | \n",
1343 | " 0.588613 | \n",
1344 | " 0.582370 | \n",
1345 | " 0.821167 | \n",
1346 | "
\n",
1347 | " \n",
1348 | " 16 | \n",
1349 | " 0.572277 | \n",
1350 | " 0.584467 | \n",
1351 | " 0.823833 | \n",
1352 | "
\n",
1353 | " \n",
1354 | " 17 | \n",
1355 | " 0.537526 | \n",
1356 | " 0.581618 | \n",
1357 | " 0.819917 | \n",
1358 | "
\n",
1359 | " \n",
1360 | " 18 | \n",
1361 | " 0.520785 | \n",
1362 | " 0.573473 | \n",
1363 | " 0.824500 | \n",
1364 | "
\n",
1365 | " \n",
1366 | " 19 | \n",
1367 | " 0.479418 | \n",
1368 | " 0.570315 | \n",
1369 | " 0.828000 | \n",
1370 | "
\n",
1371 | " \n",
1372 | " 20 | \n",
1373 | " 0.462298 | \n",
1374 | " 0.565643 | \n",
1375 | " 0.826917 | \n",
1376 | "
\n",
1377 | " \n",
1378 | " 21 | \n",
1379 | " 0.447708 | \n",
1380 | " 0.562472 | \n",
1381 | " 0.830500 | \n",
1382 | "
\n",
1383 | " \n",
1384 | " 22 | \n",
1385 | " 0.413717 | \n",
1386 | " 0.560359 | \n",
1387 | " 0.833417 | \n",
1388 | "
\n",
1389 | " \n",
1390 | " 23 | \n",
1391 | " 0.376912 | \n",
1392 | " 0.557939 | \n",
1393 | " 0.834167 | \n",
1394 | "
\n",
1395 | " \n",
1396 | " 24 | \n",
1397 | " 0.372460 | \n",
1398 | " 0.559411 | \n",
1399 | " 0.835167 | \n",
1400 | "
\n",
1401 | " \n",
1402 | " 25 | \n",
1403 | " 0.367571 | \n",
1404 | " 0.556380 | \n",
1405 | " 0.837167 | \n",
1406 | "
\n",
1407 | " \n",
1408 | " 26 | \n",
1409 | " 0.347707 | \n",
1410 | " 0.554721 | \n",
1411 | " 0.837667 | \n",
1412 | "
\n",
1413 | " \n",
1414 | " 27 | \n",
1415 | " 0.334381 | \n",
1416 | " 0.551504 | \n",
1417 | " 0.838417 | \n",
1418 | "
\n",
1419 | " \n",
1420 | " 28 | \n",
1421 | " 0.351300 | \n",
1422 | " 0.549640 | \n",
1423 | " 0.838250 | \n",
1424 | "
\n",
1425 | " \n",
1426 | " 29 | \n",
1427 | " 0.335910 | \n",
1428 | " 0.549918 | \n",
1429 | " 0.837000 | \n",
1430 | "
\n",
1431 | " \n",
1432 | " 30 | \n",
1433 | " 0.340140 | \n",
1434 | " 0.553000 | \n",
1435 | " 0.836417 | \n",
1436 | "
\n",
1437 | "
\n"
1438 | ],
1439 | "text/plain": [
1440 | ""
1441 | ]
1442 | },
1443 | "metadata": {},
1444 | "output_type": "display_data"
1445 | }
1446 | ],
1447 | "source": [
1448 | "with gpu_mem_restore_ctx():  # clears traceback frames if training raises or is interrupted\n",
1449 | " learn.fit_one_cycle(30)  # 30 epochs, no max_lr given"
1450 | ]
1451 | },
1452 | {
1453 | "cell_type": "markdown",
1454 | "metadata": {},
1455 | "source": [
1456 | "## Resnet 50 (Mixed Precision)"
1457 | ]
1458 | },
1459 | {
1460 | "cell_type": "code",
1461 | "execution_count": 30,
1462 | "metadata": {},
1463 | "outputs": [],
1464 | "source": [
1465 | "np.random.seed(42)  # same seed as the earlier DataBunch builds\n",
1466 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1467 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=164).normalize(cifar_stats)  # same call as before except bs=164"
1468 | ]
1469 | },
1470 | {
1471 | "cell_type": "code",
1472 | "execution_count": 31,
1473 | "metadata": {},
1474 | "outputs": [],
1475 | "source": [
1476 | "learn = to_fp16(create_cnn(data, models.resnet50, metrics=accuracy))  # resnet50 learner passed through to_fp16 for the mixed-precision run"
1477 | ]
1478 | },
1479 | {
1480 | "cell_type": "code",
1481 | "execution_count": 32,
1482 | "metadata": {},
1483 | "outputs": [
1484 | {
1485 | "data": {
1486 | "text/html": [
1487 | "Total time: 1:24:52 \n",
1488 | " \n",
1489 | " epoch | \n",
1490 | " train_loss | \n",
1491 | " valid_loss | \n",
1492 | " accuracy | \n",
1493 | "
\n",
1494 | " \n",
1495 | " 1 | \n",
1496 | " 2.834446 | \n",
1497 | " 2.072671 | \n",
1498 | " 0.511667 | \n",
1499 | "
\n",
1500 | " \n",
1501 | " 2 | \n",
1502 | " 1.883764 | \n",
1503 | " 1.311698 | \n",
1504 | " 0.639583 | \n",
1505 | "
\n",
1506 | " \n",
1507 | " 3 | \n",
1508 | " 1.395265 | \n",
1509 | " 0.999120 | \n",
1510 | " 0.704250 | \n",
1511 | "
\n",
1512 | " \n",
1513 | " 4 | \n",
1514 | " 1.161582 | \n",
1515 | " 0.848280 | \n",
1516 | " 0.743583 | \n",
1517 | "
\n",
1518 | " \n",
1519 | " 5 | \n",
1520 | " 1.024427 | \n",
1521 | " 0.772155 | \n",
1522 | " 0.759583 | \n",
1523 | "
\n",
1524 | " \n",
1525 | " 6 | \n",
1526 | " 0.940368 | \n",
1527 | " 0.736896 | \n",
1528 | " 0.770417 | \n",
1529 | "
\n",
1530 | " \n",
1531 | " 7 | \n",
1532 | " 0.892585 | \n",
1533 | " 0.707394 | \n",
1534 | " 0.776083 | \n",
1535 | "
\n",
1536 | " \n",
1537 | " 8 | \n",
1538 | " 0.829782 | \n",
1539 | " 0.685450 | \n",
1540 | " 0.790000 | \n",
1541 | "
\n",
1542 | " \n",
1543 | " 9 | \n",
1544 | " 0.780883 | \n",
1545 | " 0.662696 | \n",
1546 | " 0.798583 | \n",
1547 | "
\n",
1548 | " \n",
1549 | " 10 | \n",
1550 | " 0.755269 | \n",
1551 | " 0.638769 | \n",
1552 | " 0.802083 | \n",
1553 | "
\n",
1554 | " \n",
1555 | " 11 | \n",
1556 | " 0.695746 | \n",
1557 | " 0.623059 | \n",
1558 | " 0.809167 | \n",
1559 | "
\n",
1560 | " \n",
1561 | " 12 | \n",
1562 | " 0.674482 | \n",
1563 | " 0.617339 | \n",
1564 | " 0.806833 | \n",
1565 | "
\n",
1566 | " \n",
1567 | " 13 | \n",
1568 | " 0.614091 | \n",
1569 | " 0.608190 | \n",
1570 | " 0.814583 | \n",
1571 | "
\n",
1572 | " \n",
1573 | " 14 | \n",
1574 | " 0.609434 | \n",
1575 | " 0.598253 | \n",
1576 | " 0.812250 | \n",
1577 | "
\n",
1578 | " \n",
1579 | " 15 | \n",
1580 | " 0.561976 | \n",
1581 | " 0.589217 | \n",
1582 | " 0.818333 | \n",
1583 | "
\n",
1584 | " \n",
1585 | " 16 | \n",
1586 | " 0.544676 | \n",
1587 | " 0.583386 | \n",
1588 | " 0.821833 | \n",
1589 | "
\n",
1590 | " \n",
1591 | " 17 | \n",
1592 | " 0.503154 | \n",
1593 | " 0.581094 | \n",
1594 | " 0.821750 | \n",
1595 | "
\n",
1596 | " \n",
1597 | " 18 | \n",
1598 | " 0.489639 | \n",
1599 | " 0.572993 | \n",
1600 | " 0.823667 | \n",
1601 | "
\n",
1602 | " \n",
1603 | " 19 | \n",
1604 | " 0.469337 | \n",
1605 | " 0.574861 | \n",
1606 | " 0.826917 | \n",
1607 | "
\n",
1608 | " \n",
1609 | " 20 | \n",
1610 | " 0.445635 | \n",
1611 | " 0.570555 | \n",
1612 | " 0.829500 | \n",
1613 | "
\n",
1614 | " \n",
1615 | " 21 | \n",
1616 | " 0.421942 | \n",
1617 | " 0.570226 | \n",
1618 | " 0.830333 | \n",
1619 | "
\n",
1620 | " \n",
1621 | " 22 | \n",
1622 | " 0.395235 | \n",
1623 | " 0.558666 | \n",
1624 | " 0.831833 | \n",
1625 | "
\n",
1626 | " \n",
1627 | " 23 | \n",
1628 | " 0.376425 | \n",
1629 | " 0.563551 | \n",
1630 | " 0.835083 | \n",
1631 | "
\n",
1632 | " \n",
1633 | " 24 | \n",
1634 | " 0.366422 | \n",
1635 | " 0.559915 | \n",
1636 | " 0.835250 | \n",
1637 | "
\n",
1638 | " \n",
1639 | " 25 | \n",
1640 | " 0.344141 | \n",
1641 | " 0.557956 | \n",
1642 | " 0.835000 | \n",
1643 | "
\n",
1644 | " \n",
1645 | " 26 | \n",
1646 | " 0.328442 | \n",
1647 | " 0.559216 | \n",
1648 | " 0.835250 | \n",
1649 | "
\n",
1650 | " \n",
1651 | " 27 | \n",
1652 | " 0.322673 | \n",
1653 | " 0.556995 | \n",
1654 | " 0.835167 | \n",
1655 | "
\n",
1656 | " \n",
1657 | " 28 | \n",
1658 | " 0.324144 | \n",
1659 | " 0.557831 | \n",
1660 | " 0.835250 | \n",
1661 | "
\n",
1662 | " \n",
1663 | " 29 | \n",
1664 | " 0.320713 | \n",
1665 | " 0.555050 | \n",
1666 | " 0.836333 | \n",
1667 | "
\n",
1668 | " \n",
1669 | " 30 | \n",
1670 | " 0.327241 | \n",
1671 | " 0.557459 | \n",
1672 | " 0.836417 | \n",
1673 | "
\n",
1674 | "
\n"
1675 | ],
1676 | "text/plain": [
1677 | ""
1678 | ]
1679 | },
1680 | "metadata": {},
1681 | "output_type": "display_data"
1682 | }
1683 | ],
1684 | "source": [
1685 | "with gpu_mem_restore_ctx():\n",
1686 | " learn.fit_one_cycle(30)"
1687 | ]
1688 | },
1689 | {
1690 | "cell_type": "markdown",
1691 | "metadata": {},
1692 | "source": [
1693 | "## Resnet 101"
1694 | ]
1695 | },
1696 | {
1697 | "cell_type": "code",
1698 | "execution_count": 33,
1699 | "metadata": {},
1700 | "outputs": [],
1701 | "source": [
1702 | "np.random.seed(42)\n",
1703 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1704 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=64).normalize(cifar_stats)"
1705 | ]
1706 | },
1707 | {
1708 | "cell_type": "code",
1709 | "execution_count": 34,
1710 | "metadata": {},
1711 | "outputs": [
1712 | {
1713 | "name": "stderr",
1714 | "output_type": "stream",
1715 | "text": [
1716 | "Downloading: \"https://download.pytorch.org/models/resnet101-5d3b4d8f.pth\" to /home/ekami/.torch/models/resnet101-5d3b4d8f.pth\n",
1717 | "100%|██████████| 178728960/178728960 [00:29<00:00, 6127228.47it/s]\n"
1718 | ]
1719 | }
1720 | ],
1721 | "source": [
1722 | "learn = create_cnn(data, models.resnet101, metrics=accuracy)"
1723 | ]
1724 | },
1725 | {
1726 | "cell_type": "code",
1727 | "execution_count": 35,
1728 | "metadata": {},
1729 | "outputs": [
1730 | {
1731 | "data": {
1732 | "text/html": [
1733 | "Total time: 2:36:27 \n",
1734 | " \n",
1735 | " epoch | \n",
1736 | " train_loss | \n",
1737 | " valid_loss | \n",
1738 | " accuracy | \n",
1739 | "
\n",
1740 | " \n",
1741 | " 1 | \n",
1742 | " 2.083457 | \n",
1743 | " 1.451530 | \n",
1744 | " 0.638333 | \n",
1745 | "
\n",
1746 | " \n",
1747 | " 2 | \n",
1748 | " 1.321773 | \n",
1749 | " 0.908532 | \n",
1750 | " 0.733333 | \n",
1751 | "
\n",
1752 | " \n",
1753 | " 3 | \n",
1754 | " 1.074472 | \n",
1755 | " 0.728461 | \n",
1756 | " 0.775750 | \n",
1757 | "
\n",
1758 | " \n",
1759 | " 4 | \n",
1760 | " 0.985465 | \n",
1761 | " 0.666169 | \n",
1762 | " 0.793667 | \n",
1763 | "
\n",
1764 | " \n",
1765 | " 5 | \n",
1766 | " 0.838859 | \n",
1767 | " 0.618405 | \n",
1768 | " 0.806750 | \n",
1769 | "
\n",
1770 | " \n",
1771 | " 6 | \n",
1772 | " 0.803868 | \n",
1773 | " 0.600917 | \n",
1774 | " 0.812500 | \n",
1775 | "
\n",
1776 | " \n",
1777 | " 7 | \n",
1778 | " 0.770296 | \n",
1779 | " 0.592337 | \n",
1780 | " 0.814750 | \n",
1781 | "
\n",
1782 | " \n",
1783 | " 8 | \n",
1784 | " 0.712573 | \n",
1785 | " 0.587856 | \n",
1786 | " 0.818417 | \n",
1787 | "
\n",
1788 | " \n",
1789 | " 9 | \n",
1790 | " 0.700638 | \n",
1791 | " 0.557925 | \n",
1792 | " 0.827667 | \n",
1793 | "
\n",
1794 | " \n",
1795 | " 10 | \n",
1796 | " 0.656998 | \n",
1797 | " 0.543478 | \n",
1798 | " 0.833500 | \n",
1799 | "
\n",
1800 | " \n",
1801 | " 11 | \n",
1802 | " 0.587333 | \n",
1803 | " 0.537264 | \n",
1804 | " 0.833167 | \n",
1805 | "
\n",
1806 | " \n",
1807 | " 12 | \n",
1808 | " 0.576772 | \n",
1809 | " 0.523525 | \n",
1810 | " 0.844917 | \n",
1811 | "
\n",
1812 | " \n",
1813 | " 13 | \n",
1814 | " 0.546539 | \n",
1815 | " 0.511617 | \n",
1816 | " 0.843750 | \n",
1817 | "
\n",
1818 | " \n",
1819 | " 14 | \n",
1820 | " 0.531197 | \n",
1821 | " 0.519306 | \n",
1822 | " 0.839500 | \n",
1823 | "
\n",
1824 | " \n",
1825 | " 15 | \n",
1826 | " 0.467777 | \n",
1827 | " 0.507615 | \n",
1828 | " 0.845667 | \n",
1829 | "
\n",
1830 | " \n",
1831 | " 16 | \n",
1832 | " 0.468377 | \n",
1833 | " 0.504536 | \n",
1834 | " 0.848250 | \n",
1835 | "
\n",
1836 | " \n",
1837 | " 17 | \n",
1838 | " 0.454281 | \n",
1839 | " 0.497895 | \n",
1840 | " 0.849583 | \n",
1841 | "
\n",
1842 | " \n",
1843 | " 18 | \n",
1844 | " 0.420571 | \n",
1845 | " 0.491803 | \n",
1846 | " 0.852833 | \n",
1847 | "
\n",
1848 | " \n",
1849 | " 19 | \n",
1850 | " 0.401252 | \n",
1851 | " 0.486161 | \n",
1852 | " 0.853583 | \n",
1853 | "
\n",
1854 | " \n",
1855 | " 20 | \n",
1856 | " 0.359169 | \n",
1857 | " 0.492218 | \n",
1858 | " 0.853583 | \n",
1859 | "
\n",
1860 | " \n",
1861 | " 21 | \n",
1862 | " 0.353262 | \n",
1863 | " 0.479224 | \n",
1864 | " 0.857250 | \n",
1865 | "
\n",
1866 | " \n",
1867 | " 22 | \n",
1868 | " 0.317846 | \n",
1869 | " 0.482238 | \n",
1870 | " 0.860833 | \n",
1871 | "
\n",
1872 | " \n",
1873 | " 23 | \n",
1874 | " 0.305395 | \n",
1875 | " 0.482287 | \n",
1876 | " 0.860000 | \n",
1877 | "
\n",
1878 | " \n",
1879 | " 24 | \n",
1880 | " 0.302323 | \n",
1881 | " 0.477879 | \n",
1882 | " 0.861250 | \n",
1883 | "
\n",
1884 | " \n",
1885 | " 25 | \n",
1886 | " 0.285679 | \n",
1887 | " 0.477347 | \n",
1888 | " 0.863667 | \n",
1889 | "
\n",
1890 | " \n",
1891 | " 26 | \n",
1892 | " 0.265691 | \n",
1893 | " 0.472855 | \n",
1894 | " 0.864000 | \n",
1895 | "
\n",
1896 | " \n",
1897 | " 27 | \n",
1898 | " 0.235859 | \n",
1899 | " 0.476183 | \n",
1900 | " 0.865417 | \n",
1901 | "
\n",
1902 | " \n",
1903 | " 28 | \n",
1904 | " 0.253295 | \n",
1905 | " 0.475753 | \n",
1906 | " 0.864083 | \n",
1907 | "
\n",
1908 | " \n",
1909 | " 29 | \n",
1910 | " 0.252339 | \n",
1911 | " 0.476051 | \n",
1912 | " 0.865167 | \n",
1913 | "
\n",
1914 | " \n",
1915 | " 30 | \n",
1916 | " 0.238589 | \n",
1917 | " 0.474209 | \n",
1918 | " 0.863583 | \n",
1919 | "
\n",
1920 | "
\n"
1921 | ],
1922 | "text/plain": [
1923 | ""
1924 | ]
1925 | },
1926 | "metadata": {},
1927 | "output_type": "display_data"
1928 | }
1929 | ],
1930 | "source": [
1931 | "with gpu_mem_restore_ctx():\n",
1932 | " learn.fit_one_cycle(30)"
1933 | ]
1934 | },
1935 | {
1936 | "cell_type": "markdown",
1937 | "metadata": {},
1938 | "source": [
1939 | "## Resnet 101 (Mixed Precision)"
1940 | ]
1941 | },
1942 | {
1943 | "cell_type": "code",
1944 | "execution_count": 36,
1945 | "metadata": {},
1946 | "outputs": [],
1947 | "source": [
1948 | "np.random.seed(42)\n",
1949 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1950 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=116).normalize(cifar_stats)"
1951 | ]
1952 | },
1953 | {
1954 | "cell_type": "code",
1955 | "execution_count": 37,
1956 | "metadata": {},
1957 | "outputs": [],
1958 | "source": [
1959 | "learn = to_fp16(create_cnn(data, models.resnet101, metrics=accuracy))"
1960 | ]
1961 | },
1962 | {
1963 | "cell_type": "code",
1964 | "execution_count": 38,
1965 | "metadata": {},
1966 | "outputs": [
1967 | {
1968 | "data": {
1969 | "text/html": [
1970 | "Total time: 2:08:00 \n",
1971 | " \n",
1972 | " epoch | \n",
1973 | " train_loss | \n",
1974 | " valid_loss | \n",
1975 | " accuracy | \n",
1976 | "
\n",
1977 | " \n",
1978 | " 1 | \n",
1979 | " 2.402573 | \n",
1980 | " 1.641406 | \n",
1981 | " 0.603583 | \n",
1982 | "
\n",
1983 | " \n",
1984 | " 2 | \n",
1985 | " 1.470369 | \n",
1986 | " 0.982278 | \n",
1987 | " 0.721500 | \n",
1988 | "
\n",
1989 | " \n",
1990 | " 3 | \n",
1991 | " 1.102205 | \n",
1992 | " 0.770284 | \n",
1993 | " 0.762917 | \n",
1994 | "
\n",
1995 | " \n",
1996 | " 4 | \n",
1997 | " 0.931413 | \n",
1998 | " 0.675970 | \n",
1999 | " 0.792833 | \n",
2000 | "
\n",
2001 | " \n",
2002 | " 5 | \n",
2003 | " 0.843089 | \n",
2004 | " 0.631572 | \n",
2005 | " 0.804000 | \n",
2006 | "
\n",
2007 | " \n",
2008 | " 6 | \n",
2009 | " 0.773870 | \n",
2010 | " 0.607783 | \n",
2011 | " 0.812583 | \n",
2012 | "
\n",
2013 | " \n",
2014 | " 7 | \n",
2015 | " 0.747394 | \n",
2016 | " 0.579102 | \n",
2017 | " 0.823000 | \n",
2018 | "
\n",
2019 | " \n",
2020 | " 8 | \n",
2021 | " 0.698355 | \n",
2022 | " 0.569311 | \n",
2023 | " 0.822667 | \n",
2024 | "
\n",
2025 | " \n",
2026 | " 9 | \n",
2027 | " 0.665380 | \n",
2028 | " 0.569775 | \n",
2029 | " 0.824250 | \n",
2030 | "
\n",
2031 | " \n",
2032 | " 10 | \n",
2033 | " 0.601234 | \n",
2034 | " 0.540831 | \n",
2035 | " 0.833250 | \n",
2036 | "
\n",
2037 | " \n",
2038 | " 11 | \n",
2039 | " 0.583828 | \n",
2040 | " 0.515747 | \n",
2041 | " 0.842083 | \n",
2042 | "
\n",
2043 | " \n",
2044 | " 12 | \n",
2045 | " 0.515766 | \n",
2046 | " 0.514028 | \n",
2047 | " 0.841167 | \n",
2048 | "
\n",
2049 | " \n",
2050 | " 13 | \n",
2051 | " 0.501062 | \n",
2052 | " 0.508455 | \n",
2053 | " 0.844500 | \n",
2054 | "
\n",
2055 | " \n",
2056 | " 14 | \n",
2057 | " 0.475348 | \n",
2058 | " 0.509687 | \n",
2059 | " 0.845667 | \n",
2060 | "
\n",
2061 | " \n",
2062 | " 15 | \n",
2063 | " 0.463935 | \n",
2064 | " 0.501683 | \n",
2065 | " 0.846500 | \n",
2066 | "
\n",
2067 | " \n",
2068 | " 16 | \n",
2069 | " 0.433365 | \n",
2070 | " 0.493446 | \n",
2071 | " 0.852083 | \n",
2072 | "
\n",
2073 | " \n",
2074 | " 17 | \n",
2075 | " 0.399650 | \n",
2076 | " 0.496941 | \n",
2077 | " 0.854250 | \n",
2078 | "
\n",
2079 | " \n",
2080 | " 18 | \n",
2081 | " 0.380342 | \n",
2082 | " 0.498303 | \n",
2083 | " 0.853250 | \n",
2084 | "
\n",
2085 | " \n",
2086 | " 19 | \n",
2087 | " 0.350620 | \n",
2088 | " 0.483819 | \n",
2089 | " 0.857833 | \n",
2090 | "
\n",
2091 | " \n",
2092 | " 20 | \n",
2093 | " 0.329624 | \n",
2094 | " 0.486034 | \n",
2095 | " 0.855667 | \n",
2096 | "
\n",
2097 | " \n",
2098 | " 21 | \n",
2099 | " 0.320623 | \n",
2100 | " 0.480329 | \n",
2101 | " 0.858750 | \n",
2102 | "
\n",
2103 | " \n",
2104 | " 22 | \n",
2105 | " 0.299029 | \n",
2106 | " 0.477023 | \n",
2107 | " 0.859917 | \n",
2108 | "
\n",
2109 | " \n",
2110 | " 23 | \n",
2111 | " 0.264857 | \n",
2112 | " 0.476886 | \n",
2113 | " 0.863000 | \n",
2114 | "
\n",
2115 | " \n",
2116 | " 24 | \n",
2117 | " 0.261844 | \n",
2118 | " 0.477507 | \n",
2119 | " 0.864083 | \n",
2120 | "
\n",
2121 | " \n",
2122 | " 25 | \n",
2123 | " 0.254985 | \n",
2124 | " 0.476712 | \n",
2125 | " 0.864250 | \n",
2126 | "
\n",
2127 | " \n",
2128 | " 26 | \n",
2129 | " 0.255878 | \n",
2130 | " 0.476127 | \n",
2131 | " 0.864417 | \n",
2132 | "
\n",
2133 | " \n",
2134 | " 27 | \n",
2135 | " 0.231718 | \n",
2136 | " 0.474823 | \n",
2137 | " 0.866250 | \n",
2138 | "
\n",
2139 | " \n",
2140 | " 28 | \n",
2141 | " 0.221085 | \n",
2142 | " 0.473201 | \n",
2143 | " 0.866667 | \n",
2144 | "
\n",
2145 | " \n",
2146 | " 29 | \n",
2147 | " 0.221620 | \n",
2148 | " 0.471682 | \n",
2149 | " 0.865083 | \n",
2150 | "
\n",
2151 | " \n",
2152 | " 30 | \n",
2153 | " 0.234261 | \n",
2154 | " 0.476006 | \n",
2155 | " 0.865417 | \n",
2156 | "
\n",
2157 | "
\n"
2158 | ],
2159 | "text/plain": [
2160 | ""
2161 | ]
2162 | },
2163 | "metadata": {},
2164 | "output_type": "display_data"
2165 | }
2166 | ],
2167 | "source": [
2168 | "with gpu_mem_restore_ctx():\n",
2169 | " learn.fit_one_cycle(30)"
2170 | ]
2171 | }
2172 | ],
2173 | "metadata": {
2174 | "kernelspec": {
2175 | "display_name": "Python 3",
2176 | "language": "python",
2177 | "name": "python3"
2178 | },
2179 | "language_info": {
2180 | "codemirror_mode": {
2181 | "name": "ipython",
2182 | "version": 3
2183 | },
2184 | "file_extension": ".py",
2185 | "mimetype": "text/x-python",
2186 | "name": "python",
2187 | "nbconvert_exporter": "python",
2188 | "pygments_lexer": "ipython3",
2189 | "version": "3.6.6"
2190 | }
2191 | },
2192 | "nbformat": 4,
2193 | "nbformat_minor": 2
2194 | }
2195 |
--------------------------------------------------------------------------------
/2080Ti.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import functools\n",
21 | "import sys, traceback\n",
22 | "def get_ref_free_exc_info():\n",
23 | "    \"Return `sys.exc_info()` for the exception currently being handled, after clearing the local variables of every frame in its traceback, so the traceback no longer keeps those objects referenced (which would prevent gc.collect() from reclaiming them).\"\n",
24 | "    exc_type, exc_val, exc_tb = sys.exc_info()  # local names chosen so the builtin `type` is not shadowed\n",
25 | "    traceback.clear_frames(exc_tb)  # drops the locals held by each frame of the traceback\n",
26 | "    return (exc_type, exc_val, exc_tb)\n",
27 | "\n",
28 | "def gpu_mem_restore(func):\n",
29 | "    \"Decorator: if `func` raises (CUDA out of memory, KeyboardInterrupt, anything else), clear the locals of the traceback's frames via `get_ref_free_exc_info` and re-raise the same exception. Intended to let GPU RAM be reclaimed afterwards; it does not itself call gc.collect() or empty the CUDA cache.\"\n",
30 | "    @functools.wraps(func)\n",
31 | "    def wrapper(*args, **kwargs):\n",
32 | "        try:\n",
33 | "            return func(*args, **kwargs)\n",
34 | "        except BaseException:  # explicit spelling of the former bare except: interrupts must be cleaned up too\n",
35 | "            _, val, tb = get_ref_free_exc_info() # must!\n",
36 | "            raise val.with_traceback(tb) from None  # re-raise the original object: rebuilding it as type(val) fails for exception classes that need several constructor arguments\n",
37 | "    return wrapper"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "from fastai.vision import *\n",
47 | "from fastai.metrics import error_rate"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 4,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "path = Path('/home/init27/Downloads/cifar100/')"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 5,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "train = path/'train'\n",
66 | "test = path/'test'"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 8,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "bs = 108"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 6,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "data": {
85 | "text/plain": [
86 | "[PosixPath('/home/init27/Downloads/cifar100/models'),\n",
87 | " PosixPath('/home/init27/Downloads/cifar100/train'),\n",
88 | " PosixPath('/home/init27/Downloads/cifar100/test')]"
89 | ]
90 | },
91 | "execution_count": 6,
92 | "metadata": {},
93 | "output_type": "execute_result"
94 | }
95 | ],
96 | "source": [
97 | "path.ls()"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 9,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "np.random.seed(42)\n",
107 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
108 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=bs).normalize(cifar_stats)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 10,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "class gpu_mem_restore_ctx():\n",
118 | "    \"Context manager: if the body raises (CUDA out of memory, KeyboardInterrupt, anything else), clear the locals of the traceback's frames and re-raise the same exception. Intended to let GPU RAM be reclaimed afterwards; it does not itself call gc.collect() or empty the CUDA cache.\"\n",
119 | "    def __enter__(self): return self\n",
120 | "    def __exit__(self, exc_type, exc_val, exc_tb):\n",
121 | "        if exc_val is None: return True  # body finished normally; an identity test so a falsy exception instance is never swallowed\n",
122 | "        traceback.clear_frames(exc_tb)  # drops the locals held by each frame of the traceback\n",
123 | "        raise exc_val.with_traceback(exc_tb) from None  # re-raise the original object: rebuilding it as exc_type(exc_val) fails for exception classes that need several constructor arguments"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 11,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "#Allow crashing\n",
133 | "learn = create_cnn(data, models.resnet18 , metrics=accuracy)"
134 | ]
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "metadata": {},
139 | "source": [
140 | "## Resnet 18\n",
141 | "\n"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 17,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "np.random.seed(42)\n",
151 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
152 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=408).normalize(cifar_stats)\n",
153 | "learn = create_cnn(data, models.resnet18, metrics=accuracy)"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 18,
159 | "metadata": {},
160 | "outputs": [
161 | {
162 | "data": {
163 | "text/html": [
164 | "Total time: 00:54 \n",
165 | " \n",
166 | " epoch | \n",
167 | " train_loss | \n",
168 | " valid_loss | \n",
169 | " accuracy | \n",
170 | "
\n",
171 | " \n",
172 | " 1 | \n",
173 | " 2.425544 | \n",
174 | " 1.632482 | \n",
175 | " 0.559083 | \n",
176 | "
\n",
177 | "
\n"
178 | ],
179 | "text/plain": [
180 | ""
181 | ]
182 | },
183 | "metadata": {},
184 | "output_type": "display_data"
185 | }
186 | ],
187 | "source": [
188 | "with gpu_mem_restore_ctx():\n",
189 | " learn.fit_one_cycle(1)"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 19,
195 | "metadata": {},
196 | "outputs": [
197 | {
198 | "data": {
199 | "text/html": [
200 | "Total time: 28:12 \n",
201 | " \n",
202 | " epoch | \n",
203 | " train_loss | \n",
204 | " valid_loss | \n",
205 | " accuracy | \n",
206 | "
\n",
207 | " \n",
208 | " 1 | \n",
209 | " 1.683162 | \n",
210 | " 1.234888 | \n",
211 | " 0.647083 | \n",
212 | "
\n",
213 | " \n",
214 | " 2 | \n",
215 | " 1.377793 | \n",
216 | " 1.006688 | \n",
217 | " 0.698833 | \n",
218 | "
\n",
219 | " \n",
220 | " 3 | \n",
221 | " 1.210670 | \n",
222 | " 0.953726 | \n",
223 | " 0.711000 | \n",
224 | "
\n",
225 | " \n",
226 | " 4 | \n",
227 | " 1.161485 | \n",
228 | " 0.981398 | \n",
229 | " 0.702750 | \n",
230 | "
\n",
231 | " \n",
232 | " 5 | \n",
233 | " 1.158831 | \n",
234 | " 0.985615 | \n",
235 | " 0.702833 | \n",
236 | "
\n",
237 | " \n",
238 | " 6 | \n",
239 | " 1.168251 | \n",
240 | " 1.045662 | \n",
241 | " 0.683667 | \n",
242 | "
\n",
243 | " \n",
244 | " 7 | \n",
245 | " 1.150126 | \n",
246 | " 1.071930 | \n",
247 | " 0.681250 | \n",
248 | "
\n",
249 | " \n",
250 | " 8 | \n",
251 | " 1.146979 | \n",
252 | " 1.025527 | \n",
253 | " 0.692417 | \n",
254 | "
\n",
255 | " \n",
256 | " 9 | \n",
257 | " 1.109232 | \n",
258 | " 1.041167 | \n",
259 | " 0.691250 | \n",
260 | "
\n",
261 | " \n",
262 | " 10 | \n",
263 | " 1.088920 | \n",
264 | " 1.046639 | \n",
265 | " 0.685083 | \n",
266 | "
\n",
267 | " \n",
268 | " 11 | \n",
269 | " 1.055113 | \n",
270 | " 1.009589 | \n",
271 | " 0.702417 | \n",
272 | "
\n",
273 | " \n",
274 | " 12 | \n",
275 | " 1.027755 | \n",
276 | " 0.937168 | \n",
277 | " 0.716500 | \n",
278 | "
\n",
279 | " \n",
280 | " 13 | \n",
281 | " 1.003785 | \n",
282 | " 0.953145 | \n",
283 | " 0.715917 | \n",
284 | "
\n",
285 | " \n",
286 | " 14 | \n",
287 | " 0.968761 | \n",
288 | " 0.916469 | \n",
289 | " 0.725417 | \n",
290 | "
\n",
291 | " \n",
292 | " 15 | \n",
293 | " 0.932007 | \n",
294 | " 0.907113 | \n",
295 | " 0.728667 | \n",
296 | "
\n",
297 | " \n",
298 | " 16 | \n",
299 | " 0.899529 | \n",
300 | " 0.874359 | \n",
301 | " 0.738333 | \n",
302 | "
\n",
303 | " \n",
304 | " 17 | \n",
305 | " 0.876744 | \n",
306 | " 0.819828 | \n",
307 | " 0.748167 | \n",
308 | "
\n",
309 | " \n",
310 | " 18 | \n",
311 | " 0.853781 | \n",
312 | " 0.813430 | \n",
313 | " 0.750667 | \n",
314 | "
\n",
315 | " \n",
316 | " 19 | \n",
317 | " 0.821074 | \n",
318 | " 0.790245 | \n",
319 | " 0.762583 | \n",
320 | "
\n",
321 | " \n",
322 | " 20 | \n",
323 | " 0.793510 | \n",
324 | " 0.764171 | \n",
325 | " 0.768833 | \n",
326 | "
\n",
327 | " \n",
328 | " 21 | \n",
329 | " 0.766965 | \n",
330 | " 0.755508 | \n",
331 | " 0.770500 | \n",
332 | "
\n",
333 | " \n",
334 | " 22 | \n",
335 | " 0.738198 | \n",
336 | " 0.737558 | \n",
337 | " 0.777167 | \n",
338 | "
\n",
339 | " \n",
340 | " 23 | \n",
341 | " 0.711821 | \n",
342 | " 0.725124 | \n",
343 | " 0.779583 | \n",
344 | "
\n",
345 | " \n",
346 | " 24 | \n",
347 | " 0.680861 | \n",
348 | " 0.722985 | \n",
349 | " 0.778667 | \n",
350 | "
\n",
351 | " \n",
352 | " 25 | \n",
353 | " 0.663084 | \n",
354 | " 0.706659 | \n",
355 | " 0.783667 | \n",
356 | "
\n",
357 | " \n",
358 | " 26 | \n",
359 | " 0.644237 | \n",
360 | " 0.696470 | \n",
361 | " 0.788083 | \n",
362 | "
\n",
363 | " \n",
364 | " 27 | \n",
365 | " 0.619859 | \n",
366 | " 0.691847 | \n",
367 | " 0.790667 | \n",
368 | "
\n",
369 | " \n",
370 | " 28 | \n",
371 | " 0.609522 | \n",
372 | " 0.691935 | \n",
373 | " 0.788750 | \n",
374 | "
\n",
375 | " \n",
376 | " 29 | \n",
377 | " 0.597701 | \n",
378 | " 0.690438 | \n",
379 | " 0.789417 | \n",
380 | "
\n",
381 | " \n",
382 | " 30 | \n",
383 | " 0.605155 | \n",
384 | " 0.690853 | \n",
385 | " 0.788500 | \n",
386 | "
\n",
387 | "
\n"
388 | ],
389 | "text/plain": [
390 | ""
391 | ]
392 | },
393 | "metadata": {},
394 | "output_type": "display_data"
395 | }
396 | ],
397 | "source": [
398 | "with gpu_mem_restore_ctx():\n",
399 | " learn.fit_one_cycle(30, max_lr=1e-2)"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "## Resnet 18 (Mixed Precision)"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 29,
412 | "metadata": {},
413 | "outputs": [],
414 | "source": [
415 | "np.random.seed(42)\n",
416 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
417 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=720).normalize(cifar_stats)"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 30,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "learn = to_fp16(create_cnn(data, models.resnet18, metrics=accuracy))"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": 31,
432 | "metadata": {},
433 | "outputs": [
434 | {
435 | "data": {
436 | "text/html": [
437 | "Total time: 00:58 \n",
438 | " \n",
439 | " epoch | \n",
440 | " train_loss | \n",
441 | " valid_loss | \n",
442 | " accuracy | \n",
443 | "
\n",
444 | " \n",
445 | " 1 | \n",
446 | " 2.892367 | \n",
447 | " 1.857484 | \n",
448 | " 0.514833 | \n",
449 | "
\n",
450 | "
\n"
451 | ],
452 | "text/plain": [
453 | ""
454 | ]
455 | },
456 | "metadata": {},
457 | "output_type": "display_data"
458 | }
459 | ],
460 | "source": [
461 | "with gpu_mem_restore_ctx():\n",
462 | " learn.fit_one_cycle(1)"
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "execution_count": 32,
468 | "metadata": {},
469 | "outputs": [
470 | {
471 | "data": {
472 | "text/html": [
473 | "Total time: 28:28 \n",
474 | " \n",
475 | " epoch | \n",
476 | " train_loss | \n",
477 | " valid_loss | \n",
478 | " accuracy | \n",
479 | "
\n",
480 | " \n",
481 | " 1 | \n",
482 | " 2.217803 | \n",
483 | " 1.776256 | \n",
484 | " 0.535000 | \n",
485 | "
\n",
486 | " \n",
487 | " 2 | \n",
488 | " 2.111644 | \n",
489 | " 1.634846 | \n",
490 | " 0.565083 | \n",
491 | "
\n",
492 | " \n",
493 | " 3 | \n",
494 | " 1.948603 | \n",
495 | " 1.451506 | \n",
496 | " 0.600167 | \n",
497 | "
\n",
498 | " \n",
499 | " 4 | \n",
500 | " 1.763333 | \n",
501 | " 1.288680 | \n",
502 | " 0.634083 | \n",
503 | "
\n",
504 | " \n",
505 | " 5 | \n",
506 | " 1.576218 | \n",
507 | " 1.170622 | \n",
508 | " 0.659167 | \n",
509 | "
\n",
510 | " \n",
511 | " 6 | \n",
512 | " 1.435181 | \n",
513 | " 1.088699 | \n",
514 | " 0.677000 | \n",
515 | "
\n",
516 | " \n",
517 | " 7 | \n",
518 | " 1.335717 | \n",
519 | " 1.024443 | \n",
520 | " 0.695833 | \n",
521 | "
\n",
522 | " \n",
523 | " 8 | \n",
524 | " 1.241808 | \n",
525 | " 0.974325 | \n",
526 | " 0.710167 | \n",
527 | "
\n",
528 | " \n",
529 | " 9 | \n",
530 | " 1.171949 | \n",
531 | " 0.934963 | \n",
532 | " 0.718083 | \n",
533 | "
\n",
534 | " \n",
535 | " 10 | \n",
536 | " 1.118036 | \n",
537 | " 0.914327 | \n",
538 | " 0.726500 | \n",
539 | "
\n",
540 | " \n",
541 | " 11 | \n",
542 | " 1.073453 | \n",
543 | " 0.887545 | \n",
544 | " 0.731083 | \n",
545 | "
\n",
546 | " \n",
547 | " 12 | \n",
548 | " 1.032388 | \n",
549 | " 0.872336 | \n",
550 | " 0.735083 | \n",
551 | "
\n",
552 | " \n",
553 | " 13 | \n",
554 | " 0.991246 | \n",
555 | " 0.859043 | \n",
556 | " 0.737417 | \n",
557 | "
\n",
558 | " \n",
559 | " 14 | \n",
560 | " 0.968865 | \n",
561 | " 0.838764 | \n",
562 | " 0.744333 | \n",
563 | "
\n",
564 | " \n",
565 | " 15 | \n",
566 | " 0.941147 | \n",
567 | " 0.827038 | \n",
568 | " 0.748500 | \n",
569 | "
\n",
570 | " \n",
571 | " 16 | \n",
572 | " 0.921168 | \n",
573 | " 0.815047 | \n",
574 | " 0.750500 | \n",
575 | "
\n",
576 | " \n",
577 | " 17 | \n",
578 | " 0.896468 | \n",
579 | " 0.810671 | \n",
580 | " 0.750750 | \n",
581 | "
\n",
582 | " \n",
583 | " 18 | \n",
584 | " 0.876352 | \n",
585 | " 0.797368 | \n",
586 | " 0.752500 | \n",
587 | "
\n",
588 | " \n",
589 | " 19 | \n",
590 | " 0.855625 | \n",
591 | " 0.796557 | \n",
592 | " 0.754833 | \n",
593 | "
\n",
594 | " \n",
595 | " 20 | \n",
596 | " 0.834282 | \n",
597 | " 0.787105 | \n",
598 | " 0.758667 | \n",
599 | "
\n",
600 | " \n",
601 | " 21 | \n",
602 | " 0.824325 | \n",
603 | " 0.783010 | \n",
604 | " 0.761167 | \n",
605 | "
\n",
606 | " \n",
607 | " 22 | \n",
608 | " 0.799274 | \n",
609 | " 0.776148 | \n",
610 | " 0.762583 | \n",
611 | "
\n",
612 | " \n",
613 | " 23 | \n",
614 | " 0.790784 | \n",
615 | " 0.770834 | \n",
616 | " 0.761250 | \n",
617 | "
\n",
618 | " \n",
619 | " 24 | \n",
620 | " 0.774112 | \n",
621 | " 0.768528 | \n",
622 | " 0.763750 | \n",
623 | "
\n",
624 | " \n",
625 | " 25 | \n",
626 | " 0.767463 | \n",
627 | " 0.766201 | \n",
628 | " 0.765917 | \n",
629 | "
\n",
630 | " \n",
631 | " 26 | \n",
632 | " 0.758423 | \n",
633 | " 0.764043 | \n",
634 | " 0.766750 | \n",
635 | "
\n",
636 | " \n",
637 | " 27 | \n",
638 | " 0.749885 | \n",
639 | " 0.762675 | \n",
640 | " 0.766500 | \n",
641 | "
\n",
642 | " \n",
643 | " 28 | \n",
644 | " 0.746697 | \n",
645 | " 0.761255 | \n",
646 | " 0.767583 | \n",
647 | "
\n",
648 | " \n",
649 | " 29 | \n",
650 | " 0.749717 | \n",
651 | " 0.761567 | \n",
652 | " 0.766833 | \n",
653 | "
\n",
654 | " \n",
655 | " 30 | \n",
656 | " 0.747547 | \n",
657 | " 0.762019 | \n",
658 | " 0.766500 | \n",
659 | "
\n",
660 | "
\n"
661 | ],
662 | "text/plain": [
663 | ""
664 | ]
665 | },
666 | "metadata": {},
667 | "output_type": "display_data"
668 | }
669 | ],
670 | "source": [
671 | "with gpu_mem_restore_ctx():\n",
672 | " learn.fit_one_cycle(30)"
673 | ]
674 | },
675 | {
676 | "cell_type": "markdown",
677 | "metadata": {},
678 | "source": [
679 | "## Resnet 34 "
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": 155,
685 | "metadata": {},
686 | "outputs": [],
687 | "source": [
688 | "np.random.seed(42)\n",
689 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
690 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=248).normalize(cifar_stats)"
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": 156,
696 | "metadata": {},
697 | "outputs": [],
698 | "source": [
699 | "learn = create_cnn(data, models.resnet34, metrics=accuracy)"
700 | ]
701 | },
702 | {
703 | "cell_type": "code",
704 | "execution_count": 157,
705 | "metadata": {},
706 | "outputs": [
707 | {
708 | "data": {
709 | "text/html": [
710 | "Total time: 39:37 \n",
711 | " \n",
712 | " epoch | \n",
713 | " train_loss | \n",
714 | " valid_loss | \n",
715 | " accuracy | \n",
716 | "
\n",
717 | " \n",
718 | " 1 | \n",
719 | " 4.091527 | \n",
720 | " 3.254414 | \n",
721 | " 0.291917 | \n",
722 | "
\n",
723 | " \n",
724 | " 2 | \n",
725 | " 2.743854 | \n",
726 | " 1.980968 | \n",
727 | " 0.516750 | \n",
728 | "
\n",
729 | " \n",
730 | " 3 | \n",
731 | " 1.920904 | \n",
732 | " 1.337292 | \n",
733 | " 0.628833 | \n",
734 | "
\n",
735 | " \n",
736 | " 4 | \n",
737 | " 1.471699 | \n",
738 | " 1.053553 | \n",
739 | " 0.692083 | \n",
740 | "
\n",
741 | " \n",
742 | " 5 | \n",
743 | " 1.233301 | \n",
744 | " 0.913338 | \n",
745 | " 0.720250 | \n",
746 | "
\n",
747 | " \n",
748 | " 6 | \n",
749 | " 1.109750 | \n",
750 | " 0.848846 | \n",
751 | " 0.743417 | \n",
752 | "
\n",
753 | " \n",
754 | " 7 | \n",
755 | " 1.037520 | \n",
756 | " 0.791737 | \n",
757 | " 0.756333 | \n",
758 | "
\n",
759 | " \n",
760 | " 8 | \n",
761 | " 0.969036 | \n",
762 | " 0.762382 | \n",
763 | " 0.766417 | \n",
764 | "
\n",
765 | " \n",
766 | " 9 | \n",
767 | " 0.919657 | \n",
768 | " 0.729532 | \n",
769 | " 0.773250 | \n",
770 | "
\n",
771 | " \n",
772 | " 10 | \n",
773 | " 0.885872 | \n",
774 | " 0.712802 | \n",
775 | " 0.778000 | \n",
776 | "
\n",
777 | " \n",
778 | " 11 | \n",
779 | " 0.839404 | \n",
780 | " 0.704638 | \n",
781 | " 0.782583 | \n",
782 | "
\n",
783 | " \n",
784 | " 12 | \n",
785 | " 0.813941 | \n",
786 | " 0.690926 | \n",
787 | " 0.786333 | \n",
788 | "
\n",
789 | " \n",
790 | " 13 | \n",
791 | " 0.789469 | \n",
792 | " 0.672594 | \n",
793 | " 0.793000 | \n",
794 | "
\n",
795 | " \n",
796 | " 14 | \n",
797 | " 0.765909 | \n",
798 | " 0.663917 | \n",
799 | " 0.796250 | \n",
800 | "
\n",
801 | " \n",
802 | " 15 | \n",
803 | " 0.728020 | \n",
804 | " 0.653305 | \n",
805 | " 0.800167 | \n",
806 | "
\n",
807 | " \n",
808 | " 16 | \n",
809 | " 0.707719 | \n",
810 | " 0.656965 | \n",
811 | " 0.797250 | \n",
812 | "
\n",
813 | " \n",
814 | " 17 | \n",
815 | " 0.682595 | \n",
816 | " 0.642476 | \n",
817 | " 0.804083 | \n",
818 | "
\n",
819 | " \n",
820 | " 18 | \n",
821 | " 0.663653 | \n",
822 | " 0.637333 | \n",
823 | " 0.803500 | \n",
824 | "
\n",
825 | " \n",
826 | " 19 | \n",
827 | " 0.651230 | \n",
828 | " 0.625249 | \n",
829 | " 0.806083 | \n",
830 | "
\n",
831 | " \n",
832 | " 20 | \n",
833 | " 0.621778 | \n",
834 | " 0.622463 | \n",
835 | " 0.809583 | \n",
836 | "
\n",
837 | " \n",
838 | " 21 | \n",
839 | " 0.613027 | \n",
840 | " 0.621823 | \n",
841 | " 0.809250 | \n",
842 | "
\n",
843 | " \n",
844 | " 22 | \n",
845 | " 0.608741 | \n",
846 | " 0.618127 | \n",
847 | " 0.813333 | \n",
848 | "
\n",
849 | " \n",
850 | " 23 | \n",
851 | " 0.577575 | \n",
852 | " 0.613892 | \n",
853 | " 0.814583 | \n",
854 | "
\n",
855 | " \n",
856 | " 24 | \n",
857 | " 0.570914 | \n",
858 | " 0.609038 | \n",
859 | " 0.814833 | \n",
860 | "
\n",
861 | " \n",
862 | " 25 | \n",
863 | " 0.544643 | \n",
864 | " 0.609252 | \n",
865 | " 0.816750 | \n",
866 | "
\n",
867 | " \n",
868 | " 26 | \n",
869 | " 0.548220 | \n",
870 | " 0.606148 | \n",
871 | " 0.817583 | \n",
872 | "
\n",
873 | " \n",
874 | " 27 | \n",
875 | " 0.535698 | \n",
876 | " 0.604082 | \n",
877 | " 0.817750 | \n",
878 | "
\n",
879 | " \n",
880 | " 28 | \n",
881 | " 0.525107 | \n",
882 | " 0.604753 | \n",
883 | " 0.817667 | \n",
884 | "
\n",
885 | " \n",
886 | " 29 | \n",
887 | " 0.531150 | \n",
888 | " 0.604243 | \n",
889 | " 0.817250 | \n",
890 | "
\n",
891 | " \n",
892 | " 30 | \n",
893 | " 0.533026 | \n",
894 | " 0.604470 | \n",
895 | " 0.817167 | \n",
896 | "
\n",
897 | "
\n"
898 | ],
899 | "text/plain": [
900 | ""
901 | ]
902 | },
903 | "metadata": {},
904 | "output_type": "display_data"
905 | }
906 | ],
907 | "source": [
908 | "with gpu_mem_restore_ctx():\n",
909 | " learn.fit_one_cycle(30)"
910 | ]
911 | },
912 | {
913 | "cell_type": "markdown",
914 | "metadata": {},
915 | "source": [
916 | "## Resnet 34 (Mixed Precision)"
917 | ]
918 | },
919 | {
920 | "cell_type": "code",
921 | "execution_count": 158,
922 | "metadata": {},
923 | "outputs": [],
924 | "source": [
925 | "np.random.seed(42)\n",
926 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
927 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=496).normalize(cifar_stats)"
928 | ]
929 | },
930 | {
931 | "cell_type": "code",
932 | "execution_count": 159,
933 | "metadata": {},
934 | "outputs": [],
935 | "source": [
936 | "learn = to_fp16(create_cnn(data, models.resnet34, metrics=accuracy))"
937 | ]
938 | },
939 | {
940 | "cell_type": "code",
941 | "execution_count": 160,
942 | "metadata": {},
943 | "outputs": [
944 | {
945 | "data": {
946 | "text/html": [
947 | "Total time: 32:13 \n",
948 | " \n",
949 | " epoch | \n",
950 | " train_loss | \n",
951 | " valid_loss | \n",
952 | " accuracy | \n",
953 | "
\n",
954 | " \n",
955 | " 1 | \n",
956 | " 4.527595 | \n",
957 | " 3.723982 | \n",
958 | " 0.192333 | \n",
959 | "
\n",
960 | " \n",
961 | " 2 | \n",
962 | " 3.496515 | \n",
963 | " 2.463351 | \n",
964 | " 0.443083 | \n",
965 | "
\n",
966 | " \n",
967 | " 3 | \n",
968 | " 2.486472 | \n",
969 | " 1.608410 | \n",
970 | " 0.577167 | \n",
971 | "
\n",
972 | " \n",
973 | " 4 | \n",
974 | " 1.810522 | \n",
975 | " 1.204057 | \n",
976 | " 0.655750 | \n",
977 | "
\n",
978 | " \n",
979 | " 5 | \n",
980 | " 1.430669 | \n",
981 | " 1.008747 | \n",
982 | " 0.701667 | \n",
983 | "
\n",
984 | " \n",
985 | " 6 | \n",
986 | " 1.225085 | \n",
987 | " 0.905269 | \n",
988 | " 0.727500 | \n",
989 | "
\n",
990 | " \n",
991 | " 7 | \n",
992 | " 1.092392 | \n",
993 | " 0.835543 | \n",
994 | " 0.743750 | \n",
995 | "
\n",
996 | " \n",
997 | " 8 | \n",
998 | " 1.013207 | \n",
999 | " 0.793349 | \n",
1000 | " 0.753833 | \n",
1001 | "
\n",
1002 | " \n",
1003 | " 9 | \n",
1004 | " 0.954900 | \n",
1005 | " 0.760904 | \n",
1006 | " 0.762250 | \n",
1007 | "
\n",
1008 | " \n",
1009 | " 10 | \n",
1010 | " 0.897951 | \n",
1011 | " 0.739107 | \n",
1012 | " 0.773417 | \n",
1013 | "
\n",
1014 | " \n",
1015 | " 11 | \n",
1016 | " 0.860628 | \n",
1017 | " 0.723568 | \n",
1018 | " 0.776000 | \n",
1019 | "
\n",
1020 | " \n",
1021 | " 12 | \n",
1022 | " 0.823930 | \n",
1023 | " 0.703144 | \n",
1024 | " 0.784000 | \n",
1025 | "
\n",
1026 | " \n",
1027 | " 13 | \n",
1028 | " 0.792410 | \n",
1029 | " 0.681215 | \n",
1030 | " 0.790583 | \n",
1031 | "
\n",
1032 | " \n",
1033 | " 14 | \n",
1034 | " 0.767022 | \n",
1035 | " 0.683587 | \n",
1036 | " 0.789167 | \n",
1037 | "
\n",
1038 | " \n",
1039 | " 15 | \n",
1040 | " 0.741968 | \n",
1041 | " 0.668285 | \n",
1042 | " 0.796083 | \n",
1043 | "
\n",
1044 | " \n",
1045 | " 16 | \n",
1046 | " 0.718596 | \n",
1047 | " 0.661959 | \n",
1048 | " 0.797583 | \n",
1049 | "
\n",
1050 | " \n",
1051 | " 17 | \n",
1052 | " 0.693880 | \n",
1053 | " 0.657565 | \n",
1054 | " 0.796500 | \n",
1055 | "
\n",
1056 | " \n",
1057 | " 18 | \n",
1058 | " 0.676531 | \n",
1059 | " 0.652226 | \n",
1060 | " 0.804333 | \n",
1061 | "
\n",
1062 | " \n",
1063 | " 19 | \n",
1064 | " 0.661541 | \n",
1065 | " 0.650092 | \n",
1066 | " 0.803000 | \n",
1067 | "
\n",
1068 | " \n",
1069 | " 20 | \n",
1070 | " 0.641904 | \n",
1071 | " 0.632877 | \n",
1072 | " 0.808167 | \n",
1073 | "
\n",
1074 | " \n",
1075 | " 21 | \n",
1076 | " 0.628384 | \n",
1077 | " 0.633768 | \n",
1078 | " 0.807500 | \n",
1079 | "
\n",
1080 | " \n",
1081 | " 22 | \n",
1082 | " 0.607901 | \n",
1083 | " 0.630894 | \n",
1084 | " 0.808167 | \n",
1085 | "
\n",
1086 | " \n",
1087 | " 23 | \n",
1088 | " 0.597407 | \n",
1089 | " 0.630145 | \n",
1090 | " 0.809083 | \n",
1091 | "
\n",
1092 | " \n",
1093 | " 24 | \n",
1094 | " 0.580347 | \n",
1095 | " 0.625942 | \n",
1096 | " 0.809583 | \n",
1097 | "
\n",
1098 | " \n",
1099 | " 25 | \n",
1100 | " 0.576550 | \n",
1101 | " 0.624290 | \n",
1102 | " 0.813000 | \n",
1103 | "
\n",
1104 | " \n",
1105 | " 26 | \n",
1106 | " 0.564124 | \n",
1107 | " 0.620901 | \n",
1108 | " 0.811500 | \n",
1109 | "
\n",
1110 | " \n",
1111 | " 27 | \n",
1112 | " 0.561751 | \n",
1113 | " 0.620830 | \n",
1114 | " 0.811000 | \n",
1115 | "
\n",
1116 | " \n",
1117 | " 28 | \n",
1118 | " 0.549595 | \n",
1119 | " 0.620113 | \n",
1120 | " 0.811500 | \n",
1121 | "
\n",
1122 | " \n",
1123 | " 29 | \n",
1124 | " 0.549034 | \n",
1125 | " 0.619881 | \n",
1126 | " 0.812917 | \n",
1127 | "
\n",
1128 | " \n",
1129 | " 30 | \n",
1130 | " 0.553179 | \n",
1131 | " 0.618634 | \n",
1132 | " 0.812083 | \n",
1133 | "
\n",
1134 | "
\n"
1135 | ],
1136 | "text/plain": [
1137 | ""
1138 | ]
1139 | },
1140 | "metadata": {},
1141 | "output_type": "display_data"
1142 | }
1143 | ],
1144 | "source": [
1145 | "with gpu_mem_restore_ctx():\n",
1146 | " learn.fit_one_cycle(30)"
1147 | ]
1148 | },
1149 | {
1150 | "cell_type": "markdown",
1151 | "metadata": {},
1152 | "source": [
1153 | "## Resnet 50"
1154 | ]
1155 | },
1156 | {
1157 | "cell_type": "code",
1158 | "execution_count": 44,
1159 | "metadata": {},
1160 | "outputs": [],
1161 | "source": [
1162 | "np.random.seed(42)\n",
1163 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1164 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=98).normalize(cifar_stats)"
1165 | ]
1166 | },
1167 | {
1168 | "cell_type": "code",
1169 | "execution_count": 45,
1170 | "metadata": {},
1171 | "outputs": [],
1172 | "source": [
1173 | "learn = create_cnn(data, models.resnet50, metrics=accuracy)"
1174 | ]
1175 | },
1176 | {
1177 | "cell_type": "code",
1178 | "execution_count": 46,
1179 | "metadata": {},
1180 | "outputs": [
1181 | {
1182 | "data": {
1183 | "text/html": [
1184 | "Total time: 1:16:18 \n",
1185 | " \n",
1186 | " epoch | \n",
1187 | " train_loss | \n",
1188 | " valid_loss | \n",
1189 | " accuracy | \n",
1190 | "
\n",
1191 | " \n",
1192 | " 1 | \n",
1193 | " 2.538904 | \n",
1194 | " 1.896686 | \n",
1195 | " 0.537917 | \n",
1196 | "
\n",
1197 | " \n",
1198 | " 2 | \n",
1199 | " 1.700987 | \n",
1200 | " 1.219014 | \n",
1201 | " 0.660167 | \n",
1202 | "
\n",
1203 | " \n",
1204 | " 3 | \n",
1205 | " 1.304718 | \n",
1206 | " 0.952666 | \n",
1207 | " 0.716250 | \n",
1208 | "
\n",
1209 | " \n",
1210 | " 4 | \n",
1211 | " 1.097150 | \n",
1212 | " 0.850275 | \n",
1213 | " 0.743417 | \n",
1214 | "
\n",
1215 | " \n",
1216 | " 5 | \n",
1217 | " 1.008206 | \n",
1218 | " 0.781421 | \n",
1219 | " 0.758083 | \n",
1220 | "
\n",
1221 | " \n",
1222 | " 6 | \n",
1223 | " 0.947150 | \n",
1224 | " 0.740677 | \n",
1225 | " 0.770500 | \n",
1226 | "
\n",
1227 | " \n",
1228 | " 7 | \n",
1229 | " 0.891749 | \n",
1230 | " 0.720329 | \n",
1231 | " 0.778250 | \n",
1232 | "
\n",
1233 | " \n",
1234 | " 8 | \n",
1235 | " 0.800058 | \n",
1236 | " 0.686394 | \n",
1237 | " 0.791250 | \n",
1238 | "
\n",
1239 | " \n",
1240 | " 9 | \n",
1241 | " 0.771210 | \n",
1242 | " 0.674660 | \n",
1243 | " 0.793167 | \n",
1244 | "
\n",
1245 | " \n",
1246 | " 10 | \n",
1247 | " 0.744009 | \n",
1248 | " 0.653407 | \n",
1249 | " 0.798000 | \n",
1250 | "
\n",
1251 | " \n",
1252 | " 11 | \n",
1253 | " 0.710490 | \n",
1254 | " 0.635136 | \n",
1255 | " 0.803083 | \n",
1256 | "
\n",
1257 | " \n",
1258 | " 12 | \n",
1259 | " 0.653651 | \n",
1260 | " 0.618544 | \n",
1261 | " 0.810833 | \n",
1262 | "
\n",
1263 | " \n",
1264 | " 13 | \n",
1265 | " 0.654817 | \n",
1266 | " 0.612365 | \n",
1267 | " 0.811500 | \n",
1268 | "
\n",
1269 | " \n",
1270 | " 14 | \n",
1271 | " 0.604655 | \n",
1272 | " 0.610485 | \n",
1273 | " 0.818000 | \n",
1274 | "
\n",
1275 | " \n",
1276 | " 15 | \n",
1277 | " 0.593437 | \n",
1278 | " 0.591317 | \n",
1279 | " 0.820833 | \n",
1280 | "
\n",
1281 | " \n",
1282 | " 16 | \n",
1283 | " 0.553768 | \n",
1284 | " 0.596443 | \n",
1285 | " 0.819417 | \n",
1286 | "
\n",
1287 | " \n",
1288 | " 17 | \n",
1289 | " 0.507188 | \n",
1290 | " 0.585333 | \n",
1291 | " 0.825250 | \n",
1292 | "
\n",
1293 | " \n",
1294 | " 18 | \n",
1295 | " 0.493085 | \n",
1296 | " 0.583739 | \n",
1297 | " 0.829417 | \n",
1298 | "
\n",
1299 | " \n",
1300 | " 19 | \n",
1301 | " 0.476964 | \n",
1302 | " 0.575371 | \n",
1303 | " 0.828667 | \n",
1304 | "
\n",
1305 | " \n",
1306 | " 20 | \n",
1307 | " 0.455830 | \n",
1308 | " 0.569629 | \n",
1309 | " 0.831167 | \n",
1310 | "
\n",
1311 | " \n",
1312 | " 21 | \n",
1313 | " 0.447361 | \n",
1314 | " 0.572533 | \n",
1315 | " 0.831417 | \n",
1316 | "
\n",
1317 | " \n",
1318 | " 22 | \n",
1319 | " 0.407300 | \n",
1320 | " 0.567695 | \n",
1321 | " 0.833667 | \n",
1322 | "
\n",
1323 | " \n",
1324 | " 23 | \n",
1325 | " 0.383044 | \n",
1326 | " 0.565551 | \n",
1327 | " 0.833667 | \n",
1328 | "
\n",
1329 | " \n",
1330 | " 24 | \n",
1331 | " 0.369885 | \n",
1332 | " 0.563926 | \n",
1333 | " 0.835917 | \n",
1334 | "
\n",
1335 | " \n",
1336 | " 25 | \n",
1337 | " 0.376316 | \n",
1338 | " 0.561316 | \n",
1339 | " 0.836833 | \n",
1340 | "
\n",
1341 | " \n",
1342 | " 26 | \n",
1343 | " 0.335686 | \n",
1344 | " 0.558788 | \n",
1345 | " 0.838000 | \n",
1346 | "
\n",
1347 | " \n",
1348 | " 27 | \n",
1349 | " 0.346769 | \n",
1350 | " 0.560210 | \n",
1351 | " 0.838083 | \n",
1352 | "
\n",
1353 | " \n",
1354 | " 28 | \n",
1355 | " 0.340222 | \n",
1356 | " 0.561173 | \n",
1357 | " 0.839667 | \n",
1358 | "
\n",
1359 | " \n",
1360 | " 29 | \n",
1361 | " 0.327940 | \n",
1362 | " 0.558264 | \n",
1363 | " 0.839667 | \n",
1364 | "
\n",
1365 | " \n",
1366 | " 30 | \n",
1367 | " 0.316947 | \n",
1368 | " 0.559535 | \n",
1369 | " 0.839333 | \n",
1370 | "
\n",
1371 | "
\n"
1372 | ],
1373 | "text/plain": [
1374 | ""
1375 | ]
1376 | },
1377 | "metadata": {},
1378 | "output_type": "display_data"
1379 | }
1380 | ],
1381 | "source": [
1382 | "with gpu_mem_restore_ctx():\n",
1383 | " learn.fit_one_cycle(30)"
1384 | ]
1385 | },
1386 | {
1387 | "cell_type": "markdown",
1388 | "metadata": {},
1389 | "source": [
1390 | "## Resnet 50 (Mixed Precision)"
1391 | ]
1392 | },
1393 | {
1394 | "cell_type": "code",
1395 | "execution_count": 57,
1396 | "metadata": {},
1397 | "outputs": [],
1398 | "source": [
1399 | "np.random.seed(42)\n",
1400 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1401 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=164).normalize(cifar_stats)"
1402 | ]
1403 | },
1404 | {
1405 | "cell_type": "code",
1406 | "execution_count": 58,
1407 | "metadata": {},
1408 | "outputs": [],
1409 | "source": [
1410 | "learn = to_fp16(create_cnn(data, models.resnet50, metrics=accuracy))"
1411 | ]
1412 | },
1413 | {
1414 | "cell_type": "code",
1415 | "execution_count": 59,
1416 | "metadata": {},
1417 | "outputs": [
1418 | {
1419 | "data": {
1420 | "text/html": [
1421 | "Total time: 53:38 \n",
1422 | " \n",
1423 | " epoch | \n",
1424 | " train_loss | \n",
1425 | " valid_loss | \n",
1426 | " accuracy | \n",
1427 | "
\n",
1428 | " \n",
1429 | " 1 | \n",
1430 | " 2.853479 | \n",
1431 | " 2.087002 | \n",
1432 | " 0.499583 | \n",
1433 | "
\n",
1434 | " \n",
1435 | " 2 | \n",
1436 | " 1.851662 | \n",
1437 | " 1.342598 | \n",
1438 | " 0.631250 | \n",
1439 | "
\n",
1440 | " \n",
1441 | " 3 | \n",
1442 | " 1.387896 | \n",
1443 | " 1.018817 | \n",
1444 | " 0.699417 | \n",
1445 | "
\n",
1446 | " \n",
1447 | " 4 | \n",
1448 | " 1.165870 | \n",
1449 | " 0.866866 | \n",
1450 | " 0.736500 | \n",
1451 | "
\n",
1452 | " \n",
1453 | " 5 | \n",
1454 | " 1.031915 | \n",
1455 | " 0.790143 | \n",
1456 | " 0.755833 | \n",
1457 | "
\n",
1458 | " \n",
1459 | " 6 | \n",
1460 | " 0.921515 | \n",
1461 | " 0.759719 | \n",
1462 | " 0.765500 | \n",
1463 | "
\n",
1464 | " \n",
1465 | " 7 | \n",
1466 | " 0.880313 | \n",
1467 | " 0.717190 | \n",
1468 | " 0.778917 | \n",
1469 | "
\n",
1470 | " \n",
1471 | " 8 | \n",
1472 | " 0.838310 | \n",
1473 | " 0.697754 | \n",
1474 | " 0.784250 | \n",
1475 | "
\n",
1476 | " \n",
1477 | " 9 | \n",
1478 | " 0.775729 | \n",
1479 | " 0.681422 | \n",
1480 | " 0.791750 | \n",
1481 | "
\n",
1482 | " \n",
1483 | " 10 | \n",
1484 | " 0.725294 | \n",
1485 | " 0.649651 | \n",
1486 | " 0.801583 | \n",
1487 | "
\n",
1488 | " \n",
1489 | " 11 | \n",
1490 | " 0.692380 | \n",
1491 | " 0.643033 | \n",
1492 | " 0.801167 | \n",
1493 | "
\n",
1494 | " \n",
1495 | " 12 | \n",
1496 | " 0.661506 | \n",
1497 | " 0.633204 | \n",
1498 | " 0.805750 | \n",
1499 | "
\n",
1500 | " \n",
1501 | " 13 | \n",
1502 | " 0.616202 | \n",
1503 | " 0.624044 | \n",
1504 | " 0.810333 | \n",
1505 | "
\n",
1506 | " \n",
1507 | " 14 | \n",
1508 | " 0.591845 | \n",
1509 | " 0.612394 | \n",
1510 | " 0.812417 | \n",
1511 | "
\n",
1512 | " \n",
1513 | " 15 | \n",
1514 | " 0.566731 | \n",
1515 | " 0.603592 | \n",
1516 | " 0.818500 | \n",
1517 | "
\n",
1518 | " \n",
1519 | " 16 | \n",
1520 | " 0.542707 | \n",
1521 | " 0.597049 | \n",
1522 | " 0.818333 | \n",
1523 | "
\n",
1524 | " \n",
1525 | " 17 | \n",
1526 | " 0.506298 | \n",
1527 | " 0.596841 | \n",
1528 | " 0.822417 | \n",
1529 | "
\n",
1530 | " \n",
1531 | " 18 | \n",
1532 | " 0.492856 | \n",
1533 | " 0.588816 | \n",
1534 | " 0.824667 | \n",
1535 | "
\n",
1536 | " \n",
1537 | " 19 | \n",
1538 | " 0.450351 | \n",
1539 | " 0.585315 | \n",
1540 | " 0.826667 | \n",
1541 | "
\n",
1542 | " \n",
1543 | " 20 | \n",
1544 | " 0.441241 | \n",
1545 | " 0.580710 | \n",
1546 | " 0.827167 | \n",
1547 | "
\n",
1548 | " \n",
1549 | " 21 | \n",
1550 | " 0.420925 | \n",
1551 | " 0.578927 | \n",
1552 | " 0.828500 | \n",
1553 | "
\n",
1554 | " \n",
1555 | " 22 | \n",
1556 | " 0.392566 | \n",
1557 | " 0.581997 | \n",
1558 | " 0.831417 | \n",
1559 | "
\n",
1560 | " \n",
1561 | " 23 | \n",
1562 | " 0.374037 | \n",
1563 | " 0.571375 | \n",
1564 | " 0.832500 | \n",
1565 | "
\n",
1566 | " \n",
1567 | " 24 | \n",
1568 | " 0.366736 | \n",
1569 | " 0.572550 | \n",
1570 | " 0.832583 | \n",
1571 | "
\n",
1572 | " \n",
1573 | " 25 | \n",
1574 | " 0.339871 | \n",
1575 | " 0.569914 | \n",
1576 | " 0.831500 | \n",
1577 | "
\n",
1578 | " \n",
1579 | " 26 | \n",
1580 | " 0.344734 | \n",
1581 | " 0.568628 | \n",
1582 | " 0.833083 | \n",
1583 | "
\n",
1584 | " \n",
1585 | " 27 | \n",
1586 | " 0.322793 | \n",
1587 | " 0.567510 | \n",
1588 | " 0.832667 | \n",
1589 | "
\n",
1590 | " \n",
1591 | " 28 | \n",
1592 | " 0.307898 | \n",
1593 | " 0.565839 | \n",
1594 | " 0.835750 | \n",
1595 | "
\n",
1596 | " \n",
1597 | " 29 | \n",
1598 | " 0.322832 | \n",
1599 | " 0.566465 | \n",
1600 | " 0.832583 | \n",
1601 | "
\n",
1602 | " \n",
1603 | " 30 | \n",
1604 | " 0.317137 | \n",
1605 | " 0.568857 | \n",
1606 | " 0.832000 | \n",
1607 | "
\n",
1608 | "
\n"
1609 | ],
1610 | "text/plain": [
1611 | ""
1612 | ]
1613 | },
1614 | "metadata": {},
1615 | "output_type": "display_data"
1616 | }
1617 | ],
1618 | "source": [
1619 | "with gpu_mem_restore_ctx():\n",
1620 | " learn.fit_one_cycle(30)"
1621 | ]
1622 | },
1623 | {
1624 | "cell_type": "markdown",
1625 | "metadata": {},
1626 | "source": [
1627 | "## Resnet 101"
1628 | ]
1629 | },
1630 | {
1631 | "cell_type": "code",
1632 | "execution_count": 161,
1633 | "metadata": {},
1634 | "outputs": [],
1635 | "source": [
1636 | "np.random.seed(42)\n",
1637 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1638 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=64).normalize(cifar_stats)"
1639 | ]
1640 | },
1641 | {
1642 | "cell_type": "code",
1643 | "execution_count": 162,
1644 | "metadata": {},
1645 | "outputs": [],
1646 | "source": [
1647 | "learn = create_cnn(data, models.resnet101, metrics=accuracy)"
1648 | ]
1649 | },
1650 | {
1651 | "cell_type": "code",
1652 | "execution_count": 163,
1653 | "metadata": {},
1654 | "outputs": [
1655 | {
1656 | "data": {
1657 | "text/html": [
1658 | "Total time: 2:01:58 \n",
1659 | " \n",
1660 | " epoch | \n",
1661 | " train_loss | \n",
1662 | " valid_loss | \n",
1663 | " accuracy | \n",
1664 | "
\n",
1665 | " \n",
1666 | " 1 | \n",
1667 | " 2.055326 | \n",
1668 | " 1.447209 | \n",
1669 | " 0.630000 | \n",
1670 | "
\n",
1671 | " \n",
1672 | " 2 | \n",
1673 | " 1.328482 | \n",
1674 | " 0.938083 | \n",
1675 | " 0.723333 | \n",
1676 | "
\n",
1677 | " \n",
1678 | " 3 | \n",
1679 | " 1.065412 | \n",
1680 | " 0.762804 | \n",
1681 | " 0.767833 | \n",
1682 | "
\n",
1683 | " \n",
1684 | " 4 | \n",
1685 | " 0.893859 | \n",
1686 | " 0.681791 | \n",
1687 | " 0.788333 | \n",
1688 | "
\n",
1689 | " \n",
1690 | " 5 | \n",
1691 | " 0.857969 | \n",
1692 | " 0.640077 | \n",
1693 | " 0.801417 | \n",
1694 | "
\n",
1695 | " \n",
1696 | " 6 | \n",
1697 | " 0.766642 | \n",
1698 | " 0.637175 | \n",
1699 | " 0.802417 | \n",
1700 | "
\n",
1701 | " \n",
1702 | " 7 | \n",
1703 | " 0.781614 | \n",
1704 | " 0.625195 | \n",
1705 | " 0.804750 | \n",
1706 | "
\n",
1707 | " \n",
1708 | " 8 | \n",
1709 | " 0.731761 | \n",
1710 | " 0.596706 | \n",
1711 | " 0.817333 | \n",
1712 | "
\n",
1713 | " \n",
1714 | " 9 | \n",
1715 | " 0.682865 | \n",
1716 | " 0.582054 | \n",
1717 | " 0.822250 | \n",
1718 | "
\n",
1719 | " \n",
1720 | " 10 | \n",
1721 | " 0.648590 | \n",
1722 | " 0.560301 | \n",
1723 | " 0.828667 | \n",
1724 | "
\n",
1725 | " \n",
1726 | " 11 | \n",
1727 | " 0.618241 | \n",
1728 | " 0.556258 | \n",
1729 | " 0.829667 | \n",
1730 | "
\n",
1731 | " \n",
1732 | " 12 | \n",
1733 | " 0.574654 | \n",
1734 | " 0.547130 | \n",
1735 | " 0.835917 | \n",
1736 | "
\n",
1737 | " \n",
1738 | " 13 | \n",
1739 | " 0.530491 | \n",
1740 | " 0.540169 | \n",
1741 | " 0.837667 | \n",
1742 | "
\n",
1743 | " \n",
1744 | " 14 | \n",
1745 | " 0.528848 | \n",
1746 | " 0.534272 | \n",
1747 | " 0.839333 | \n",
1748 | "
\n",
1749 | " \n",
1750 | " 15 | \n",
1751 | " 0.491567 | \n",
1752 | " 0.526254 | \n",
1753 | " 0.840750 | \n",
1754 | "
\n",
1755 | " \n",
1756 | " 16 | \n",
1757 | " 0.451276 | \n",
1758 | " 0.531170 | \n",
1759 | " 0.842167 | \n",
1760 | "
\n",
1761 | " \n",
1762 | " 17 | \n",
1763 | " 0.438837 | \n",
1764 | " 0.518594 | \n",
1765 | " 0.847167 | \n",
1766 | "
\n",
1767 | " \n",
1768 | " 18 | \n",
1769 | " 0.433088 | \n",
1770 | " 0.524159 | \n",
1771 | " 0.845750 | \n",
1772 | "
\n",
1773 | " \n",
1774 | " 19 | \n",
1775 | " 0.381678 | \n",
1776 | " 0.522843 | \n",
1777 | " 0.847667 | \n",
1778 | "
\n",
1779 | " \n",
1780 | " 20 | \n",
1781 | " 0.365216 | \n",
1782 | " 0.521657 | \n",
1783 | " 0.848917 | \n",
1784 | "
\n",
1785 | " \n",
1786 | " 21 | \n",
1787 | " 0.343347 | \n",
1788 | " 0.517553 | \n",
1789 | " 0.852500 | \n",
1790 | "
\n",
1791 | " \n",
1792 | " 22 | \n",
1793 | " 0.339596 | \n",
1794 | " 0.507800 | \n",
1795 | " 0.854000 | \n",
1796 | "
\n",
1797 | " \n",
1798 | " 23 | \n",
1799 | " 0.309928 | \n",
1800 | " 0.514493 | \n",
1801 | " 0.854917 | \n",
1802 | "
\n",
1803 | " \n",
1804 | " 24 | \n",
1805 | " 0.297399 | \n",
1806 | " 0.506818 | \n",
1807 | " 0.856000 | \n",
1808 | "
\n",
1809 | " \n",
1810 | " 25 | \n",
1811 | " 0.268865 | \n",
1812 | " 0.504412 | \n",
1813 | " 0.857667 | \n",
1814 | "
\n",
1815 | " \n",
1816 | " 26 | \n",
1817 | " 0.262993 | \n",
1818 | " 0.505606 | \n",
1819 | " 0.860333 | \n",
1820 | "
\n",
1821 | " \n",
1822 | " 27 | \n",
1823 | " 0.244982 | \n",
1824 | " 0.503122 | \n",
1825 | " 0.858667 | \n",
1826 | "
\n",
1827 | " \n",
1828 | " 28 | \n",
1829 | " 0.249251 | \n",
1830 | " 0.501970 | \n",
1831 | " 0.860833 | \n",
1832 | "
\n",
1833 | " \n",
1834 | " 29 | \n",
1835 | " 0.254412 | \n",
1836 | " 0.504455 | \n",
1837 | " 0.861500 | \n",
1838 | "
\n",
1839 | " \n",
1840 | " 30 | \n",
1841 | " 0.244284 | \n",
1842 | " 0.502431 | \n",
1843 | " 0.859750 | \n",
1844 | "
\n",
1845 | "
\n"
1846 | ],
1847 | "text/plain": [
1848 | ""
1849 | ]
1850 | },
1851 | "metadata": {},
1852 | "output_type": "display_data"
1853 | }
1854 | ],
1855 | "source": [
1856 | "with gpu_mem_restore_ctx():\n",
1857 | " learn.fit_one_cycle(30)"
1858 | ]
1859 | },
1860 | {
1861 | "cell_type": "markdown",
1862 | "metadata": {},
1863 | "source": [
1864 | "## Resnet 101 (Mixed Precision)"
1865 | ]
1866 | },
1867 | {
1868 | "cell_type": "code",
1869 | "execution_count": 164,
1870 | "metadata": {},
1871 | "outputs": [],
1872 | "source": [
1873 | "np.random.seed(42)\n",
1874 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
1875 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=116).normalize(cifar_stats)"
1876 | ]
1877 | },
1878 | {
1879 | "cell_type": "code",
1880 | "execution_count": 165,
1881 | "metadata": {},
1882 | "outputs": [],
1883 | "source": [
1884 | "learn = to_fp16(create_cnn(data, models.resnet101, metrics=accuracy))"
1885 | ]
1886 | },
1887 | {
1888 | "cell_type": "code",
1889 | "execution_count": 166,
1890 | "metadata": {},
1891 | "outputs": [
1892 | {
1893 | "data": {
1894 | "text/html": [
1895 | "Total time: 1:22:32 \n",
1896 | " \n",
1897 | " epoch | \n",
1898 | " train_loss | \n",
1899 | " valid_loss | \n",
1900 | " accuracy | \n",
1901 | "
\n",
1902 | " \n",
1903 | " 1 | \n",
1904 | " 2.356006 | \n",
1905 | " 1.653988 | \n",
1906 | " 0.594750 | \n",
1907 | "
\n",
1908 | " \n",
1909 | " 2 | \n",
1910 | " 1.458173 | \n",
1911 | " 1.015118 | \n",
1912 | " 0.709333 | \n",
1913 | "
\n",
1914 | " \n",
1915 | " 3 | \n",
1916 | " 1.113660 | \n",
1917 | " 0.803704 | \n",
1918 | " 0.758167 | \n",
1919 | "
\n",
1920 | " \n",
1921 | " 4 | \n",
1922 | " 0.931843 | \n",
1923 | " 0.710339 | \n",
1924 | " 0.782333 | \n",
1925 | "
\n",
1926 | " \n",
1927 | " 5 | \n",
1928 | " 0.844236 | \n",
1929 | " 0.666915 | \n",
1930 | " 0.791833 | \n",
1931 | "
\n",
1932 | " \n",
1933 | " 6 | \n",
1934 | " 0.785052 | \n",
1935 | " 0.625276 | \n",
1936 | " 0.807167 | \n",
1937 | "
\n",
1938 | " \n",
1939 | " 7 | \n",
1940 | " 0.725867 | \n",
1941 | " 0.601363 | \n",
1942 | " 0.815750 | \n",
1943 | "
\n",
1944 | " \n",
1945 | " 8 | \n",
1946 | " 0.704728 | \n",
1947 | " 0.592559 | \n",
1948 | " 0.817833 | \n",
1949 | "
\n",
1950 | " \n",
1951 | " 9 | \n",
1952 | " 0.656651 | \n",
1953 | " 0.576617 | \n",
1954 | " 0.821583 | \n",
1955 | "
\n",
1956 | " \n",
1957 | " 10 | \n",
1958 | " 0.599608 | \n",
1959 | " 0.572227 | \n",
1960 | " 0.822667 | \n",
1961 | "
\n",
1962 | " \n",
1963 | " 11 | \n",
1964 | " 0.569434 | \n",
1965 | " 0.566142 | \n",
1966 | " 0.828750 | \n",
1967 | "
\n",
1968 | " \n",
1969 | " 12 | \n",
1970 | " 0.538815 | \n",
1971 | " 0.539515 | \n",
1972 | " 0.835917 | \n",
1973 | "
\n",
1974 | " \n",
1975 | " 13 | \n",
1976 | " 0.508781 | \n",
1977 | " 0.536017 | \n",
1978 | " 0.835250 | \n",
1979 | "
\n",
1980 | " \n",
1981 | " 14 | \n",
1982 | " 0.471497 | \n",
1983 | " 0.535257 | \n",
1984 | " 0.839833 | \n",
1985 | "
\n",
1986 | " \n",
1987 | " 15 | \n",
1988 | " 0.455695 | \n",
1989 | " 0.540152 | \n",
1990 | " 0.836750 | \n",
1991 | "
\n",
1992 | " \n",
1993 | " 16 | \n",
1994 | " 0.437731 | \n",
1995 | " 0.531233 | \n",
1996 | " 0.841583 | \n",
1997 | "
\n",
1998 | " \n",
1999 | " 17 | \n",
2000 | " 0.398460 | \n",
2001 | " 0.528875 | \n",
2002 | " 0.847250 | \n",
2003 | "
\n",
2004 | " \n",
2005 | " 18 | \n",
2006 | " 0.370818 | \n",
2007 | " 0.521328 | \n",
2008 | " 0.851000 | \n",
2009 | "
\n",
2010 | " \n",
2011 | " 19 | \n",
2012 | " 0.361273 | \n",
2013 | " 0.524604 | \n",
2014 | " 0.851583 | \n",
2015 | "
\n",
2016 | " \n",
2017 | " 20 | \n",
2018 | " 0.328099 | \n",
2019 | " 0.516841 | \n",
2020 | " 0.851167 | \n",
2021 | "
\n",
2022 | " \n",
2023 | " 21 | \n",
2024 | " 0.309971 | \n",
2025 | " 0.510228 | \n",
2026 | " 0.854833 | \n",
2027 | "
\n",
2028 | " \n",
2029 | " 22 | \n",
2030 | " 0.300888 | \n",
2031 | " 0.518774 | \n",
2032 | " 0.852333 | \n",
2033 | "
\n",
2034 | " \n",
2035 | " 23 | \n",
2036 | " 0.289590 | \n",
2037 | " 0.506502 | \n",
2038 | " 0.856250 | \n",
2039 | "
\n",
2040 | " \n",
2041 | " 24 | \n",
2042 | " 0.258556 | \n",
2043 | " 0.513164 | \n",
2044 | " 0.855417 | \n",
2045 | "
\n",
2046 | " \n",
2047 | " 25 | \n",
2048 | " 0.254689 | \n",
2049 | " 0.512996 | \n",
2050 | " 0.857583 | \n",
2051 | "
\n",
2052 | " \n",
2053 | " 26 | \n",
2054 | " 0.232998 | \n",
2055 | " 0.516241 | \n",
2056 | " 0.857750 | \n",
2057 | "
\n",
2058 | " \n",
2059 | " 27 | \n",
2060 | " 0.238141 | \n",
2061 | " 0.514106 | \n",
2062 | " 0.857417 | \n",
2063 | "
\n",
2064 | " \n",
2065 | " 28 | \n",
2066 | " 0.217220 | \n",
2067 | " 0.514888 | \n",
2068 | " 0.856250 | \n",
2069 | "
\n",
2070 | " \n",
2071 | " 29 | \n",
2072 | " 0.224265 | \n",
2073 | " 0.513912 | \n",
2074 | " 0.856167 | \n",
2075 | "
\n",
2076 | " \n",
2077 | " 30 | \n",
2078 | " 0.220724 | \n",
2079 | " 0.514018 | \n",
2080 | " 0.856750 | \n",
2081 | "
\n",
2082 | "
\n"
2083 | ],
2084 | "text/plain": [
2085 | ""
2086 | ]
2087 | },
2088 | "metadata": {},
2089 | "output_type": "display_data"
2090 | }
2091 | ],
2092 | "source": [
2093 | "with gpu_mem_restore_ctx():\n",
2094 | " learn.fit_one_cycle(30)"
2095 | ]
2096 | },
2097 | {
2098 | "cell_type": "markdown",
2099 | "metadata": {},
2100 | "source": [
2101 | "## Resnet 152"
2102 | ]
2103 | },
2104 | {
2105 | "cell_type": "code",
2106 | "execution_count": 65,
2107 | "metadata": {},
2108 | "outputs": [],
2109 | "source": [
2110 | "np.random.seed(42)\n",
2111 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
2112 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=48).normalize(cifar_stats)"
2113 | ]
2114 | },
2115 | {
2116 | "cell_type": "code",
2117 | "execution_count": 66,
2118 | "metadata": {},
2119 | "outputs": [],
2120 | "source": [
2121 | "learn = create_cnn(data, models.resnet152, metrics=accuracy)"
2122 | ]
2123 | },
2124 | {
2125 | "cell_type": "code",
2126 | "execution_count": 67,
2127 | "metadata": {},
2128 | "outputs": [
2129 | {
2130 | "data": {
2131 | "text/html": [
2132 | "Total time: 2:58:27 \n",
2133 | " \n",
2134 | " epoch | \n",
2135 | " train_loss | \n",
2136 | " valid_loss | \n",
2137 | " accuracy | \n",
2138 | "
\n",
2139 | " \n",
2140 | " 1 | \n",
2141 | " 1.898828 | \n",
2142 | " 1.305966 | \n",
2143 | " 0.653833 | \n",
2144 | "
\n",
2145 | " \n",
2146 | " 2 | \n",
2147 | " 1.175181 | \n",
2148 | " 0.842290 | \n",
2149 | " 0.747250 | \n",
2150 | "
\n",
2151 | " \n",
2152 | " 3 | \n",
2153 | " 0.966224 | \n",
2154 | " 0.699221 | \n",
2155 | " 0.786167 | \n",
2156 | "
\n",
2157 | " \n",
2158 | " 4 | \n",
2159 | " 0.884034 | \n",
2160 | " 0.651086 | \n",
2161 | " 0.795667 | \n",
2162 | "
\n",
2163 | " \n",
2164 | " 5 | \n",
2165 | " 0.792795 | \n",
2166 | " 0.621682 | \n",
2167 | " 0.804750 | \n",
2168 | "
\n",
2169 | " \n",
2170 | " 6 | \n",
2171 | " 0.795135 | \n",
2172 | " 0.593315 | \n",
2173 | " 0.819083 | \n",
2174 | "
\n",
2175 | " \n",
2176 | " 7 | \n",
2177 | " 0.723622 | \n",
2178 | " 0.596839 | \n",
2179 | " 0.816833 | \n",
2180 | "
\n",
2181 | " \n",
2182 | " 8 | \n",
2183 | " 0.658714 | \n",
2184 | " 0.581674 | \n",
2185 | " 0.821417 | \n",
2186 | "
\n",
2187 | " \n",
2188 | " 9 | \n",
2189 | " 0.670455 | \n",
2190 | " 0.553189 | \n",
2191 | " 0.833000 | \n",
2192 | "
\n",
2193 | " \n",
2194 | " 10 | \n",
2195 | " 0.622607 | \n",
2196 | " 0.560935 | \n",
2197 | " 0.830250 | \n",
2198 | "
\n",
2199 | " \n",
2200 | " 11 | \n",
2201 | " 0.587018 | \n",
2202 | " 0.533542 | \n",
2203 | " 0.836250 | \n",
2204 | "
\n",
2205 | " \n",
2206 | " 12 | \n",
2207 | " 0.580308 | \n",
2208 | " 0.520549 | \n",
2209 | " 0.841167 | \n",
2210 | "
\n",
2211 | " \n",
2212 | " 13 | \n",
2213 | " 0.521300 | \n",
2214 | " 0.532823 | \n",
2215 | " 0.840583 | \n",
2216 | "
\n",
2217 | " \n",
2218 | " 14 | \n",
2219 | " 0.487249 | \n",
2220 | " 0.526109 | \n",
2221 | " 0.842750 | \n",
2222 | "
\n",
2223 | " \n",
2224 | " 15 | \n",
2225 | " 0.448812 | \n",
2226 | " 0.513343 | \n",
2227 | " 0.850167 | \n",
2228 | "
\n",
2229 | " \n",
2230 | " 16 | \n",
2231 | " 0.431717 | \n",
2232 | " 0.506566 | \n",
2233 | " 0.850667 | \n",
2234 | "
\n",
2235 | " \n",
2236 | " 17 | \n",
2237 | " 0.401086 | \n",
2238 | " 0.506412 | \n",
2239 | " 0.851917 | \n",
2240 | "
\n",
2241 | " \n",
2242 | " 18 | \n",
2243 | " 0.398720 | \n",
2244 | " 0.505510 | \n",
2245 | " 0.853083 | \n",
2246 | "
\n",
2247 | " \n",
2248 | " 19 | \n",
2249 | " 0.361149 | \n",
2250 | " 0.500569 | \n",
2251 | " 0.855833 | \n",
2252 | "
\n",
2253 | " \n",
2254 | " 20 | \n",
2255 | " 0.357157 | \n",
2256 | " 0.493713 | \n",
2257 | " 0.856417 | \n",
2258 | "
\n",
2259 | " \n",
2260 | " 21 | \n",
2261 | " 0.319523 | \n",
2262 | " 0.496027 | \n",
2263 | " 0.860000 | \n",
2264 | "
\n",
2265 | " \n",
2266 | " 22 | \n",
2267 | " 0.284045 | \n",
2268 | " 0.494504 | \n",
2269 | " 0.861000 | \n",
2270 | "
\n",
2271 | " \n",
2272 | " 23 | \n",
2273 | " 0.253644 | \n",
2274 | " 0.493373 | \n",
2275 | " 0.861333 | \n",
2276 | "
\n",
2277 | " \n",
2278 | " 24 | \n",
2279 | " 0.275702 | \n",
2280 | " 0.500611 | \n",
2281 | " 0.861167 | \n",
2282 | "
\n",
2283 | " \n",
2284 | " 25 | \n",
2285 | " 0.248939 | \n",
2286 | " 0.488999 | \n",
2287 | " 0.863083 | \n",
2288 | "
\n",
2289 | " \n",
2290 | " 26 | \n",
2291 | " 0.233902 | \n",
2292 | " 0.496497 | \n",
2293 | " 0.862333 | \n",
2294 | "
\n",
2295 | " \n",
2296 | " 27 | \n",
2297 | " 0.218012 | \n",
2298 | " 0.491988 | \n",
2299 | " 0.864417 | \n",
2300 | "
\n",
2301 | " \n",
2302 | " 28 | \n",
2303 | " 0.225783 | \n",
2304 | " 0.493372 | \n",
2305 | " 0.864667 | \n",
2306 | "
\n",
2307 | " \n",
2308 | " 29 | \n",
2309 | " 0.217687 | \n",
2310 | " 0.493124 | \n",
2311 | " 0.866333 | \n",
2312 | "
\n",
2313 | " \n",
2314 | " 30 | \n",
2315 | " 0.202206 | \n",
2316 | " 0.491947 | \n",
2317 | " 0.863083 | \n",
2318 | "
\n",
2319 | "
\n"
2320 | ],
2321 | "text/plain": [
2322 | ""
2323 | ]
2324 | },
2325 | "metadata": {},
2326 | "output_type": "display_data"
2327 | }
2328 | ],
2329 | "source": [
2330 | "with gpu_mem_restore_ctx():\n",
2331 | " learn.fit_one_cycle(30)"
2332 | ]
2333 | },
2334 | {
2335 | "cell_type": "markdown",
2336 | "metadata": {},
2337 | "source": [
2338 | "## Resnet 152 (Mixed Precision)"
2339 | ]
2340 | },
2341 | {
2342 | "cell_type": "code",
2343 | "execution_count": 167,
2344 | "metadata": {},
2345 | "outputs": [],
2346 | "source": [
2347 | "np.random.seed(42)\n",
2348 | "data = ImageDataBunch.from_folder(path, train=train, valid_pct=0.2,\n",
2349 | " ds_tfms=get_transforms(), size=224, num_workers=4, bs=88).normalize(cifar_stats)"
2350 | ]
2351 | },
2352 | {
2353 | "cell_type": "code",
2354 | "execution_count": 168,
2355 | "metadata": {},
2356 | "outputs": [],
2357 | "source": [
2358 | "learn = to_fp16(create_cnn(data, models.resnet152, metrics=accuracy))"
2359 | ]
2360 | },
2361 | {
2362 | "cell_type": "code",
2363 | "execution_count": 169,
2364 | "metadata": {},
2365 | "outputs": [
2366 | {
2367 | "data": {
2368 | "text/html": [
2369 | "Total time: 1:51:21 \n",
2370 | " \n",
2371 | " epoch | \n",
2372 | " train_loss | \n",
2373 | " valid_loss | \n",
2374 | " accuracy | \n",
2375 | "
\n",
2376 | " \n",
2377 | " 1 | \n",
2378 | " 2.160135 | \n",
2379 | " 1.495938 | \n",
2380 | " 0.625750 | \n",
2381 | "
\n",
2382 | " \n",
2383 | " 2 | \n",
2384 | " 1.294585 | \n",
2385 | " 0.920020 | \n",
2386 | " 0.733083 | \n",
2387 | "
\n",
2388 | " \n",
2389 | " 3 | \n",
2390 | " 0.984379 | \n",
2391 | " 0.727062 | \n",
2392 | " 0.776750 | \n",
2393 | "
\n",
2394 | " \n",
2395 | " 4 | \n",
2396 | " 0.865696 | \n",
2397 | " 0.641753 | \n",
2398 | " 0.800250 | \n",
2399 | "
\n",
2400 | " \n",
2401 | " 5 | \n",
2402 | " 0.777403 | \n",
2403 | " 0.599795 | \n",
2404 | " 0.815000 | \n",
2405 | "
\n",
2406 | " \n",
2407 | " 6 | \n",
2408 | " 0.726824 | \n",
2409 | " 0.607037 | \n",
2410 | " 0.811833 | \n",
2411 | "
\n",
2412 | " \n",
2413 | " 7 | \n",
2414 | " 0.695808 | \n",
2415 | " 0.577306 | \n",
2416 | " 0.824917 | \n",
2417 | "
\n",
2418 | " \n",
2419 | " 8 | \n",
2420 | " 0.663318 | \n",
2421 | " 0.561540 | \n",
2422 | " 0.828833 | \n",
2423 | "
\n",
2424 | " \n",
2425 | " 9 | \n",
2426 | " 0.601016 | \n",
2427 | " 0.554505 | \n",
2428 | " 0.833417 | \n",
2429 | "
\n",
2430 | " \n",
2431 | " 10 | \n",
2432 | " 0.576079 | \n",
2433 | " 0.546127 | \n",
2434 | " 0.834667 | \n",
2435 | "
\n",
2436 | " \n",
2437 | " 11 | \n",
2438 | " 0.549800 | \n",
2439 | " 0.524947 | \n",
2440 | " 0.842250 | \n",
2441 | "
\n",
2442 | " \n",
2443 | " 12 | \n",
2444 | " 0.494807 | \n",
2445 | " 0.513583 | \n",
2446 | " 0.844250 | \n",
2447 | "
\n",
2448 | " \n",
2449 | " 13 | \n",
2450 | " 0.475365 | \n",
2451 | " 0.516716 | \n",
2452 | " 0.846250 | \n",
2453 | "
\n",
2454 | " \n",
2455 | " 14 | \n",
2456 | " 0.443925 | \n",
2457 | " 0.506661 | \n",
2458 | " 0.849750 | \n",
2459 | "
\n",
2460 | " \n",
2461 | " 15 | \n",
2462 | " 0.410093 | \n",
2463 | " 0.509516 | \n",
2464 | " 0.849250 | \n",
2465 | "
\n",
2466 | " \n",
2467 | " 16 | \n",
2468 | " 0.378922 | \n",
2469 | " 0.499594 | \n",
2470 | " 0.855167 | \n",
2471 | "
\n",
2472 | " \n",
2473 | " 17 | \n",
2474 | " 0.344542 | \n",
2475 | " 0.489011 | \n",
2476 | " 0.856333 | \n",
2477 | "
\n",
2478 | " \n",
2479 | " 18 | \n",
2480 | " 0.340006 | \n",
2481 | " 0.507246 | \n",
2482 | " 0.852333 | \n",
2483 | "
\n",
2484 | " \n",
2485 | " 19 | \n",
2486 | " 0.325086 | \n",
2487 | " 0.486156 | \n",
2488 | " 0.861083 | \n",
2489 | "
\n",
2490 | " \n",
2491 | " 20 | \n",
2492 | " 0.302048 | \n",
2493 | " 0.490937 | \n",
2494 | " 0.857667 | \n",
2495 | "
\n",
2496 | " \n",
2497 | " 21 | \n",
2498 | " 0.288165 | \n",
2499 | " 0.486860 | \n",
2500 | " 0.864250 | \n",
2501 | "
\n",
2502 | " \n",
2503 | " 22 | \n",
2504 | " 0.256561 | \n",
2505 | " 0.492568 | \n",
2506 | " 0.860917 | \n",
2507 | "
\n",
2508 | " \n",
2509 | " 23 | \n",
2510 | " 0.237763 | \n",
2511 | " 0.485718 | \n",
2512 | " 0.862583 | \n",
2513 | "
\n",
2514 | " \n",
2515 | " 24 | \n",
2516 | " 0.224336 | \n",
2517 | " 0.486326 | \n",
2518 | " 0.863333 | \n",
2519 | "
\n",
2520 | " \n",
2521 | " 25 | \n",
2522 | " 0.209552 | \n",
2523 | " 0.485967 | \n",
2524 | " 0.863500 | \n",
2525 | "
\n",
2526 | " \n",
2527 | " 26 | \n",
2528 | " 0.210889 | \n",
2529 | " 0.485575 | \n",
2530 | " 0.865250 | \n",
2531 | "
\n",
2532 | " \n",
2533 | " 27 | \n",
2534 | " 0.202123 | \n",
2535 | " 0.481421 | \n",
2536 | " 0.865833 | \n",
2537 | "
\n",
2538 | " \n",
2539 | " 28 | \n",
2540 | " 0.205988 | \n",
2541 | " 0.486022 | \n",
2542 | " 0.867167 | \n",
2543 | "
\n",
2544 | " \n",
2545 | " 29 | \n",
2546 | " 0.173248 | \n",
2547 | " 0.485957 | \n",
2548 | " 0.865583 | \n",
2549 | "
\n",
2550 | " \n",
2551 | " 30 | \n",
2552 | " 0.182509 | \n",
2553 | " 0.484172 | \n",
2554 | " 0.866750 | \n",
2555 | "
\n",
2556 | "
\n"
2557 | ],
2558 | "text/plain": [
2559 | ""
2560 | ]
2561 | },
2562 | "metadata": {},
2563 | "output_type": "display_data"
2564 | }
2565 | ],
2566 | "source": [
2567 | "with gpu_mem_restore_ctx():\n",
2568 | " learn.fit_one_cycle(30)"
2569 | ]
2570 | }
2571 | ],
2572 | "metadata": {
2573 | "kernelspec": {
2574 | "display_name": "Python 3",
2575 | "language": "python",
2576 | "name": "python3"
2577 | },
2578 | "language_info": {
2579 | "codemirror_mode": {
2580 | "name": "ipython",
2581 | "version": 3
2582 | },
2583 | "file_extension": ".py",
2584 | "mimetype": "text/x-python",
2585 | "name": "python",
2586 | "nbconvert_exporter": "python",
2587 | "pygments_lexer": "ipython3",
2588 | "version": "3.6.6"
2589 | }
2590 | },
2591 | "nbformat": 4,
2592 | "nbformat_minor": 2
2593 | }
2594 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RTX-2080Ti-Vs-GTX-1080Ti-CIFAR-100-Benchmarks
2 |
3 | Scripts for the accompanying [blog post](https://hackernoon.com/rtx-2080ti-vs-gtx-1080ti-fastai-mixed-precision-training-comparisons-on-cifar-100-761d8f615d7f?source=user_profile---------4------------------).
4 |
5 | For more details, please check out the blog post above.
6 |
7 | For details on fast.ai and the latest MOOC (2019), please check out [course.fast.ai](https://course.fast.ai).
8 |
9 | ### Special Thanks
10 |
11 | Special thanks to Tuatini Godard for helping me run the tests on his 1080Ti.
12 |
--------------------------------------------------------------------------------