├── DatasetConverter
│   ├── build.xml
│   ├── build
│   │   └── classes
│   │       ├── .netbeans_automatic_build
│   │       ├── .netbeans_update_resources
│   │       └── converters
│   │           ├── Converter$1.class
│   │           ├── Converter.class
│   │           └── ConverterMain.class
│   ├── manifest.mf
│   ├── nbproject
│   │   ├── build-impl.xml
│   │   ├── genfiles.properties
│   │   ├── private
│   │   │   ├── private.properties
│   │   │   └── private.xml
│   │   ├── project.properties
│   │   └── project.xml
│   └── src
│       └── converters
│           ├── Converter.java
│           └── ConverterMain.java
├── DeleteExtrasFromDataSet.sh
├── Logs
│   ├── imdbLogs
│   └── initial_train_logs
├── README.md
├── classify_img.py
├── classify_img_arg.py
├── create_imdb.py
├── freeze_and_convert_to_tflite.sh
├── imgs
│   ├── accuracy_per_epoch.png
│   ├── loss_per_epoch.png
│   ├── network_visualization.png
│   └── tensorboard_plots.png
├── test.py
├── train_alexnet.py
└── tune_cnn.py

/DatasetConverter/build.xml:
--------------------------------------------------------------------------------
[NetBeans-generated Ant build script; its XML markup was lost in extraction. The only surviving content is the project description on line 11: "Builds, tests, and runs the project Converters."]
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/.netbeans_automatic_build:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/.netbeans_automatic_build
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/.netbeans_update_resources:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/.netbeans_update_resources
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/converters/Converter$1.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/converters/Converter$1.class
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/converters/Converter.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/converters/Converter.class
--------------------------------------------------------------------------------
/DatasetConverter/build/classes/converters/ConverterMain.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/DatasetConverter/build/classes/converters/ConverterMain.class
--------------------------------------------------------------------------------
/DatasetConverter/manifest.mf:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | X-COMMENT: Main-Class will be added automatically by build
3 |
4 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/build-impl.xml:
--------------------------------------------------------------------------------
[NetBeans-generated build-impl.xml (about 1,420 lines); its XML markup was lost in extraction. Only IDE and property-validation messages survive, e.g. "Must set src.dir", "Must set build.dir", "Must select one file in the IDE or set run.class", "Some tests failed; see details above.", "This target only works when run from inside the NetBeans IDE.", and the hint "To run this application from the command line without Ant, try: java -jar "${dist.jar.resolved}".]
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/genfiles.properties:
--------------------------------------------------------------------------------
1 | build.xml.data.CRC32=7f403d13
2 | build.xml.script.CRC32=e7a17702
3 | build.xml.stylesheet.CRC32=8064a381@1.80.1.48
4 | # This file is used by a NetBeans-based IDE to track changes in generated files such as build-impl.xml.
5 | # Do not edit this file. You may delete it but then the IDE will never regenerate such files for you.
6 | nbproject/build-impl.xml.data.CRC32=7f403d13
7 | nbproject/build-impl.xml.script.CRC32=75480b50
8 | nbproject/build-impl.xml.stylesheet.CRC32=830a3534@1.80.1.48
9 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/private/private.properties:
--------------------------------------------------------------------------------
1 | compile.on.save=true
2 | user.properties.file=/home/olu/.netbeans/8.2/build.properties
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/private/private.xml:
--------------------------------------------------------------------------------
[NetBeans private.xml; its XML markup was lost in extraction. The surviving content lists the files open in the NetBeans editor:
file:/home/olu/NetBeansProjects/FolderNameConverter/src/converters/Converter.java
file:/home/olu/NetBeansProjects/FolderNameConverter/src/converters/ConverterMain.java]
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/project.properties:
--------------------------------------------------------------------------------
1 | annotation.processing.enabled=true
2 | annotation.processing.enabled.in.editor=false
3 | annotation.processing.processor.options=
4 | annotation.processing.processors.list=
5 | annotation.processing.run.all.processors=true
6 | annotation.processing.source.output=${build.generated.sources.dir}/ap-source-output
7 | build.classes.dir=${build.dir}/classes
8 | build.classes.excludes=**/*.java,**/*.form
9 | # This directory is removed when the project is cleaned:
10 | build.dir=build
11 | build.generated.dir=${build.dir}/generated
12 | build.generated.sources.dir=${build.dir}/generated-sources
13 | # Only compile against the classpath explicitly listed here:
14 | build.sysclasspath=ignore
15 | build.test.classes.dir=${build.dir}/test/classes
16 | build.test.results.dir=${build.dir}/test/results
17 | # Uncomment to specify the preferred debugger connection transport:
18 | #debug.transport=dt_socket
19 | debug.classpath=\
20 |     ${run.classpath}
21 | debug.test.classpath=\
22 |     ${run.test.classpath}
23 | # Files in build.classes.dir which should be excluded from distribution jar
24 | dist.archive.excludes=
25 | # This directory is removed when the project is cleaned:
26 | dist.dir=dist
27 | dist.jar=${dist.dir}/Converters.jar
28 | dist.javadoc.dir=${dist.dir}/javadoc
29 | excludes=
30 | includes=**
31 | jar.compress=false
32 | javac.classpath=
33 | # Space-separated list of extra javac options
34 | javac.compilerargs=
35 | javac.deprecation=false
36 | javac.external.vm=true
37 | javac.processorpath=\
38 |     ${javac.classpath}
39 | javac.source=1.8
40 | javac.target=1.8
41 | javac.test.classpath=\
42 |     ${javac.classpath}:\
43 |     ${build.classes.dir}
44 | javac.test.processorpath=\
45 |     ${javac.test.classpath}
46 | javadoc.additionalparam=
47 | javadoc.author=false
48 | javadoc.encoding=${source.encoding}
49 | javadoc.noindex=false
50 | javadoc.nonavbar=false
51 | javadoc.notree=false
52 | javadoc.private=false
53 | javadoc.splitindex=true
54 | javadoc.use=true
55 | javadoc.version=false
56 | javadoc.windowtitle=
57 | main.class=converters.ConverterMain
58 | manifest.file=manifest.mf
59 | meta.inf.dir=${src.dir}/META-INF
60 | mkdist.disabled=false
61 | platform.active=default_platform
62 | run.classpath=\
63 |     ${javac.classpath}:\
64 |     ${build.classes.dir}
65 | # Space-separated list of JVM arguments used when running the project.
66 | # You may also define separate properties like run-sys-prop.name=value instead of -Dname=value.
67 | # To set system properties for unit tests define test-sys-prop.name=value:
68 | run.jvmargs=
69 | run.test.classpath=\
70 |     ${javac.test.classpath}:\
71 |     ${build.test.classes.dir}
72 | source.encoding=UTF-8
73 | src.dir=src
74 | test.src.dir=test
75 |
--------------------------------------------------------------------------------
/DatasetConverter/nbproject/project.xml:
--------------------------------------------------------------------------------
[NetBeans project.xml; its XML markup was lost in extraction. The surviving content identifies the project type (org.netbeans.modules.java.j2seproject) and the project name (Converters).]
--------------------------------------------------------------------------------
/DatasetConverter/src/converters/Converter.java:
--------------------------------------------------------------------------------
1 | /*
2 |  * To change this license header, choose License Headers in Project Properties.
3 |  * To change this template file, choose Tools | Templates
4 |  * and open the template in the editor.
5 |  */
6 | package converters;
7 |
8 | import java.awt.image.BufferedImage;
9 | import java.io.BufferedReader;
10 | import java.io.File;
11 | import java.io.FileInputStream;
12 | import java.io.FileOutputStream;
13 | import java.io.FileReader;
14 | import java.io.IOException;
15 | import java.io.InputStreamReader;
16 | import java.text.SimpleDateFormat;
17 | import java.util.ArrayList;
18 | import java.util.Arrays;
19 | import java.util.Comparator;
20 | import java.util.Date;
21 | import java.util.Scanner;
22 | import java.util.concurrent.*;
23 | import java.util.logging.Level;
24 | import java.util.logging.Logger;
25 | import javax.imageio.ImageIO;
26 |
27 | /**
28 |  *
29 |  * @author olu
30 |  * email: oluwoleoyetoke@gmail.com
31 |  * Class contains methods used to
32 |  * 1. Convert GTSB dataset folder names to labels for the class
33 |  * 2.
Convert the datasets .ppm images to .jpeg 34 | * (GTSB - German Traffic Sign Benchmark) 35 | */ 36 | public class Converter implements Runnable { 37 | String pathToFolderUnderView=null; 38 | String format; 39 | static int totalConversion; 40 | 41 | //Default constructor 42 | Converter(){ 43 | 44 | } 45 | 46 | //Constructor used when multithread operation is needed for the picture conversion 47 | Converter(String pathToFolderUnderView, String format){ 48 | this.pathToFolderUnderView= pathToFolderUnderView; 49 | this.format = format; 50 | } 51 | 52 | @Override 53 | public void run(){ 54 | System.out.println("(Multithreaded) Now Converting Contents of Folder "+this.pathToFolderUnderView); 55 | //Connect to sub folder and begin to convert all its subfiles 56 | long timeStart = System.currentTimeMillis(); 57 | 58 | File file = new File(this.pathToFolderUnderView); 59 | String extension=""; 60 | String[] fileNames = file.list(); 61 | for(int i=0; iSubfolders-->Each subfolder containing " 110 | + "specific classes of image\n" 111 | + "E.g Training Folder -> stop_sign_folder -> 1.jpg, 2.jpg, 3.jpg...."); 112 | 113 | //Get user confirmation 114 | Scanner scanner = new Scanner(System.in); 115 | System.out.println("Would you like to proceed? [Y/N]: "); 116 | String answer = scanner.next(); 117 | if(!answer.equals("Y")){ 118 | System.out.println("Exiting...."); 119 | return false; 120 | } 121 | 122 | //Connect to dataset base deirectory 123 | File base =null; 124 | try{ //For saftey sake 125 | base = new File(baseFolderDir); 126 | if(!base.isDirectory()){ //Check to make sure directory specified isnt just a file 127 | System.out.println("Not a directory"); 128 | return false; 129 | } 130 | }catch (Exception ex){ 131 | System.out.println("Error occured while opening directory: "+ex.getMessage()); 132 | return false; 133 | } 134 | 135 | System.out.println("Base Directory: "+baseFolderDir); 136 | 137 | //Get sub directories 138 | String[] subFiles = base.list(); 139 | File[] subFilesHandle = base.listFiles(); 140 | 141 | 142 | //Confirm that base directory has sub directories or at least, sub files 143 | int noOfContents = 0; 144 | noOfContents = subFiles.length; 145 | System.out.println("Number of sub directories or posibly files: "+noOfContents); 146 | if(noOfContents==0){ 147 | System.out.println("There are no sub files/directories in the base directory"); 148 | return false; 149 | } 150 | 151 | System.out.println("About to begin multithreaded conversion. Please note, this might take some time"); 152 | String pathToFolderUnderOperation =""; 153 | //Open each subdirectory and convert image present to desired format (Multi Threaded) 154 | //Use executors to manage concurrency 155 | ExecutorService executor = Executors.newCachedThreadPool(); 156 | for(int i=0; iSubfolders-->Each subfolder containing " 186 | + "specific classes of image\n" 187 | + "E.g Training Folder -> stop_sign_folder -> 1.jpg, 2.jpg, 3.jpg...."); 188 | 189 | //Get user confirmation 190 | Scanner scanner = new Scanner(System.in); 191 | System.out.println("Would you like to proceed? 
[Y/N]: "); 192 | String answer = scanner.next(); 193 | if(!answer.equals("Y")){ 194 | System.out.println("Exiting...."); 195 | return false; 196 | } 197 | 198 | //Validation 199 | if(baseFolderDir.isEmpty()){ 200 | System.out.println("baseFolderDir not set"); 201 | return false; 202 | }else if(labelingFileDir.isEmpty()){ 203 | System.out.println("blabelingFileDir not set"); 204 | return false; 205 | } 206 | 207 | //Try to open directory 208 | try{ //For saftey sake 209 | dir = new File(baseFolderDir); 210 | if(!dir.isDirectory()){ //Check to make sure directory specified isnt just a file 211 | System.out.println("Not a directory"); 212 | return false; 213 | } 214 | }catch (Exception ex){ 215 | System.out.println("Error occured while opening directory: "+ex.getMessage()); 216 | return false; 217 | } 218 | 219 | //Get sub directories 220 | String[] subFiles = dir.list(); 221 | File[] subFilesHandle = dir.listFiles(); 222 | 223 | // Sort files/folders handle by name 224 | Arrays.sort(subFilesHandle, new Comparator(){ 225 | @Override 226 | public int compare(Object f1, Object f2){ 227 | return ((File) f1).getName().compareTo(((File) f2).getName()); 228 | } 229 | }); 230 | Arrays.sort(subFiles); //sort files/folders string by name 231 | ArrayList subDirs = new ArrayList(); 232 | File test = null; 233 | int noOfContents = subFiles.length; 234 | for (int i=0; i getFormats() { 365 | return new ArrayList<>(Arrays.asList("jpg", "jpeg", "png", "ppm", "gif")); 366 | } 367 | 368 | private static String getFileExtensionFromPath(String path) { 369 | int i = path.lastIndexOf('.'); 370 | if (i > 0) { 371 | return path.substring(i + 1); 372 | } 373 | return ""; 374 | } 375 | 376 | private static String getOutputPathFromInputPath(String path, String format) { 377 | return path.substring(0, path.lastIndexOf('.')) + "." + format; 378 | } 379 | 380 | private static String executeCommand(String command) { 381 | 382 | StringBuilder output = new StringBuilder(); 383 | 384 | Process p; 385 | try { 386 | p = Runtime.getRuntime().exec(command); 387 | p.waitFor(); 388 | BufferedReader reader 389 | = new BufferedReader(new InputStreamReader(p.getInputStream())); 390 | 391 | String line; 392 | while ((line = reader.readLine()) != null) { 393 | output.append(line).append("\n"); 394 | } 395 | 396 | } catch (IOException | InterruptedException ex) { 397 | Logger.getLogger(Converter.class.getName()).log(Level.SEVERE, null, ex); 398 | } 399 | 400 | return output.toString(); 401 | } 402 | 403 | 404 | } 405 | -------------------------------------------------------------------------------- /DatasetConverter/src/converters/ConverterMain.java: -------------------------------------------------------------------------------- 1 | /* 2 | * To change this license header, choose License Headers in Project Properties. 3 | * To change this template file, choose Tools | Templates 4 | * and open the template in the editor. 
5 |  */
6 | package converters;
7 |
8 | import java.util.Date;
9 |
10 | /**
11 |  *
12 |  * @author olu
13 |  */
14 | public class ConverterMain {
15 |
16 |     /**
17 |      * @param args the command line arguments
18 |      */
19 |     public static void main(String[] args) {
20 |         String baseFolderDir="/home/olu/Dev/data_base/sign_base/training"; //Change as appropriate to you
21 |         String labelingFileDir="/home/olu/Dev/data_base/sign_base/labels.txt"; //Change as appropriate to you
22 |         String formatName = "jpeg";
23 |
24 |         //Convert folder names
25 |         Converter convert = new Converter();
26 |         boolean converted = convert.convertFolderName(baseFolderDir, labelingFileDir);
27 |
28 |         //Convert all of the dataset's .ppm images to .jpeg
29 |         long timeStart = System.currentTimeMillis();
30 |         boolean converted2 = convert.convertAllDatasetImages(baseFolderDir, formatName);
31 |         if(converted2==true){
32 |             long timeEnd = System.currentTimeMillis(); //in milliseconds
33 |             long diff = timeEnd - timeStart;
34 |             long diffSeconds = diff / 1000 % 60;
35 |             long diffMinutes = diff / (60 * 1000) % 60;
36 |             long diffHours = diff / (60 * 60 * 1000) % 24;
37 |             long diffDays = diff / (24 * 60 * 60 * 1000);
38 |             System.out.println("ALL "+formatName+" CONVERSIONS NOW COMPLETED. Took "+diffDays+" Day(s), "+diffHours+" Hour(s) "+diffMinutes+" Minute(s) and "+diffSeconds+" Second(s)");
39 |         }
40 |
41 |     }
42 |
43 | }
44 |
--------------------------------------------------------------------------------
/DeleteExtrasFromDataSet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #Pass in base folder location as the first argument
4 | cd $1
--------------------------------------------------------------------------------
/Logs/imdbLogs:
--------------------------------------------------------------------------------
1 | ========== RESTART: /home/olu/Dev/scratch_train_sign/create_imdb.py ==========
2 | Result will be saved to /home/olu/Dev/data_base/sign_base/output
3 | Determining list of input files and labels from /home/olu/Dev/data_base/sign_base/labels.txt
4 | File path /home/olu/Dev/data_base/sign_base/training/speed_20/*
5 |
6 | File path /home/olu/Dev/data_base/sign_base/training/speed_30/*
7 |
8 | File path /home/olu/Dev/data_base/sign_base/training/speed_50/*
9 |
10 | File path /home/olu/Dev/data_base/sign_base/training/speed_60/*
11 |
12 | File path /home/olu/Dev/data_base/sign_base/training/speed_70/*
13 |
14 | File path /home/olu/Dev/data_base/sign_base/training/speed_80/*
15 |
16 | File path /home/olu/Dev/data_base/sign_base/training/speed_less_80/*
17 |
18 | File path /home/olu/Dev/data_base/sign_base/training/speed_100/*
19 |
20 | File path /home/olu/Dev/data_base/sign_base/training/speed_120/*
21 |
22 | File path /home/olu/Dev/data_base/sign_base/training/no_car_overtaking/*
23 |
24 | File path /home/olu/Dev/data_base/sign_base/training/no_truck_overtaking/*
25 |
26 | File path /home/olu/Dev/data_base/sign_base/training/priority_road/*
27 |
28 | File path /home/olu/Dev/data_base/sign_base/training/priority_road_2/*
29 |
30 | File path /home/olu/Dev/data_base/sign_base/training/yield_right_of_way/*
31 |
32 | File path /home/olu/Dev/data_base/sign_base/training/stop/*
33 |
34 | File path /home/olu/Dev/data_base/sign_base/training/road_closed/*
35 |
36 | File path /home/olu/Dev/data_base/sign_base/training/maximum_weight_allowed/*
37 |
38 | File path /home/olu/Dev/data_base/sign_base/training/entry_prohibited/*
39 |
40 | File path /home/olu/Dev/data_base/sign_base/training/danger/*
41 |
42 | File path
/home/olu/Dev/data_base/sign_base/training/curve_left/* 43 | 44 | File path /home/olu/Dev/data_base/sign_base/training/curve_right/* 45 | 46 | File path /home/olu/Dev/data_base/sign_base/training/double_curve_right/* 47 | 48 | File path /home/olu/Dev/data_base/sign_base/training/rough_road/* 49 | 50 | File path /home/olu/Dev/data_base/sign_base/training/slippery_road/* 51 | 52 | File path /home/olu/Dev/data_base/sign_base/training/road_narrows_right/* 53 | 54 | File path /home/olu/Dev/data_base/sign_base/training/work_in_progress/* 55 | 56 | File path /home/olu/Dev/data_base/sign_base/training/traffic_light_ahead/* 57 | 58 | File path /home/olu/Dev/data_base/sign_base/training/pedestrian_crosswalk/* 59 | 60 | File path /home/olu/Dev/data_base/sign_base/training/children_area/* 61 | 62 | File path /home/olu/Dev/data_base/sign_base/training/bicycle_crossing/* 63 | 64 | File path /home/olu/Dev/data_base/sign_base/training/beware_of_ice/* 65 | 66 | File path /home/olu/Dev/data_base/sign_base/training/wild_animal_crossing/* 67 | 68 | File path /home/olu/Dev/data_base/sign_base/training/end_of_restriction/* 69 | 70 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_right/* 71 | 72 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_left/* 73 | 74 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight/* 75 | 76 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_right/* 77 | 78 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_left/* 79 | 80 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle/* 81 | 82 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle2/* 83 | 84 | File path /home/olu/Dev/data_base/sign_base/training/traffic_circle/* 85 | 86 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_car_overtaking/* 87 | 88 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_truck_overtaking/* 89 | 90 | Found 39209 JPEG files across 43 labels inside /home/olu/Dev/data_base/sign_base/training 91 | Launching 2 threads for spacings: [[0, 19604], [19604, 39209]] 92 | 2017-12-01 15:56:39.543697 [thread 1]: Processed 1000 of 19605 images in thread batch.2017-12-01 15:56:39.545929 [thread 0]: Processed 1000 of 19604 images in thread batch. 93 | 94 | 2017-12-01 15:56:44.254576 [thread 0]: Processed 2000 of 19604 images in thread batch. 95 | 2017-12-01 15:56:44.433656 [thread 1]: Processed 2000 of 19605 images in thread batch. 96 | 2017-12-01 15:56:48.754633 [thread 0]: Processed 3000 of 19604 images in thread batch. 97 | 2017-12-01 15:56:49.798092 [thread 1]: Processed 3000 of 19605 images in thread batch. 98 | 2017-12-01 15:56:53.908201 [thread 0]: Processed 4000 of 19604 images in thread batch. 99 | 2017-12-01 15:56:54.351327 [thread 1]: Processed 4000 of 19605 images in thread batch. 100 | 2017-12-01 15:56:58.855722 [thread 0]: Processed 5000 of 19604 images in thread batch. 101 | 2017-12-01 15:56:59.301768 [thread 1]: Processed 5000 of 19605 images in thread batch. 102 | 2017-12-01 15:57:03.704445 [thread 0]: Processed 6000 of 19604 images in thread batch. 103 | 2017-12-01 15:57:04.144965 [thread 1]: Processed 6000 of 19605 images in thread batch. 104 | 2017-12-01 15:57:08.543383 [thread 0]: Processed 7000 of 19604 images in thread batch. 105 | 2017-12-01 15:57:09.007709 [thread 1]: Processed 7000 of 19605 images in thread batch. 
106 | 2017-12-01 15:57:13.622921 [thread 0]: Processed 8000 of 19604 images in thread batch. 107 | 2017-12-01 15:57:14.059550 [thread 1]: Processed 8000 of 19605 images in thread batch. 108 | 2017-12-01 15:57:18.477144 [thread 0]: Processed 9000 of 19604 images in thread batch. 109 | 2017-12-01 15:57:18.811312 [thread 1]: Processed 9000 of 19605 images in thread batch. 110 | 2017-12-01 15:57:23.763258 [thread 0]: Processed 10000 of 19604 images in thread batch. 111 | 2017-12-01 15:57:24.086001 [thread 1]: Processed 10000 of 19605 images in thread batch. 112 | 2017-12-01 15:57:28.965086 [thread 0]: Processed 11000 of 19604 images in thread batch. 113 | 2017-12-01 15:57:29.261494 [thread 1]: Processed 11000 of 19605 images in thread batch. 114 | 2017-12-01 15:57:34.061319 [thread 0]: Processed 12000 of 19604 images in thread batch. 115 | 2017-12-01 15:57:34.177079 [thread 1]: Processed 12000 of 19605 images in thread batch. 116 | 2017-12-01 15:57:39.415580 [thread 1]: Processed 13000 of 19605 images in thread batch.2017-12-01 15:57:39.418505 [thread 0]: Processed 13000 of 19604 images in thread batch. 117 | 118 | 2017-12-01 15:57:44.575577 [thread 1]: Processed 14000 of 19605 images in thread batch. 119 | 2017-12-01 15:57:44.973289 [thread 0]: Processed 14000 of 19604 images in thread batch. 120 | 2017-12-01 15:57:49.592148 [thread 1]: Processed 15000 of 19605 images in thread batch. 121 | 2017-12-01 15:57:50.393305 [thread 0]: Processed 15000 of 19604 images in thread batch. 122 | 2017-12-01 15:57:54.668816 [thread 1]: Processed 16000 of 19605 images in thread batch. 123 | 2017-12-01 15:57:55.473357 [thread 0]: Processed 16000 of 19604 images in thread batch. 124 | 2017-12-01 15:57:59.705012 [thread 1]: Processed 17000 of 19605 images in thread batch. 125 | 2017-12-01 15:58:00.571395 [thread 0]: Processed 17000 of 19604 images in thread batch. 126 | 2017-12-01 15:58:04.590437 [thread 1]: Processed 18000 of 19605 images in thread batch. 127 | 2017-12-01 15:58:05.412019 [thread 0]: Processed 18000 of 19604 images in thread batch. 128 | 2017-12-01 15:58:09.393710 [thread 1]: Processed 19000 of 19605 images in thread batch. 129 | 2017-12-01 15:58:10.293427 [thread 0]: Processed 19000 of 19604 images in thread batch. 130 | 2017-12-01 15:58:12.288467 [thread 1]: Wrote 19605 images to /home/olu/Dev/data_base/sign_base/output/validation-00001-of-00002 131 | 2017-12-01 15:58:12.326821 [thread 1]: Wrote 19605 images to 19605 shards. 132 | 2017-12-01 15:58:12.793973 [thread 0]: Wrote 19604 images to /home/olu/Dev/data_base/sign_base/output/validation-00000-of-00002 133 | 2017-12-01 15:58:12.840278 [thread 0]: Wrote 19604 images to 19604 shards. 134 | 2017-12-01 15:58:13.796779: Finished writing all 39209 images in data set. 
135 | Determining list of input files and labels from /home/olu/Dev/data_base/sign_base/labels.txt 136 | File path /home/olu/Dev/data_base/sign_base/training/speed_20/* 137 | 138 | File path /home/olu/Dev/data_base/sign_base/training/speed_30/* 139 | 140 | File path /home/olu/Dev/data_base/sign_base/training/speed_50/* 141 | 142 | File path /home/olu/Dev/data_base/sign_base/training/speed_60/* 143 | 144 | File path /home/olu/Dev/data_base/sign_base/training/speed_70/* 145 | 146 | File path /home/olu/Dev/data_base/sign_base/training/speed_80/* 147 | 148 | File path /home/olu/Dev/data_base/sign_base/training/speed_less_80/* 149 | 150 | File path /home/olu/Dev/data_base/sign_base/training/speed_100/* 151 | 152 | File path /home/olu/Dev/data_base/sign_base/training/speed_120/* 153 | 154 | File path /home/olu/Dev/data_base/sign_base/training/no_car_overtaking/* 155 | 156 | File path /home/olu/Dev/data_base/sign_base/training/no_truck_overtaking/* 157 | 158 | File path /home/olu/Dev/data_base/sign_base/training/priority_road/* 159 | 160 | File path /home/olu/Dev/data_base/sign_base/training/priority_road_2/* 161 | 162 | File path /home/olu/Dev/data_base/sign_base/training/yield_right_of_way/* 163 | 164 | File path /home/olu/Dev/data_base/sign_base/training/stop/* 165 | 166 | File path /home/olu/Dev/data_base/sign_base/training/road_closed/* 167 | 168 | File path /home/olu/Dev/data_base/sign_base/training/maximum_weight_allowed/* 169 | 170 | File path /home/olu/Dev/data_base/sign_base/training/entry_prohibited/* 171 | 172 | File path /home/olu/Dev/data_base/sign_base/training/danger/* 173 | 174 | File path /home/olu/Dev/data_base/sign_base/training/curve_left/* 175 | 176 | File path /home/olu/Dev/data_base/sign_base/training/curve_right/* 177 | 178 | File path /home/olu/Dev/data_base/sign_base/training/double_curve_right/* 179 | 180 | File path /home/olu/Dev/data_base/sign_base/training/rough_road/* 181 | 182 | File path /home/olu/Dev/data_base/sign_base/training/slippery_road/* 183 | 184 | File path /home/olu/Dev/data_base/sign_base/training/road_narrows_right/* 185 | 186 | File path /home/olu/Dev/data_base/sign_base/training/work_in_progress/* 187 | 188 | File path /home/olu/Dev/data_base/sign_base/training/traffic_light_ahead/* 189 | 190 | File path /home/olu/Dev/data_base/sign_base/training/pedestrian_crosswalk/* 191 | 192 | File path /home/olu/Dev/data_base/sign_base/training/children_area/* 193 | 194 | File path /home/olu/Dev/data_base/sign_base/training/bicycle_crossing/* 195 | 196 | File path /home/olu/Dev/data_base/sign_base/training/beware_of_ice/* 197 | 198 | File path /home/olu/Dev/data_base/sign_base/training/wild_animal_crossing/* 199 | 200 | File path /home/olu/Dev/data_base/sign_base/training/end_of_restriction/* 201 | 202 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_right/* 203 | 204 | File path /home/olu/Dev/data_base/sign_base/training/must_turn_left/* 205 | 206 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight/* 207 | 208 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_right/* 209 | 210 | File path /home/olu/Dev/data_base/sign_base/training/must_go_straight_or_left/* 211 | 212 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle/* 213 | 214 | File path /home/olu/Dev/data_base/sign_base/training/mandatroy_direction_bypass_obstacle2/* 215 | 216 | File path /home/olu/Dev/data_base/sign_base/training/traffic_circle/* 217 | 218 | File path 
/home/olu/Dev/data_base/sign_base/training/end_of_no_car_overtaking/* 219 | 220 | File path /home/olu/Dev/data_base/sign_base/training/end_of_no_truck_overtaking/* 221 | 222 | Found 39209 JPEG files across 43 labels inside /home/olu/Dev/data_base/sign_base/training 223 | Launching 2 threads for spacings: [[0, 19604], [19604, 39209]] 224 | 2017-12-01 15:58:19.321528 [thread 1]: Processed 1000 of 19605 images in thread batch.2017-12-01 15:58:19.328218 [thread 0]: Processed 1000 of 19604 images in thread batch. 225 | 226 | 2017-12-01 15:58:24.103518 [thread 1]: Processed 2000 of 19605 images in thread batch. 227 | 2017-12-01 15:58:24.379995 [thread 0]: Processed 2000 of 19604 images in thread batch. 228 | 2017-12-01 15:58:28.848063 [thread 1]: Processed 3000 of 19605 images in thread batch. 229 | 2017-12-01 15:58:29.290989 [thread 0]: Processed 3000 of 19604 images in thread batch. 230 | 2017-12-01 15:58:33.692279 [thread 1]: Processed 4000 of 19605 images in thread batch. 231 | 2017-12-01 15:58:34.348252 [thread 0]: Processed 4000 of 19604 images in thread batch. 232 | 2017-12-01 15:58:38.885082 [thread 1]: Processed 5000 of 19605 images in thread batch. 233 | 2017-12-01 15:58:39.212437 [thread 0]: Processed 5000 of 19604 images in thread batch. 234 | 2017-12-01 15:58:44.014678 [thread 1]: Processed 6000 of 19605 images in thread batch. 235 | 2017-12-01 15:58:44.319710 [thread 0]: Processed 6000 of 19604 images in thread batch. 236 | 2017-12-01 15:58:48.727675 [thread 1]: Processed 7000 of 19605 images in thread batch. 237 | 2017-12-01 15:58:49.454886 [thread 0]: Processed 7000 of 19604 images in thread batch. 238 | 2017-12-01 15:58:53.757572 [thread 1]: Processed 8000 of 19605 images in thread batch. 239 | 2017-12-01 15:58:54.732761 [thread 0]: Processed 8000 of 19604 images in thread batch. 240 | 2017-12-01 15:58:59.082209 [thread 1]: Processed 9000 of 19605 images in thread batch. 241 | 2017-12-01 15:58:59.685727 [thread 0]: Processed 9000 of 19604 images in thread batch. 242 | 2017-12-01 15:59:03.855798 [thread 1]: Processed 10000 of 19605 images in thread batch. 243 | 2017-12-01 15:59:04.623629 [thread 0]: Processed 10000 of 19604 images in thread batch. 244 | 2017-12-01 15:59:08.897193 [thread 1]: Processed 11000 of 19605 images in thread batch. 245 | 2017-12-01 15:59:09.754831 [thread 0]: Processed 11000 of 19604 images in thread batch. 246 | 2017-12-01 15:59:13.925373 [thread 1]: Processed 12000 of 19605 images in thread batch. 247 | 2017-12-01 15:59:14.625715 [thread 0]: Processed 12000 of 19604 images in thread batch. 248 | 2017-12-01 15:59:18.758661 [thread 1]: Processed 13000 of 19605 images in thread batch. 249 | 2017-12-01 15:59:19.334345 [thread 0]: Processed 13000 of 19604 images in thread batch. 250 | 2017-12-01 15:59:23.557197 [thread 1]: Processed 14000 of 19605 images in thread batch. 251 | 2017-12-01 15:59:24.091283 [thread 0]: Processed 14000 of 19604 images in thread batch. 252 | 2017-12-01 15:59:28.380125 [thread 1]: Processed 15000 of 19605 images in thread batch. 253 | 2017-12-01 15:59:28.939316 [thread 0]: Processed 15000 of 19604 images in thread batch. 254 | 2017-12-01 15:59:33.511671 [thread 1]: Processed 16000 of 19605 images in thread batch. 255 | 2017-12-01 15:59:33.791349 [thread 0]: Processed 16000 of 19604 images in thread batch. 256 | 2017-12-01 15:59:38.410925 [thread 1]: Processed 17000 of 19605 images in thread batch. 257 | 2017-12-01 15:59:38.555400 [thread 0]: Processed 17000 of 19604 images in thread batch. 
258 | 2017-12-01 15:59:43.261218 [thread 1]: Processed 18000 of 19605 images in thread batch. 259 | 2017-12-01 15:59:43.384039 [thread 0]: Processed 18000 of 19604 images in thread batch. 260 | 2017-12-01 15:59:48.125398 [thread 1]: Processed 19000 of 19605 images in thread batch. 261 | 2017-12-01 15:59:48.189964 [thread 0]: Processed 19000 of 19604 images in thread batch. 262 | 2017-12-01 15:59:50.915470 [thread 1]: Wrote 19605 images to /home/olu/Dev/data_base/sign_base/output/train-00001-of-00002 263 | 2017-12-01 15:59:50.925460 [thread 1]: Wrote 19605 images to 19605 shards. 264 | 2017-12-01 15:59:51.064479 [thread 0]: Wrote 19604 images to /home/olu/Dev/data_base/sign_base/output/train-00000-of-00002 265 | 2017-12-01 15:59:51.069508 [thread 0]: Wrote 19604 images to 19604 shards. 266 | 2017-12-01 15:59:51.536970: Finished writing all 39209 images in data set. 267 | >>> 268 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Computer_Vision_Using_TensorFlowLite 2 | 3 | On this project the AlexNet Convolutional Neural Network is trained using traffic sign images from the German Road Traffic Sign Benchmark. The initially trained network is then quantized/optimized for deployment on mobile devices using TensorFlowLite 4 | 5 | ## Project Steps 6 | 1. Download German Traffic Sign Benchmark Training Images [from here](https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip) 7 | 2. Convert From .ppm to .jpg Format 8 | 3. Label db Folder Appropriately 9 | 4. Convert Dataset to TFRecord 10 | 5. Create CNN (alexnet) 11 | 6. Train CNN Using TFRecord Data 12 | 7. Test & Evaluate Trained Model 13 | 8. Quantize/Tune/Optimize trained model for mobile deployment 14 | 9. Test & Evaluate Tuned Model 15 | 10. Deploy Tuned Model to Mobile Platform [Here](https://github.com/OluwoleOyetoke/Accelerated-Android-Vision) 16 | 17 | ### Steps 1, 2 & 3: Get Dataset, Convert (from .ppm to .jpeg) & Label Appropriately 18 | The **DatasetConverter** folder contains the java program written to go through the GTSBR dataset, rename all its folders using the class name of the sets of images the folders contain. It also converts the image file types from .ppm to .jpeg 19 | 20 | ### Step 4: Convert Dataset to TFRecord 21 | With TensorFlow, we can store our whole dataset and all of its meta-data as a serialized record called **TFRecord**. Loading up our data into this format helps for portability and simplicity. Also note that this record can be broken into multiple shards and used when performing distributed training. A 'TFRecord' in TensorFlow is basically TensorFlow's default data format. A record file containing serialized tf.train.examples. To avoid confusion, a TensorFlow 'example' is a normalized data format for storing data for training and inference purposes. It contains a key-value store where each key string maps to a feature message. These feature messages can be things like a packed byte list, float list, int64 list. Note that many 'examples' come together to form a TFRecord. The script **create_imdb.py** is used to make the TFRecords out of the dataset. 
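To make the format concrete, here is a minimal, self-contained sketch of writing one image into a TFRecord with the TF 1.x API. The file paths, feature keys and label value here are illustrative placeholders, not necessarily the exact ones **create_imdb.py** uses:

```
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Read an encoded JPEG from disk (hypothetical path).
with open("stop_sign.jpeg", "rb") as f:
    encoded_jpeg = f.read()

# One tf.train.Example = one labeled image; many Examples form a TFRecord.
example = tf.train.Example(features=tf.train.Features(feature={
    "image/encoded": _bytes_feature(encoded_jpeg),
    "image/class/label": _int64_feature(14),  # e.g. the 'stop' class
}))

writer = tf.python_io.TFRecordWriter("/tmp/example.tfrecord")
writer.write(example.SerializeToString())
writer.close()
```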
To learn more about creating a TFRecord file or streaming data back out of it, see these posts: [TensorFlow TFRecord Blog Post 1](http://eagle-beacon.com/blog/posts/Loading_And_Poping_TFRecords.html) & [TensorFlow TFRecord Blog Post 2](http://eagle-beacon.com/blog/posts/Loading_And_Poping_TFRecords_P2.html)
22 |
23 | ### Step 5: Create AlexNet Structure
24 | The script **train_alexnet.py** is used to create and train the CNN. See the TensorBoard visualization of the AlexNet structure below:
25 |
26 |
27 | ![Network Visualization](https://github.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/blob/master/imgs/network_visualization.png)
28 |
29 | **Figure Showing TensorBoard Visualization of the Network**
30 |
31 |
32 | ### Step 6: Train CNN Using TFRecord Data
33 | The figures below show the drop in the network's loss as training progressed. The Adam optimizer was used. The figure below shows the network's performance improving per epoch:
34 | ![Accuracy Per Epoch](https://github.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/blob/master/imgs/accuracy_per_epoch.png)
35 |
36 | **Figure Showing Improvement In Network Performance Per Epoch (Epoch 0-20)**
37 |
38 | After the full training procedure, the trained model performed at **over 98% accuracy**.
39 |
40 | ### Step 7: Test & Evaluate Trained Model
41 | To test the trained network model, the script **classify_img_arg.py** can be used.
42 | Usage format:
43 |
44 | ```
45 |
46 | $python classify_img_arg.py [saved_model_directory] [path_to_image]
47 |
48 | ```
49 |
50 | ### Step 8: Quantize/Tune/Optimize Trained Model for Mobile Deployment
51 | Here, the TensorFlow Graph Transform tool comes in handy, helping us shrink the size of the model and make it deployable on mobile. The transform tools are designed to work on models saved as GraphDef files, usually in a binary protobuf format: the low-level definition of a TensorFlow computational graph, which includes a list of nodes and the input and output connections between them. But before we start any transformation/optimization of the model, it is wise to freeze the graph, i.e. make sure the trained weights are fused with the graph by converting them into embedded constants within the graph file itself. To do this, we will need to run the [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) script.
52 |
53 | The various transforms we can perform on our graph include stripping unused nodes, removing identity/check nodes, folding batch norm layers, etc.
54 |
55 | In a nutshell, to achieve our conversion of the TensorFlow .pb model to a TensorFlowLite .lite model, we will:
56 | 1. Freeze the graph, i.e. merge checkpoint values with the graph structure. In other words, load the variables into the graph and convert them to constants (a sketch of this step follows below)
57 | 2. Perform one or two simple transformations/optimizations on the frozen model
58 | 3. Convert the frozen graph definition into the [flat buffer format](https://google.github.io/flatbuffers/) (.lite)
59 |
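As a quick illustration of the freeze step, the invocation below is a minimal sketch: the graph, checkpoint and output file names are hypothetical placeholders, while the flags mirror the freeze_graph call used in the freeze_and_convert_to_tflite.sh excerpt further down, and `softmax_tensor` is the probability output node name used by the scripts in this repository.

```
python tensorflow/python/tools/freeze_graph.py \
  --input_graph=alexnet_graph.pb \
  --input_checkpoint=trained_alexnet_model/model.ckpt \
  --input_binary=true \
  --output_graph=frozen_alexnet.pb \
  --output_node_names=softmax_tensor
```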
61 |
62 | #### Step 8.1: Sample Transform Definition and A Little Bit of Explanation
63 | Note that these transform options are key to shrinking the model file size.
64 |
65 | ```
66 | transforms = 'strip_unused_nodes(type=float, shape="1,299,299,3")
67 | remove_nodes(op=Identity, op=CheckNumerics)
68 | fold_constants(ignore_errors=true)
69 | fold_batch_norms fold_old_batch_norms
70 | round_weights(num_steps=256)
71 | quantize_weights obfuscate_names
72 | quantize_nodes sort_by_execution_order'
73 | ```
74 |
75 | **round_weights(num_steps=256)** - Essentially, this rounds the weights so that nearby numbers are stored as exactly the same values; the resulting bit stream then has a lot more repetition and compresses down a lot more effectively. The compressed tflite model can be as much as 70% smaller than the original model. The nice thing about this transform option is that it doesn't change the structure of the graph at all, so it runs exactly the same operations and should have the same latency and memory usage as before. We can adjust the num_steps parameter to control how many values each weight buffer is rounded to; lower numbers increase the compression at the cost of accuracy.
76 |
77 | **quantize_weights** - Storing the weights of the model as 8-bit values will drastically reduce its size; however, we are trading off some inference accuracy here.
78 |
79 | **obfuscate_names** - Supposing our graph has a lot of small nodes in it, the node names themselves can end up being a significant cause of space utilization. Using this option will help cut that down.
80 |
81 | For some platforms it is very helpful to be able to do as many calculations as possible in eight-bit rather than floating-point. A few transform options are available to help achieve this, although in most cases some modification will need to be made to the actual training code for the transform to work effectively.
82 |
83 | #### Step 8.2: Transformation & Conversion Execution
84 | As we know, TensorFlowLite is still in its early development stage and many new features are being added daily. Consequently, some TensorFlow operations are not fully supported at the moment, [e.g. ArgMax](https://github.com/tensorflow/tensorflow/issues/15948). As a matter of fact, at the moment, [models that were quantized using transform_graph are not supported by TF Lite](https://github.com/tensorflow/tensorflow/issues/15871#issuecomment-356419505). On the brighter side, we can still convert our custom TensorFlow AlexNet model to a TFLite model, but we will need to turn off some transform options such as 'quantize_weights' and 'quantize_nodes' and keep our model's inference and input types as FLOATs. In this case, the model size would not really change.
85 |
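Once the conversion in the excerpt just below has produced a .lite file, a quick sanity check from Python is to run a dummy input through the TF Lite interpreter. This is only a minimal sketch under stated assumptions: the model file name is a hypothetical placeholder, and the interpreter's location in the API varies by TensorFlow version (`tf.contrib.lite.Interpreter` in the 1.x contrib era, `tf.lite.Interpreter` in later releases).

```
import numpy as np
import tensorflow as tf

# Load the converted flat-buffer model (hypothetical file name).
interpreter = tf.contrib.lite.Interpreter(model_path="alexnet.lite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Float input of shape 1x227x227x3, matching the input_shapes flag below.
dummy = np.zeros((1, 227, 227, 3), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()

probabilities = interpreter.get_tensor(output_details[0]['index'])
print(probabilities.shape)  # expect (1, 43) for the 43 sign classes
```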
86 | The code excerpt below, from the [freeze_and_convert_to_tflite.sh](https://github.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/blob/master/freeze_and_convert_to_tflite.sh) script, shows how this .tflite conversion is achieved.
87 |
88 | ```
89 | #FREEZE GRAPH
90 | bazel build tensorflow/python/tools:freeze_graph
91 | bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=$1 --input_checkpoint=$2 --input_binary=$3 --output_graph=$4 --output_node_names=$6
92 |
93 |
94 | #VIEW SUMMARY OF GRAPH
95 | bazel build tensorflow/tools/graph_transforms:summarize_graph
96 | bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=$4
97 |
98 | #TRANSFORM/OPTIMIZE GRAPH
99 | bazel build tensorflow/tools/graph_transforms:transform_graph
100 | bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=$4 --out_graph=${10} --inputs=$5 --outputs=$6 --transforms='strip_unused_nodes(type=float, shape="1,227,227,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) round_weights(num_steps=256) obfuscate_names sort_by_execution_order fold_old_batch_norms fold_batch_norms'
101 |
102 |
103 | #CONVERT TO TFLITE MODEL
104 | bazel build tensorflow/contrib/lite/toco:toco
105 | bazel-bin/tensorflow/contrib/lite/toco/toco --input_format=TENSORFLOW_GRAPHDEF --input_file=${10} --output_format=TFLITE --output_file=$7 --inference_type=$9 --input_arrays=$5 --output_arrays=$6 --inference_input_type=$8 --input_shapes=1,227,227,3
106 |
107 | ```
108 |
109 | Note, this will require you to build TensorFlow from source. You can get instructions on how to do this from [here](https://www.tensorflow.org/install/install_sources)
--------------------------------------------------------------------------------
/classify_img.py:
--------------------------------------------------------------------------------
1 | """
2 | Script passes a raw image to an already trained model to get a prediction
3 |
4 | @date: 20th December, 2017
5 | @author: Oluwole Oyetoke
6 | @Language: Python
7 | @email: oluwoleoyetoke@gmail.com
8 |
9 | #IMPORTS, VARIABLE DECLARATION, AND LOGGING TYPE SETTING
10 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
11 | from __future__ import absolute_import
12 | from __future__ import division
13 | from __future__ import print_function
14 | from sklearn.model_selection import train_test_split
15 |
16 | import os
17 | import math
18 | import time
19 | import numpy as np
20 | import tensorflow as tf
21 | import matplotlib.pyplot as plt
22 | from PIL import Image
23 |
24 | flags = tf.app.flags
25 | flags.DEFINE_integer("image_width", "227", "Alexnet input layer width")
26 | flags.DEFINE_integer("image_height", "227", "Alexnet input layer height")
27 | flags.DEFINE_integer("image_channels", "3", "Alexnet input layer channels")
28 | flags.DEFINE_integer("num_of_classes", "43", "Number of training classes")
29 | FLAGS = flags.FLAGS
30 |
31 | tf.logging.set_verbosity(tf.logging.WARN) #setting up logging (can be DEBUG, ERROR, FATAL, INFO or WARN)
32 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
33 |
34 |
35 | #CLASSIFYING AN IMAGE WITH THE TRAINED ALEXNET CNN CLASSIFIER
36 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
37 | def main(unused_argv):
38 |
39 |     #Specify checkpoint & image directory
40 |     checkpoint_directory="/home/olu/Dev/data_base/sign_base/backup/Checkpoints_N_Model_Epoch_34__copy/trained_alexnet_model"
41 |     filename="/home/olu/Dev/data_base/sign_base/training_227x227/road_closed/00002_00005.jpeg"
42 |
43 |     #Process image to be sent to the neural net
44 |     #im = np.array(Image.open(filename))
45 |     #img_batch = im.reshape(1, FLAGS.image_width, FLAGS.image_height, FLAGS.image_channels)
46 |
47 |     img = Image.open(filename)
48 |     img_resized = img.resize((227, 227), Image.ANTIALIAS)
49 |     img_batch_np = np.array(img_resized)
50 |     img_batch = img_batch_np.reshape(1, FLAGS.image_width, FLAGS.image_height, FLAGS.image_channels)
51 |     plt.imshow(img_batch_np)
52 |
53 |     #Declare categories/classes as strings
54 |     categories = ["speed_20", "speed_30","speed_50","speed_60","speed_70",
55 |                   "speed_80","speed_less_80","speed_100","speed_120",
56 |                   "no_car_overtaking","no_truck_overtaking","priority_road",
57 |                   "priority_road_2","yield_right_of_way","stop","road_closed",
58 |                   "maximum_weight_allowed","entry_prohibited","danger","curve_left",
59 |                   "curve_right","double_curve_right","rough_road","slippery_road",
60 |                   "road_narrows_right","work_in_progress","traffic_light_ahead",
61 |                   "pedestrian_crosswalk","children_area","bicycle_crossing",
62 |                   "beware_of_ice","wild_animal_crossing","end_of_restriction",
63 |                   "must_turn_right","must_turn_left","must_go_straight",
64 |                   "must_go_straight_or_right","must_go_straight_or_left",
65 |                   "mandatroy_direction_bypass_obstacle",
66 |                   "mandatroy_direction_bypass_obstacle2",
67 |                   "traffic_circle","end_of_no_car_overtaking",
68 |                   "end_of_no_truck_overtaking"]
69 |
70 |
71 |     #Recreate the network graph
72 |     sess = tf.Session()
73 |     latest_checkpoint_name = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_directory)
74 |     saver = tf.train.import_meta_graph(latest_checkpoint_name+'.meta') #At this step only the graph is created
75 |
76 |     #Access the default graph which we have restored
77 |     graph = tf.get_default_graph()
78 |
79 |     #Restore the trained weights into the graph
80 |     checkpoint_file=tf.train.latest_checkpoint(checkpoint_directory)
81 |     saver.restore(sess, checkpoint_file) #Load the saved weights using the restore method
82 |
83 |     probabilities = graph.get_tensor_by_name("softmax_tensor:0")
84 |     classes = graph.get_tensor_by_name("classes_tensor:0") #'classes_tensor:0' is the name given to the argmax tensor in train_alexnet.py
85 |     feed_dict = {"input_layer:0": img_batch} #'input_layer:0' is the name of the input tensor in train_alexnet.py
    predicted_class = sess.run(classes, feed_dict)
    predicted_probabilities = sess.run(probabilities, feed_dict)
    assurance = predicted_probabilities[0, int(predicted_class)]*100

    print("Predicted Sign: ", categories[int(predicted_class)], " (With ", assurance, " Percent Assurance)")
    print("finished")
    plt.show()
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""


if __name__=="__main__":
    tf.app.run()

--------------------------------------------------------------------------------
/classify_img_arg.py:
--------------------------------------------------------------------------------
"""
Script passes a raw image to an already trained model to get a prediction

@date: 20th December, 2017
@Language: Python

usage: $ python classify_img_arg.py [saved_model_directory] [path_to_image]

#IMPORTS, VARIABLE DECLARATION, AND LOGGING TYPE SETTING
----------------------------------------------------------------------------------------------------------------------------------------------------------------"""


from __future__ import absolute_import
from __future__ import print_function

import sys
import os
import math
import time
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

flags = tf.app.flags
flags.DEFINE_integer("image_width", 227, "Alexnet input layer width")
flags.DEFINE_integer("image_height", 227, "Alexnet input layer height")
flags.DEFINE_integer("image_channels", 3, "Alexnet input layer channels")
flags.DEFINE_integer("num_of_classes", 43, "Number of training classes")
FLAGS = flags.FLAGS

tf.logging.set_verbosity(tf.logging.WARN) #setting up logging (can be DEBUG, ERROR, FATAL, INFO or WARN)
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""


#LOADING THE TRAINED ALEXNET CNN CLASSIFIER AND CLASSIFYING A SINGLE IMAGE
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
#Specify checkpoint & image directory
checkpoint_directory= sys.argv[1] # "/home/olu/Dev/data_base/sign_base/backup/Checkpoints_N_Model2-After_Epoch_20_copy/trained_alexnet_model"
filename= sys.argv[2] # "/home/olu/Dev/data_base/sign_base/training_227x227/road_closed/00002_00005.jpeg"

#Declare categories/classes as strings
categories = ["speed_20", "speed_30","speed_50","speed_60","speed_70",
              "speed_80","speed_less_80","speed_100","speed_120",
              "no_car_overtaking","no_truck_overtaking","priority_road",
              "priority_road_2","yield_right_of_way","stop","road_closed",
              "maximum_weight_allowed","entry_prohibited","danger","curve_left",
              "curve_right","double_curve_right","rough_road","slippery_road",
              "road_narrows_right","work_in_progress","traffic_light_ahead",
              "pedestrian_crosswalk","children_area","bicycle_crossing",
              "beware_of_ice","wild_animal_crossing","end_of_restriction",
              "must_turn_right","must_turn_left","must_go_straight",
              "must_go_straight_or_right","must_go_straight_or_left",
              "mandatroy_direction_bypass_obstacle",
              "mandatroy_direction_bypass_obstacle2",
"mandatroy_direction_bypass_obstacle2", 56 | "traffic_circle","end_of_no_car_overtaking", 57 | "end_of_no_truck_overtaking"]; 58 | 59 | print("resizing image.....") 60 | #Process image to be sent to Neural Net 61 | img = Image.open(filename) 62 | img_resized = img.resize((FLAGS.image_width, FLAGS.image_height), Image.ANTIALIAS) 63 | img_batch_np = np.array(img_resized) 64 | plt.imshow(img_batch_np) 65 | img_batch = img_batch_np.reshape(1, FLAGS.image_width, FLAGS.image_height,FLAGS.image_channels) 66 | 67 | print("loading network graph.....") 68 | #Recreate network graph. 69 | sess = tf.Session() 70 | latest_checkpoint_name = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_directory) 71 | saver = tf.train.import_meta_graph(latest_checkpoint_name+'.meta') #At this step only graph is created. 72 | 73 | #Accessing the default graph which we have restored 74 | graph = tf.get_default_graph() 75 | 76 | print("loading network weights.....") 77 | #Get model's graph 78 | checkpoint_file=tf.train.latest_checkpoint(checkpoint_directory) 79 | saver.restore(sess, checkpoint_file) #Load the weights saved using the restore method. 80 | 81 | print("classification process started.....") 82 | start = time.time() 83 | probabilities = graph.get_tensor_by_name("softmax_tensor:0") 84 | classes = graph.get_tensor_by_name("ArgMax:0") #'ArgMax:0' is the name of the argmax tensor in the train_alexnet.py file. 85 | feed_dict = {"Reshape:0": img_batch} #'Reshape:0' is the name of the 'input_layer' tensor in the train_alexnet.py. Given to it as default. 86 | predicted_class = sess.run(classes, feed_dict) 87 | predicted_probabilities = sess.run(probabilities, feed_dict) 88 | assurance = predicted_probabilities[0,int(predicted_class)]*100; 89 | end = time.time() 90 | difference = end-start 91 | difference_milli = difference*1000 92 | 93 | print("Predicted Sign: '", categories[int(predicted_class)], "' With ", assurance," Percent Assurance") 94 | print("Time Taken For Classification: %f millisecond(s)" % difference_milli) 95 | plt.show() 96 | """----------------------------------------------------------------------------------------------------------------------------------------------------------------""" 97 | -------------------------------------------------------------------------------- /create_imdb.py: -------------------------------------------------------------------------------- 1 | """-------------------------------------------------------------------------------------------------------------------------------------------------- 2 | REFERENCE: 3 | ---------- 4 | Code adapted from Google Tensor FLow Git Hub Repositiory: 5 | https://github.com/tensorflow/models/blob/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/inception/inception/data/build_image_data.py 6 | 7 | @author: Oluwole Oyetoke 8 | @date: 5th December, 2017 9 | @langauge: Python/TF 10 | @email: oluwoleoyetoke@gmail.com 11 | 12 | INTRODUCTION: 13 | ------------- 14 | Converts image dataset to a sharded dataset. The sharded dataset consists of Tensor Flow Records Format (TFRecords) and with Example protos. 15 | train_directory/train-00000-of-01024 16 | train_directory/train-00001-of-01024 17 | ... 18 | train_directory/train-01023-of-01024 19 | and 20 | validation_directory/validation-00000-of-00128 21 | validation_directory/validation-00001-of-00128 22 | ... 23 | validation_directory/validation-00127-of-00128 24 | 25 | EXPECTATIONS: 26 | ------------ 27 | 1. Image data set should be in .jpeg format 28 | 2. 
   Base folder --> Subfolders --> Each subfolder containing a specific class of image,
   e.g. Training Folder -> stop_sign_folder -> 1.jpg, 2.jpg, 3.jpg....
   (data_dir/label_0/image0.jpeg)
   (data_dir/label_0/image1.jpg)
3. The sub-directory name should be the unique label associated with the images in that folder.


SHARDS CONTENT:
--------------
Where we have selected [x] image files per training dataset shard and [y] image files per evaluation dataset shard,
each record within a TFRecord file (shard) is a serialized Example proto consisting of the following fields:
image/encoded: string containing JPEG encoded image in RGB colorspace
image/height: integer, image height in pixels
image/width: integer, image width in pixels
image/colorspace: string, specifying the colorspace, always 'RGB'
image/channels: integer, specifying the number of channels, always 3
image/format: string, specifying the format, always 'JPEG'
image/filename: string containing the basename of the image file e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
image/class/label: integer specifying the index in a classification layer. The label ranges over [0, num_labels] where 0 is unused and left as the background class.
image/class/text: string specifying the human-readable version of the label e.g. 'dog'
If your data set involves bounding boxes, please look at build_imagenet_data.py.




#IMPORTS
-------------------------------------------------------------------------------------------------------------------------------------------------------"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from PIL import Image
from time import sleep
import os
import random
import sys
import threading

import numpy as np
import tensorflow as tf
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""




# SETTING SOME GLOBAL DATA
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
tf.app.flags.DEFINE_string('train_directory', '/home/olu/Dev/data_base/sign_base/training_227x227', 'Training data directory')
tf.app.flags.DEFINE_string('validation_directory', '/home/olu/Dev/data_base/sign_base/training_227x227', 'Validation data directory')
tf.app.flags.DEFINE_string('output_directory', '/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227', 'Output data directory')
tf.app.flags.DEFINE_integer('train_shards', 2, 'Number of shards in training TFRecord files.')
tf.app.flags.DEFINE_integer('validation_shards', 2, 'Number of shards in validation TFRecord files.')
tf.app.flags.DEFINE_integer('num_threads', 2, 'Number of threads to preprocess the images.')
tf.app.flags.DEFINE_string('labels_file', '/home/olu/Dev/data_base/sign_base/labels.txt', 'Labels_file.txt')
tf.app.flags.DEFINE_integer("image_height", 227, "Height of the output image after crop and resize.") #Alexnet takes a 227 x 227 image input
tf.app.flags.DEFINE_integer("image_width", 227, "Width of the output image after crop and resize.")
FLAGS = tf.app.flags.FLAGS
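#NOTE (added): 'train_directory' and 'validation_directory' above point at the same folder, so
#the 'validation' shards written by this script duplicate the training images; point
#'validation_directory' at a held-out folder if an independent validation set is required.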
tf.app.flags.DEFINE_integer("image_width", 227, "Width of the output image after crop and resize.") 86 | FLAGS = tf.app.flags.FLAGS 87 | """ The labels file contains a list of valid labels are held in this file. The file contains entries such as: 88 | speed_100 89 | speed_120 90 | no_car_overtaking 91 | no_truck_overtaking 92 | Each line corresponds to a label, and each label (per line) is mapped to an integer corresponding to the line number starting from 0. 93 | 94 | 95 | -------------------------------------------------------------------------------------------------------------------------------------------------------""" 96 | 97 | 98 | 99 | 100 | # WRAPPER FOR INSERTING int64 FEATURES int64 FEATURES & BYTES FEATURES INTO EXAMPLES PROTO 101 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 102 | def _int64_feature(value): 103 | if not isinstance(value, list): 104 | value = [value] 105 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 106 | 107 | def _bytes_feature(value): 108 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 109 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 110 | 111 | 112 | 113 | 114 | 115 | #FUNCTION FOR BUILDING A PROTO 116 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 117 | def _convert_to_example(filename, image_buffer, label, text, shape_buffer): 118 | """Build an Example proto for an example. 119 | Args: 120 | filename: string, path to an image file, e.g., '/path/to/example.JPG' 121 | image_buffer: string, JPEG encoding of RGB image 122 | label: integer, identifier for the ground truth for the network 123 | text: string, unique human-readable, e.g. 'dog' 124 | height: integer, image height in pixels 125 | width: integer, image width in pixels 126 | Returns: 127 | Example proto 128 | """ 129 | 130 | colorspace = 'RGB' 131 | channels = 3 132 | image_format = 'JPEG' 133 | #Save TFrecord containing image_bytes, shape [337,337,3], label, text, filename 134 | example = tf.train.Example(features=tf.train.Features(feature={ 135 | 'image/shape': _bytes_feature(shape_buffer), 136 | 'image/class/label': _int64_feature(label), 137 | 'image/class/text': _bytes_feature(tf.compat.as_bytes(text)), 138 | 'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))), 139 | 'image/encoded': _bytes_feature(image_buffer)})) 140 | return example 141 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 142 | 143 | 144 | 145 | 146 | #CLASS WIH FUNCTIONS TO HELP ENCODE & DECODE IMAGES 147 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 148 | class ImageCoder(object): 149 | 150 | def __init__(self): 151 | # Create a single Session to run all image coding calls. 152 | self._sess = tf.Session() 153 | 154 | # Initializes function that converts PNG to JPEG data. 
#CLASS WITH FUNCTIONS TO HELP ENCODE & DECODE IMAGES
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
class ImageCoder(object):

  def __init__(self):
    # Create a single Session to run all image coding calls.
    self._sess = tf.Session()

    # Initializes function that converts PNG to JPEG data.
    self._png_data = tf.placeholder(dtype=tf.string)
    image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def png_to_jpeg(self, image_data):
    return self._sess.run(self._png_to_jpeg,
                          feed_dict={self._png_data: image_data})

  def decode_jpeg(self, image_data):
    image = self._sess.run(self._decode_jpeg,
                           feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""




# PRE-PROCESS A SINGLE IMAGE (check if PNG, convert to JPEG, confirm conversion)
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _is_png(filename):
  """Determine if a file contains a PNG format image.
  Args:
    filename: string, path of the image file.
  Returns:
    boolean indicating if the image is a PNG.
  """
  return '.png' in filename


def _process_image(filename, coder):
  """Process a single image file.
  Args:
    filename: string, path to an image file e.g., '/path/to/example.JPG'.
    coder: instance of ImageCoder to provide TensorFlow image coding utils.
  Returns:
    image_data: bytes, raw pixel data of the resized RGB image.
    shape_data: bytes, serialized np.int32 array holding the image shape.
  """
  #Resize image to the network's input size
  size=(FLAGS.image_height, FLAGS.image_width)
  original_image = Image.open(filename)
  width, height = original_image.size
  #print('The original image size is {wide} wide x {height} high'.format(wide=width, height=height))

  resized_image = original_image.resize(size)
  width, height = resized_image.size
  #print('The resized image size is {wide} wide x {height} high'.format(wide=width, height=height))
  resized_image.save(filename)


  #Sleep briefly (5 milliseconds) before the file is re-read
  sleep(0.005)

  #serialize the resized image; this also ensures that all dataset images have been converted to .jpeg
  image = np.asarray(resized_image, np.uint8) #get image data (from resized_image, so the stored pixels match the saved 227x227 file)
  shape = np.array(image.shape, np.int32) #get image shape
  shape_data = shape.tobytes() #convert image shape to bytes
  image_data = image.tobytes() #convert image to raw data bytes in the array

  """ ANOTHER METHOD
  # Read the image file.
  with tf.gfile.FastGFile(filename, 'rb') as f:
    image_data = f.read()


  # Convert any PNG to JPEG's for consistency.
  if _is_png(filename):
    print('Converting PNG to JPEG for %s' % filename)
    image_data = coder.png_to_jpeg(image_data)

  # Decode the RGB JPEG.
  image = coder.decode_jpeg(image_data)

  # Check that image converted to RGB
  assert len(image.shape) == 3
  height = image.shape[0]
  width = image.shape[1]
  assert image.shape[2] == 3"""
  return image_data, shape_data
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""




# PROCESS BATCHES OF IMAGES AS EXAMPLE PROTOS SAVED TO ONE TFRecord PER SHARD
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                               texts, labels, num_shards):
  """Processes and saves a list of images as TFRecords in 1 thread.
  Args:
    coder: instance of ImageCoder to provide TensorFlow image coding utils.
    thread_index: integer, index of the batch this thread handles, within [0, len(ranges)).
    ranges: list of pairs of integers specifying the ranges of the batches to
      analyze in parallel.
    name: string, unique identifier specifying the data set
    filenames: list of strings; each string is a path to an image file
    texts: list of strings; each string is human readable, e.g. 'dog'
    labels: list of integers; each integer identifies the ground truth
    num_shards: integer number of shards for this data set.
  """
  # Each thread produces N shards where N = int(num_shards / num_threads).
  # For instance, if num_shards = 128, and num_threads = 2, then the first
  # thread would produce shards [0, 64).
  num_threads = len(ranges)
  assert not num_shards % num_threads
  num_shards_per_batch = int(num_shards / num_threads) #Same as the number of shards per thread

  shard_ranges = np.linspace(ranges[thread_index][0],
                             ranges[thread_index][1],
                             num_shards_per_batch + 1).astype(int)
  num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

  counter = 0
  for s in range(num_shards_per_batch):
    # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
    shard = thread_index * num_shards_per_batch + s
    output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
    output_file = os.path.join(FLAGS.output_directory, output_filename)
    writer = tf.python_io.TFRecordWriter(output_file)

    shard_counter = 0
    files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
    for i in files_in_shard:
      filename = filenames[i]
      label = labels[i]
      text = texts[i]

      try:
        # image_buffer, height, width = _process_image(filename, coder)
        image_buffer, shape_buffer = _process_image(filename, coder)
      except Exception as e:
        print(e)
        print('SKIPPED: Unexpected error while decoding %s.' % filename)
        continue

      example = _convert_to_example(filename, image_buffer, label,
                                    text, shape_buffer)
      writer.write(example.SerializeToString())
      shard_counter += 1
      counter += 1

      if not counter % 1000:
        print('%s [thread %d]: Processed %d of %d images in thread batch.' %
              (datetime.now(), thread_index, counter, num_files_in_thread))
        sys.stdout.flush()
    writer.close()
    print('%s [thread %d]: Wrote %d images to %s' %
          (datetime.now(), thread_index, shard_counter, output_file))
    sys.stdout.flush()
    shard_counter = 0
  print('%s [thread %d]: Wrote %d images to %d shards.' %
        (datetime.now(), thread_index, counter, num_shards_per_batch))
  sys.stdout.flush()
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""




# PROCESSES AND SAVES THE LIST OF IMAGES AS TFRecords OF EXAMPLE PROTOS (the entire dataset's details are sent here)
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _process_image_files(name, filenames, texts, labels, num_shards):
  """Process and save list of images as TFRecord of Example protos.
  Args:
    name: string, unique identifier specifying the data set
    filenames: list of strings; each string is a path to an image file
    texts: list of strings; each string is human readable, e.g. 'dog'
    labels: list of integers; each integer identifies the ground truth
    num_shards: integer number of shards for this data set.
  """
  assert len(filenames) == len(texts)
  assert len(filenames) == len(labels)

  # Break all images into batches with [ranges[i][0], ranges[i][1]] boundaries.
  spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
  ranges = []
  for i in range(len(spacing) - 1):
    ranges.append([spacing[i], spacing[i + 1]])

  # Launch a thread for each batch.
  print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
  sys.stdout.flush()

  # Create a mechanism for monitoring when all threads are finished.
  coord = tf.train.Coordinator()

  # Create a generic TensorFlow-based utility for converting all image codings.
  coder = ImageCoder()

  threads = []
  for thread_index in range(len(ranges)):
    args = (coder, thread_index, ranges, name, filenames,
            texts, labels, num_shards)
    #From the entire data set details sent to _process_image_files, convert them to Example protos in batches (per the number of threads set) and save as TFRecords
    t = threading.Thread(target=_process_image_files_batch, args=args)
    t.start()
    threads.append(t)

  # Wait for all the threads to terminate.
  coord.join(threads)
  print('%s: Finished writing all %d images in data set.' %
        (datetime.now(), len(filenames)))
  sys.stdout.flush()
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""




#BUILD LIST OF ALL IMAGE FILES AND LABELS IN THE DATA SET
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _find_image_files(data_dir, labels_file):
  """
  Args:
    data_dir: string, path to the root directory of images.
      Assumes that the image data set resides in JPEG files located in
      the following directory structure.
        data_dir/dog/another-image.JPEG
        data_dir/dog/my-image.jpg
      where 'dog' is the label associated with these images.
    labels_file: string, path to the labels file.
      The list of valid labels is held in this file. Assumes that the file
      contains entries as such:
        dog
        cat
        flower
      where each line corresponds to a label. We map each label contained in
      the file to an integer starting with the integer 1 (index 0 is left for
      the background class), in order of the lines.
  Returns:
    filenames: list of strings; each string is a path to an image file.
    texts: list of strings; each string is the class, e.g. 'dog'
    labels: list of integers; each integer identifies the ground truth.
  """
  print('Determining list of input files and labels from %s ' % labels_file)
  unique_labels = [l.strip() for l in tf.gfile.FastGFile(
      labels_file, 'r').readlines()]

  labels = []
  filenames = []
  texts = []

  # Leave label index 0 empty as a background class.
  label_index = 1

  # Construct the list of JPEG files and labels.
  for text in unique_labels:
    jpeg_file_path = '%s/%s/*' % (data_dir, text)
    print("File path %s \n" % jpeg_file_path)
    matching_files = tf.gfile.Glob(jpeg_file_path)

    labels.extend([label_index] * len(matching_files))
    texts.extend([text] * len(matching_files))
    filenames.extend(matching_files)

    if not label_index % 100:
      print('Finished finding files in %d of %d classes.' % (
          label_index, len(unique_labels)))
    label_index += 1

  # Shuffle the ordering of all image files in order to guarantee
  # random ordering of the images with respect to label in the
  # saved TFRecord files. Make the randomization repeatable.
  shuffled_index = list(range(len(filenames)))
  random.seed(12345)
  random.shuffle(shuffled_index)

  filenames = [filenames[i] for i in shuffled_index]
  texts = [texts[i] for i in shuffled_index]
  labels = [labels[i] for i in shuffled_index]

  print('Found %d JPEG files across %d labels inside %s' %
        (len(filenames), len(unique_labels), data_dir))
  return filenames, texts, labels
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""



#CALL TO PROCESS DATASET IS MADE HERE
"""-------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _process_dataset(name, directory, num_shards, labels_file):
  """Process a complete data set and save it as a TFRecord.
  Args:
    name: string, unique identifier specifying the data set.
    directory: string, root path to the data set.
    num_shards: integer number of shards for this data set.
    labels_file: string, path to the labels file.
449 | """ 450 | filenames, texts, labels = _find_image_files(directory, labels_file) #Build list of dataset image file (path to them) and their labels as string and integer 451 | _process_image_files(name, filenames, texts, labels, num_shards) #Process the entire list of images in the dataset into a TF record 452 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 453 | 454 | 455 | 456 | #MAIN 457 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 458 | def main(unused_argv): 459 | assert not FLAGS.train_shards % FLAGS.num_threads, ('Please make the FLAGS.num_threads commensurate with FLAGS.train_shards') 460 | assert not FLAGS.validation_shards % FLAGS.num_threads, ('Please make the FLAGS.num_threads commensurate with ''FLAGS.validation_shards') 461 | print('Result will be saved to %s' % FLAGS.output_directory) 462 | 463 | # Run it! 464 | _process_dataset('validation', FLAGS.validation_directory, FLAGS.validation_shards, FLAGS.labels_file) 465 | _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, FLAGS.labels_file) 466 | 467 | 468 | if __name__ == '__main__': 469 | tf.app.run() 470 | 471 | """-------------------------------------------------------------------------------------------------------------------------------------------------------""" 472 | -------------------------------------------------------------------------------- /freeze_and_convert_to_tflite.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################################################################################################# 4 | # Bash script to help freez tensorflow GraphDef + ckpt into FrozenGraphDef, Optimize and convert to tflite model# 5 | # $1 = path/to/.pbtx file # 6 | # $2 = path/to/.ckpt file # 7 | # $3 = true/false. 
# $4 = path/to/save/frozengraph                                                                                 #
# $5 = graph input node name                                                                                    #
# $6 = graph output_node_name                                                                                   #
# $7 = path/to/tflite/file e.g path/to/tflite_model.lite                                                        #
# $8 = input type e.g FLOAT                                                                                     #
# $9 = inference (output) type (FLOAT or QUANTIZED)                                                             #
# ${10} = path/to/optimized_graph.pb                                                                            #
#                                                                                                               #
# Sample usage (all ten arguments, in order):                                                                   #
# freeze_and_convert_to_tflite.sh /tmp/graph.pbtxt /tmp/model.ckpt-0 false /tmp/frozen.pb input_layer ArgMax \  #
#                                 /tmp/model.lite FLOAT FLOAT /tmp/optimized.pb                                 #
#                                                                                                               #
# Make sure you are running this from the tensorflow home folder after cloning the TF repository                #
#                                                                                                               #
#################################################################################################################


#FREEZE GRAPH
#bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=$1 --input_checkpoint=$2 --input_binary=$3 --output_graph=$4 --output_node_names=$6


#VIEW SUMMARY OF GRAPH
#bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=$4

#TRANSFORM/OPTIMIZE GRAPH
#bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=$4 --out_graph=${10} --inputs=$5 --outputs=$6 --transforms='strip_unused_nodes(type=float, shape="1,227,227,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) round_weights(num_steps=256) obfuscate_names sort_by_execution_order fold_old_batch_norms fold_batch_norms'


#CONVERT TO TFLITE MODEL
#bazel build tensorflow/contrib/lite/toco:toco
bazel-bin/tensorflow/contrib/lite/toco/toco --input_format=TENSORFLOW_GRAPHDEF --input_file=${10} --output_format=TFLITE --output_file=$7 --inference_type=$9 --input_type=$8 --input_arrays=$5 --output_arrays=$6 --inference_input_type=$8 --input_shapes=1,227,227,3

--------------------------------------------------------------------------------
/imgs/accuracy_per_epoch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/accuracy_per_epoch.png
--------------------------------------------------------------------------------
/imgs/loss_per_epoch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/loss_per_epoch.png
--------------------------------------------------------------------------------
/imgs/network_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/network_visualization.png
--------------------------------------------------------------------------------
/imgs/tensorboard_plots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/imgs/tensorboard_plots.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#IMPORTS
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from sklearn.model_selection import train_test_split

#import Image
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import time

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


#FUNCTION TO GET ALL DATASET DATA
def _process_dataset(serialized):


    #Specify the features you want to extract
    features = {'image/shape': tf.FixedLenFeature([], tf.string),
                'image/class/label': tf.FixedLenFeature([], tf.int64),
                'image/class/text': tf.FixedLenFeature([], tf.string),
                'image/filename': tf.FixedLenFeature([], tf.string),
                'image/encoded': tf.FixedLenFeature([], tf.string)}
    parsed_example = tf.parse_single_example(serialized, features=features)

    #Decode and reshape the extracted data
    image_raw = tf.decode_raw(parsed_example['image/encoded'], tf.uint8)
    shape = tf.decode_raw(parsed_example['image/shape'], tf.int32)
    label = tf.cast(parsed_example['image/class/label'], dtype=tf.int32)
    reshaped_img = tf.reshape(image_raw, shape)
    casted_img = tf.cast(reshaped_img, tf.float32)
    label_tensor = [label]
    image_tensor = [casted_img]
    return label_tensor, image_tensor


#MAIN FUNCTION
def main(unused_argv):


    print("STARTED\n\n")

    #Declare needed variables
    perform_shuffle=False
    repeat_count=1
    batch_size=1000
    num_of_epochs=1

    #Directory path to the '.tfrecord' files
    filenames = ["/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00000-of-00002", "/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00001-of-00002"]


    print("GETTING RECORD COUNT")
    #Determine the total number of records in the '.tfrecord' files
    record_count = 0
    for fn in filenames:
        for record in tf.python_io.tf_record_iterator(fn):
            record_count += 1
    print("Total Number of Records in the .tfrecord file(s): %i" % record_count)


    dataset = tf.data.TFRecordDataset(filenames=filenames)
    dataset = dataset.map(_process_dataset) #Get all content of dataset
    dataset = dataset.shuffle(buffer_size=1000) #Shuffle selection from the dataset
    dataset = dataset.repeat(repeat_count) #Repeats dataset this # of times
    dataset = dataset.batch(batch_size) #Batch size to use
    iterator = dataset.make_initializable_iterator() #Create iterator which helps to get all images in the dataset
    labels_tensor, images_tensor = iterator.get_next() #Get batch data
    no_of_rounds = int(math.ceil(record_count/batch_size))

    #Create tf session, get the next set of batches, and evaluate them in batches
    sess = tf.Session()
    print("Total number of strides needed to stream through dataset: ~%i" %no_of_rounds)


    for _ in range(2):
        sess.run(iterator.initializer)
        count=0
        complete_evaluation_image_set = np.array([])
        complete_evaluation_label_set = np.array([])
        while True:
            try:
                print("Now evaluating tensors for stride %i out of %i" % (count, no_of_rounds))
                evaluated_label, evaluated_image = sess.run([labels_tensor, images_tensor])
                #convert evaluated tensors to np arrays
                label_np_array = np.asarray(evaluated_label, dtype=np.uint8)
                image_np_array = np.asarray(evaluated_image, dtype=np.uint8)
                #squeeze np arrays to make their dimensions appropriate
                squeezed_label_np_array = label_np_array.squeeze()
                squeezed_image_np_array = image_np_array.squeeze()
                #Split data into training and testing data
                image_train, image_test, label_train, label_test = train_test_split(squeezed_image_np_array, squeezed_label_np_array, test_size=0.010, random_state=42, shuffle=True)
                #Store evaluation data in its place
                complete_evaluation_image_set = np.append(complete_evaluation_image_set, image_test)
                complete_evaluation_label_set = np.append(complete_evaluation_label_set, label_test)
                #Feed current batch to TF Estimator for training
            except tf.errors.OutOfRangeError:
                print("End of Dataset Reached")
                break
            count=count+1
        print(complete_evaluation_label_set.shape)
        print(complete_evaluation_image_set.shape)
    sess.close()

    print("End of Test Run")



if __name__ == "__main__":
    tf.app.run()

--------------------------------------------------------------------------------
/train_alexnet.py:
--------------------------------------------------------------------------------


"""
Used to create AlexNet and train it using image data stored in a TFRecord over several epochs

@date: 4th December, 2017
@author: Oluwole Oyetoke
@Language: Python
@email: oluwoleoyetoke@gmail.com

AlexNet NETWORK OVERVIEW
AlexNet Structure: 60 million Parameters
8 layers in total: 5 Convolutional and 3 Fully Connected Layers
[227x227x3] INPUT
[55x55x96] CONV1: 96 11x11 filters at stride 4, pad 0
[27x27x96] MAX POOL1: 3x3 filters at stride 2
[27x27x96] NORM1: Normalization layer
[27x27x256] CONV2: 256 5x5 filters at stride 1, pad 2
[13x13x256] MAX POOL2: 3x3 filters at stride 2
[13x13x256] NORM2: Normalization layer
[13x13x384] CONV3: 384 3x3 filters at stride 1, pad 1
[13x13x384] CONV4: 384 3x3 filters at stride 1, pad 1
[13x13x256] CONV5: 256 3x3 filters at stride 1, pad 1
[6x6x256] MAX POOL3: 3x3 filters at stride 2
[4096] FC6: 4096 neurons
[4096] FC7: 4096 neurons
[43] FC8: 43 neurons (class scores)




#IMPORTS, VARIABLE DECLARATION, AND LOGGING TYPE SETTING
----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from sklearn.model_selection import train_test_split
from dateutil.relativedelta import relativedelta

import os
import math
import time
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import datetime

flags = tf.app.flags
flags.DEFINE_integer("image_width", 227, "Alexnet input layer width")
flags.DEFINE_integer("image_height", 227, "Alexnet input layer height")
flags.DEFINE_integer("image_channels", 3, "Alexnet input layer channels")
flags.DEFINE_integer("num_of_classes", 43, "Number of training classes")
FLAGS = flags.FLAGS

losses_bank = np.array([]) #global
accuracy_bank = np.array([]) #global
steps_bank = np.array([]) #global
epoch_bank = np.array([]) #global

tf.logging.set_verbosity(tf.logging.WARN) #setting up logging (can be DEBUG, ERROR, FATAL, INFO or WARN)
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""






#WRAPPERS FOR INSERTING int64 FEATURES & BYTES FEATURES INTO EXAMPLE PROTOS
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""


#CREATE CNN STRUCTURE
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
def cnn_model_fn(features, labels, mode):

    """INPUT LAYER"""
    input_layer = tf.reshape(features["x"], [-1, FLAGS.image_width, FLAGS.image_height, FLAGS.image_channels], name="input_layer") #Alexnet uses a 227x227x3 input layer. '-1' lets the batch dimension be inferred at run time
    #print(input_layer)

    """FIRST CONVOLUTION BLOCK
    The first convolutional layer filters the 227×227×3 input image with
    96 kernels of size 11×11 with a stride of 4 pixels. Bias of 1."""
    conv1 = tf.layers.conv2d(inputs=input_layer, filters=96, kernel_size=[11, 11], strides=4, padding="valid", activation=tf.nn.relu)
    lrn1 = tf.nn.lrn(input=conv1, depth_radius=5, bias=1.0, alpha=0.0001/5.0, beta=0.75) #Normalization layer
    pool1_conv1 = tf.layers.max_pooling2d(inputs=lrn1, pool_size=[3, 3], strides=2) #Max Pool Layer
    #print(pool1_conv1)


    """SECOND CONVOLUTION BLOCK
    (In the original two-GPU AlexNet the 96-channel blob from block one was divided into two
    groups of 48 and processed independently; here the layer operates on all 96 channels at once.)"""
    conv2 = tf.layers.conv2d(inputs=pool1_conv1, filters=256, kernel_size=[5, 5], strides=1, padding="same", activation=tf.nn.relu)
    lrn2 = tf.nn.lrn(input=conv2, depth_radius=5, bias=1.0, alpha=0.0001/5.0, beta=0.75) #Normalization layer
    pool2_conv2 = tf.layers.max_pooling2d(inputs=lrn2, pool_size=[3, 3], strides=2) #Max Pool Layer
    #print(pool2_conv2)
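    #Added note: spatial sizes so far follow (input_size - kernel_size)/stride + 1:
    #conv1 ('valid' 11x11, stride 4): (227-11)/4 + 1 = 55 -> 55x55x96
    #pool1 (3x3, stride 2):           (55-3)/2 + 1   = 27 -> 27x27x96
    #conv2 ('same', stride 1):        stays 27x27         -> 27x27x256
    #pool2 (3x3, stride 2):           (27-3)/2 + 1   = 13 -> 13x13x256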
    """THIRD CONVOLUTION BLOCK
    Note that the third, fourth, and fifth convolution layers are connected to one
    another without any intervening pooling or normalization layers.
    The third convolutional layer has 384 kernels of size 3 × 3
    connected to the (normalized, pooled) outputs of the second convolutional layer"""
    conv3 = tf.layers.conv2d(inputs=pool2_conv2, filters=384, kernel_size=[3, 3], strides=1, padding="same", activation=tf.nn.relu)
    #print(conv3)

    #FOURTH CONVOLUTION BLOCK
    """The fourth convolutional layer has 384 kernels of size 3 × 3"""
    conv4 = tf.layers.conv2d(inputs=conv3, filters=384, kernel_size=[3, 3], strides=1, padding="same", activation=tf.nn.relu)
    #print(conv4)

    #FIFTH CONVOLUTION BLOCK
    """The fifth convolutional layer has 256 kernels of size 3 × 3"""
    conv5 = tf.layers.conv2d(inputs=conv4, filters=256, kernel_size=[3, 3], strides=1, padding="same", activation=tf.nn.relu)
    pool3_conv5 = tf.layers.max_pooling2d(inputs=conv5, pool_size=[3, 3], strides=2, padding="valid") #Max Pool Layer
    #print(pool3_conv5)


    #FULLY CONNECTED LAYER 1
    """The fully-connected layers have 4096 neurons each"""
    pool3_conv5_flat = tf.reshape(pool3_conv5, [-1, 6 * 6 * 256]) #the output of the conv block is 6x6x256, so to connect it to a fully connected layer we can flatten it out
    fc1 = tf.layers.dense(inputs=pool3_conv5_flat, units=4096, activation=tf.nn.relu)
    #fc1 = tf.layers.conv2d(inputs=pool3_conv5, filters=4096, kernel_size=[6, 6], strides=1, padding="valid", activation=tf.nn.relu) #representing the FCL using a convolution block (no need to do 'pool3_conv5_flat' above)
    #print(fc1)

    #FULLY CONNECTED LAYER 2
    """since the output from above is [1x1x4096]"""
    fc2 = tf.layers.dense(inputs=fc1, units=4096, activation=tf.nn.relu)
    #fc2 = tf.layers.conv2d(inputs=fc1, filters=4096, kernel_size=[1, 1], strides=1, padding="valid", activation=tf.nn.relu)
    #print(fc2)

    #FULLY CONNECTED LAYER 3
    """since the output from above is [1x1x4096]"""
    logits = tf.layers.dense(inputs=fc2, units=FLAGS.num_of_classes, name="logits_layer")
    #fc3 = tf.layers.conv2d(inputs=fc2, filters=43, kernel_size=[1, 1], strides=1, padding="valid")
    #logits = tf.layers.dense(inputs=fc3, units=FLAGS.num_of_classes) #converting the convolutional block (tf.layers.conv2d) to a dense layer (tf.layers.dense). Only needed if we had used tf.layers.conv2d to represent the FCLs
    #print(logits)

    #PASS OUTPUT OF LAST FC LAYER TO A SOFTMAX LAYER
    """convert these raw values into two different formats that our model function can return:
    The predicted class for each example: an index from 0-42.
    The probabilities for each possible target class for each example.
    tf.argmax(input=logits, axis=1): Generate predictions from the 43 logits; axis 1 applies argmax across the rows.
    tf.nn.softmax(logits, name="softmax_tensor"): Generate the probability distribution.
    """
    predictions = {
        "classes": tf.argmax(input=logits, axis=1, name="classes_tensor"),
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }

    #Return result if we were in prediction mode and not training
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    #CALCULATE OUR LOSS
    """For both training and evaluation, we need to define a loss function that measures how closely the
    model's predictions match the target classes. For multiclass classification, cross entropy is typically used as the loss metric."""
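    #NOTE (added): labels written by create_imdb.py run from 1 to 43 (index 0 is reserved there
    #for the background class), while tf.one_hot below with depth=43 expects indices 0..42 and
    #silently produces an all-zero row for index 43; if the reserved-0 labelling is kept, the
    #labels may need shifting down by one (e.g. labels - 1) before the one-hot encoding.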
    onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=FLAGS.num_of_classes)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)
    tf.summary.scalar('Loss Per Stride', loss) #Just to see loss values per batch on tensorboard

    #CONFIGURE TRAINING
    """Since the loss of the CNN is the softmax cross-entropy between the logits layer
    and our labels, let's configure our model to optimize this loss value during
    training. We'll use a learning rate of 0.00001 and the Adam optimizer:"""
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=0.00001)
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) #global_step needed for a proper graph on tensorboard
        #optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.00005) #Very small learning rate used. Training will be slower at converging but better
        #train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    #ADD EVALUATION METRICS
    eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])}
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""




#FUNCTION TO PROCESS ALL DATASET DATA
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
def _process_dataset(serialized):
    #Specify the features you want to extract
    features = {'image/shape': tf.FixedLenFeature([], tf.string),
                'image/class/label': tf.FixedLenFeature([], tf.int64),
                'image/class/text': tf.FixedLenFeature([], tf.string),
                'image/filename': tf.FixedLenFeature([], tf.string),
                'image/encoded': tf.FixedLenFeature([], tf.string)}
    parsed_example = tf.parse_single_example(serialized, features=features)

    #Decode and reshape the extracted data
    image_raw = tf.decode_raw(parsed_example['image/encoded'], tf.uint8)
    shape = tf.decode_raw(parsed_example['image/shape'], tf.int32)
    label = tf.cast(parsed_example['image/class/label'], dtype=tf.int32)
    reshaped_img = tf.reshape(image_raw, [FLAGS.image_width, FLAGS.image_height, FLAGS.image_channels])
    casted_img = tf.cast(reshaped_img, tf.float32)
    label_tensor = [label]
    image_tensor = [casted_img]
    return label_tensor, image_tensor
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""

#PLOT TRAINING PROGRESS
def _plot_training_progress():
    global losses_bank #to make sure losses_bank is not declared again in this method (as a local variable)
    global accuracy_bank
    global steps_bank
    global epoch_bank

    #PLOT LOSS PER EPOCH
    loss_per_epoch_fig = plt.figure("LOSS PER EPOCH PLOT")
    plt.plot(epoch_bank, losses_bank, 'ro-')
    loss_per_epoch_fig.suptitle('LOSS LEVEL PER EPOCH')
    plt.xlabel('Epoch Count')
    plt.ylabel('Loss Value')
    manager = plt.get_current_fig_manager()
    manager.resize(*manager.window.maxsize()) #maximize plot
    loss_per_epoch_fig.canvas.draw()
    plt.show(block=False)
    loss_per_epoch_fig.savefig("/home/olu/Dev/data_base/sign_base/output/Checkpoints_N_Model/evaluation_plots/loss_per_epoch.png") #save plot


    #PLOT ACCURACY PER EPOCH
    accuracy_per_epoch_fig = plt.figure("ACCURACY PER EPOCH PLOT")
    plt.plot(epoch_bank, accuracy_bank, 'bo-')
    accuracy_per_epoch_fig.suptitle('ACCURACY PERCENTAGE PER EPOCH')
    plt.xlabel('Epoch Count')
    plt.ylabel('Accuracy Percentage')
    manager = plt.get_current_fig_manager()
    manager.resize(*manager.window.maxsize()) #maximize plot
    accuracy_per_epoch_fig.canvas.draw()
    plt.show(block=False)
    accuracy_per_epoch_fig.savefig("/home/olu/Dev/data_base/sign_base/output/Checkpoints_N_Model/evaluation_plots/accuracy_per_epoch.png") #save plot


#TRAINING AND EVALUATING THE ALEXNET CNN CLASSIFIER
"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""
def main(unused_argv):

    #Declare needed variables
    global losses_bank
    global accuracy_bank
    global steps_bank
    global epoch_bank

    perform_shuffle=False
    repeat_count=1
    dataset_batch_size=1024 #1024 #Chunks picked from the dataset per time
    training_batch_size = np.int32(dataset_batch_size/8) #Chunks processed by the tf.estimator per time
    epoch_count=0
    overall_training_epochs=60 #60 epochs in total

    start_time=time.time() #taking current time as starting time
    start_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(start_time))
    start_time_dt = datetime.datetime.strptime(start_time_string, '%Y-%m-%d %H:%M:%S')
    print("STARTED @ %s" % start_time_string)
    #LOAD TRAINING DATA
    print("LOADING DATASET\n\n")
    filenames = ["/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00000-of-00002", "/home/olu/Dev/data_base/sign_base/output/TFRecord_227x227/train-00001-of-00002"] #Directory path to the '.tfrecord' files
    model_dir="/home/olu/Dev/data_base/sign_base/output/Checkpoints_N_Model/trained_alexnet_model"
    #DETERMINE TOTAL NUMBER OF RECORDS IN THE '.tfrecord' FILES
    print("GETTING COUNT OF RECORDS/EXAMPLES IN DATASET")
    record_count = 0
    for fn in filenames:
        for record in tf.python_io.tf_record_iterator(fn):
            record_count += 1
    print("Total number of records in the .tfrecord file(s): %i\n\n" % record_count)
    no_of_rounds = int(math.ceil(record_count/dataset_batch_size))


    #EDIT HOW OFTEN CHECKPOINTS SHOULD BE SAVED
    check_point_interval = int(record_count/training_batch_size)
    my_estimator_config = tf.estimator.RunConfig(model_dir=model_dir,tf_random_seed=None,save_summary_steps=100,
        save_checkpoints_steps=check_point_interval,save_checkpoints_secs=None,session_config=None,keep_checkpoint_max=10,keep_checkpoint_every_n_hours=10000,log_step_count_steps=100)
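    #Added note: with save_checkpoints_steps=check_point_interval, a checkpoint is written
    #roughly once per epoch, since the estimator advances one global step per batch of
    #'training_batch_size' images and record_count/training_batch_size such steps make up
    #one pass over the dataset.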

    print("CREATING ESTIMATOR AND LOADING DATASET")
    #CREATE ESTIMATOR
    """Estimator: a TensorFlow class for performing high-level model training, evaluation, and inference"""
    traffic_sign_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=model_dir, config=my_estimator_config) #specify where the finally trained model (and checkpoints during training) should be saved

    #SET-UP LOGGING FOR PREDICTIONS
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50) #Log after every 50 iterations


    #PROCESS AND RETRIEVE DATASET CONTENT IN BATCHES OF 'dataset_batch_size'
    dataset = tf.data.TFRecordDataset(filenames=filenames)
    dataset = dataset.map(_process_dataset) #Get all content of dataset & apply the function '_process_dataset' to all of its content
    dataset = dataset.shuffle(buffer_size=dataset_batch_size) #Shuffle selection from the dataset/epoch, with a buffer of 'dataset_batch_size' elements
    dataset = dataset.repeat(repeat_count) #Repeat iteration through the dataset 'repeat_count' times
    dataset = dataset.batch(dataset_batch_size) #Batch size to use when picking from the dataset
    iterator = dataset.make_initializable_iterator() #Create iterator which helps to get all images in the dataset
    labels_tensor, images_tensor = iterator.get_next() #Get batch data


    #CREATE TF SESSION TO ITERATIVELY EVALUATE THE BATCHES OF DATASET TENSORS RETRIEVED AND PASS THEM TO THE ESTIMATOR FOR TRAINING/EVALUATION
    sess = tf.Session()
    print("Approximately %i strides needed to stream through 1 epoch of dataset\n\n" %no_of_rounds)


    for _ in range(overall_training_epochs):
        epoch_start_time=time.time() #taking current time as starting time
        epoch_start_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(epoch_start_time))
        epoch_start_time_dt = datetime.datetime.strptime(epoch_start_time_string, '%Y-%m-%d %H:%M:%S')
        epoch_count=epoch_count+1
        print("Epoch %i out of %i" % (epoch_count, overall_training_epochs))
        print("Note: Each complete epoch processes %i images, feeding them into the classifier in batches of %i" % (record_count, dataset_batch_size))
        print("...Classifier then deals with each batch of %i images pushed to it in batch sizes of %i\n"%(dataset_batch_size, training_batch_size))
        sess.run(iterator.initializer)

        strides_count=1
        complete_evaluation_image_set = np.array([])
        complete_evaluation_label_set = np.array([])
        while True:
            try:
                imgs_so_far = dataset_batch_size*strides_count
                if(imgs_so_far>record_count):
                    imgs_so_far = record_count
                print("\nStride %i (%i images) of %i strides (%i images) for Epoch %i" %(strides_count, imgs_so_far, no_of_rounds, record_count, epoch_count))
                evaluated_label, evaluated_image = sess.run([labels_tensor, images_tensor]) #evaluate tensors
                #convert evaluated tensors to np arrays
                label_np_array = np.asarray(evaluated_label, dtype=np.int32)
                image_np_array = np.asarray(evaluated_image, dtype=np.float32)
                #squeeze np arrays to make their dimensions appropriate
                squeezed_label_np_array = label_np_array.squeeze()
                squeezed_image_np_array = image_np_array.squeeze()
                #mean normalization - normalize the current batch of images i.e. get the mean of the images and subtract it from all image intensities
                dataset_image_mean = np.mean(squeezed_image_np_array)
                normalized_image_dataset = np.subtract(squeezed_image_np_array, dataset_image_mean) #helps with faster convergence during training
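                #NOTE (added): this mean is computed over the current batch only, so every batch is
                #normalized with a slightly different offset; computing one fixed dataset-wide mean
                #offline and reusing it here (and at inference time, e.g. in classify_img.py, which
                #currently subtracts no mean) would make training and classification preprocessing consistent.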
                #split data into training and testing/evaluation data
                image_train, image_evaluate, label_train, label_evaluate = train_test_split(normalized_image_dataset, squeezed_label_np_array, test_size=0.10, random_state=42, shuffle=True) #10% of each batch will be used for evaluation/testing
                #rectify dimension/shape
                if (image_evaluate.ndim<4): #if dimension is just 3 i.e. only 1 image loaded
                    print(image_evaluate.shape)
                    image_evaluate = image_evaluate.reshape((1,) + image_evaluate.shape)
                if (image_train.ndim<4): #if dimension is just 3 i.e. only 1 image loaded
                    print(image_train.shape)
                    image_train = image_train.reshape((1,) + image_train.shape)
                #rectify precision
                image_train = image_train.astype(np.float32)
                image_evaluate = image_evaluate.astype(np.float32)
                #Store evaluation data in its place
                if(strides_count==1):
                    complete_evaluation_image_set = image_evaluate
                else:
                    complete_evaluation_image_set = np.concatenate((complete_evaluation_image_set.squeeze(), image_evaluate))
                complete_evaluation_label_set = np.append(complete_evaluation_label_set, label_evaluate)
                complete_evaluation_label_set = complete_evaluation_label_set.squeeze()
                #Feed the current batch of training images to the TF Estimator for training; the Estimator deals with them in batches of 'training_batch_size'
                train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": image_train},y=label_train,batch_size=training_batch_size,num_epochs=1, shuffle=True) #Note, images were already shuffled when placed in the TFRecord, shuffled again when retrieved from the record & will be shuffled again when sent to the classifier
                traffic_sign_classifier.train(input_fn=train_input_fn,hooks=[logging_hook])
            except tf.errors.OutOfRangeError:
                print("End of Dataset Reached")
                break
            strides_count=strides_count+1

        #EVALUATE MODEL after every complete epoch (note that out-of-memory issues occur if too large a share of the dataset's images must be held in memory until the full epoch completes, so the held-out set is kept at 10% of each batch)
        """Once training is completed, we then proceed to evaluate the accuracy level of our trained model.
        To create eval_input_fn, we set num_epochs=1, so that the model evaluates the metrics over one epoch of
        data and returns the result. We also set shuffle=False to iterate through the data sequentially."""
        eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": complete_evaluation_image_set},y=complete_evaluation_label_set,num_epochs=1,shuffle=False)
        evaluation_result = traffic_sign_classifier.evaluate(input_fn=eval_input_fn) #Get dictionary of loss, global step count and accuracy e.g {'loss': 1.704558, 'accuracy': 0.58105469, 'global_step': 742}

        #PLOT TRAINING PERFORMANCE
        epoch_loss = evaluation_result.get('loss')
        epoch_accuracy = evaluation_result.get('accuracy')
        epoch_steps = evaluation_result.get('global_step')
        losses_bank = np.append(losses_bank, epoch_loss)
        accuracy_bank = np.append(accuracy_bank, (epoch_accuracy*100))
        steps_bank = np.append(steps_bank, epoch_steps)
        epoch_bank = np.append(epoch_bank, epoch_count)
        _plot_training_progress()

        accuracy_percentage = epoch_accuracy*100
        print("Network Performance Analysis: Loss(%f), Accuracy(%f percent), Steps(%i)\n"%(epoch_loss,accuracy_percentage,epoch_steps))
        #print(evaluation_result)

        #PRINT EPOCH OPERATION TIME
        epoch_end_time=time.time() #taking current time as ending time
        epoch_end_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(epoch_end_time))
        epoch_end_time_dt = datetime.datetime.strptime(epoch_end_time_string, '%Y-%m-%d %H:%M:%S')
        epoch_elapsed_time = relativedelta(epoch_end_time_dt, epoch_start_time_dt)
        print("End of epoch %i.\n Time taken for this epoch: %d Day(s) : %d Hour(s) : %d Minute(s) : %d Second(s)" %(epoch_count, epoch_elapsed_time.days, epoch_elapsed_time.hours, epoch_elapsed_time.minutes, epoch_elapsed_time.seconds))
        training_elapsed_time_sofar = relativedelta(epoch_end_time_dt, start_time_dt)
        print("Overall training time so far: %d Day(s) : %d Hour(s) : %d Minute(s) : %d Second(s) \n\n\n" %(training_elapsed_time_sofar.days, training_elapsed_time_sofar.hours, training_elapsed_time_sofar.minutes, training_elapsed_time_sofar.seconds))
    sess.close()

    #SAVE FINAL MODEL
    #Not really needed, because the TF Estimator saves .meta/.data/.ckpt files at the end of every epoch



    #PRINT TOTAL TRAINING TIME & END
    end_time=time.time() #taking current time as ending time
    end_time_string = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(end_time))
    end_time_dt = datetime.datetime.strptime(end_time_string, '%Y-%m-%d %H:%M:%S')
    elapsed_time_dt = relativedelta(end_time_dt, start_time_dt)
    print("END OF TRAINING..... ENDED @ %s" %end_time_string)
    print("Final Trained Model is Saved Here: %s" % model_dir)
    print("TIME TAKEN FOR ENTIRE TRAINING: %d Day(s) : %d Hour(s) : %d Minute(s) : %d Second(s)" % (elapsed_time_dt.days, elapsed_time_dt.hours, elapsed_time_dt.minutes, elapsed_time_dt.seconds))

"""----------------------------------------------------------------------------------------------------------------------------------------------------------------"""


if __name__=="__main__":
    tf.app.run()

--------------------------------------------------------------------------------
/tune_cnn.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OluwoleOyetoke/Computer_Vision_Using_TensorFlowLite/01f8ab51a33cd2f578a691327fd56800681e407b/tune_cnn.py
--------------------------------------------------------------------------------