├── DatasetConverter
│   ├── build.xml
│   ├── build
│   │   └── classes
│   │       ├── .netbeans_automatic_build
│   │       ├── .netbeans_update_resources
│   │       └── converters
│   │           ├── Converter$1.class
│   │           ├── Converter.class
│   │           └── ConverterMain.class
│   ├── manifest.mf
│   ├── nbproject
│   │   ├── build-impl.xml
│   │   ├── genfiles.properties
│   │   ├── private
│   │   │   ├── private.properties
│   │   │   └── private.xml
│   │   ├── project.properties
│   │   └── project.xml
│   └── src
│       └── converters
│           ├── Converter.java
│           └── ConverterMain.java
├── DeleteExtrasFromDataSet.sh
├── Logs
│   ├── imdbLogs
│   └── initial_train_logs
├── README.md
├── classify_img.py
├── classify_img_arg.py
├── create_imdb.py
├── freeze_and_convert_to_tflite.sh
├── imgs
│   ├── accuracy_per_epoch.png
│   ├── loss_per_epoch.png
│   ├── network_visualization.png
│   └── tensorboard_plots.png
├── test.py
├── train_alexnet.py
└── tune_cnn.py
/DatasetConverter/build.xml:
--------------------------------------------------------------------------------
[Ant build script; the XML markup was stripped during extraction, leaving only its text content:]

11 | Builds, tests, and runs the project Converters.
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/.netbeans_automatic_build:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/.netbeans_automatic_build
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/.netbeans_update_resources:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/.netbeans_update_resources
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/converters/Converter$1.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/converters/Converter$1.class
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/converters/Converter.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/converters/Converter.class
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/converters/ConverterMain.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/converters/ConverterMain.class
--------------------------------------------------------------------------------
/DatasetConverter/manifest.mf:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | X-COMMENT: Main-Class will be added automatically by build
3 |
4 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/build-impl.xml:
--------------------------------------------------------------------------------
[NetBeans-generated Ant build script (~1421 lines); the XML markup was stripped during extraction. The surviving text content, with its original line numbers:]

233 | Must set src.dir
234 | Must set test.src.dir
235 | Must set build.dir
236 | Must set dist.dir
237 | Must set build.classes.dir
238 | Must set dist.javadoc.dir
239 | Must set build.test.classes.dir
240 | Must set build.test.results.dir
241 | Must set build.classes.excludes
242 | Must set dist.jar
343 | Must set javac.includes
472 | No tests executed.
722 | Must set JVM to use for profiling in profiler.info.jvm
723 | Must set profiler agent JVM arguments in profiler.info.jvmargs.agent
951 | Must select some files in the IDE or set javac.includes
1001 | To run this application from the command line without Ant, try:
1003 | java -jar "${dist.jar.resolved}"
1050 | Must select one file in the IDE or set run.class
1054 | Must select one file in the IDE or set run.class
1081 | Must select one file in the IDE or set debug.class
1086 | Must select one file in the IDE or set debug.class
1091 | Must set fix.includes
1107 | This target only works when run from inside the NetBeans IDE.
1116 | Must select one file in the IDE or set profile.class
1117 | This target only works when run from inside the NetBeans IDE.
1126 | This target only works when run from inside the NetBeans IDE.
1139 | This target only works when run from inside the NetBeans IDE.
1177 | Must select one file in the IDE or set run.class
1183 | Must select some files in the IDE or set test.includes
1188 | Must select one file in the IDE or set run.class
1193 | Must select one file in the IDE or set applet.url
1272 | Must select some files in the IDE or set javac.includes
1296 | Some tests failed; see details above.
1305 | Must select some files in the IDE or set test.includes
1309 | Some tests failed; see details above.
1313 | Must select some files in the IDE or set test.class
1314 | Must select some method in the IDE or set test.method
1318 | Some tests failed; see details above.
1327 | Must select one file in the IDE or set test.class
1331 | Must select one file in the IDE or set test.class
1332 | Must select some method in the IDE or set test.method
1350 | Must select one file in the IDE or set applet.url
1363 | Must select one file in the IDE or set applet.url
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/genfiles.properties:
--------------------------------------------------------------------------------
1 | build.xml.data.CRC32=7f403d13
2 | build.xml.script.CRC32=e7a17702
3 | build.xml.stylesheet.CRC32=8064a381@1.80.1.48
4 | # This file is used by a NetBeans-based IDE to track changes in generated files such as build-impl.xml.
5 | # Do not edit this file. You may delete it but then the IDE will never regenerate such files for you.
6 | nbproject/build-impl.xml.data.CRC32=7f403d13
7 | nbproject/build-impl.xml.script.CRC32=75480b50
8 | nbproject/build-impl.xml.stylesheet.CRC32=830a3534@1.80.1.48
9 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/private/private.properties:
--------------------------------------------------------------------------------
1 | compile.on.save=true
2 | user.properties.file=/home/olu/.netbeans/8.2/build.properties
3 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/private/private.xml:
--------------------------------------------------------------------------------
[XML markup stripped during extraction, leaving only the open-file references:]

6 | file:/home/olu/NetBeansProjects/FolderNameConverter/src/converters/Converter.java
7 | file:/home/olu/NetBeansProjects/FolderNameConverter/src/converters/ConverterMain.java
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/project.properties:
--------------------------------------------------------------------------------
1 | annotation.processing.enabled=true
2 | annotation.processing.enabled.in.editor=false
3 | annotation.processing.processor.options=
4 | annotation.processing.processors.list=
5 | annotation.processing.run.all.processors=true
6 | annotation.processing.source.output=${build.generated.sources.dir}/ap-source-output
7 | build.classes.dir=${build.dir}/classes
8 | build.classes.excludes=**/*.java,**/*.form
9 | # This directory is removed when the project is cleaned:
10 | build.dir=build
11 | build.generated.dir=${build.dir}/generated
12 | build.generated.sources.dir=${build.dir}/generated-sources
13 | # Only compile against the classpath explicitly listed here:
14 | build.sysclasspath=ignore
15 | build.test.classes.dir=${build.dir}/test/classes
16 | build.test.results.dir=${build.dir}/test/results
17 | # Uncomment to specify the preferred debugger connection transport:
18 | #debug.transport=dt_socket
19 | debug.classpath=\
20 | ${run.classpath}
21 | debug.test.classpath=\
22 | ${run.test.classpath}
23 | # Files in build.classes.dir which should be excluded from distribution jar
24 | dist.archive.excludes=
25 | # This directory is removed when the project is cleaned:
26 | dist.dir=dist
27 | dist.jar=${dist.dir}/Converters.jar
28 | dist.javadoc.dir=${dist.dir}/javadoc
29 | excludes=
30 | includes=**
31 | jar.compress=false
32 | javac.classpath=
33 | # Space-separated list of extra javac options
34 | javac.compilerargs=
35 | javac.deprecation=false
36 | javac.external.vm=true
37 | javac.processorpath=\
38 | ${javac.classpath}
39 | javac.source=1.8
40 | javac.target=1.8
41 | javac.test.classpath=\
42 | ${javac.classpath}:\
43 | ${build.classes.dir}
44 | javac.test.processorpath=\
45 | ${javac.test.classpath}
46 | javadoc.additionalparam=
47 | javadoc.author=false
48 | javadoc.encoding=${source.encoding}
49 | javadoc.noindex=false
50 | javadoc.nonavbar=false
51 | javadoc.notree=false
52 | javadoc.private=false
53 | javadoc.splitindex=true
54 | javadoc.use=true
55 | javadoc.version=false
56 | javadoc.windowtitle=
57 | main.class=converters.ConverterMain
58 | manifest.file=manifest.mf
59 | meta.inf.dir=${src.dir}/META-INF
60 | mkdist.disabled=false
61 | platform.active=default_platform
62 | run.classpath=\
63 | ${javac.classpath}:\
64 | ${build.classes.dir}
65 | # Space-separated list of JVM arguments used when running the project.
66 | # You may also define separate properties like run-sys-prop.name=value instead of -Dname=value.
67 | # To set system properties for unit tests define test-sys-prop.name=value:
68 | run.jvmargs=
69 | run.test.classpath=\
70 | ${javac.test.classpath}:\
71 | ${build.test.classes.dir}
72 | source.encoding=UTF-8
73 | src.dir=src
74 | test.src.dir=test
75 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/project.xml:
--------------------------------------------------------------------------------
[XML markup stripped during extraction, leaving only the project type and name:]

3 | org.netbeans.modules.java.j2seproject
6 | Converters
--------------------------------------------------------------------------------
/DatasetConverter/src/converters/Converter.java:
--------------------------------------------------------------------------------
1 | /*
2 | * To change this license header, choose License Headers in Project Properties.
3 | * To change this template file, choose Tools | Templates
4 | * and open the template in the editor.
5 | */
6 | package converters;
7 |
8 | import java.awt.image.BufferedImage;
9 | import java.io.BufferedReader;
10 | import java.io.File;
11 | import java.io.FileInputStream;
12 | import java.io.FileOutputStream;
13 | import java.io.FileReader;
14 | import java.io.IOException;
15 | import java.io.InputStreamReader;
16 | import java.text.SimpleDateFormat;
17 | import java.util.ArrayList;
18 | import java.util.Arrays;
19 | import java.util.Comparator;
20 | import java.util.Date;
21 | import java.util.Scanner;
22 | import java.util.concurrent.*;
23 | import java.util.logging.Level;
24 | import java.util.logging.Logger;
25 | import javax.imageio.ImageIO;
26 |
27 | /**
28 | *
29 | * @author olu
30 | * email: oluwoleoyetoke@gmail.com
31 | * Class contains methods used to
32 |  * 1. Convert GTSRB dataset folder names to labels for the class
33 |  * 2. Convert the dataset's .ppm images to .jpeg
34 |  * (GTSRB - German Traffic Sign Recognition Benchmark)
35 | */
36 | public class Converter implements Runnable {
37 | String pathToFolderUnderView=null;
38 | String format;
39 | static int totalConversion;
40 |
41 | //Default constructor
42 | Converter(){
43 |
44 | }
45 |
46 | //Constructor used when multithread operation is needed for the picture conversion
47 | Converter(String pathToFolderUnderView, String format){
48 | this.pathToFolderUnderView= pathToFolderUnderView;
49 | this.format = format;
50 | }
51 |
52 | @Override
53 | public void run(){
54 | System.out.println("(Multithreaded) Now Converting Contents of Folder "+this.pathToFolderUnderView);
55 | //Connect to sub folder and begin to convert all its subfiles
56 | long timeStart = System.currentTimeMillis();
57 |
58 | File file = new File(this.pathToFolderUnderView);
59 | String extension="";
60 | String[] fileNames = file.list();
61 | for(int i=0; i<fileNames.length; i++){
[loop bound on line 61 reconstructed; lines 62-108 lost in extraction, resuming mid-string:]
109 | Subfolders-->Each subfolder containing "
110 | + "specific classes of image\n"
111 | + "E.g Training Folder -> stop_sign_folder -> 1.jpg, 2.jpg, 3.jpg....");
112 |
113 | //Get user confirmation
114 | Scanner scanner = new Scanner(System.in);
115 | System.out.println("Would you like to proceed? [Y/N]: ");
116 | String answer = scanner.next();
117 | if(!answer.equals("Y")){
118 | System.out.println("Exiting....");
119 | return false;
120 | }
121 |
122 | //Connect to dataset base directory
123 | File base =null;
124 | try{ //For safety's sake
125 | base = new File(baseFolderDir);
126 | if(!base.isDirectory()){ //Check to make sure directory specified isn't just a file
127 | System.out.println("Not a directory");
128 | return false;
129 | }
130 | }catch (Exception ex){
131 | System.out.println("Error occurred while opening directory: "+ex.getMessage());
132 | return false;
133 | }
134 |
135 | System.out.println("Base Directory: "+baseFolderDir);
136 |
137 | //Get sub directories
138 | String[] subFiles = base.list();
139 | File[] subFilesHandle = base.listFiles();
140 |
141 |
142 | //Confirm that base directory has sub directories or at least, sub files
143 | int noOfContents = 0;
144 | noOfContents = subFiles.length;
145 | System.out.println("Number of sub directories or possibly files: "+noOfContents);
146 | if(noOfContents==0){
147 | System.out.println("There are no sub files/directories in the base directory");
148 | return false;
149 | }
150 |
151 | System.out.println("About to begin multithreaded conversion. Please note, this might take some time");
152 | String pathToFolderUnderOperation ="";
153 | //Open each subdirectory and convert image present to desired format (Multi Threaded)
154 | //Use executors to manage concurrency
155 | ExecutorService executor = Executors.newCachedThreadPool();
156 | for(int i=0; i<noOfContents; i++){
[loop bound on line 156 reconstructed; lines 157-184 lost in extraction, resuming mid-string:]
185 | Subfolders-->Each subfolder containing "
186 | + "specific classes of image\n"
187 | + "E.g Training Folder -> stop_sign_folder -> 1.jpg, 2.jpg, 3.jpg....");
188 |
189 | //Get user confirmation
190 | Scanner scanner = new Scanner(System.in);
191 | System.out.println("Would you like to proceed? [Y/N]: ");
192 | String answer = scanner.next();
193 | if(!answer.equals("Y")){
194 | System.out.println("Exiting....");
195 | return false;
196 | }
197 |
198 | //Validation
199 | if(baseFolderDir.isEmpty()){
200 | System.out.println("baseFolderDir not set");
201 | return false;
202 | }else if(labelingFileDir.isEmpty()){
203 | System.out.println("labelingFileDir not set");
204 | return false;
205 | }
206 |
207 | //Try to open directory
208 | try{ //For safety's sake
209 | dir = new File(baseFolderDir);
210 | if(!dir.isDirectory()){ //Check to make sure directory specified isn't just a file
211 | System.out.println("Not a directory");
212 | return false;
213 | }
214 | }catch (Exception ex){
215 | System.out.println("Error occurred while opening directory: "+ex.getMessage());
216 | return false;
217 | }
218 |
219 | //Get sub directories
220 | String[] subFiles = dir.list();
221 | File[] subFilesHandle = dir.listFiles();
222 |
223 | // Sort files/folders handle by name
224 | Arrays.sort(subFilesHandle, new Comparator(){
225 | @Override
226 | public int compare(Object f1, Object f2){
227 | return ((File) f1).getName().compareTo(((File) f2).getName());
228 | }
229 | });
230 | Arrays.sort(subFiles); //sort files/folders string by name
231 | ArrayList subDirs = new ArrayList();
232 | File test = null;
233 | int noOfContents = subFiles.length;
234 | for (int i=0; i<noOfContents; i++){
[loop bound on line 234 reconstructed; lines 235-363 lost in extraction, resuming at a helper method:]
364 | private static ArrayList<String> getFormats() {
365 | return new ArrayList<>(Arrays.asList("jpg", "jpeg", "png", "ppm", "gif"));
366 | }
367 |
368 | private static String getFileExtensionFromPath(String path) {
369 | int i = path.lastIndexOf('.');
370 | if (i > 0) {
371 | return path.substring(i + 1);
372 | }
373 | return "";
374 | }
375 |
376 | private static String getOutputPathFromInputPath(String path, String format) {
377 | return path.substring(0, path.lastIndexOf('.')) + "." + format;
378 | }
379 |
380 | private static String executeCommand(String command) {
381 |
382 | StringBuilder output = new StringBuilder();
383 |
384 | Process p;
385 | try {
386 | p = Runtime.getRuntime().exec(command);
387 | p.waitFor();
388 | BufferedReader reader
389 | = new BufferedReader(new InputStreamReader(p.getInputStream()));
390 |
391 | String line;
392 | while ((line = reader.readLine()) != null) {
393 | output.append(line).append("\n");
394 | }
395 |
396 | } catch (IOException | InterruptedException ex) {
397 | Logger.getLogger(Converter.class.getName()).log(Level.SEVERE, null, ex);
398 | }
399 |
400 | return output.toString();
401 | }
402 |
403 |
404 | }
405 |
--------------------------------------------------------------------------------
/DatasetConverter/src/converters/ConverterMain.java:
--------------------------------------------------------------------------------
1 | /*
2 | * To change this license header, choose License Headers in Project Properties.
3 | * To change this template file, choose Tools | Templates
4 | * and open the template in the editor.
5 | */
6 | package converters;
7 |
8 | import java.util.Date;
9 |
10 | /**
11 | *
12 | * @author olu
13 | */
14 | public class ConverterMain {
15 |
16 | /**
17 | * @param args the command line arguments
18 | */
19 | public static void main(String[] args) {
20 | String baseFolderDir="/home/olu/Dev/data_base/sign_base/training"; //Change as appropriate to you
21 | String labelingFileDir="/home/olu/Dev/data_base/sign_base/labels.txt"; //Change as appropriate to you
22 | String formatName = "jpeg";
23 |
24 | //Convert Folder Names
25 | Converter convert = new Converter();
26 | boolean converted = convert.convertFolderName(baseFolderDir, labelingFileDir);
27 |
28 | //Convert all of the dataset's .ppm images to .jpeg
29 | long timeStart = System.currentTimeMillis();
30 | boolean converted2 = convert.convertAllDatasetImages(baseFolderDir, formatName);
31 | if(converted2==true){
32 | long timeEnd = System.currentTimeMillis(); //in milliseconds
33 | long diff = timeEnd - timeStart;
34 | long diffSeconds = diff / 1000 % 60;
35 | long diffMinutes = diff / (60 * 1000) % 60;
36 | long diffHours = diff / (60 * 60 * 1000) % 24;
37 | long diffDays = diff / (24 * 60 * 60 * 1000);
38 | System.out.println("ALL "+formatName+" CONVERSIONS NOW COMPLETED. Took "+diffDays+" Day(s), "+diffHours+" Hour(s) "+diffMinutes+" Minute(s) and "+diffSeconds+" Second(s)");
39 | }
40 |
41 | }
42 |
43 | }
44 |
--------------------------------------------------------------------------------
/DeleteExtrasFromDataSet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #Pass in Base Folder Location
4 | cd "$1"   # quote so paths containing spaces survive
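
# Hypothetical sketch, not part of the original script: the GTSRB training
# archive ships GT-*.csv annotation files alongside the images, and a command
# like the one below would delete those "extras":
# find . -name 'GT-*.csv' -delete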
--------------------------------------------------------------------------------
/Logs/imdbLogs:
--------------------------------------------------------------------------------
1 | ========== RESTART: /home/olu/Dev/scratch_train_sign/create_imdb.py ==========
2 | Result will be saved to /home/olu/Dev/data_base/sign_base/output
3 | Determining list of input files and labels from /home/olu/Dev/data_base/sign_base/labels.txt
4 | File path /home/olu/Dev/data_base/sign_base/training/speed_20/*
5 |
6 | File path /home/olu/Dev/data_base/sign_base/training/speed_30/*
7 |
8 | File path /home/olu/Dev/data_base/sign_base/training/speed_50/*
9 |
10 | File path /home/olu/Dev/data_base/sign_base/training/speed_60/*
11 |
12 | File path /home/olu/Dev/data_base/sign_base/training/speed_70/*
13 |
14 | File path /home/olu/Dev/data_base/sign_base/training/speed_80/*
15 |
16 | File path /home/olu/Dev/data_base/sign_base/training/speed_less_80/*
17 |
18 | File path /home/olu/Dev/data_base/sign_base/training/speed_100/*
19 |
20 | File path /home/olu/Dev/data_base/sign_base/training/speed_120/*
21 |
22 | File path /home/olu/Dev/data_base/sign_base/training/no_car_overtaking/*
23 |
24 | File path /home/olu/Dev/data_base/sign_base/training/no_truck_overtaking/*
25 |
26 | File path /home/olu/Dev/data_base/sign_base/training/priority_road/*
27 |
28 | File path /home/olu/Dev/data_base/sign_base/training/priority_road_2/*
29 |
30 | File path /home/olu/Dev/data_base/sign_base/training/yield_right_of_way/*
31 |
32 | File path /home/olu/Dev/data_base/sign_base/training/stop/*
33 |
34 | File path /home/olu/Dev/data_base/sign_base/training/road_closed/*
35 |
36 | File path /home/olu/Dev/data_base/sign_base/training/maximum_weight_allowed/*
37 |
38 | File path /home/olu/Dev/data_base/sign_base/training/entry_prohibited/*
39 |
40 | File path /home/olu/Dev/data_base/sign_base/training/danger/*
41 |
42 | File path /home/olu/Dev/data_base/sign_base/training/curve_left/*
43 |
44 | File path /home/olu/Dev/data_base/sign_base/training/curve_right/*
45 |
46 | File path /home/olu/Dev/data_base/sign_base/training/double_curve_right/*
47 |
48 | File path /home/olu/Dev/data_base/sign_base/training/rough_road/*
49 |
50 | File path /home/olu/Dev/data_base/sign_base/training/slippery_road/*
51 |
52 | File path /home/olu/Dev/data_base/sign_base/training/road_narrows_right/*
53 |
54 | File path /home/olu/Dev/data_base/sign_base/training/work_in_progress/*
55 |
56 | File path /home/olu/Dev/data_base/sign_base/training/traffic_light_ahead/*
57 |
58 | File path /home/olu/Dev/data_base/sign_base/training/pedestrian_crosswalk/*
59 |
60 | File path /home/olu/Dev/data_base/sign_base/training/children_area/*
61 |
62 | File path /home/olu/Dev/data_base/sign_base/training/bicycle_crossing/*
63 |
64 | File path /home/olu/Dev/data_base/sign_base/training/beware_of_ice/*
65 |
66 | File path /home/olu/Dev/data_base/sign_base/training/wild_animal_crossing/*
67 |
68 | File path /home/olu/Dev/data_base/sign_base/training/end_of_restriction/*
69 |
70 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_right/*
71 |
72 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_left/*
73 |
74 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight/*
75 |
76 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_right/*
77 |
78 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_left/*
79 |
80 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle/*
81 |
82 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle2/*
83 |
84 | File path /home/olu/Dev/data_base/sign_base/training/traffic_circle/*
85 |
86 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_car_overtaking/*
87 |
88 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_truck_overtaking/*
89 |
90 | Found 39209 JPEG files across 43 labels inside /home/olu/Dev/data_base/sign_base/training
91 | Launching 2 threads for spacings: [[0, 19604], [19604, 39209]]
92 | 2017-12-01 15:56:39.543697 [thread 1]: Processed 1000 of 19605 images in thread batch.2017-12-01 15:56:39.545929 [thread 0]: Processed 1000 of 19604 images in thread batch.
93 |
94 | 2017-12-01 15:56:44.254576 [thread 0]: Processed 2000 of 19604 images in thread batch.
95 | 2017-12-01 15:56:44.433656 [thread 1]: Processed 2000 of 19605 images in thread batch.
96 | 2017-12-01 15:56:48.754633 [thread 0]: Processed 3000 of 19604 images in thread batch.
97 | 2017-12-01 15:56:49.798092 [thread 1]: Processed 3000 of 19605 images in thread batch.
98 | 2017-12-01 15:56:53.908201 [thread 0]: Processed 4000 of 19604 images in thread batch.
99 | 2017-12-01 15:56:54.351327 [thread 1]: Processed 4000 of 19605 images in thread batch.
100 | 2017-12-01 15:56:58.855722 [thread 0]: Processed 5000 of 19604 images in thread batch.
101 | 2017-12-01 15:56:59.301768 [thread 1]: Processed 5000 of 19605 images in thread batch.
102 | 2017-12-01 15:57:03.704445 [thread 0]: Processed 6000 of 19604 images in thread batch.
103 | 2017-12-01 15:57:04.144965 [thread 1]: Processed 6000 of 19605 images in thread batch.
104 | 2017-12-01 15:57:08.543383 [thread 0]: Processed 7000 of 19604 images in thread batch.
105 | 2017-12-01 15:57:09.007709 [thread 1]: Processed 7000 of 19605 images in thread batch.
106 | 2017-12-01 15:57:13.622921 [thread 0]: Processed 8000 of 19604 images in thread batch.
107 | 2017-12-01 15:57:14.059550 [thread 1]: Processed 8000 of 19605 images in thread batch.
108 | 2017-12-01 15:57:18.477144 [thread 0]: Processed 9000 of 19604 images in thread batch.
109 | 2017-12-01 15:57:18.811312 [thread 1]: Processed 9000 of 19605 images in thread batch.
110 | 2017-12-01 15:57:23.763258 [thread 0]: Processed 10000 of 19604 images in thread batch.
111 | 2017-12-01 15:57:24.086001 [thread 1]: Processed 10000 of 19605 images in thread batch.
112 | 2017-12-01 15:57:28.965086 [thread 0]: Processed 11000 of 19604 images in thread batch.
113 | 2017-12-01 15:57:29.261494 [thread 1]: Processed 11000 of 19605 images in thread batch.
114 | 2017-12-01 15:57:34.061319 [thread 0]: Processed 12000 of 19604 images in thread batch.
115 | 2017-12-01 15:57:34.177079 [thread 1]: Processed 12000 of 19605 images in thread batch.
116 | 2017-12-01 15:57:39.415580 [thread 1]: Processed 13000 of 19605 images in thread batch.2017-12-01 15:57:39.418505 [thread 0]: Processed 13000 of 19604 images in thread batch.
117 |
118 | 2017-12-01 15:57:44.575577 [thread 1]: Processed 14000 of 19605 images in thread batch.
119 | 2017-12-01 15:57:44.973289 [thread 0]: Processed 14000 of 19604 images in thread batch.
120 | 2017-12-01 15:57:49.592148 [thread 1]: Processed 15000 of 19605 images in thread batch.
121 | 2017-12-01 15:57:50.393305 [thread 0]: Processed 15000 of 19604 images in thread batch.
122 | 2017-12-01 15:57:54.668816 [thread 1]: Processed 16000 of 19605 images in thread batch.
123 | 2017-12-01 15:57:55.473357 [thread 0]: Processed 16000 of 19604 images in thread batch.
124 | 2017-12-01 15:57:59.705012 [thread 1]: Processed 17000 of 19605 images in thread batch.
125 | 2017-12-01 15:58:00.571395 [thread 0]: Processed 17000 of 19604 images in thread batch.
126 | 2017-12-01 15:58:04.590437 [thread 1]: Processed 18000 of 19605 images in thread batch.
127 | 2017-12-01 15:58:05.412019 [thread 0]: Processed 18000 of 19604 images in thread batch.
128 | 2017-12-01 15:58:09.393710 [thread 1]: Processed 19000 of 19605 images in thread batch.
129 | 2017-12-01 15:58:10.293427 [thread 0]: Processed 19000 of 19604 images in thread batch.
130 | 2017-12-01 15:58:12.288467 [thread 1]: Wrote 19605 images to /home/olu/Dev/data_base/sign_base/output/validation-00001-of-00002
131 | 2017-12-01 15:58:12.326821 [thread 1]: Wrote 19605 images to 19605 shards.
132 | 2017-12-01 15:58:12.793973 [thread 0]: Wrote 19604 images to /home/olu/Dev/data_base/sign_base/output/validation-00000-of-00002
133 | 2017-12-01 15:58:12.840278 [thread 0]: Wrote 19604 images to 19604 shards.
134 | 2017-12-01 15:58:13.796779: Finished writing all 39209 images in data set.
135 | Determining list of input files and labels from /home/olu/Dev/data_base/sign_base/labels.txt
136 | File path /home/olu/Dev/data_base/sign_base/training/speed_20/*
137 |
138 | File path /home/olu/Dev/data_base/sign_base/training/speed_30/*
139 |
140 | File path /home/olu/Dev/data_base/sign_base/training/speed_50/*
141 |
142 | File path /home/olu/Dev/data_base/sign_base/training/speed_60/*
143 |
144 | File path /home/olu/Dev/data_base/sign_base/training/speed_70/*
145 |
146 | File path /home/olu/Dev/data_base/sign_base/training/speed_80/*
147 |
148 | File path /home/olu/Dev/data_base/sign_base/training/speed_less_80/*
149 |
150 | File path /home/olu/Dev/data_base/sign_base/training/speed_100/*
151 |
152 | File path /home/olu/Dev/data_base/sign_base/training/speed_120/*
153 |
154 | File path /home/olu/Dev/data_base/sign_base/training/no_car_overtaking/*
155 |
156 | File path /home/olu/Dev/data_base/sign_base/training/no_truck_overtaking/*
157 |
158 | File path /home/olu/Dev/data_base/sign_base/training/priority_road/*
159 |
160 | File path /home/olu/Dev/data_base/sign_base/training/priority_road_2/*
161 |
162 | File path /home/olu/Dev/data_base/sign_base/training/yield_right_of_way/*
163 |
164 | File path /home/olu/Dev/data_base/sign_base/training/stop/*
165 |
166 | File path /home/olu/Dev/data_base/sign_base/training/road_closed/*
167 |
168 | File path /home/olu/Dev/data_base/sign_base/training/maximum_weight_allowed/*
169 |
170 | File path /home/olu/Dev/data_base/sign_base/training/entry_prohibited/*
171 |
172 | File path /home/olu/Dev/data_base/sign_base/training/danger/*
173 |
174 | File path /home/olu/Dev/data_base/sign_base/training/curve_left/*
175 |
176 | File path /home/olu/Dev/data_base/sign_base/training/curve_right/*
177 |
178 | File path /home/olu/Dev/data_base/sign_base/training/double_curve_right/*
179 |
180 | File path /home/olu/Dev/data_base/sign_base/training/rough_road/*
181 |
182 | File path /home/olu/Dev/data_base/sign_base/training/slippery_road/*
183 |
184 | File path /home/olu/Dev/data_base/sign_base/training/road_narrows_right/*
185 |
186 | File path /home/olu/Dev/data_base/sign_base/training/work_in_progress/*
187 |
188 | File path /home/olu/Dev/data_base/sign_base/training/traffic_light_ahead/*
189 |
190 | File path /home/olu/Dev/data_base/sign_base/training/pedestrian_crosswalk/*
191 |
192 | File path /home/olu/Dev/data_base/sign_base/training/children_area/*
193 |
194 | File path /home/olu/Dev/data_base/sign_base/training/bicycle_crossing/*
195 |
196 | File path /home/olu/Dev/data_base/sign_base/training/beware_of_ice/*
197 |
198 | File path /home/olu/Dev/data_base/sign_base/training/wild_animal_crossing/*
199 |
200 | File path /home/olu/Dev/data_base/sign_base/training/end_of_restriction/*
201 |
202 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_right/*
203 |
204 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_left/*
205 |
206 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight/*
207 |
208 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_right/*
209 |
210 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_left/*
211 |
212 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle/*
213 |
214 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle2/*
215 |
216 | File path /home/olu/Dev/data_base/sign_base/training/traffic_circle/*
217 |
218 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_car_overtaking/*
219 |
220 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_truck_overtaking/*
221 |
222 | Found 39209 JPEG files across 43 labels inside /home/olu/Dev/data_base/sign_base/training
223 | Launching 2 threads for spacings: [[0, 19604], [19604, 39209]]
224 | 2017-12-01 15:58:19.321528 [thread 1]: Processed 1000 of 19605 images in thread batch.2017-12-01 15:58:19.328218 [thread 0]: Processed 1000 of 19604 images in thread batch.
225 |
226 | 2017-12-01 15:58:24.103518 [thread 1]: Processed 2000 of 19605 images in thread batch.
227 | 2017-12-01 15:58:24.379995 [thread 0]: Processed 2000 of 19604 images in thread batch.
228 | 2017-12-01 15:58:28.848063 [thread 1]: Processed 3000 of 19605 images in thread batch.
229 | 2017-12-01 15:58:29.290989 [thread 0]: Processed 3000 of 19604 images in thread batch.
230 | 2017-12-01 15:58:33.692279 [thread 1]: Processed 4000 of 19605 images in thread batch.
231 | 2017-12-01 15:58:34.348252 [thread 0]: Processed 4000 of 19604 images in thread batch.
232 | 2017-12-01 15:58:38.885082 [thread 1]: Processed 5000 of 19605 images in thread batch.
233 | 2017-12-01 15:58:39.212437 [thread 0]: Processed 5000 of 19604 images in thread batch.
234 | 2017-12-01 15:58:44.014678 [thread 1]: Processed 6000 of 19605 images in thread batch.
235 | 2017-12-01 15:58:44.319710 [thread 0]: Processed 6000 of 19604 images in thread batch.
236 | 2017-12-01 15:58:48.727675 [thread 1]: Processed 7000 of 19605 images in thread batch.
237 | 2017-12-01 15:58:49.454886 [thread 0]: Processed 7000 of 19604 images in thread batch.
238 | 2017-12-01 15:58:53.757572 [thread 1]: Processed 8000 of 19605 images in thread batch.
239 | 2017-12-01 15:58:54.732761 [thread 0]: Processed 8000 of 19604 images in thread batch.
240 | 2017-12-01 15:58:59.082209 [thread 1]: Processed 9000 of 19605 images in thread batch.
241 | 2017-12-01 15:58:59.685727 [thread 0]: Processed 9000 of 19604 images in thread batch.
242 | 2017-12-01 15:59:03.855798 [thread 1]: Processed 10000 of 19605 images in thread batch.
243 | 2017-12-01 15:59:04.623629 [thread 0]: Processed 10000 of 19604 images in thread batch.
244 | 2017-12-01 15:59:08.897193 [thread 1]: Processed 11000 of 19605 images in thread batch.
245 | 2017-12-01 15:59:09.754831 [thread 0]: Processed 11000 of 19604 images in thread batch.
246 | 2017-12-01 15:59:13.925373 [thread 1]: Processed 12000 of 19605 images in thread batch.
247 | 2017-12-01 15:59:14.625715 [thread 0]: Processed 12000 of 19604 images in thread batch.
248 | 2017-12-01 15:59:18.758661 [thread 1]: Processed 13000 of 19605 images in thread batch.
249 | 2017-12-01 15:59:19.334345 [thread 0]: Processed 13000 of 19604 images in thread batch.
250 | 2017-12-01 15:59:23.557197 [thread 1]: Processed 14000 of 19605 images in thread batch.
251 | 2017-12-01 15:59:24.091283 [thread 0]: Processed 14000 of 19604 images in thread batch.
252 | 2017-12-01 15:59:28.380125 [thread 1]: Processed 15000 of 19605 images in thread batch.
253 | 2017-12-01 15:59:28.939316 [thread 0]: Processed 15000 of 19604 images in thread batch.
254 | 2017-12-01 15:59:33.511671 [thread 1]: Processed 16000 of 19605 images in thread batch.
255 | 2017-12-01 15:59:33.791349 [thread 0]: Processed 16000 of 19604 images in thread batch.
256 | 2017-12-01 15:59:38.410925 [thread 1]: Processed 17000 of 19605 images in thread batch.
257 | 2017-12-01 15:59:38.555400 [thread 0]: Processed 17000 of 19604 images in thread batch.
258 | 2017-12-01 15:59:43.261218 [thread 1]: Processed 18000 of 19605 images in thread batch.
259 | 2017-12-01 15:59:43.384039 [thread 0]: Processed 18000 of 19604 images in thread batch.
260 | 2017-12-01 15:59:48.125398 [thread 1]: Processed 19000 of 19605 images in thread batch.
261 | 2017-12-01 15:59:48.189964 [thread 0]: Processed 19000 of 19604 images in thread batch.
262 | 2017-12-01 15:59:50.915470 [thread 1]: Wrote 19605 images to /home/olu/Dev/data_base/sign_base/output/train-00001-of-00002
263 | 2017-12-01 15:59:50.925460 [thread 1]: Wrote 19605 images to 19605 shards.
264 | 2017-12-01 15:59:51.064479 [thread 0]: Wrote 19604 images to /home/olu/Dev/data_base/sign_base/output/train-00000-of-00002
265 | 2017-12-01 15:59:51.069508 [thread 0]: Wrote 19604 images to 19604 shards.
266 | 2017-12-01 15:59:51.536970: Finished writing all 39209 images in data set.
267 | >>>
268 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Computer_Vision_Using_TensorFlowLite
2 |
3 | In this project, the AlexNet Convolutional Neural Network is trained on traffic sign images from the German Traffic Sign Recognition Benchmark (GTSRB). The trained network is then quantized/optimized for deployment on mobile devices using TensorFlowLite.
4 |
5 | ## Project Steps
6 | 1. Download German Traffic Sign Benchmark Training Images [from here](https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip)
7 | 2. Convert From .ppm to .jpg Format
8 | 3. Label db Folder Appropriately
9 | 4. Convert Dataset to TFRecord
10 | 5. Create CNN (alexnet)
11 | 6. Train CNN Using TFRecord Data
12 | 7. Test & Evaluate Trained Model
13 | 8. Quantize/Tune/Optimize trained model for mobile deployment
14 | 9. Test & Evaluate Tuned Model
15 | 10. Deploy Tuned Model to Mobile Platform [Here](https://github.com/OluwoleOyetoke/Accelerated-Android-Vision)
16 |
17 | ### Steps 1, 2 & 3: Get Dataset, Convert (from .ppm to .jpeg) & Label Appropriately
18 | The **DatasetConverter** folder contains the Java program written to go through the GTSRB dataset and rename each of its folders after the class of images that folder contains. It also converts the image files from .ppm to .jpeg.
19 |
20 | ### Step 4: Convert Dataset to TFRecord
21 | With TensorFlow, we can store our whole dataset and all of its metadata as a serialized record called a **TFRecord**. Loading our data into this format helps with portability and simplicity, and the record can be broken into multiple shards for distributed training. A TFRecord is essentially TensorFlow's default data format: a record file containing serialized tf.train.Example messages. To avoid confusion, a TensorFlow 'Example' is a normalized data format for storing data for training and inference. It contains a key-value store where each key string maps to a feature message, and a feature message can be a packed byte list, float list or int64 list. Many 'Examples' together form a TFRecord. The script **create_imdb.py** is used to make the TFRecords out of the dataset. To learn more about creating a TFRecord file or streaming data out of it, see [TensorFlow TFRecord Blog Post 1](http://eagle-beacon.com/blog/posts/Loading_And_Poping_TFRecords.html) & [TensorFlow TFRecord Blog Post 2](http://eagle-beacon.com/blog/posts/Loading_And_Poping_TFRecords_P2.html)
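
As a quick illustration, the snippet below is a minimal sketch of serializing one image/label pair into a TFRecord shard. It is not an excerpt from **create_imdb.py**; the shard name, image path and feature keys are illustrative assumptions:

```
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# One serialized tf.train.Example is written to the shard per image
with tf.python_io.TFRecordWriter('train-00000-of-00002') as writer:
    with open('stop/00001.jpeg', 'rb') as f:
        encoded_jpeg = f.read()
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(encoded_jpeg),
        'image/class/label': _int64_feature(14),  # e.g. index of the 'stop' class
    }))
    writer.write(example.SerializeToString())
```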
22 |
23 | ### Step 5: Create AlexNet Structure
24 | The script **train_alexnet.py** is used to create and train the CNN. See TensorBoard visualization of AlexNet structure below:
25 |
26 |
27 | 
28 |
29 | **Figure Showing TensorBoard Visualization of the Network**
30 |
31 |
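For orientation, the snippet below sketches what an AlexNet-style graph looks like in TensorFlow's layers API. It is an illustration under assumptions, not an excerpt from **train_alexnet.py**: the filter sizes are the standard AlexNet ones, and the tensor name `input_layer` is borrowed from the names referenced in **classify_img.py**.

```
import tensorflow as tf

def alexnet_logits(images, num_classes=43):
    # images: a float32 batch of shape [N, 227, 227, 3]
    net = tf.reshape(images, [-1, 227, 227, 3], name='input_layer')
    net = tf.layers.conv2d(net, 96, 11, strides=4, activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 3, 2)
    net = tf.layers.conv2d(net, 256, 5, padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 3, 2)
    net = tf.layers.conv2d(net, 384, 3, padding='same', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 384, 3, padding='same', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 256, 3, padding='same', activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, 3, 2)
    net = tf.layers.flatten(net)
    net = tf.layers.dense(net, 4096, activation=tf.nn.relu)
    net = tf.layers.dense(net, 4096, activation=tf.nn.relu)
    return tf.layers.dense(net, num_classes)  # logits over the 43 sign classes
```
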
32 | ### Step 6: Train CNN Using TFRecord Data
33 | The figure below shows the drop in the network's loss, per epoch, as training progressed. The Adam optimizer was used.
34 | 
35 |
36 | **Figure Showing Improvement In Network Performance Per Epoch (Epoch 0-20)**
37 |
38 | After the full training procedure, the trained model performed at **over 98% accuracy**
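
As a rough sketch of how the training op can be wired up with Adam (again an assumption for illustration, reusing `alexnet_logits` from the sketch above, not an excerpt from **train_alexnet.py**; the learning rate is a placeholder value):

```
import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 227, 227, 3])
labels = tf.placeholder(tf.int64, [None])  # integer class indices

logits = alexnet_logits(images)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
    loss, global_step=tf.train.get_or_create_global_step())
```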
39 |
40 | ### Step 7: Test & Evaluate Trained Model
41 | To test the trained network model, the script **classify_img_arg.py** can be used.
42 | Usage format:
43 |
44 | ```
45 |
46 | $python classify_img_arg.py [saved_model_directory] [path_to_image]
47 |
48 | ```
49 |
50 | ### Step 8: Quantize/Tune/Optimize Trained Model for Mobile Deployment
51 | Here, the TensorFlow Graph Transform tool comes in handy in helping us shrink the size of the model and make it deployable on mobile. The transform tools are designed to work on models that are saved as GraphDef files, usually in a binary protobuf format: the low-level definition of a TensorFlow computational graph, which includes a list of nodes and the input and output connections between them. But before we start any form of transformation/optimization of the model, it is wise to freeze the graph, i.e. make sure the trained weights are fused with the graph and converted into embedded constants within the graph file itself. To do this, we will need to run the [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) script.
52 |
53 | The various transforms we can perform on our graph include stripping unused nodes, removing identity/check nodes, folding batch-norm layers, etc.
54 |
55 | In a nutshell, to achieve our conversion of the TensorFlow .pb model to a TensorFlowLite .lite model, we will:
56 | 1. Freeze the graph, i.e. merge checkpoint values with the graph structure. In other words, load variables into the graph and convert them to constants
57 | 2. Perform one or two simple transformations/optimizations on the frozen model
58 | 3. Convert the frozen graph definition into the [flat buffer format](https://google.github.io/flatbuffers/) (.lite)
59 |
61 |
62 | #### Step 8.1: Sample Transform Definition and A Little Bit of Explanation
63 | Note that these transform options are key to shrinking the model file size
64 |
65 | ```
66 | transforms = 'strip_unused_nodes(type=float, shape="1,299,299,3")
67 | remove_nodes(op=Identity, op=CheckNumerics)
68 | fold_constants(ignore_errors=true)
69 | fold_batch_norms fold_old_batch_norms
70 | round_weights(num_steps=256)
71 | quantize_weights obfuscate_names
72 | quantize_nodes sort_by_execution_order'
73 | ```
74 |
75 | **round_weights(num_steps=256)** - Essentially, this rounds the weights so that nearby numbers are stored as exactly the same values; the resulting bit stream has a lot more repetition and so compresses down a lot more effectively. The compressed tflite model can be as much as 70% smaller than the original model. The nice thing about this transform option is that it doesn't change the structure of the graph at all, so it runs exactly the same operations and should have the same latency and memory usage as before. We can adjust the num_steps parameter to control how many values each weight buffer is rounded to; lower numbers increase the compression at the cost of accuracy.
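
To make the arithmetic concrete, here is a small NumPy illustration (not from the repo) of rounding a weight buffer onto 256 evenly spaced levels:

```
import numpy as np

def round_weights(w, num_steps=256):
    # Snap each weight to the nearest of num_steps evenly spaced levels
    # between min(w) and max(w). Values stay float32, but the buffer now
    # holds many repeated values, so it compresses far better.
    lo, hi = w.min(), w.max()
    step = (hi - lo) / (num_steps - 1)
    return (lo + np.round((w - lo) / step) * step).astype(w.dtype)

w = np.random.randn(4096).astype(np.float32)
print(len(np.unique(round_weights(w))))  # at most 256 distinct values
```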
76 |
77 | **quantize_weights** - Storing the weights of the model as 8-bit will drastically reduce its size; however, we will be trading off some inference accuracy here.
78 |
79 | **obfuscate_names** - If our graph has a lot of small nodes, their names can also end up taking noticeable space. Using this option will help cut that down.
80 |
81 | For some platforms it is very helpful to be able to do as many calculations as possible in eight-bit, rather than in floating-point. A few transform options are available to help achieve this, although in most cases some modifications will need to be made to the actual training code for this transform to work effectively.
82 |
83 | #### Step 8.2: Transformation & Conversion Execution
84 | As we know, TensorFlowLite is still in its early development stage and many new features are being added daily. Consequently, some TensorFlow operations are not yet fully supported, [e.g. ArgMax](https://github.com/tensorflow/tensorflow/issues/15948). As a matter of fact, at the moment, [models that were quantized using transform_graph are not supported by TF Lite](https://github.com/tensorflow/tensorflow/issues/15871#issuecomment-356419505). On a brighter note, we can still convert our custom TensorFlow AlexNet model to a TFLite model, but we will need to turn off some transform options such as 'quantize_weights' and 'quantize_nodes' and keep our model's inference and input types as FLOAT. In this case, the model size does not really change.
85 |
86 | The code excerpt below, from the [freeze_and_convert_to_tflite.sh](https://github.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/blob/master/freeze_and_convert_to_tflite.sh) script, shows how this .tflite conversion is achieved.
87 |
88 | ```
89 | #FREEZE GRAPH
90 | bazel build tensorflow/python/tools:freeze_graph
91 | bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=$1 --input_checkpoint=$2 --input_binary=$3 --output_graph=$4 --output_node_names=$6
92 |
93 |
94 | #VIEW SUMMARY OF GRAPH
95 | bazel build tensorflow/tools/graph_transforms:summarize_graph
96 | bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=$4
97 |
98 | #TRANSFORM/OPTIMIZE GRAPH
99 | bazel build tensorflow/tools/graph_transforms:transform_graph
100 | bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=$4 --out_graph=${10} --inputs=$5 --outputs=$6 --transforms='strip_unused_nodes(type=float, shape="1,227,227,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) round_weights(num_steps=256) obfuscate_names sort_by_execution_order fold_old_batch_norms fold_batch_norms'
101 |
102 |
103 | #CONVERT TO TFLITE MODEL
104 | bazel build tensorflow/contrib/lite/toco:toco
105 | bazel-bin/tensorflow/contrib/lite/toco/toco --input_format=TENSORFLOW_GRAPHDEF --input_file=${10} --output_format=TFLITE --output_file=$7 --inference_type=$9 --input_arrays=$5 --output_arrays=$6 --inference_input_type=$8 --input_shapes=1,227,227,3
106 |
107 | ```
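
A hypothetical invocation of the script could look like this (the argument order is inferred from the excerpt above; every path is a placeholder, and `input_layer`/`classes_tensor` are the tensor names used in **classify_img.py**):

```
./freeze_and_convert_to_tflite.sh alexnet_graph.pb model.ckpt true \
    frozen_alexnet.pb input_layer classes_tensor alexnet.lite \
    FLOAT FLOAT transformed_alexnet.pb
```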
108 |
109 | Note, this will require you to build TensorFlow from source. You can get instructions to do this from [here](https://www.tensorflow.org/install/install_sources)
110 |
--------------------------------------------------------------------------------
/classify_img.py:
--------------------------------------------------------------------------------
1 | """
2 | Script passes raw image to an already trained model to get prediction
3 |
4 | @date: 20th December, 2017
5 | @author: Oluwole Oyetoke
6 | @Language: Python
7 | @email: oluwoleoyetoke@gmail.com
8 |
9 | #IMPORTS, VARIABLE DECLARATION, AND LOGGING TYPE SETTING
10 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
11 | from __future__ import absolute_import
12 | from __future__ import division
13 | from __future__ import print_function
14 | from sklearn.model_selection import train_test_split
15 |
16 | import os
17 | import math
18 | import time
19 | import numpy as np
20 | import tensorflow as tf #import tensorflow
21 | import matplotlib.pyplot as plt
22 | from PIL import Image
23 |
24 | flags = tf.app.flags
25 | flags.DEFINE_integer("image_width", "227", "Alexnet input layer width")
26 | flags.DEFINE_integer("image_height", "227", "Alexnet input layer height")
27 | flags.DEFINE_integer("image_channels", "3", "Alexnet input layer channels")
28 | flags.DEFINE_integer("num_of_classes", "43", "Number of training clases")
29 | FLAGS = flags.FLAGS
30 |
31 | tf.logging.set_verbosity(tf.logging.WARN) #setting up logging (can be DEBUG, ERROR, FATAL, INFO or WARN)
32 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
33 |
34 |
35 | #TRAINING AND EVALUATING THE ALEXNET CNN CLASSIFIER
36 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
37 | def main(unused_argv):
38 |
39 | #Specify checkpoint & image directory
40 | checkpoint_directory="/home/olu/Dev/data_base/sign_base/backup/Checkpoints_N_Model_Epoch_34__copy/trained_alexnet_model"
41 | filename="/home/olu/Dev/data_base/sign_base/training_227x227/road_closed/00002_00005.jpeg"
42 |
43 | #Process image to be sent to Neural Net
44 | #im = np.array(Image.open(filename))
45 | #img_batch = im.reshape(1, FLAGS.image_width, FLAGS.image_height,FLAGS.image_channels)
46 |
47 | img = Image.open(filename)
48 | img_resized = img.resize((227, 227), Image.ANTIALIAS)
49 | img_batch_np = np.array(img_resized)
50 | img_batch = img_batch_np.reshape(1, FLAGS.image_width, FLAGS.image_height,FLAGS.image_channels)
51 | plt.imshow(img_batch_np)
52 |
53 | #Declare categories/classes as string
54 | categories = ["speed_20", "speed_30","speed_50","speed_60","speed_70",
55 | "speed_80","speed_less_80","speed_100","speed_120",
56 | "no_car_overtaking","no_truck_overtaking","priority_road",
57 | "priority_road_2","yield_right_of_way","stop","road_closed",
58 | "maximum_weight_allowed","entry_prohibited","danger","curve_left",
59 | "curve_right","double_curve_right","rough_road","slippery_road",
60 | "road_narrows_right","work_in_progress","traffic_light_ahead",
61 | "pedestrian_crosswalk","children_area","bicycle_crossing",
62 | "beware_of_ice","wild_animal_crossing","end_of_restriction",
63 | "must_turn_right","must_turn_left","must_go_straight",
64 | "must_go_straight_or_right","must_go_straight_or_left",
65 | "mandatroy_direction_bypass_obstacle",
66 | "mandatroy_direction_bypass_obstacle2",
67 | "traffic_circle","end_of_no_car_overtaking",
68 | "end_of_no_truck_overtaking"];
69 |
70 |
71 | #Recreate network graph.
72 | sess = tf.Session()
73 | latest_checkpoint_name = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_directory)
74 | saver = tf.train.import_meta_graph(latest_checkpoint_name+'.meta') #At this step only graph is created.
75 |
76 | #Accessing the default graph which we have restored
77 | graph = tf.get_default_graph()
78 |
79 | #Load the trained weights into the restored graph
80 | checkpoint_file=tf.train.latest_checkpoint(checkpoint_directory)
81 | saver.restore(sess, checkpoint_file) #Load the saved weights using the saver's restore method.
82 |
83 | probabilities = graph.get_tensor_by_name("softmax_tensor:0")
84 | classes = graph.get_tensor_by_name("classes_tensor:0") #'classes_tensor:0' is the name given to the argmax tensor in train_alexnet.py.
85 | feed_dict = {"input_layer:0": img_batch} #'input_layer:0' is the name given to the input reshape tensor in train_alexnet.py.
86 | predicted_class = sess.run(classes, feed_dict)
87 | predicted_probabilities = sess.run(probabilities, feed_dict)
88 | assurance = predicted_probabilities[0,int(predicted_class)]*100;
89 |
90 | print("Predicted Sign: ", categories[int(predicted_class)], " (With ", assurance," Percent Assurance)")
91 | print("finished")
92 | plt.show()
93 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
94 |
95 |
96 | if __name__=="__main__":
97 | tf.app.run()
98 |
99 |
--------------------------------------------------------------------------------
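Note: the two separate sess.run calls above each execute a full forward pass through the network. Fetching both tensors in a single call halves the inference work; a minimal sketch, assuming the same restored session and tensor names as in the script:

    predicted_class, predicted_probabilities = sess.run(
        [classes, probabilities],                # one forward pass for both outputs
        feed_dict={"input_layer:0": img_batch})
    assurance = predicted_probabilities[0, int(predicted_class)] * 100
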
/classify_img_arg.py:
--------------------------------------------------------------------------------
1 | """
2 | Passes a raw image to an already-trained model to get a prediction
3 |
4 | @date: 20th December, 2017
5 | @Language: Python
6 |
7 | usage = $python classify_img_arg.py [saved_model_directory] [path_to_image]
8 |
9 | #IMPORTS, VARIABLE DECLARATION, AND LOGGING TYPE SETTING
10 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
11 |
12 |
13 | from __future__ import absolute_import
14 | from __future__ import print_function
15 |
16 | import sys
17 | import os
18 | import math
19 | import time
20 | import numpy as np
21 | import tensorflow as tf #import tensorflow
22 | import matplotlib.pyplot as plt
23 | from PIL import Image
24 |
25 | flags = tf.app.flags
26 | flags.DEFINE_integer("image_width", "227", "Alexnet input layer width")
27 | flags.DEFINE_integer("image_height", "227", "Alexnet input layer height")
28 | flags.DEFINE_integer("image_channels", "3", "Alexnet input layer channels")
29 | flags.DEFINE_integer("num_of_classes", "43", "Number of training clases")
30 | FLAGS = flags.FLAGS
31 |
32 | tf.logging.set_verbosity(tf.logging.WARN) #setting up logging (can be DEBUG, ERROR, FATAL, INFO or WARN)
33 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
34 |
35 |
36 | #RESTORING THE TRAINED ALEXNET MODEL AND CLASSIFYING A SINGLE IMAGE
37 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
38 | #Specify checkpoint & image directory
39 | checkpoint_directory= sys.argv[1] # "/home/olu/Dev/data_base/sign_base/backup/Checkpoints_N_Model2-After_Epoch_20_copy/trained_alexnet_model"
40 | filename= sys.argv[2] # "/home/olu/Dev/data_base/sign_base/training_227x227/road_closed/00002_00005.jpeg"
41 |
42 | #Declare categories/classes as string
43 | categories = ["speed_20", "speed_30","speed_50","speed_60","speed_70",
44 | "speed_80","speed_less_80","speed_100","speed_120",
45 | "no_car_overtaking","no_truck_overtaking","priority_road",
46 | "priority_road_2","yield_right_of_way","stop","road_closed",
47 | "maximum_weight_allowed","entry_prohibited","danger","curve_left",
48 | "curve_right","double_curve_right","rough_road","slippery_road",
49 | "road_narrows_right","work_in_progress","traffic_light_ahead",
50 | "pedestrian_crosswalk","children_area","bicycle_crossing",
51 | "beware_of_ice","wild_animal_crossing","end_of_restriction",
52 | "must_turn_right","must_turn_left","must_go_straight",
53 | "must_go_straight_or_right","must_go_straight_or_left",
54 | "mandatroy_direction_bypass_obstacle",
55 | "mandatroy_direction_bypass_obstacle2",
56 | "traffic_circle","end_of_no_car_overtaking",
57 | "end_of_no_truck_overtaking"];
58 |
59 | print("resizing image.....")
60 | #Process image to be sent to Neural Net
61 | img = Image.open(filename)
62 | img_resized = img.resize((FLAGS.image_width, FLAGS.image_height), Image.ANTIALIAS)
63 | img_batch_np = np.array(img_resized)
64 | plt.imshow(img_batch_np)
65 | img_batch = img_batch_np.reshape(1, FLAGS.image_width, FLAGS.image_height,FLAGS.image_channels)
66 |
67 | print("loading network graph.....")
68 | #Recreate network graph.
69 | sess = tf.Session()
70 | latest_checkpoint_name = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_directory)
71 | saver = tf.train.import_meta_graph(latest_checkpoint_name+'.meta') #At this step only graph is created.
72 |
73 | #Accessing the default graph which we have restored
74 | graph = tf.get_default_graph()
75 |
76 | print("loading network weights.....")
77 | #Load the trained weights into the restored graph
78 | checkpoint_file=tf.train.latest_checkpoint(checkpoint_directory)
79 | saver.restore(sess, checkpoint_file) #Load the saved weights using the saver's restore method.
80 |
81 | print("classification process started.....")
82 | start = time.time()
83 | probabilities = graph.get_tensor_by_name("softmax_tensor:0")
84 | classes = graph.get_tensor_by_name("ArgMax:0") #'ArgMax:0' is the name of the argmax tensor in the train_alexnet.py file.
85 | feed_dict = {"Reshape:0": img_batch} #'Reshape:0' is the name of the 'input_layer' tensor in the train_alexnet.py. Given to it as default.
86 | predicted_class = sess.run(classes, feed_dict)
87 | predicted_probabilities = sess.run(probabilities, feed_dict)
88 | assurance = predicted_probabilities[0,int(predicted_class)]*100;
89 | end = time.time()
90 | difference = end-start
91 | difference_milli = difference*1000
92 |
93 | print("Predicted Sign: '", categories[int(predicted_class)], "' With ", assurance," Percent Assurance")
94 | print("Time Taken For Classification: %f millisecond(s)" % difference_milli)
95 | plt.show()
96 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
97 |
--------------------------------------------------------------------------------
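Note: this variant relies on the default op names 'Reshape:0' and 'ArgMax:0', which change whenever the training graph is edited. A quick way to check what the restored graph actually names its ops (a sketch; the checkpoint directory is a placeholder):

    import tensorflow as tf

    latest = tf.train.latest_checkpoint("/path/to/checkpoints")  # placeholder path
    saver = tf.train.import_meta_graph(latest + ".meta")
    for op in tf.get_default_graph().get_operations():
        print(op.name)  # look for the input reshape and argmax ops
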
/create_imdb.py:
--------------------------------------------------------------------------------
1 | """--------------------------------------------------------------------------------------------------------------------------------------------------
2 | REFERENCE:
3 | ----------
4 | Code adapted from the Google TensorFlow GitHub repository:
5 | https://github.com/tensorflow/models/blob/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/inception/inception/data/build_image_data.py
6 |
7 | @author: Oluwole Oyetoke
8 | @date: 5th December, 2017
9 | @language: Python/TF
10 | @email: oluwoleoyetoke@gmail.com
11 |
12 | INTRODUCTION:
13 | -------------
14 | Converts an image dataset into a sharded dataset of TensorFlow Records (TFRecords) containing Example protos, e.g.:
15 | train_directory/train-00000-of-01024
16 | train_directory/train-00001-of-01024
17 | ...
18 | train_directory/train-01023-of-01024
19 | and
20 | validation_directory/validation-00000-of-00128
21 | validation_directory/validation-00001-of-00128
22 | ...
23 | validation_directory/validation-00127-of-00128
24 |
25 | EXPECTATIONS:
26 | ------------
27 | 1. Image data set should be in .jpeg format
28 | 2. It is advised that the base directory contains only folders holding your training images:
29 | Base folder --> Subfolders --> Each subfolder containing one specific class of images,
30 | e.g. Training Folder -> stop_sign_folder -> 1.jpg, 2.jpg, 3.jpg...
31 | (data_dir/label_0/image0.jpeg)
32 | (data_dir/label_0/image1.jpg)
33 | 3. The sub-directory name should be the unique label associated with the images in the folder.
34 |
35 |
36 | SHARDS CONTENT:
37 | --------------
38 | Where we have selected [x] number of image files per training dataset shard and [y] number of image files per evaluation dataset shard,
39 | for each of the shards, each record within the TFRecord file (shard) is a serialized example proto consisting of the following fields:
40 | image/encoded: string containing JPEG encoded image in RGB colorspace
41 | image/height: integer, image height in pixels
42 | image/width: integer, image width in pixels
43 | image/colorspace: string, specifying the colorspace, always 'RGB'
44 | image/channels: integer, specifying the number of channels, always 3
45 | image/format: string, specifying the format, always 'JPEG'
46 | image/filename: string containing the basename of the image file e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
47 | image/class/label: integer specifying the index in a classification layer. The label ranges over [0, num_labels], where 0 is unused and left as the background class.
48 | image/class/text: string specifying the human-readable version of the label e.g. 'dog'
49 | If your data set involves bounding boxes, please look at build_imagenet_data.py.
50 |
51 |
52 |
53 |
54 | #IMPORTS
55 | -------------------------------------------------------------------------------------------------------------------------------------------------------"""
56 | from __future__ import absolute_import
57 | from __future__ import division
58 | from __future__ import print_function
59 |
60 | from datetime import datetime
61 | from PIL import Image
62 | from time import sleep
63 | import os
64 | import random
65 | import sys
66 | import threading
67 |
68 | import numpy as np
69 | import tensorflow as tf
70 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
71 |
72 |
73 |
74 |
75 | # SETTING SOME GLOBAL DATA
76 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
77 | tf.app.flags.DEFINE_string('train_directory', '/home/olu/Dev/data_base/sign_base/training_227x227', 'Training data directory')
78 | tf.app.flags.DEFINE_string('validation_directory', '/home/olu/Dev/data_base/sign_base/training_227x227', 'Validation data directory')
79 | tf.app.flags.DEFINE_string('output_directory', '/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227', 'Output data directory')
80 | tf.app.flags.DEFINE_integer('train_shards', 2, 'Number of shards in training TFRecord files.')
81 | tf.app.flags.DEFINE_integer('validation_shards', 2, 'Number of shards in validation TFRecord files.')
82 | tf.app.flags.DEFINE_integer('num_threads', 2, 'Number of threads to preprocess the images.')
83 | tf.app.flags.DEFINE_string('labels_file', '/home/olu/Dev/data_base/sign_base/labels.txt', 'Labels_file.txt')
84 | tf.app.flags.DEFINE_integer("image_height", 227, "Height of the output image after crop and resize.") #Alexnet takes 227 x 227 image input
85 | tf.app.flags.DEFINE_integer("image_width", 227, "Width of the output image after crop and resize.")
86 | FLAGS = tf.app.flags.FLAGS
87 | """ The labels file contains a list of valid labels are held in this file. The file contains entries such as:
88 | speed_100
89 | speed_120
90 | no_car_overtaking
91 | no_truck_overtaking
92 | Each line corresponds to a label, and each label is mapped to an integer corresponding to its line number, starting from 1 (index 0 is left as a background class).
93 |
94 |
95 | -------------------------------------------------------------------------------------------------------------------------------------------------------"""
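# The mapping _find_image_files builds from this labels file can be sketched
# directly (a hypothetical illustration; index 0 stays reserved for the
# background class):
#     unique_labels = [l.strip() for l in open(FLAGS.labels_file)]
#     label_of = {text: i + 1 for i, text in enumerate(unique_labels)}
#     # e.g. {'speed_100': 1, 'speed_120': 2, 'no_car_overtaking': 3, ...}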
96 |
97 |
98 |
99 |
100 | # WRAPPERS FOR INSERTING int64 &amp; BYTES FEATURES INTO EXAMPLE PROTOS
101 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
102 | def _int64_feature(value):
103 | if not isinstance(value, list):
104 | value = [value]
105 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
106 |
107 | def _bytes_feature(value):
108 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
109 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
110 |
111 |
112 |
113 |
114 |
115 | #FUNCTION FOR BUILDING A PROTO
116 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
117 | def _convert_to_example(filename, image_buffer, label, text, shape_buffer):
118 | """Build an Example proto for an example.
119 | Args:
120 | filename: string, path to an image file, e.g., '/path/to/example.JPG'
121 | image_buffer: string, JPEG encoding of RGB image
122 | label: integer, identifier for the ground truth for the network
123 | text: string, unique human-readable, e.g. 'dog'
124 | shape_buffer: bytes, serialized np.int32 array holding the image
125 | shape (height, width, channels)
126 | Returns:
127 | Example proto
128 | """
129 |
130 | colorspace = 'RGB'
131 | channels = 3
132 | image_format = 'JPEG'
133 | #Save TFRecord containing image_bytes, shape (e.g. [227,227,3]), label, text, filename
134 | example = tf.train.Example(features=tf.train.Features(feature={
135 | 'image/shape': _bytes_feature(shape_buffer),
136 | 'image/class/label': _int64_feature(label),
137 | 'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
138 | 'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
139 | 'image/encoded': _bytes_feature(image_buffer)}))
140 | return example
141 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
142 |
143 |
144 |
145 |
146 | #CLASS WITH FUNCTIONS TO HELP ENCODE &amp; DECODE IMAGES
147 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
148 | class ImageCoder(object):
149 |
150 | def __init__(self):
151 | # Create a single Session to run all image coding calls.
152 | self._sess = tf.Session()
153 |
154 | # Initializes function that converts PNG to JPEG data.
155 | self._png_data = tf.placeholder(dtype=tf.string)
156 | image = tf.image.decode_png(self._png_data, channels=3)
157 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
158 |
159 | # Initializes function that decodes RGB JPEG data.
160 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
161 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
162 |
163 | def png_to_jpeg(self, image_data):
164 | return self._sess.run(self._png_to_jpeg,
165 | feed_dict={self._png_data: image_data})
166 |
167 | def decode_jpeg(self, image_data):
168 | image = self._sess.run(self._decode_jpeg,
169 | feed_dict={self._decode_jpeg_data: image_data})
170 | assert len(image.shape) == 3
171 | assert image.shape[2] == 3
172 | return image
173 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
174 |
175 |
176 |
177 |
178 | # PRE-PROCESS SINGLE IMAGE(Check if PNG, convert to JPEG, confirm conversion)
179 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
180 | def _is_png(filename):
181 | """Determine if a file contains a PNG format image.
182 | Args:
183 | filename: string, path of the image file.
184 | Returns:
185 | boolean indicating if the image is a PNG.
186 | """
187 | return '.png' in filename
188 |
189 |
190 | def _process_image(filename, coder):
191 | """Process a single image file.
192 | Args:
193 | filename: string, path to an image file e.g., '/path/to/example.JPG'.
194 | coder: instance of ImageCoder to provide TensorFlow image coding utils.
195 | Returns:
196 | image_buffer: string, JPEG encoding of RGB image.
197 | height: integer, image height in pixels.
198 | width: integer, image width in pixels.
199 | """
200 | #Resize image to networks input size
201 | size=(FLAGS.image_height, FLAGS.image_width)
202 | original_image = Image.open(filename)
203 | width, height = original_image.size
204 | #print('The original image size is {wide} wide x {height} high'.format(wide=width, height=height))
205 |
206 | resized_image = original_image.resize(size)
207 | width, height = resized_image.size
208 | #print('The resized image size is {wide} wide x {height} high'.format(wide=width, height=height))
209 | resized_image.save(filename)
210 |
211 |
212 | #Sleep ~5 milliseconds before the file is re-read
213 | sleep(0.005)
214 |
215 | #extract the image's raw pixel data and shape as bytes
216 | image = np.asarray(original_image, np.uint8) #get image data
217 | shape = np.array(image.shape, np.int32) #get image shape
218 | shape_data = shape.tobytes() #convert image shape to bytes
219 | image_data = image.tobytes() # convert image to raw data bytes in the array.
220 |
221 | """ ANOTHER METHOD
222 | # Read the image file.
223 | with tf.gfile.FastGFile(filename, 'rb') as f:
224 | image_data = f.read()
225 |
226 |
227 | # Convert any PNG to JPEG's for consistency.
228 | if _is_png(filename):
229 | print('Converting PNG to JPEG for %s' % filename)
230 | image_data = coder.png_to_jpeg(image_data)
231 |
232 | # Decode the RGB JPEG.
233 | image = coder.decode_jpeg(image_data)
234 |
235 | # Check that image converted to RGB
236 | assert len(image.shape) == 3
237 | height = image.shape[0]
238 | width = image.shape[1]
239 | assert image.shape[2] == 3"""
240 | return image_data, shape_data
241 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
242 |
243 |
244 |
245 |
246 | # PROCESS BATCHES OF IMAGES AS EXAMPLE PROTOS SAVED TO ONE TFRecord PER SHARD
247 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
248 | def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
249 | texts, labels, num_shards):
250 | """Processes and saves list of images as TFRecord in 1 thread.
251 | Args:
252 | coder: instance of ImageCoder to provide TensorFlow image coding utils.
253 | thread_index: integer, unique thread index within [0, len(ranges)).
254 | ranges: list of pairs of integers specifying the range of files each
255 | thread should analyze in parallel.
256 | name: string, unique identifier specifying the data set
257 | filenames: list of strings; each string is a path to an image file
258 | texts: list of strings; each string is human readable, e.g. 'dog'
259 | labels: list of integer; each integer identifies the ground truth
260 | num_shards: integer number of shards for this data set.
261 | """
262 | # Each thread produces N shards where N = int(num_shards / num_threads).
263 | # For instance, if num_shards = 128, and the num_threads = 2, then the first
264 | # thread would produce shards [0, 64).
265 | num_threads = len(ranges)
266 | assert not num_shards % num_threads
267 | num_shards_per_batch = int(num_shards / num_threads) #Same as number of shards per thread
268 |
269 | shard_ranges = np.linspace(ranges[thread_index][0],
270 | ranges[thread_index][1],
271 | num_shards_per_batch + 1).astype(int)
272 | num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
273 |
274 | counter = 0
275 | for s in range(num_shards_per_batch):
276 | # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
277 | shard = thread_index * num_shards_per_batch + s
278 | output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
279 | output_file = os.path.join(FLAGS.output_directory, output_filename)
280 | writer = tf.python_io.TFRecordWriter(output_file)
281 |
282 | shard_counter = 0
283 | files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
284 | for i in files_in_shard:
285 | filename = filenames[i]
286 | label = labels[i]
287 | text = texts[i]
288 |
289 | try:
290 | # image_buffer, height, width = _process_image(filename, coder)
291 | image_buffer, shape_buffer = _process_image(filename, coder)
292 | except Exception as e:
293 | print(e)
294 | print('SKIPPED: Unexpected error while decoding %s.' % filename)
295 | continue
296 |
297 | example = _convert_to_example(filename, image_buffer, label,
298 | text, shape_buffer)
299 | writer.write(example.SerializeToString())
300 | shard_counter += 1
301 | counter += 1
302 |
303 | if not counter % 1000:
304 | print('%s [thread %d]: Processed %d of %d images in thread batch.' %
305 | (datetime.now(), thread_index, counter, num_files_in_thread))
306 | sys.stdout.flush()
307 |
308 | writer.close()
309 | print('%s [thread %d]: Wrote %d images to %s' %
310 | (datetime.now(), thread_index, shard_counter, output_file))
311 | sys.stdout.flush()
312 | shard_counter = 0
313 | print('%s [thread %d]: Wrote %d images to %d shards.' %
314 | (datetime.now(), thread_index, counter, num_files_in_thread))
315 | sys.stdout.flush()
316 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
317 |
318 |
319 |
320 |
321 | # PROCESS AND SAVES LIST OF IMAGES AS TFRecord OF EXAMPLE PROTOS (Entire Dataset details sent here)
322 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
323 | def _process_image_files(name, filenames, texts, labels, num_shards):
324 | """Process and save list of images as TFRecord of Example protos.
325 | Args:
326 | name: string, unique identifier specifying the data set
327 | filenames: list of strings; each string is a path to an image file
328 | texts: list of strings; each string is human readable, e.g. 'dog'
329 | labels: list of integer; each integer identifies the ground truth
330 | num_shards: integer number of shards for this data set.
331 | """
332 | assert len(filenames) == len(texts)
333 | assert len(filenames) == len(labels)
334 |
335 | # Break all images into batches, each spanning [ranges[i][0], ranges[i][1]).
336 | spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
337 | ranges = []
338 | for i in range(len(spacing) - 1):
339 | ranges.append([spacing[i], spacing[i + 1]])
340 |
341 | # Launch a thread for each batch.
342 | print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
343 | sys.stdout.flush()
344 |
345 | # Create a mechanism for monitoring when all threads are finished.
346 | coord = tf.train.Coordinator()
347 |
348 | # Create a generic TensorFlow-based utility for converting all image codings.
349 | coder = ImageCoder()
350 |
351 | threads = []
352 | for thread_index in range(len(ranges)):
353 | args = (coder, thread_index, ranges, name, filenames,
354 | texts, labels, num_shards)
355 | #From the entire dataset details sent to _process_image_files, convert them to Example protos in batches (per the number of threads set) and save as TFRecords
356 | t = threading.Thread(target=_process_image_files_batch, args=args)
357 | t.start()
358 | threads.append(t)
359 |
360 | # Wait for all the threads to terminate.
361 | coord.join(threads)
362 | print('%s: Finished writing all %d images in data set.' %
363 | (datetime.now(), len(filenames)))
364 | sys.stdout.flush()
365 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
366 |
367 |
368 |
369 |
370 |
371 | #BUILD LIST OF ALL IMAGES FILES AND LABELS IN THE DATA SET
372 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
373 | def _find_image_files(data_dir, labels_file):
374 | """
375 | Args:
376 | data_dir: string, path to the root directory of images.
377 | Assumes that the image data set resides in JPEG files located in
378 | the following directory structure.
379 | data_dir/dog/another-image.JPEG
380 | data_dir/dog/my-image.jpg
381 | where 'dog' is the label associated with these images.
382 | labels_file: string, path to the labels file.
383 | The list of valid labels are held in this file. Assumes that the file
384 | contains entries as such:
385 | dog
386 | cat
387 | flower
388 | where each line corresponds to a label. We map each label contained in
389 | the file to an integer, starting with the integer 1 for the label on the
390 | first line (index 0 is left as a background class).
391 | Returns:
392 | filenames: list of strings; each string is a path to an image file.
393 | texts: list of strings; each string is the class, e.g. 'dog'
394 | labels: list of integer; each integer identifies the ground truth.
395 | """
396 | print('Determining list of input files and labels from %s ' % labels_file)
397 | unique_labels = [l.strip() for l in tf.gfile.FastGFile(
398 | labels_file, 'r').readlines()]
399 |
400 | labels = []
401 | filenames = []
402 | texts = []
403 |
404 | # Leave label index 0 empty as a background class.
405 | label_index = 1
406 |
407 | # Construct the list of JPEG files and labels.
408 | for text in unique_labels:
409 | jpeg_file_path = '%s/%s/*' % (data_dir, text)
410 | print("File path %s \n" % jpeg_file_path);
411 | matching_files = tf.gfile.Glob(jpeg_file_path)
412 |
413 | labels.extend([label_index] * len(matching_files))
414 | texts.extend([text] * len(matching_files))
415 | filenames.extend(matching_files)
416 |
417 | if not label_index % 100:
418 | print('Finished finding files in %d of %d classes.' % (
419 | label_index, len(unique_labels)))
420 | label_index += 1
421 |
422 | # Shuffle the ordering of all image files in order to guarantee
423 | # random ordering of the images with respect to label in the
424 | # saved TFRecord files. Make the randomization repeatable.
425 | shuffled_index = list(range(len(filenames)))
426 | random.seed(12345)
427 | random.shuffle(shuffled_index)
428 |
429 | filenames = [filenames[i] for i in shuffled_index]
430 | texts = [texts[i] for i in shuffled_index]
431 | labels = [labels[i] for i in shuffled_index]
432 |
433 | print('Found %d JPEG files across %d labels inside %s' %
434 | (len(filenames), len(unique_labels), data_dir))
435 | return filenames, texts, labels
436 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
437 |
438 |
439 |
440 | #CALL TO PROCESS DATASET IS MADE HERE
441 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
442 | def _process_dataset(name, directory, num_shards, labels_file):
443 | """Process a complete data set and save it as a TFRecord.
444 | Args:
445 | name: string, unique identifier specifying the data set.
446 | directory: string, root path to the data set.
447 | num_shards: integer number of shards for this data set.
448 | labels_file: string, path to the labels file.
449 | """
450 | filenames, texts, labels = _find_image_files(directory, labels_file) #Build list of dataset image file (path to them) and their labels as string and integer
451 | _process_image_files(name, filenames, texts, labels, num_shards) #Process the entire list of images in the dataset into a TF record
452 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
453 |
454 |
455 |
456 | #MAIN
457 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
458 | def main(unused_argv):
459 | assert not FLAGS.train_shards % FLAGS.num_threads, ('Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
460 | assert not FLAGS.validation_shards % FLAGS.num_threads, ('Please make the FLAGS.num_threads commensurate with FLAGS.validation_shards')
461 | print('Result will be saved to %s' % FLAGS.output_directory)
462 |
463 | # Run it!
464 | _process_dataset('validation', FLAGS.validation_directory, FLAGS.validation_shards, FLAGS.labels_file)
465 | _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, FLAGS.labels_file)
466 |
467 |
468 | if __name__ == '__main__':
469 | tf.app.run()
470 |
471 | """-------------------------------------------------------------------------------------------------------------------------------------------------------"""
472 |
--------------------------------------------------------------------------------
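To sanity-check the shards this script writes, each record can be parsed back with the protobuf API. Note that despite the 'image/encoded' field name (and the 'JPEG' notes in the header), _process_image stores raw pixel bytes, so the buffer is decoded with np.frombuffer rather than a JPEG decoder. A minimal sketch; the shard path is a placeholder:

    import numpy as np
    import tensorflow as tf

    record_path = "/path/to/TFRecord_227x227/train-00000-of-00002"  # placeholder
    for serialized in tf.python_io.tf_record_iterator(record_path):
        example = tf.train.Example()
        example.ParseFromString(serialized)
        feat = example.features.feature
        label = feat['image/class/label'].int64_list.value[0]
        text = feat['image/class/text'].bytes_list.value[0].decode()
        shape = np.frombuffer(feat['image/shape'].bytes_list.value[0], dtype=np.int32)
        image = np.frombuffer(feat['image/encoded'].bytes_list.value[0],
                              dtype=np.uint8).reshape(shape)
        print(label, text, image.shape)
        break  # inspect just the first record
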
/freeze_and_convert_to_tflite.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #################################################################################################################
4 | # Bash script to help freeze a TensorFlow GraphDef + ckpt into a FrozenGraphDef, optimize it, and convert it to a tflite model #
5 | # $1 = path/to/.pbtxt file #
6 | # $2 = path/to/.ckpt file #
7 | # $3 = true/false. 'true' for .pb and 'false' for .pbtxt input #
8 | # $4 = path/to/save/frozengraph #
9 | # $5 = graph input node name #
10 | # $6 = graph output_node_name #
11 | # $7 = path/to/tflite/file e.g path/to/tflite_model.lite #
12 | # $8 = input type e.g FLOAT #
13 | # $9 = inference (output) type (FLOAT or QUANTIZED) #
14 | # $10 = path/to/optimized_graph.pb #
15 | # #
16 | # Sample usage: #
17 | # freeze_and_convert_to_tflite.sh /tmp/graph.pbtxt /tmp/model.ckpt-0 false /tmp/frozen.pb input_layer ArgMax /tmp/model.lite FLOAT FLOAT /tmp/optimized.pb #
18 | # #
19 | # Make sure you are running this from the tensorflow home folder after cloning the TF repository #
20 | # #
21 | #################################################################################################################
22 |
23 |
24 | #FREEZE GRAPH
25 | #bazel build tensorflow/python/tools:freeze_graph
26 | bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=$1 --input_checkpoint=$2 --input_binary=$3 --output_graph=$4 --output_node_names=$6
27 |
28 |
29 | #VIEW SUMMARY OF GRAPH
30 | #bazel build /tensorflow/tools/graph_transforms:summarize_graph
31 | bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=$4
32 |
33 | #TRANSFORM/OPTIMIZE GRAPH
34 | #bazel build tensorflow/tools/graph_transforms:transform_graph
35 | bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=$4 --out_graph=${10} --inputs=$5 --outputs=$6 --transforms='strip_unused_nodes(type=float, shape="1,227,227,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) round_weights(num_steps=256) obfuscate_names sort_by_execution_order fold_old_batch_norms fold_batch_norms'
36 |
37 |
38 | #CONVERT TO TFLITE MODEL
39 | #bazel build /tensorflow/contrib/lite/toco:toco
40 | bazel-bin/tensorflow/contrib/lite/toco/toco --input_format=TENSORFLOW_GRAPHDEF --input_file=${10} --output_format=TFLITE --output_file=$7 --inference_type=$9 --input_type=$8 --input_arrays=$5 --output_arrays=$6 --inference_input_type=$8 --input_shapes=1,227,227,3
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
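Once the script has produced the .lite file, it can be exercised from Python with the TF Lite interpreter. A minimal sketch; the model path is a placeholder, and the Interpreter class lived under tf.contrib.lite in the TensorFlow 1.x releases this repository targets (tf.lite.Interpreter in later versions):

    import numpy as np
    import tensorflow as tf

    interpreter = tf.contrib.lite.Interpreter(model_path="/tmp/tflite_model.lite")  # placeholder
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    img_batch = np.zeros((1, 227, 227, 3), dtype=np.float32)  # replace with a real preprocessed image
    interpreter.set_tensor(input_details[0]['index'], img_batch)
    interpreter.invoke()
    print(interpreter.get_tensor(output_details[0]['index']))  # model output for the batch
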
/imgs/accuracy_per_epoch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/accuracy_per_epoch.png
--------------------------------------------------------------------------------
/imgs/loss_per_epoch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/loss_per_epoch.png
--------------------------------------------------------------------------------
/imgs/network_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/network_visualization.png
--------------------------------------------------------------------------------
/imgs/tensorboard_plots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/tensorboard_plots.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #IMPORTS
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | from sklearn.model_selection import train_test_split
6 |
7 | #import Image
8 | import math
9 | import numpy as np
10 | import tensorflow as tf
11 | import matplotlib.pyplot as plt
12 | import os
13 | import time
14 |
15 | def _int64_feature(value):
16 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
17 |
18 | def _bytes_feature(value):
19 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
20 |
21 |
22 | #FUNCTION TO GET ALL DATASET DATA
23 | def _process_dataset(serialized):
24 |
25 |
26 | #Specify the features you want to extract
27 | features = {'image/shape': tf.FixedLenFeature([], tf.string),
28 | 'image/class/label': tf.FixedLenFeature([], tf.int64),
29 | 'image/class/text': tf.FixedLenFeature([], tf.string),
30 | 'image/filename': tf.FixedLenFeature([], tf.string),
31 | 'image/encoded': tf.FixedLenFeature([], tf.string)}
32 | parsed_example = tf.parse_single_example(serialized, features=features)
33 |
34 | #Finesse the extracted data
35 | image_raw = tf.decode_raw(parsed_example['image/encoded'], tf.uint8)
36 | shape = tf.decode_raw(parsed_example['image/shape'], tf.int32)
37 | label = tf.cast(parsed_example['image/class/label'], dtype=tf.int32)
38 | reshaped_img = tf.reshape(image_raw, shape)
39 | casted_img = tf.cast(reshaped_img, tf.float32)
40 | label_tensor= [label]
41 | image_tensor = [casted_img]
42 | return label_tensor, image_tensor
43 |
44 |
45 | #MAIN FUNCTION
46 | def main(unused_argv):
47 |
48 |
49 | print("STARTED\n\n")
50 |
51 | #Declare needed variables
52 | perform_shuffle=False
53 | repeat_count=1
54 | batch_size=1000
55 | num_of_epochs=1
56 |
57 | #Directory path to the '.tfrecord' files
58 | filenames = ["/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00000-of-00002", "/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00001-of-00002"]
59 |
60 |
61 | print("GETTING RECORD COUNT")
62 | #Determine total number of records in the '.tfrecord' files
63 | record_count = 0
64 | for fn in filenames:
65 | for record in tf.python_io.tf_record_iterator(fn):
66 | record_count += 1
67 | print("Total Number of Records in the .tfrecord file(s): %i" % record_count)
68 |
69 |
70 | dataset = tf.data.TFRecordDataset(filenames=filenames)
71 | dataset = dataset.map(_process_dataset) #Get all content of dataset
72 | dataset = dataset.shuffle(buffer_size=1000) #Shuffle selection from the dataset
73 | dataset = dataset.repeat(repeat_count) #Repeats dataset this # times
74 | dataset = dataset.batch(batch_size) #Batch size to use
75 | iterator = dataset.make_initializable_iterator() #Create iterator which helps to get all images in the dataset
76 | labels_tensor, images_tensor = iterator.get_next() #Get batch data
77 | no_of_rounds = int(math.ceil(record_count/batch_size));
78 |
79 | #Create tf session, get next set of batches, and evaluate them in batches
80 | sess = tf.Session()
81 | print("Total number of strides needed to stream through dataset: ~%i" %no_of_rounds)
82 |
83 |
84 | for _ in range(2):
85 | sess.run(iterator.initializer)
86 | count=0
87 | complete_evaluation_image_set = np.array([])
88 | complete_evaluation_label_set = np.array([])
89 | while True:
90 | try:
91 | print("Now evaluating tensors for stride %i out of %i" % (count, no_of_rounds))
92 | evaluated_label, evaluated_image = sess.run([labels_tensor, images_tensor])
93 | #convert evaluated tensors to np array
94 | label_np_array = np.asarray(evaluated_label, dtype=np.uint8)
95 | image_np_array = np.asarray(evaluated_image, dtype=np.uint8)
96 | #squeeze np array to make dimensions appropriate
97 | squeezed_label_np_array = label_np_array.squeeze()
98 | squeezed_image_np_array = image_np_array.squeeze()
99 | #Split data into training and testing data
100 | image_train, image_test, label_train, label_test = train_test_split(squeezed_image_np_array, squeezed_label_np_array, test_size=0.010, random_state=42, shuffle=True)
101 | #Store evaluation data in its place
102 | complete_evaluation_image_set = np.append(complete_evaluation_image_set, image_test)
103 | complete_evaluation_label_set = np.append(complete_evaluation_label_set, label_test)
104 | #Feed current batch to TF Estimator for training
105 | except tf.errors.OutOfRangeError:
106 | print("End of Dataset Reached")
107 | break
108 | count=count+1
109 | print(complete_evaluation_label_set.shape)
110 | print(complete_evaluation_image_set.shape)
111 | sess.close()
112 |
113 | print("End of Training")
114 |
115 |
116 |
117 | if __name__ == "__main__":
118 | tf.app.run()
119 |
120 |
--------------------------------------------------------------------------------
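Note: np.append without an axis argument flattens its inputs, so complete_evaluation_image_set accumulates as a 1-D array here. A minimal sketch of recovering the batch shape afterwards (assuming 227x227x3 images):

    images = complete_evaluation_image_set.reshape(-1, 227, 227, 3)
    labels = complete_evaluation_label_set.astype(np.int32)
    print(images.shape, labels.shape)  # e.g. (N, 227, 227, 3) and (N,)
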
/train_alexnet.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | """
4 | Creates AlexNet and trains it over several epochs using image data stored in a TFRecord
5 |
6 | @date: 4th December, 2017
7 | @author: Oluwole Oyetoke
8 | @Language: Python
9 | @email: oluwoleoyetoke@gmail.com
10 |
11 | AlexNet NETWORK OVERVIEW
12 | AlexNet Structure: 60 million Parameters
13 | 8 layers in total: 5 Convolutional and 3 Fully Connected Layers
14 | [227x227x3] INPUT
15 | [55x55x96] CONV1: 96 11x11 filters at stride 4, pad 0
16 | [27x27x96] MAX POOL1: 3x3 filters at stride 2
17 | [27x27x96] NORM1: Normalization layer
18 | [27x27x256] CONV2: 256 5x5 filters at stride 1, pad 2
19 | [13x13x256] MAX POOL2: 3x3 filters at stride 2
20 | [13x13x256] NORM2: Normalization layer
21 | [13x13x384] CONV3: 384 3x3 filters at stride 1, pad 1
22 | [13x13x384] CONV4: 384 3x3 filters at stride 1, pad 1
23 | [13x13x256] CONV5: 256 3x3 filters at stride 1, pad 1
24 | [6x6x256] MAX POOL3: 3x3 filters at stride 2
25 | [4096] FC6: 4096 neurons
26 | [4096] FC7: 4096 neurons
27 | [43] FC8: 43 neurons (class scores)
28 |
29 |
30 |
31 |
32 |
33 |
34 | #IMPORTS, VARIABLE DECLARATION, AND LOGGING TYPE SETTING
35 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
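# The layer sizes listed in the header follow the standard convolution output
# formula out = floor((in - kernel + 2*pad) / stride) + 1. A minimal sanity
# check (a sketch, assuming the strides and paddings stated above):
def _conv_out(in_size, kernel, stride, pad=0):
    return (in_size - kernel + 2 * pad) // stride + 1

assert _conv_out(227, 11, 4) == 55        # CONV1
assert _conv_out(55, 3, 2) == 27          # MAX POOL1
assert _conv_out(27, 5, 1, pad=2) == 27   # CONV2 ('same' padding)
assert _conv_out(27, 3, 2) == 13          # MAX POOL2
assert _conv_out(13, 3, 2) == 6           # MAX POOL3 -> the 6*6*256 flatten in cnn_model_fn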
36 | from __future__ import absolute_import
37 | from __future__ import division
38 | from __future__ import print_function
39 | from sklearn.model_selection import train_test_split
40 | from dateutil.relativedelta import relativedelta
41 |
42 | import os
43 | import math
44 | import time
45 | import numpy as np
46 | import tensorflow as tf #import tensorflow
47 | import matplotlib.pyplot as plt
48 | import datetime
49 |
50 | flags = tf.app.flags
51 | flags.DEFINE_integer("image_width", "227", "Alexnet input layer width")
52 | flags.DEFINE_integer("image_height", "227", "Alexnet input layer height")
53 | flags.DEFINE_integer("image_channels", "3", "Alexnet input layer channels")
54 | flags.DEFINE_integer("num_of_classes", "43", "Number of training clases")
55 | FLAGS = flags.FLAGS
56 |
57 | losses_bank = np.array([]) #global
58 | accuracy_bank = np.array([]) #global
59 | steps_bank = np.array([]) #global
60 | epoch_bank = np.array([]) #global
61 |
62 | tf.logging.set_verbosity(tf.logging.WARN) #setting up logging (can be DEBUG, ERROR, FATAL, INFO or WARN)
63 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
64 |
65 |
66 |
67 |
68 |
69 |
70 | #WRAPPERS FOR INSERTING int64 &amp; BYTES FEATURES INTO EXAMPLE PROTOS
71 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
72 | def _int64_feature(value):
73 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
74 |
75 | def _bytes_feature(value):
76 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
77 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
78 |
79 |
80 | #CREATE CNN STRUCTURE
81 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
82 | def cnn_model_fn(features, labels, mode):
83 |
84 | """INPUT LAYER"""
85 | input_layer = tf.reshape(features["x"], [-1, FLAGS.image_width, FLAGS.image_height, FLAGS.image_channels], name="input_layer") #Alexnet uses a 227x227x3 input layer. '-1' lets the batch dimension be inferred
86 | #print(input_layer)
87 |
88 | """%FIRST CONVOLUTION BLOCK
89 | The first convolutional layer filters the 227×227×3 input image with
90 | 96 kernels of size 11×11 with a stride of 4 pixels. Bias of 1."""
91 | conv1 = tf.layers.conv2d(inputs=input_layer, filters=96, kernel_size=[11, 11], strides=4, padding="valid", activation=tf.nn.relu)
92 | lrn1 = tf.nn.lrn(input=conv1, depth_radius=5, bias=1.0, alpha=0.0001/5.0, beta=0.75); #Normalization layer
93 | pool1_conv1 = tf.layers.max_pooling2d(inputs=lrn1, pool_size=[3, 3], strides=2) #Max Pool Layer
94 | #print(pool1_conv1)
95 |
96 |
97 | """SECOND CONVOLUTION BLOCK
98 | In the original AlexNet, the 96-channel output of block one is split into two groups of 48 and processed independently on two GPUs; here it is processed as a single group"""
99 | conv2 = tf.layers.conv2d(inputs=pool1_conv1, filters=256, kernel_size=[5, 5], strides=1, padding="same", activation=tf.nn.relu)
100 | lrn2 = tf.nn.lrn(input=conv2, depth_radius=5, bias=1.0, alpha=0.0001/5.0, beta=0.75); #Normalization layer
101 | pool2_conv2 = tf.layers.max_pooling2d(inputs=lrn2, pool_size=[3, 3], strides=2) #Max Pool Layer
102 | #print(pool2_conv2)
103 |
104 | """THIRD CONVOLUTION BLOCK
105 | Note that the third, fourth, and fifth convolution layers are connected to one
106 | another without any intervening pooling or normalization layers.
107 | The third convolutional layer has 384 kernels of size 3 × 3
108 | connected to the (normalized, pooled) outputs of the second convolutional layer"""
109 | conv3 = tf.layers.conv2d(inputs=pool2_conv2, filters=384, kernel_size=[3, 3], strides=1, padding="same", activation=tf.nn.relu)
110 | #print(conv3)
111 |
112 | #FOURTH CONVOLUTION BLOCK
113 | """%The fourth convolutional layer has 384 kernels of size 3 × 3"""
114 | conv4 = tf.layers.conv2d(inputs=conv3, filters=384, kernel_size=[3, 3], strides=1, padding="same", activation=tf.nn.relu)
115 | #print(conv4)
116 |
117 | #FIFTH CONVOLUTION BLOCK
118 | """%the fifth convolutional layer has 256 kernels of size 3 × 3"""
119 | conv5 = tf.layers.conv2d(inputs=conv4, filters=256, kernel_size=[3, 3], strides=1, padding="same", activation=tf.nn.relu)
120 | pool3_conv5 = tf.layers.max_pooling2d(inputs=conv5, pool_size=[3, 3], strides=2, padding="valid") #Max Pool Layer
121 | #print(pool3_conv5)
122 |
123 |
124 | #FULLY CONNECTED LAYER 1
125 | """The fully-connected layers have 4096 neurons each"""
126 | pool3_conv5_flat = tf.reshape(pool3_conv5, [-1, 6* 6 * 256]) #output of conv block is 6x6x256; to connect it to a fully connected layer, we flatten it out
127 | fc1 = tf.layers.dense(inputs=pool3_conv5_flat, units=4096, activation=tf.nn.relu)
128 | #fc1 = tf.layers.conv2d(inputs=pool3_conv5, filters=4096, kernel_size=[6, 6], strides=1, padding="valid", activation=tf.nn.relu) #representing the FCL using a convolution block (no need to do 'pool3_conv5_flat' above)
129 | #print(fc1)
130 |
131 | #FULLY CONNECTED LAYER 2
132 | """since the output from above is [1x1x4096]"""
133 | fc2 = tf.layers.dense(inputs=fc1, units=4096, activation=tf.nn.relu)
134 | #fc2 = tf.layers.conv2d(inputs=fc1, filters=4096, kernel_size=[1, 1], strides=1, padding="valid", activation=tf.nn.relu)
135 | #print(fc2)
136 |
137 | #FULLY CONNECTED LAYER 3
138 | """since the output from above is [1x1x4096]"""
139 | logits = tf.layers.dense(inputs=fc2, units=FLAGS.num_of_classes, name="logits_layer")
140 | #fc3 = tf.layers.conv2d(inputs=fc2, filters=43, kernel_size=[1, 1], strides=1, padding="valid")
141 | #logits = tf.layers.dense(inputs=fc3, units=FLAGS.num_of_classes) #converting the convolutional block (tf.layers.conv2d) to a dense layer (tf.layers.dense). Only needed if we had used tf.layers.conv2d to represent the FCLs
142 | #print(logits)
143 |
144 | #PASS OUTPUT OF LAST FC LAYER TO A SOFTMAX LAYER
145 | """convert these raw values into two different formats that our model function can return:
146 | The predicted class for each example: a digit from 1–43.
147 | The probabilities for each possible target class for each example
148 | tf.argmax(input=logits, axis=1): Generate predictions from the 43 class scores in logits. Axis 1 will apply argmax across the rows
149 | tf.nn.softmax(logits, name="softmax_tensor"): Generate the probability distribution
150 | """
151 | predictions = {
152 | "classes": tf.argmax(input=logits, axis=1, name="classes_tensor"),
153 | "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
154 | }
155 |
156 | #Return result if we were in prediction mode and not training
157 | if mode == tf.estimator.ModeKeys.PREDICT:
158 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
159 |
160 | #CALCULATE OUR LOSS
161 | """For both training and evaluation, we need to define a loss function that measures how closely the
162 | model's predictions match the target classes. For multiclass classification, cross entropy is typically used as the loss metric."""
163 | onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=FLAGS.num_of_classes)
164 | loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)
165 | tf.summary.scalar('Loss Per Stride', loss) #Just to see loss values per batch on tensorboard
166 |
167 | #CONFIGURE TRAINING
168 | """Since the loss of the CNN is the softmax cross-entropy of the fc3 layer
169 | and our labels. Let's configure our model to optimize this loss value during
170 | training. We'll use a learning rate of 0.001 and stochastic gradient descent
171 | as the optimization algorithm:"""
172 | if mode == tf.estimator.ModeKeys.TRAIN:
173 | optimizer = tf.train.AdamOptimizer(learning_rate=0.00001)
174 | train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) #global_Step needed for proper graph on tensor board
175 | #optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.00005) #Very small learning rate used. Training will be slower to converge but better
176 | #train_op = optimizer.minimize(loss=loss,global_step=tf.train.get_global_step())
177 | return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
178 |
179 | #ADD EVALUATION METRICS
180 | eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])}
181 | return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
182 | """-----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
183 |
184 |
185 |
186 |
187 | #FUNCTION TO PROCESS ALL DATASET DATA
188 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
189 | def _process_dataset(serialized):
190 | #Specify the features you want to extract
191 | features = {'image/shape': tf.FixedLenFeature([], tf.string),
192 | 'image/class/label': tf.FixedLenFeature([], tf.int64),
193 | 'image/class/text': tf.FixedLenFeature([], tf.string),
194 | 'image/filename': tf.FixedLenFeature([], tf.string),
195 | 'image/encoded': tf.FixedLenFeature([], tf.string)}
196 | parsed_example = tf.parse_single_example(serialized, features=features)
197 |
198 | #Finesse the extracted data
199 | image_raw = tf.decode_raw(parsed_example['image/encoded'], tf.uint8)
200 | shape = tf.decode_raw(parsed_example['image/shape'], tf.int32)
201 | label = tf.cast(parsed_example['image/class/label'], dtype=tf.int32)
202 | reshaped_img = tf.reshape(image_raw, [FLAGS.image_width, FLAGS.image_height, FLAGS.image_channels])
203 | casted_img = tf.cast(reshaped_img, tf.float32)
204 | label_tensor= [label]
205 | image_tensor = [casted_img]
206 | return label_tensor, image_tensor
207 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
208 |
209 | #PLOT TRAINING PROGRESS
210 | def _plot_training_progress():
211 | global losses_bank #to make sure losses_bank is not declared again in this method (as a local variable)
212 | global accuracy_bank
213 | global steps_bank
214 | global epoch_bank
215 |
216 | #PLOT LOSS PER EPOCH
217 | loss_per_epoch_fig = plt.figure("LOSS PER EPOCH PLOT")
218 | plt.plot(epoch_bank, losses_bank, 'ro-')
219 | loss_per_epoch_fig.suptitle('LOSS LEVEL PER EPOCH')
220 | plt.xlabel('Epoch Count')
221 | plt.ylabel('Loss Value')
222 | manager = plt.get_current_fig_manager()
223 | manager.resize(*manager.window.maxsize()) #maximize plot
224 | loss_per_epoch_fig.canvas.draw()
225 | plt.show(block=False)
226 | loss_per_epoch_fig.savefig("/home/olu/Dev/data_base/sign_base/output/Checkpoints_N_Model/evaluation_plots/loss_per_epoch.png") #save plot
227 |
228 |
229 | #PLOT ACCURACY PER EPOCH
230 | accuracy_per_epoch_fig = plt.figure("ACCURACY PER EPOCH PLOT")
231 | plt.plot(epoch_bank, accuracy_bank, 'bo-')
232 | accuracy_per_epoch_fig.suptitle('ACCURACY PERCENTAGE PER EPOCH')
233 | plt.xlabel('Epoch Count')
234 | plt.ylabel('Accuracy Percentage')
235 | manager = plt.get_current_fig_manager()
236 | manager.resize(*manager.window.maxsize()) #maximize plot
237 | accuracy_per_epoch_fig.canvas.draw()
238 | plt.show(block=False)
239 | accuracy_per_epoch_fig.savefig("/home/olu/Dev/data_base/sign_base/output/Checkpoints_N_Model/evaluation_plots/accuracy_per_epoch.png") #save plot
240 |
241 |
242 | #TRAINING AND EVALUATING THE ALEXNET CNN CLASSIFIER
243 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
244 | def main(unused_argv):
245 |
246 | #Declare needed variables
247 | global losses_bank
248 | global accuracy_bank
249 | global steps_bank
250 | global epoch_bank
251 |
252 | perform_shuffle=False
253 | repeat_count=1
254 | dataset_batch_size=1024 #Chunks picked from the dataset per stride
255 | training_batch_size = np.int32(dataset_batch_size/8) #Chunks processed by tf.estimator per step
256 | epoch_count=0
257 | overall_training_epochs=60 #60 epochs in total
258 |
259 | start_time=time.time() #taking current time as starting time
260 | start_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(start_time))
261 | start_time_dt = datetime.datetime.strptime(start_time_string, '%Y-%m-%d %H:%M:%S')
262 | print("STARTED @ %s" % start_time_string)
263 | #LOAD TRAINING DATA
264 | print("LOADING DATASET\n\n")
265 | filenames = ["/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00000-of-00002", "/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00001-of-00002"] #Directory path to the '.tfrecord' files
266 | model_dir="/home/olu/Dev/data_base/sign_base/output/Checkpoints_N_Model/trained_alexnet_model"
267 | #DETERMINE TOTAL NUMBER OF RECORDS IN THE '.tfrecord' FILES
268 | print("GETTING COUNT OF RECORDS/EXAMPLES IN DATASET")
269 | record_count = 0
270 | for fn in filenames:
271 | for record in tf.python_io.tf_record_iterator(fn):
272 | record_count += 1
273 | print("Total number of records in the .tfrecord file(s): %i\n\n" % record_count)
274 | no_of_rounds = int(math.ceil(record_count/dataset_batch_size));
275 |
276 |
277 | #EDIT HOW OFTEN CHECK POINTS SHOULD BE SAVED
278 | check_point_interval = int(record_count/training_batch_size)
279 | my_estimator_config = tf.estimator.RunConfig(model_dir=model_dir,tf_random_seed=None,save_summary_steps=100,
280 | save_checkpoints_steps=check_point_interval,save_checkpoints_secs=None,session_config=None,keep_checkpoint_max=10,keep_checkpoint_every_n_hours=10000,log_step_count_steps=100)
281 |
282 | print("CREATING ESTIMATOR AND LOADING DATASET")
283 | #CREATE ESTIMATOR
284 | """Estimator: a TensorFlow class for performing high-level model training, evaluation, and inference"""
285 | traffic_sign_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=model_dir, config=my_estimator_config) #specify where the finally trained model and (checkpoints during training) should be saved in
286 |
287 | #SET UP LOGGING FOR PREDICTIONS
288 | tensors_to_log = {"probabilities": "softmax_tensor"}
289 | logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50) #Log after every 50 iterations
290 |
291 |
292 | #PROCESS AND RETRIEVE DATASET CONTENT IN BATCHES OF 'dataset_batch_size'
293 | dataset = tf.data.TFRecordDataset(filenames=filenames)
294 | dataset = dataset.map(_process_dataset) #Get all content of dataset & apply function '_process_dataset' to all its content
295 | dataset = dataset.shuffle(buffer_size=dataset_batch_size) #Shuffle selection from the dataset per epoch, with a buffer of 'dataset_batch_size'
296 | dataset = dataset.repeat(repeat_count) #Repeat iteration through the dataset 'repeat_count' times
297 | dataset = dataset.batch(dataset_batch_size) #Batch size to use to pick from dataset
298 | iterator = dataset.make_initializable_iterator() #Create iterator which helps to get all images in the dataset
299 | labels_tensor, images_tensor = iterator.get_next() #Get batch data
300 |
301 |
302 | #CREATE TF SESSION TO ITERATIVELY EVALUATE THE BATCHES OF DATASET TENSORS RETRIEVED AND PASS THEM TO THE ESTIMATOR FOR TRAINING/EVALUATION
303 | sess = tf.Session()
304 | print("Approximately %i strides needed to stream through 1 epoch of dataset\n\n" %no_of_rounds)
305 |
306 |
307 | for _ in range(overall_training_epochs):
308 | epoch_start_time=time.time() #taking current time as starting time
309 | epoch_start_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(epoch_start_time))
310 | epoch_start_time_dt = datetime.datetime.strptime(epoch_start_time_string, '%Y-%m-%d %H:%M:%S')
311 | epoch_count=epoch_count+1;
312 | print("Epoch %i out of %i" % (epoch_count, overall_training_epochs))
313 | print("Note: Each complete epoch processes %i images, feeding it into the classifier in batches of %i" % (record_count, dataset_batch_size))
314 | print("...Classifier then deals with each batch of %i images pushed to it in batch sizes of %i\n"%(dataset_batch_size, training_batch_size))
315 | sess.run(iterator.initializer)
316 |
317 | strides_count=1
318 | complete_evaluation_image_set = np.array([])
319 | complete_evaluation_label_set = np.array([])
320 | while True:
321 | try:
322 | imgs_so_far = dataset_batch_size*strides_count
323 | if(imgs_so_far>record_count):
324 | imgs_so_far = record_count
325 | print("\nStride %i (%i images) of stride %i (%i images) for Epoch %i" %(strides_count, imgs_so_far, no_of_rounds, record_count, epoch_count))
326 | evaluated_label, evaluated_image = sess.run([labels_tensor, images_tensor]) #evaluate tensor
327 | #convert evaluated tensors to np array
328 | label_np_array = np.asarray(evaluated_label, dtype=np.int32)
329 | image_np_array = np.asarray(evaluated_image, dtype=np.float32)
330 | #squeeze np array to make dimensions appropriate
331 | squeezed_label_np_array = label_np_array.squeeze()
332 | squeezed_image_np_array = image_np_array.squeeze()
333 | #mean normalization - normalize current batch of images i.e get mean of images in dataset and subtact it from all image intensities in the dataset
334 | dataset_image_mean = np.mean(squeezed_image_np_array)
335 | normalized_image_dataset = np.subtract(squeezed_image_np_array, dataset_image_mean) #help for faster convergence duing training
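    |         # NOTE (editorial): this mean is computed per retrieved batch rather than over the whole
    |         # dataset, so each batch is centred on a slightly different value. A per-channel variant
    |         # over the same NHWC batch would look like (sketch):
    |         #   channel_means = squeezed_image_np_array.mean(axis=(0, 1, 2), keepdims=True)
    |         #   normalized_image_dataset = squeezed_image_np_array - channel_means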
336 |         #split data into training and testing/evaluation data
337 |         image_train, image_evaluate, label_train, label_evaluate = train_test_split(normalized_image_dataset, squeezed_label_np_array, test_size=0.10, random_state=42, shuffle=True) #10% of each batch is held out for evaluation/testing
338 |         #rectify dimension/shape
339 |         if (image_evaluate.ndim<4): #if dimension is just 3, i.e. only 1 image loaded
340 |           print(image_evaluate.shape)
341 |           image_evaluate = image_evaluate.reshape((1,) + image_evaluate.shape)
342 |         if (image_train.ndim<4): #if dimension is just 3, i.e. only 1 image loaded
343 |           print(image_train.shape)
344 |           image_train = image_train.reshape((1,) + image_train.shape)
345 |         #rectify precision (astype returns a new array, so its result must be assigned back)
346 |         image_train = image_train.astype(np.float32)
347 |         image_evaluate = image_evaluate.astype(np.float32)
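    |         # NOTE (editorial, illustrative): train_test_split can also preserve class proportions in
    |         # both partitions via its 'stratify' argument, assuming every class in the batch has at
    |         # least two samples (sketch):
    |         #   ... = train_test_split(normalized_image_dataset, squeezed_label_np_array,
    |         #                          test_size=0.10, random_state=42, shuffle=True,
    |         #                          stratify=squeezed_label_np_array)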
348 |         #Store evaluation data in its place
349 |         if(strides_count==1):
350 |           complete_evaluation_image_set = image_evaluate
351 |         else:
352 |           complete_evaluation_image_set = np.concatenate((complete_evaluation_image_set.squeeze(), image_evaluate))
353 |         complete_evaluation_label_set = np.append(complete_evaluation_label_set, label_evaluate)
354 |         complete_evaluation_label_set = complete_evaluation_label_set.squeeze()
355 |         #Feed the current batch of training images to the TF Estimator for training. The Estimator deals with them in mini-batches of 'training_batch_size'
356 |         train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": image_train},y=label_train,batch_size=training_batch_size,num_epochs=1, shuffle=True) #Note: the images were already shuffled when placed in the TFRecord, shuffled again when retrieved from the record, and are shuffled once more when sent to the classifier
357 |         traffic_sign_classifier.train(input_fn=train_input_fn,hooks=[logging_hook])
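    |         # NOTE (editorial): with num_epochs=1, numpy_input_fn makes exactly one pass over this
    |         # in-memory chunk. Each .train() call restores the latest checkpoint from model_dir and
    |         # continues from it, which is how these repeated short calls accumulate global steps.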
358 |       except tf.errors.OutOfRangeError:
359 |         print("End of Dataset Reached")
360 |         break
361 |       strides_count=strides_count+1
362 |
363 |     #EVALUATE THE MODEL after every complete epoch (note: out-of-memory issues occur if 20% of the dataset's images have to be held in memory until the full epoch completes, so the held-out test set was reduced to 10%)
364 |     """Once training is completed, we proceed to evaluate the accuracy of the trained model.
365 |     To create eval_input_fn, we set num_epochs=1 so that the model evaluates the metrics over one epoch of
366 |     data and returns the result. We also set shuffle=False to iterate through the data sequentially."""
367 |     eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": complete_evaluation_image_set},y=complete_evaluation_label_set,num_epochs=1,shuffle=False)
368 |     evaluation_result = traffic_sign_classifier.evaluate(input_fn=eval_input_fn) #Returns a dictionary of loss, global step and accuracy, e.g. {'loss': 1.704558, 'accuracy': 0.58105469, 'global_step': 742}
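    |     # NOTE (editorial): 'loss' and 'global_step' are always present in the result; 'accuracy'
    |     # only appears because cnn_model_fn is assumed to define it in eval_metric_ops, e.g. (sketch):
    |     #   eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predicted_classes)}
    |     # Also, np.append promotes the empty float array, so complete_evaluation_label_set is float64
    |     # here; if the metric expects integer labels, cast it first with .astype(np.int32).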
369 |
370 |     #PLOT TRAINING PERFORMANCE
371 |     epoch_loss = evaluation_result.get('loss')
372 |     epoch_accuracy = evaluation_result.get('accuracy')
373 |     epoch_steps = evaluation_result.get('global_step')
374 |     losses_bank = np.append(losses_bank, epoch_loss)
375 |     accuracy_bank = np.append(accuracy_bank, (epoch_accuracy*100))
376 |     steps_bank = np.append(steps_bank, epoch_steps)
377 |     epoch_bank = np.append(epoch_bank, epoch_count)
378 |     _plot_training_progress()
379 |
380 |     accuracy_percentage = epoch_accuracy*100
381 |     print("Network Performance Analysis: Loss(%f), Accuracy(%f percent), Steps(%i)\n"%(epoch_loss,accuracy_percentage,epoch_steps))
382 |     #print(evaluation_result)
383 |
384 |     #PRINT EPOCH OPERATION TIME
385 |     epoch_end_time=time.time() #take current time as the ending time
386 |     epoch_end_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(epoch_end_time))
387 |     epoch_end_time_dt = datetime.datetime.strptime(epoch_end_time_string, '%Y-%m-%d %H:%M:%S')
388 |     epoch_elapsed_time = relativedelta(epoch_end_time_dt, epoch_start_time_dt)
389 |     print("End of epoch %i.\n Time taken for this epoch: %d Day(s) : %d Hour(s) : %d Minute(s) : %d Second(s)" %(epoch_count, epoch_elapsed_time.days, epoch_elapsed_time.hours, epoch_elapsed_time.minutes, epoch_elapsed_time.seconds))
390 |     training_elapsed_time_sofar = relativedelta(epoch_end_time_dt, start_time_dt)
391 |     print("Overall training time so far: %d Day(s) : %d Hour(s) : %d Minute(s) : %d Second(s) \n\n\n" %(training_elapsed_time_sofar.days, training_elapsed_time_sofar.hours, training_elapsed_time_sofar.minutes, training_elapsed_time_sofar.seconds))
392 |   sess.close()
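    |   # NOTE (editorial, illustrative): the timing above round-trips through formatted strings and
    |   # dateutil's relativedelta; plain arithmetic on time.time() gives the same information:
    |   #   elapsed = time.time() - epoch_start_time  #seconds as a float
    |   #   print("Epoch took %dm %02ds" % divmod(int(elapsed), 60))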
393 |
394 |   #SAVE FINAL MODEL
395 |   #Not really needed, because the TF Estimator already saves checkpoint files (.meta/.index/.data) under model_dir at every checkpoint interval during training
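    |   # NOTE (editorial, illustrative sketch, not part of the original script): if a SavedModel is
    |   # wanted for later freezing/TFLite conversion, a TF 1.x Estimator can export one. The input
    |   # name "x" matches the input_fns above; the 227x227x3 shape is an assumption (AlexNet-style):
    |   #   def serving_input_receiver_fn():
    |   #     inputs = tf.placeholder(dtype=tf.float32, shape=[None, 227, 227, 3], name="x")
    |   #     return tf.estimator.export.ServingInputReceiver({"x": inputs}, {"x": inputs})
    |   #   traffic_sign_classifier.export_savedmodel(model_dir, serving_input_receiver_fn)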
396 |
397 |
398 |
399 |   #PRINT TOTAL TRAINING TIME & END
400 |   end_time=time.time() #take current time as the ending time
401 |   end_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(end_time))
402 |   end_time_dt = datetime.datetime.strptime(end_time_string, '%Y-%m-%d %H:%M:%S')
403 |   elapsed_time_dt = relativedelta(end_time_dt, start_time_dt)
404 |   print("END OF TRAINING..... ENDED @ %s" %end_time_string)
405 |   print("Final Trained Model is Saved Here: %s" % model_dir)
406 |   print("TIME TAKEN FOR ENTIRE TRAINING: %d Day(s) : %d Hour(s) : %d Minute(s) : %d Second(s)" % (elapsed_time_dt.days, elapsed_time_dt.hours, elapsed_time_dt.minutes, elapsed_time_dt.seconds))
407 |
408 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
409 |
410 |
411 | if __name__ == "__main__":
412 |   tf.app.run()
413 |
--------------------------------------------------------------------------------
/tune_cnn.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/tune_cnn.py
--------------------------------------------------------------------------------