├── .gitignore ├── README.md ├── bbox.py ├── cam_demo.py ├── cfg └── yolov3.cfg ├── darknet.py ├── data ├── clothing_vocab.pkl ├── coco.names └── voc.names ├── file_demo.py ├── model.py ├── nice_example.png ├── pallete ├── pallete2 ├── preprocess.py ├── test ├── 819.jpg ├── 828.jpg └── 829.jpg └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | ##### Python ##### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | ###### macOS ##### 142 | # General 143 | .DS_Store 144 | .AppleDouble 145 | .LSOverride 146 | 147 | # Icon must end with two \r 148 | Icon 149 | 150 | 151 | # Thumbnails 152 | ._* 153 | 154 | # Files that might appear in the root of a volume 155 | .DocumentRevisions-V100 156 | .fseventsd 157 | .Spotlight-V100 158 | .TemporaryItems 159 | .Trashes 160 | .VolumeIcon.icns 161 | .com.apple.timemachine.donotpresent 162 | 163 | # Directories potentially created on remote AFP share 164 | .AppleDB 165 | .AppleDesktop 166 | Network Trash Folder 167 | Temporary Items 168 | .apdisk 169 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FashionAI-Analysis 2 | This is a project for extracting attributes of clothing detected in an image. 3 | ## Dependencies 4 | - python >= 3.6 5 | - [pytorch](https://pytorch.org/) >= 1.2 6 | - opencv 7 | - matplotlib 8 | - [RoiAlign](https://github.com/longcw/RoIAlign.pytorch) 9 | ## Installation 10 | 1. Download & install the CUDA 10.2 toolkit [here](https://developer.nvidia.com/cuda-10.2-download-archive?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=debnetwork) 11 | 2. Download & install Anaconda (Python 3.7 version) 12 | 3. Install the requirements 13 | ## How to test the model 14 | 1. Download the yolo-v3 model from [here](https://drive.google.com/file/d/1yCz6pc6qHJD2Zcz8ldDmJ3NzE8wjaiT6/view?usp=sharing) and put it in the root directory. 15 | 2. Download the encoder model from [here](https://drive.google.com/file/d/1k3lvA96ZstbV4a_QtYTuohY79xg_nJYe/view?usp=sharing) and put it in the root directory. 16 | 17 | 3. Run `file_demo.py` to run the demonstration on individual image files 18 | 19 | 4. 
Run `cam_demo.py` to run the web-cam demonstration (not tested)
20 | ## Clothing multi-attributes definition

| GT values | Top color(14) | Top pattern(6) | Top gender(2) | Top season(4) | Top type(7) | Top sleeves(3) | Bottom color(14) | Bottom pattern(6) | Bottom gender(2) | Bottom season(4) | Bottom length(2) | Bottom type(2) | leg pose(3) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | null | null | null | null | null | null | null | null | null | null | null | null | null |
| 1 | white | plain | man | spring | shirt | short sleeves | white | plain | man | spring | short pants | pants | standing |
| 2 | black | checker | woman | summer | jumper | long sleeves | black | checker | woman | summer | long pants | skirt | sitting |
| 3 | gray | dotted | | autumn | jacket | no sleeves | gray | dotted | | autumn | | | lying |
| 4 | pink | floral | | winter | vest | | pink | floral | | winter | | | |
| 5 | red | striped | | | parka | | red | striped | | | | | |
| 6 | green | mixed | | | coat | | green | mixed | | | | | |
| 7 | blue | | | | dress | | blue | | | | | | |
| 8 | brown | | | | | | brown | | | | | | |
| 9 | navy | | | | | | navy | | | | | | |
| 10 | beige | | | | | | beige | | | | | | |
| 11 | yellow | | | | | | yellow | | | | | | |
| 12 | purple | | | | | | purple | | | | | | |
| 13 | orange | | | | | | orange | | | | | | |
| 14 | mixed | | | | | | mixed | | | | | | |
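
Each detected person is described by 13 attribute predictions, one per column of the table above; in `cam_demo.py` the predicted class indices are turned back into words by indexing the `attribute_pool` label lists. Below is a minimal sketch of that decoding step. The label lists are copied from `cam_demo.py`; the `decode_attributes` helper itself is only illustrative and not part of the repository.

```python
# Label lists copied from cam_demo.py; index 0 is the empty placeholder
# (corresponding to the "null" row of the table above).
colors_a = ["", "white", "black", "gray", "pink", "red", "green", "blue", "brown",
            "navy", "beige", "yellow", "purple", "orange", "mixed color"]        # 0-14
pattern_a = ["", "no pattern", "checker", "dotted", "floral", "striped", "custom pattern"]
gender_a = ["", "man", "woman"]
season_a = ["", "spring", "summer", "autumn", "winter"]
upper_t_a = ["", "shirt", "jumper", "jacket", "vest", "parka", "coat", "dress"]
u_sleeves_a = ["", "short sleeves", "long sleeves", "no sleeves"]
lower_t_a = ["", "pants", "skirt"]
l_sleeves_a = ["", "short", "long"]
leg_pose_a = ["", "standing", "sitting", "lying"]

attribute_pool = [colors_a, pattern_a, gender_a, season_a, upper_t_a, u_sleeves_a,
                  colors_a, pattern_a, gender_a, season_a, lower_t_a, l_sleeves_a, leg_pose_a]

def decode_attributes(indices):
    """Map the 13 per-attribute class indices of one detection to words
    (illustrative helper; cam_demo.py does this inline in its main loop)."""
    return [attribute_pool[j][idx] for j, idx in enumerate(indices)]

# Roughly the GT-value-1 row of the table above:
print(decode_attributes([1] * 13))
# ['white', 'no pattern', 'man', 'spring', 'shirt', 'short sleeves',
#  'white', 'no pattern', 'man', 'spring', 'pants', 'short', 'standing']
```

Note that `cam_demo.py` swaps entries 10 and 11 of the decoded caption before printing, so the bottom length appears before the bottom type (e.g. "short pants"), matching the column order of the table.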
1529 | 1530 | ## A nice example 1531 | ![Nice example](nice_example.png?raw=true "Title") 1532 | -------------------------------------------------------------------------------- /bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import random 5 | 6 | import numpy as np 7 | import cv2 8 | 9 | def confidence_filter(result, confidence): 10 | conf_mask = (result[:,:,4] > confidence).float().unsqueeze(2) 11 | result = result*conf_mask 12 | 13 | return result 14 | 15 | def confidence_filter_cls(result, confidence): 16 | max_scores = torch.max(result[:,:,5:25], 2)[0] 17 | res = torch.cat((result, max_scores),2) 18 | print(res.shape) 19 | 20 | 21 | cond_1 = (res[:,:,4] > confidence).float() 22 | cond_2 = (res[:,:,25] > 0.995).float() 23 | 24 | conf = cond_1 + cond_2 25 | conf = torch.clamp(conf, 0.0, 1.0) 26 | conf = conf.unsqueeze(2) 27 | result = result*conf 28 | return result 29 | 30 | 31 | 32 | def get_abs_coord(box): 33 | box[2], box[3] = abs(box[2]), abs(box[3]) 34 | x1 = (box[0] - box[2]/2) - 1 35 | y1 = (box[1] - box[3]/2) - 1 36 | x2 = (box[0] + box[2]/2) - 1 37 | y2 = (box[1] + box[3]/2) - 1 38 | return x1, y1, x2, y2 39 | 40 | 41 | 42 | def sanity_fix(box): 43 | if (box[0] > box[2]): 44 | box[0], box[2] = box[2], box[0] 45 | 46 | if (box[1] > box[3]): 47 | box[1], box[3] = box[3], box[1] 48 | 49 | return box 50 | 51 | def bbox_iou(box1, box2, device): 52 | """ 53 | Returns the IoU of two bounding boxes 54 | 55 | 56 | """ 57 | #Get the coordinates of bounding boxes 58 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3] 59 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3] 60 | 61 | #get the corrdinates of the intersection rectangle 62 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 63 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 64 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 65 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 66 | 67 | #Intersection area 68 | 69 | inter_area = torch.max(inter_rect_x2 - inter_rect_x1 + 1,torch.zeros(inter_rect_x2.shape).to(device))*torch.max(inter_rect_y2 - inter_rect_y1 + 1, torch.zeros(inter_rect_x2.shape).to(device)) 70 | 71 | #Union Area 72 | b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1) 73 | b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1) 74 | 75 | iou = inter_area / (b1_area + b2_area - inter_area) 76 | 77 | return iou 78 | 79 | def bbox_iou2(box1, box2, device): 80 | """ 81 | Returns the IoU of two bounding boxes 82 | 83 | 84 | """ 85 | #Get the coordinates of bounding boxes 86 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3] 87 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3] 88 | 89 | #get the corrdinates of the intersection rectangle 90 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 91 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 92 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 93 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 94 | 95 | #Intersection area 96 | inter_area = torch.max(inter_rect_x2 - inter_rect_x1 + 1,torch.zeros(inter_rect_x2.shape).to(device))*torch.max(inter_rect_y2 - inter_rect_y1 + 1, torch.zeros(inter_rect_x2.shape).to(device)) 97 | 98 | #Union Area 99 | b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1) 100 | b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1) 101 | 102 | iou = inter_area / b2_area 103 | 104 | return iou 105 | 106 | 107 | def pred_corner_coord(prediction): 108 | #Get indices of non-zero confidence bboxes 109 | ind_nz = 
torch.nonzero(prediction[:,:,4]).transpose(0,1).contiguous() 110 | 111 | box = prediction[ind_nz[0], ind_nz[1]] 112 | 113 | 114 | box_a = box.new(box.shape) 115 | box_a[:,0] = (box[:,0] - box[:,2]/2) 116 | box_a[:,1] = (box[:,1] - box[:,3]/2) 117 | box_a[:,2] = (box[:,0] + box[:,2]/2) 118 | box_a[:,3] = (box[:,1] + box[:,3]/2) 119 | box[:,:4] = box_a[:,:4] 120 | 121 | prediction[ind_nz[0], ind_nz[1]] = box 122 | 123 | return prediction 124 | 125 | 126 | 127 | 128 | def write(x, batches, results, colors, classes): 129 | c1 = tuple(x[1:3].int()) 130 | c2 = tuple(x[3:5].int()) 131 | img = results[int(x[0])] 132 | cls = int(x[-1]) 133 | label = "{0}".format(classes[cls]) 134 | color = random.choice(colors) 135 | cv2.rectangle(img, c1, c2,color, 1) 136 | t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1 , 1)[0] 137 | c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4 138 | cv2.rectangle(img, c1, c2,color, -1) 139 | cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1); 140 | return img 141 | -------------------------------------------------------------------------------- /cam_demo.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | import torch 3 | from collections import Counter 4 | import numpy as np 5 | import argparse 6 | import pickle 7 | import os 8 | import time 9 | from torchvision import transforms 10 | from model import EncoderClothing, DecoderClothing 11 | from darknet import Darknet 12 | from PIL import Image 13 | from util import * 14 | import cv2 15 | import pickle as pkl 16 | from preprocess import prep_image2 17 | 18 | import sys 19 | if sys.version_info >= (3,0): 20 | from roi_align.roi_align import RoIAlign 21 | else : 22 | from roi_align import RoIAlign 23 | 24 | # Device configuration 25 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 26 | 27 | 28 | color = { 'white':0, 'black':1, 'grey':2, 'pink':3, 'red':4, 29 | 'green':5, 'blue':6, 'brown':7, 'navy':8, 'beige':9, 'yellow':10, 'purple':11, 30 | 'orange':12, 'others':13 } 31 | 32 | pattern = { 'single':0, 'checker':1, 'dotted':2, 'floral':3, 'striped':4, 'others':5 } 33 | 34 | gender = { 'man':0, 'woman':1 } 35 | 36 | season = { 'spring':0, 'summer':1, 'autumn':2, 'winter':3 } 37 | 38 | c_class = { 'shirt':0, 'jumper':1, 'jacket':2, 'vest':3, 'coat':4, 39 | 'dress':5, 'pants':6, 'skirt':7 } 40 | 41 | sleeves = { 'long':0, 'short':1, 'no':2 } 42 | 43 | a_class = { 'scarf':0, 'cane':1, 'bag':2, 'shoes':3, 'hat':4, 'face':5, 'glasses':6} 44 | 45 | colors_a = ["", "white", "black", "gray", "pink", "red", "green", "blue", "brown", "navy", "beige", \ 46 | "yellow", "purple", "orange", "mixed color"] #0-14 47 | pattern_a = ["", "no pattern", "checker", "dotted", "floral", "striped", "custom pattern"] #0-6 48 | gender_a = ["", "man", "woman"] #0-2 49 | season_a = ["", "spring", "summer", "autumn", "winter"] #0-4 50 | upper_t_a = ["", "shirt", "jumper", "jacket", "vest", "parka", "coat", "dress"]#0-7 51 | u_sleeves_a = ["", "short sleeves", "long sleeves", "no sleeves"]#0-3 52 | 53 | lower_t_a = ["", "pants", "skirt"]#0-2 54 | l_sleeves_a = ["", "short", "long"]#0-2 55 | leg_pose_a = ["", "standing", "sitting", "lying"]#0-3 56 | 57 | glasses_a = ["", "glasses"] 58 | 59 | attribute_pool = [colors_a, pattern_a, gender_a, season_a, upper_t_a, u_sleeves_a, \ 60 | colors_a, pattern_a, gender_a, season_a, lower_t_a, l_sleeves_a, leg_pose_a] 61 | 62 | def counter_parts1(sentence): 63 | 64 | color_bin 
= np.zeros(14) 65 | pattern_bin = np.zeros(6) 66 | gender_bin = np.zeros(2) 67 | season_bin = np.zeros(4) 68 | class_bin = np.zeros(8) 69 | sleeves_bin = np.zeros(3) 70 | 71 | token = sentence.split() 72 | for word in token: 73 | c_pos = color.get(word) 74 | p_pos = pattern.get(word) 75 | g_pos = gender.get(word) 76 | s_pos = season.get(word) 77 | cl_pos = c_class.get(word) 78 | sl_pos = sleeves.get(word) 79 | 80 | if c_pos is not None: 81 | color_bin[c_pos] = 1 82 | elif p_pos is not None: 83 | pattern_bin[p_pos] = 1 84 | elif g_pos is not None: 85 | gender_bin[g_pos] = 1 86 | elif s_pos is not None: 87 | season_bin[s_pos] = 1 88 | elif cl_pos is not None: 89 | class_bin[cl_pos] = 1 90 | elif sl_pos is not None: 91 | sleeves_bin[sl_pos] = 1 92 | 93 | return color_bin, pattern_bin, gender_bin, season_bin, class_bin, sleeves_bin 94 | 95 | def counter_parts2(sentence): 96 | 97 | color_bin = np.zeros(14) 98 | pattern_bin = np.zeros(6) 99 | gender_bin = np.zeros(2) 100 | season_bin = np.zeros(4) 101 | class_bin = np.zeros(8) 102 | sleeves_bin = np.zeros(3) 103 | 104 | token = sentence.split() 105 | for word in token: 106 | c_pos = color.get(word) 107 | p_pos = pattern.get(word) 108 | g_pos = gender.get(word) 109 | s_pos = season.get(word) 110 | cl_pos = c_class.get(word) 111 | sl_pos = sleeves.get(word) 112 | 113 | if c_pos is not None: 114 | color_bin[c_pos] = 1 115 | elif p_pos is not None: 116 | pattern_bin[p_pos] = 1 117 | elif g_pos is not None: 118 | gender_bin[g_pos] = 1 119 | elif s_pos is not None: 120 | season_bin[s_pos] = 1 121 | elif cl_pos is not None: 122 | class_bin[cl_pos] = 1 123 | elif sl_pos is not None: 124 | sleeves_bin[sl_pos] = 1 125 | 126 | return color_bin, pattern_bin, gender_bin, season_bin, class_bin, sleeves_bin 127 | 128 | 129 | def keep_majority(sentence, color_stream, pattern_stream, gender_stream, season_stream, class_stream, sleeves_stream): 130 | 131 | tokens = sentence.split('and') 132 | color_bin, pattern_bin, gender_bin, season_bin, class_bin, sleeves_bin = counter_parts1(tokens[0]) # for tokens[0] ==> upper 133 | if tokens.__len__() == 2: 134 | color_bin, pattern_bin, gender_bin, season_bin, class_bin, sleeves_bin = counter_parts1(tokens[1]) # for tokens[1] ==> lower 135 | 136 | if color_stream.__len__() < 50: 137 | color_stream.insert(0, color_bin) 138 | pattern_stream.insert(0, pattern_bin) 139 | gender_stream.insert(0, gender_bin) 140 | season_stream.insert(0, season_bin) 141 | class_stream.insert(0, class_bin) 142 | sleeves_stream.insert(0, sleeves_bin) 143 | return '' 144 | else: 145 | color_stream.pop() 146 | pattern_stream.pop() 147 | gender_stream.pop() 148 | season_stream.pop() 149 | class_stream.pop() 150 | sleeves_stream.pop() 151 | 152 | color_stream.insert(0, color_bin) 153 | pattern_stream.insert(0, pattern_bin) 154 | gender_stream.insert(0, gender_bin) 155 | season_stream.insert(0, season_bin) 156 | class_stream.insert(0, class_bin) 157 | sleeves_stream.insert(0, sleeves_bin) 158 | 159 | #print(sum(season_stream)) 160 | 161 | c_max = np.argmax(sum(color_stream)) 162 | p_max = np.argmax(sum(pattern_stream)) 163 | g_max = np.argmax(sum(gender_stream)) 164 | s_max = np.argmax(sum(season_stream)) 165 | cl_max = np.argmax(sum(class_stream)) 166 | sl_max = np.argmax(sum(sleeves_stream)) 167 | 168 | """ new_t = list() 169 | new_t.append([k for k, v in color.items() if v == c_max][0]) 170 | if [k for k, v in pattern.items() if v == p_max][0] != 'single': 171 | new_t.append([k for k, v in pattern.items() if v == p_max][0]) 172 | 173 | 
new_t.append([k for k, v in gender.items() if v == g_max][0]) 174 | new_t.append([k for k, v in season.items() if v == s_max][0]) 175 | new_t.append([k for k, v in c_class.items() if v == cl_max][0]) 176 | new_t.append('with') 177 | new_t.append([k for k, v in sleeves.items() if v == sl_max][0]) 178 | new_t.append('sleeves') 179 | 180 | new_sentence = '======> ' + ' '.join(new_t) """ 181 | 182 | new_t = list() 183 | #new_t.append([k for k, v in c_class.items() if v == cl_max][0]) 184 | #new_t.append(':') 185 | new_t.append([k for k, v in color.items() if v == c_max][0]) 186 | if [k for k, v in pattern.items() if v == p_max][0] != 'single': 187 | new_t.append([k for k, v in pattern.items() if v == p_max][0]) 188 | 189 | new_t.append([k for k, v in gender.items() if v == g_max][0]) 190 | new_t.append([k for k, v in season.items() if v == s_max][0]) 191 | new_t.append([k for k, v in c_class.items() if v == cl_max][0]) 192 | new_t.append('with') 193 | new_t.append([k for k, v in sleeves.items() if v == sl_max][0]) 194 | new_t.append('sleeves') 195 | 196 | #new_sentence = '======> ' + ' '.join(new_t) 197 | 198 | new_sentence = ' '.join(new_t) 199 | 200 | #print(sum(sleeves_stream))q 201 | #print(new_sentence) 202 | #sys.stdout.write(new_sentence + '\r\r') 203 | #sys.stdout.flush() 204 | return new_sentence 205 | 206 | def main(args): 207 | # Image preprocessing 208 | transform = transforms.Compose([ 209 | transforms.ToTensor()]) 210 | 211 | 212 | num_classes = 80 213 | yolov3 = Darknet(args.cfg_file) 214 | yolov3.load_weights(args.weights_file) 215 | yolov3.net_info["height"] = args.reso 216 | inp_dim = int(yolov3.net_info["height"]) 217 | assert inp_dim % 32 == 0 218 | assert inp_dim > 32 219 | print("yolo-v3 network successfully loaded") 220 | 221 | attribute_size = [15, 7, 3, 5, 8, 4, 15, 7, 3, 5, 3, 3, 4] 222 | 223 | encoder = EncoderClothing(args.embed_size, device, args.roi_size, attribute_size) 224 | 225 | yolov3.to(device) 226 | encoder.to(device) 227 | 228 | yolov3.eval() 229 | encoder.eval() 230 | 231 | encoder.load_state_dict(torch.load(args.encoder_path)) 232 | 233 | #cap = cv2.VideoCapture('demo2.mp4') 234 | 235 | cap = cv2.VideoCapture(0) 236 | assert cap.isOpened(), 'Cannot capture source' 237 | 238 | frames = 0 239 | start = time.time() 240 | 241 | counter = Counter() 242 | color_stream = list() 243 | pattern_stream = list() 244 | gender_stream = list() 245 | season_stream = list() 246 | class_stream = list() 247 | sleeves_stream = list() 248 | 249 | ret, frame = cap.read() 250 | if ret: 251 | 252 | image, orig_img, dim = prep_image2(frame, inp_dim) 253 | im_dim = torch.FloatTensor(dim).repeat(1,2) 254 | 255 | image_tensor = image.to(device) 256 | detections = yolov3(image_tensor, device, True) 257 | 258 | 259 | os.system('clear') 260 | cv2.imshow("frame", orig_img) 261 | cv2.moveWindow("frame", 50, 50) 262 | text_img = np.zeros((200, 1750, 3)) 263 | cv2.imshow("text", text_img) 264 | cv2.moveWindow("text", 50, dim[1]+110) 265 | 266 | while cap.isOpened(): 267 | 268 | ret, frame = cap.read() 269 | #### ret, frame = ros_message_cam_image() 270 | if ret: 271 | 272 | image, orig_img, dim = prep_image2(frame, inp_dim) 273 | im_dim = torch.FloatTensor(dim).repeat(1,2) 274 | 275 | image_tensor = image.to(device) 276 | im_dim = im_dim.to(device) 277 | 278 | # Generate an caption from the image 279 | detections = yolov3(image_tensor, device, True) # prediction mode for yolo-v3 280 | detections = write_results(detections, args.confidence, num_classes, device, nms=True, 
nms_conf=args.nms_thresh) 281 | 282 | #### detections = ros_message_rois() 283 | #### ros_rois --> [0,0, x1, y1, x2, y2] 284 | 285 | # original image dimension --> im_dim 286 | 287 | #view_image(detections) 288 | text_img = np.zeros((200, 1750, 3)) 289 | 290 | if type(detections) != int: 291 | if detections.shape[0]: 292 | bboxs = detections[:, 1:5].clone() 293 | 294 | im_dim = im_dim.repeat(detections.shape[0], 1) 295 | scaling_factor = torch.min(inp_dim/im_dim, 1)[0].view(-1, 1) 296 | 297 | detections[:, [1, 3]] -= (inp_dim - scaling_factor*im_dim[:, 0].view(-1, 1))/2 298 | detections[:, [2, 4]] -= (inp_dim - scaling_factor*im_dim[:, 1].view(-1, 1))/2 299 | 300 | detections[:, 1:5] /= scaling_factor 301 | 302 | small_object_ratio = torch.FloatTensor(detections.shape[0]) 303 | 304 | for i in range(detections.shape[0]): 305 | detections[i, [1, 3]] = torch.clamp(detections[i, [1, 3]], 0.0, im_dim[i, 0]) 306 | detections[i, [2, 4]] = torch.clamp(detections[i, [2, 4]], 0.0, im_dim[i, 1]) 307 | 308 | object_area = (detections[i, 3] - detections[i, 1])*(detections[i, 4] - detections[i, 2]) 309 | orig_img_area = im_dim[i, 0]*im_dim[i, 1] 310 | small_object_ratio[i] = object_area/orig_img_area 311 | 312 | detections = detections[small_object_ratio > 0.05] 313 | im_dim = im_dim[small_object_ratio > 0.05] 314 | 315 | if detections.size(0) > 0: 316 | feature = yolov3.get_feature() 317 | feature = feature.repeat(detections.size(0), 1, 1, 1) 318 | 319 | orig_img_dim = im_dim[:, 1:] 320 | orig_img_dim = orig_img_dim.repeat(1, 2) 321 | 322 | scaling_val = 16 323 | 324 | bboxs /= scaling_val 325 | bboxs = bboxs.round() 326 | bboxs_index = torch.arange(bboxs.size(0), dtype=torch.int) 327 | bboxs_index = bboxs_index.to(device) 328 | bboxs = bboxs.to(device) 329 | 330 | roi_align = RoIAlign(args.roi_size, args.roi_size, transform_fpcoor=True).to(device) 331 | roi_features = roi_align(feature, bboxs, bboxs_index) 332 | 333 | outputs = encoder(roi_features) 334 | 335 | for i in range(detections.shape[0]): 336 | 337 | sampled_caption = [] 338 | #attr_fc = outputs[] 339 | for j in range(len(outputs)): 340 | max_index = torch.max(outputs[j][i].data, 0)[1] 341 | word = attribute_pool[j][max_index] 342 | sampled_caption.append(word) 343 | 344 | c11 = sampled_caption[11] 345 | sampled_caption[11] = sampled_caption[10] 346 | sampled_caption[10] = c11 347 | 348 | sentence = ' '.join(sampled_caption) 349 | 350 | sys.stdout.write(' ' + '\r') 351 | 352 | sys.stdout.write(sentence + ' '+ '\r') 353 | sys.stdout.flush() 354 | write(detections[i], orig_img, sentence, i+1, coco_classes, colors) 355 | 356 | cv2.putText(text_img, sentence, (0, i*40+35), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 1 ) 357 | 358 | cv2.imshow("frame", orig_img) 359 | cv2.imshow("text", text_img) 360 | 361 | key = cv2.waitKey(1) 362 | if key & 0xFF == ord('q'): 363 | break 364 | if key & 0xFF == ord('w'): 365 | wait(0) 366 | if key & 0xFF == ord('s'): 367 | continue 368 | frames += 1 369 | #print("FPS of the video is {:5.2f}".format( frames / (time.time() - start))) 370 | 371 | else: 372 | break 373 | 374 | if __name__ == '__main__': 375 | parser = argparse.ArgumentParser() 376 | parser.add_argument('--encoder_path', type=str, default='encoder-12-1170.ckpt', help='path for trained encoder') 377 | 378 | # Encoder - Yolo-v3 parameters 379 | parser.add_argument('--confidence', type=float, default = 0.5, help = 'Object Confidence to filter predictions') 380 | parser.add_argument('--nms_thresh', type=float , default = 0.4, help = 'NMS Threshhold') 381 
| parser.add_argument('--cfg_file', type = str, default = 'cfg/yolov3.cfg', help ='Config file') 382 | parser.add_argument('--weights_file', type = str, default = 'yolov3.weights', help = 'weightsfile') 383 | parser.add_argument('--reso', type=str, default = '416', help = 'Input resolution of the network. Increase to increase accuracy. Decrease to increase speed') 384 | parser.add_argument('--scales', type=str, default = '1,2,3', help = 'Scales to use for detection') 385 | 386 | # Model parameters (should be same as paramters in train.py) 387 | parser.add_argument('--embed_size', type=int , default=256, help='dimension of word embedding vectors') 388 | parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states') 389 | parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm') 390 | parser.add_argument('--roi_size', type=int , default=13) 391 | args = parser.parse_args() 392 | 393 | coco_classes = load_classes('data/coco.names') 394 | colors = pkl.load(open("pallete2", "rb")) 395 | 396 | main(args) 397 | -------------------------------------------------------------------------------- /cfg/yolov3.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | # Testing 3 | batch=1 4 | subdivisions=1 5 | # Training 6 | # batch=64 7 | # subdivisions=16 8 | width= 320 9 | height = 320 10 | channels=3 11 | momentum=0.9 12 | decay=0.0005 13 | angle=0 14 | saturation = 1.5 15 | exposure = 1.5 16 | hue=.1 17 | 18 | learning_rate=0.001 19 | burn_in=1000 20 | max_batches = 500200 21 | policy=steps 22 | steps=400000,450000 23 | scales=.1,.1 24 | 25 | [convolutional] 26 | batch_normalize=1 27 | filters=32 28 | size=3 29 | stride=1 30 | pad=1 31 | activation=leaky 32 | 33 | # Downsample 34 | 35 | [convolutional] 36 | batch_normalize=1 37 | filters=64 38 | size=3 39 | stride=2 40 | pad=1 41 | activation=leaky 42 | 43 | [convolutional] 44 | batch_normalize=1 45 | filters=32 46 | size=1 47 | stride=1 48 | pad=1 49 | activation=leaky 50 | 51 | [convolutional] 52 | batch_normalize=1 53 | filters=64 54 | size=3 55 | stride=1 56 | pad=1 57 | activation=leaky 58 | 59 | [shortcut] 60 | from=-3 61 | activation=linear 62 | 63 | # Downsample 64 | 65 | [convolutional] 66 | batch_normalize=1 67 | filters=128 68 | size=3 69 | stride=2 70 | pad=1 71 | activation=leaky 72 | 73 | [convolutional] 74 | batch_normalize=1 75 | filters=64 76 | size=1 77 | stride=1 78 | pad=1 79 | activation=leaky 80 | 81 | [convolutional] 82 | batch_normalize=1 83 | filters=128 84 | size=3 85 | stride=1 86 | pad=1 87 | activation=leaky 88 | 89 | [shortcut] 90 | from=-3 91 | activation=linear 92 | 93 | [convolutional] 94 | batch_normalize=1 95 | filters=64 96 | size=1 97 | stride=1 98 | pad=1 99 | activation=leaky 100 | 101 | [convolutional] 102 | batch_normalize=1 103 | filters=128 104 | size=3 105 | stride=1 106 | pad=1 107 | activation=leaky 108 | 109 | [shortcut] 110 | from=-3 111 | activation=linear 112 | 113 | # Downsample 114 | 115 | [convolutional] 116 | batch_normalize=1 117 | filters=256 118 | size=3 119 | stride=2 120 | pad=1 121 | activation=leaky 122 | 123 | [convolutional] 124 | batch_normalize=1 125 | filters=128 126 | size=1 127 | stride=1 128 | pad=1 129 | activation=leaky 130 | 131 | [convolutional] 132 | batch_normalize=1 133 | filters=256 134 | size=3 135 | stride=1 136 | pad=1 137 | activation=leaky 138 | 139 | [shortcut] 140 | from=-3 141 | activation=linear 142 | 143 | [convolutional] 144 | batch_normalize=1 145 | 
filters=128 146 | size=1 147 | stride=1 148 | pad=1 149 | activation=leaky 150 | 151 | [convolutional] 152 | batch_normalize=1 153 | filters=256 154 | size=3 155 | stride=1 156 | pad=1 157 | activation=leaky 158 | 159 | [shortcut] 160 | from=-3 161 | activation=linear 162 | 163 | [convolutional] 164 | batch_normalize=1 165 | filters=128 166 | size=1 167 | stride=1 168 | pad=1 169 | activation=leaky 170 | 171 | [convolutional] 172 | batch_normalize=1 173 | filters=256 174 | size=3 175 | stride=1 176 | pad=1 177 | activation=leaky 178 | 179 | [shortcut] 180 | from=-3 181 | activation=linear 182 | 183 | [convolutional] 184 | batch_normalize=1 185 | filters=128 186 | size=1 187 | stride=1 188 | pad=1 189 | activation=leaky 190 | 191 | [convolutional] 192 | batch_normalize=1 193 | filters=256 194 | size=3 195 | stride=1 196 | pad=1 197 | activation=leaky 198 | 199 | [shortcut] 200 | from=-3 201 | activation=linear 202 | 203 | 204 | [convolutional] 205 | batch_normalize=1 206 | filters=128 207 | size=1 208 | stride=1 209 | pad=1 210 | activation=leaky 211 | 212 | [convolutional] 213 | batch_normalize=1 214 | filters=256 215 | size=3 216 | stride=1 217 | pad=1 218 | activation=leaky 219 | 220 | [shortcut] 221 | from=-3 222 | activation=linear 223 | 224 | [convolutional] 225 | batch_normalize=1 226 | filters=128 227 | size=1 228 | stride=1 229 | pad=1 230 | activation=leaky 231 | 232 | [convolutional] 233 | batch_normalize=1 234 | filters=256 235 | size=3 236 | stride=1 237 | pad=1 238 | activation=leaky 239 | 240 | [shortcut] 241 | from=-3 242 | activation=linear 243 | 244 | [convolutional] 245 | batch_normalize=1 246 | filters=128 247 | size=1 248 | stride=1 249 | pad=1 250 | activation=leaky 251 | 252 | [convolutional] 253 | batch_normalize=1 254 | filters=256 255 | size=3 256 | stride=1 257 | pad=1 258 | activation=leaky 259 | 260 | [shortcut] 261 | from=-3 262 | activation=linear 263 | 264 | [convolutional] 265 | batch_normalize=1 266 | filters=128 267 | size=1 268 | stride=1 269 | pad=1 270 | activation=leaky 271 | 272 | [convolutional] 273 | batch_normalize=1 274 | filters=256 275 | size=3 276 | stride=1 277 | pad=1 278 | activation=leaky 279 | 280 | [shortcut] 281 | from=-3 282 | activation=linear 283 | 284 | # Downsample 285 | 286 | [convolutional] 287 | batch_normalize=1 288 | filters=512 289 | size=3 290 | stride=2 291 | pad=1 292 | activation=leaky 293 | 294 | [convolutional] 295 | batch_normalize=1 296 | filters=256 297 | size=1 298 | stride=1 299 | pad=1 300 | activation=leaky 301 | 302 | [convolutional] 303 | batch_normalize=1 304 | filters=512 305 | size=3 306 | stride=1 307 | pad=1 308 | activation=leaky 309 | 310 | [shortcut] 311 | from=-3 312 | activation=linear 313 | 314 | 315 | [convolutional] 316 | batch_normalize=1 317 | filters=256 318 | size=1 319 | stride=1 320 | pad=1 321 | activation=leaky 322 | 323 | [convolutional] 324 | batch_normalize=1 325 | filters=512 326 | size=3 327 | stride=1 328 | pad=1 329 | activation=leaky 330 | 331 | [shortcut] 332 | from=-3 333 | activation=linear 334 | 335 | 336 | [convolutional] 337 | batch_normalize=1 338 | filters=256 339 | size=1 340 | stride=1 341 | pad=1 342 | activation=leaky 343 | 344 | [convolutional] 345 | batch_normalize=1 346 | filters=512 347 | size=3 348 | stride=1 349 | pad=1 350 | activation=leaky 351 | 352 | [shortcut] 353 | from=-3 354 | activation=linear 355 | 356 | 357 | [convolutional] 358 | batch_normalize=1 359 | filters=256 360 | size=1 361 | stride=1 362 | pad=1 363 | activation=leaky 364 | 365 | [convolutional] 
366 | batch_normalize=1 367 | filters=512 368 | size=3 369 | stride=1 370 | pad=1 371 | activation=leaky 372 | 373 | [shortcut] 374 | from=-3 375 | activation=linear 376 | 377 | [convolutional] 378 | batch_normalize=1 379 | filters=256 380 | size=1 381 | stride=1 382 | pad=1 383 | activation=leaky 384 | 385 | [convolutional] 386 | batch_normalize=1 387 | filters=512 388 | size=3 389 | stride=1 390 | pad=1 391 | activation=leaky 392 | 393 | [shortcut] 394 | from=-3 395 | activation=linear 396 | 397 | 398 | [convolutional] 399 | batch_normalize=1 400 | filters=256 401 | size=1 402 | stride=1 403 | pad=1 404 | activation=leaky 405 | 406 | [convolutional] 407 | batch_normalize=1 408 | filters=512 409 | size=3 410 | stride=1 411 | pad=1 412 | activation=leaky 413 | 414 | [shortcut] 415 | from=-3 416 | activation=linear 417 | 418 | 419 | [convolutional] 420 | batch_normalize=1 421 | filters=256 422 | size=1 423 | stride=1 424 | pad=1 425 | activation=leaky 426 | 427 | [convolutional] 428 | batch_normalize=1 429 | filters=512 430 | size=3 431 | stride=1 432 | pad=1 433 | activation=leaky 434 | 435 | [shortcut] 436 | from=-3 437 | activation=linear 438 | 439 | [convolutional] 440 | batch_normalize=1 441 | filters=256 442 | size=1 443 | stride=1 444 | pad=1 445 | activation=leaky 446 | 447 | [convolutional] 448 | batch_normalize=1 449 | filters=512 450 | size=3 451 | stride=1 452 | pad=1 453 | activation=leaky 454 | 455 | [shortcut] 456 | from=-3 457 | activation=linear 458 | 459 | # Downsample 460 | 461 | [convolutional] 462 | batch_normalize=1 463 | filters=1024 464 | size=3 465 | stride=2 466 | pad=1 467 | activation=leaky 468 | 469 | [convolutional] 470 | batch_normalize=1 471 | filters=512 472 | size=1 473 | stride=1 474 | pad=1 475 | activation=leaky 476 | 477 | [convolutional] 478 | batch_normalize=1 479 | filters=1024 480 | size=3 481 | stride=1 482 | pad=1 483 | activation=leaky 484 | 485 | [shortcut] 486 | from=-3 487 | activation=linear 488 | 489 | [convolutional] 490 | batch_normalize=1 491 | filters=512 492 | size=1 493 | stride=1 494 | pad=1 495 | activation=leaky 496 | 497 | [convolutional] 498 | batch_normalize=1 499 | filters=1024 500 | size=3 501 | stride=1 502 | pad=1 503 | activation=leaky 504 | 505 | [shortcut] 506 | from=-3 507 | activation=linear 508 | 509 | [convolutional] 510 | batch_normalize=1 511 | filters=512 512 | size=1 513 | stride=1 514 | pad=1 515 | activation=leaky 516 | 517 | [convolutional] 518 | batch_normalize=1 519 | filters=1024 520 | size=3 521 | stride=1 522 | pad=1 523 | activation=leaky 524 | 525 | [shortcut] 526 | from=-3 527 | activation=linear 528 | 529 | [convolutional] 530 | batch_normalize=1 531 | filters=512 532 | size=1 533 | stride=1 534 | pad=1 535 | activation=leaky 536 | 537 | [convolutional] 538 | batch_normalize=1 539 | filters=1024 540 | size=3 541 | stride=1 542 | pad=1 543 | activation=leaky 544 | 545 | [shortcut] 546 | from=-3 547 | activation=linear 548 | 549 | ###################### 550 | 551 | [convolutional] 552 | batch_normalize=1 553 | filters=512 554 | size=1 555 | stride=1 556 | pad=1 557 | activation=leaky 558 | 559 | [convolutional] 560 | batch_normalize=1 561 | size=3 562 | stride=1 563 | pad=1 564 | filters=1024 565 | activation=leaky 566 | 567 | [convolutional] 568 | batch_normalize=1 569 | filters=512 570 | size=1 571 | stride=1 572 | pad=1 573 | activation=leaky 574 | 575 | [convolutional] 576 | batch_normalize=1 577 | size=3 578 | stride=1 579 | pad=1 580 | filters=1024 581 | activation=leaky 582 | 583 | [convolutional] 
584 | batch_normalize=1 585 | filters=512 586 | size=1 587 | stride=1 588 | pad=1 589 | activation=leaky 590 | 591 | [convolutional] 592 | batch_normalize=1 593 | size=3 594 | stride=1 595 | pad=1 596 | filters=1024 597 | activation=leaky 598 | 599 | [convolutional] 600 | size=1 601 | stride=1 602 | pad=1 603 | filters=255 604 | activation=linear 605 | 606 | 607 | [yolo] 608 | mask = 6,7,8 609 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 610 | classes=80 611 | num=9 612 | jitter=.3 613 | ignore_thresh = .5 614 | truth_thresh = 1 615 | random=1 616 | 617 | 618 | [route] 619 | layers = -4 620 | 621 | [convolutional] 622 | batch_normalize=1 623 | filters=256 624 | size=1 625 | stride=1 626 | pad=1 627 | activation=leaky 628 | 629 | [upsample] 630 | stride=2 631 | 632 | [route] 633 | layers = -1, 61 634 | 635 | 636 | 637 | [convolutional] 638 | batch_normalize=1 639 | filters=256 640 | size=1 641 | stride=1 642 | pad=1 643 | activation=leaky 644 | 645 | [convolutional] 646 | batch_normalize=1 647 | size=3 648 | stride=1 649 | pad=1 650 | filters=512 651 | activation=leaky 652 | 653 | [convolutional] 654 | batch_normalize=1 655 | filters=256 656 | size=1 657 | stride=1 658 | pad=1 659 | activation=leaky 660 | 661 | [convolutional] 662 | batch_normalize=1 663 | size=3 664 | stride=1 665 | pad=1 666 | filters=512 667 | activation=leaky 668 | 669 | [convolutional] 670 | batch_normalize=1 671 | filters=256 672 | size=1 673 | stride=1 674 | pad=1 675 | activation=leaky 676 | 677 | [convolutional] 678 | batch_normalize=1 679 | size=3 680 | stride=1 681 | pad=1 682 | filters=512 683 | activation=leaky 684 | 685 | [convolutional] 686 | size=1 687 | stride=1 688 | pad=1 689 | filters=255 690 | activation=linear 691 | 692 | 693 | [yolo] 694 | mask = 3,4,5 695 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 696 | classes=80 697 | num=9 698 | jitter=.3 699 | ignore_thresh = .5 700 | truth_thresh = 1 701 | random=1 702 | 703 | 704 | 705 | [route] 706 | layers = -4 707 | 708 | [convolutional] 709 | batch_normalize=1 710 | filters=128 711 | size=1 712 | stride=1 713 | pad=1 714 | activation=leaky 715 | 716 | [upsample] 717 | stride=2 718 | 719 | [route] 720 | layers = -1, 36 721 | 722 | 723 | 724 | [convolutional] 725 | batch_normalize=1 726 | filters=128 727 | size=1 728 | stride=1 729 | pad=1 730 | activation=leaky 731 | 732 | [convolutional] 733 | batch_normalize=1 734 | size=3 735 | stride=1 736 | pad=1 737 | filters=256 738 | activation=leaky 739 | 740 | [convolutional] 741 | batch_normalize=1 742 | filters=128 743 | size=1 744 | stride=1 745 | pad=1 746 | activation=leaky 747 | 748 | [convolutional] 749 | batch_normalize=1 750 | size=3 751 | stride=1 752 | pad=1 753 | filters=256 754 | activation=leaky 755 | 756 | [convolutional] 757 | batch_normalize=1 758 | filters=128 759 | size=1 760 | stride=1 761 | pad=1 762 | activation=leaky 763 | 764 | [convolutional] 765 | batch_normalize=1 766 | size=3 767 | stride=1 768 | pad=1 769 | filters=256 770 | activation=leaky 771 | 772 | [convolutional] 773 | size=1 774 | stride=1 775 | pad=1 776 | filters=255 777 | activation=linear 778 | 779 | 780 | [yolo] 781 | mask = 0,1,2 782 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 783 | classes=80 784 | num=9 785 | jitter=.3 786 | ignore_thresh = .5 787 | truth_thresh = 1 788 | random=1 789 | 790 | -------------------------------------------------------------------------------- /darknet.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | from util import count_parameters as count 11 | from util import convert2cpu as cpu 12 | from util import predict_transform 13 | 14 | #from roi_align.roi_align import RoIAlign 15 | 16 | """ def to_varabile(arr, requires_grad=False, is_cuda=True): 17 | tensor = torch.from_numpy(arr) 18 | if is_cuda: 19 | tensor = tensor.cuda() 20 | var = Variable(tensor, requires_grad=requires_grad) 21 | return var """ 22 | 23 | def to_varabile(tensor, requires_grad=False, is_cuda=True): 24 | if is_cuda: 25 | tensor = tensor.cuda() 26 | var = Variable(tensor, requires_grad=requires_grad) 27 | return var 28 | 29 | 30 | def get_test_input(): 31 | img = cv2.imread("dog-cycle-car.png") 32 | img = cv2.resize(img, (416,416)) 33 | img_ = img[:,:,::-1].transpose((2,0,1)) 34 | img_ = img_[np.newaxis,:,:,:]/255.0 35 | img_ = torch.from_numpy(img_).float() 36 | img_ = Variable(img_) 37 | return img_ 38 | 39 | 40 | def parse_cfg(cfgfile): 41 | """ 42 | Takes a configuration file 43 | 44 | Returns a list of blocks. Each blocks describes a block in the neural 45 | network to be built. Block is represented as a dictionary in the list 46 | 47 | """ 48 | file = open(cfgfile, 'r') 49 | lines = file.read().split('\n') #store the lines in a list 50 | lines = [x for x in lines if len(x) > 0] #get read of the empty lines 51 | lines = [x for x in lines if x[0] != '#'] 52 | lines = [x.rstrip().lstrip() for x in lines] 53 | 54 | 55 | block = {} 56 | blocks = [] 57 | 58 | for line in lines: 59 | if line[0] == "[": #This marks the start of a new block 60 | if len(block) != 0: 61 | blocks.append(block) 62 | block = {} 63 | block["type"] = line[1:-1].rstrip() 64 | else: 65 | key,value = line.split("=") 66 | block[key.rstrip()] = value.lstrip() 67 | blocks.append(block) 68 | 69 | # print('\n\n'.join([repr(x) for x in blocks])) 70 | return blocks 71 | # print('\n\n'.join([repr(x) for x in blocks])) 72 | 73 | import pickle as pkl 74 | 75 | class MaxPoolStride1(nn.Module): 76 | def __init__(self, kernel_size): 77 | super(MaxPoolStride1, self).__init__() 78 | self.kernel_size = kernel_size 79 | self.pad = kernel_size - 1 80 | 81 | def forward(self, x): 82 | padded_x = F.pad(x, (0,self.pad,0,self.pad), mode="replicate") 83 | pooled_x = nn.MaxPool2d(self.kernel_size, self.pad)(padded_x) 84 | return pooled_x 85 | 86 | 87 | class EmptyLayer(nn.Module): 88 | def __init__(self): 89 | super(EmptyLayer, self).__init__() 90 | 91 | 92 | class DetectionLayer(nn.Module): 93 | def __init__(self, anchors): 94 | super(DetectionLayer, self).__init__() 95 | self.anchors = anchors 96 | 97 | def forward(self, x, inp_dim, num_classes, confidence): 98 | x = x.data 99 | global CUDA 100 | prediction = x 101 | prediction = predict_transform(prediction, inp_dim, self.anchors, num_classes, confidence, device) 102 | return prediction 103 | 104 | 105 | 106 | 107 | 108 | class Upsample(nn.Module): 109 | def __init__(self, stride=2): 110 | super(Upsample, self).__init__() 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | stride = self.stride 115 | assert(x.data.dim() == 4) 116 | B = x.data.size(0) 117 | C = x.data.size(1) 118 | H = x.data.size(2) 119 | W = x.data.size(3) 120 | ws = stride 121 | hs = stride 122 | x = x.view(B, C, H, 1, W, 
1).expand(B, C, H, stride, W, stride).contiguous().view(B, C, H*stride, W*stride) 123 | return x 124 | # 125 | 126 | class ReOrgLayer(nn.Module): 127 | def __init__(self, stride = 2): 128 | super(ReOrgLayer, self).__init__() 129 | self.stride= stride 130 | 131 | def forward(self,x): 132 | assert(x.data.dim() == 4) 133 | B,C,H,W = x.data.shape 134 | hs = self.stride 135 | ws = self.stride 136 | assert(H % hs == 0), "The stride " + str(self.stride) + " is not a proper divisor of height " + str(H) 137 | assert(W % ws == 0), "The stride " + str(self.stride) + " is not a proper divisor of height " + str(W) 138 | x = x.view(B,C, H // hs, hs, W // ws, ws).transpose(-2,-3).contiguous() 139 | x = x.view(B,C, H // hs * W // ws, hs, ws) 140 | x = x.view(B,C, H // hs * W // ws, hs*ws).transpose(-1,-2).contiguous() 141 | x = x.view(B, C, ws*hs, H // ws, W // ws).transpose(1,2).contiguous() 142 | x = x.view(B, C*ws*hs, H // ws, W // ws) 143 | return x 144 | 145 | 146 | def create_modules(blocks): 147 | net_info = blocks[0] #Captures the information about the input and pre-processing 148 | 149 | module_list = nn.ModuleList() 150 | 151 | index = 0 #indexing blocks helps with implementing route layers (skip connections) 152 | 153 | 154 | prev_filters = 3 155 | 156 | output_filters = [] 157 | 158 | for x in blocks: 159 | module = nn.Sequential() 160 | 161 | if (x["type"] == "net"): 162 | continue 163 | 164 | #If it's a convolutional layer 165 | if (x["type"] == "convolutional"): 166 | #Get the info about the layer 167 | activation = x["activation"] 168 | try: 169 | batch_normalize = int(x["batch_normalize"]) 170 | bias = False 171 | except: 172 | batch_normalize = 0 173 | bias = True 174 | 175 | filters= int(x["filters"]) 176 | padding = int(x["pad"]) 177 | kernel_size = int(x["size"]) 178 | stride = int(x["stride"]) 179 | 180 | if padding: 181 | pad = (kernel_size - 1) // 2 182 | else: 183 | pad = 0 184 | 185 | #Add the convolutional layer 186 | conv = nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias = bias) 187 | module.add_module("conv_{0}".format(index), conv) 188 | 189 | #Add the Batch Norm Layer 190 | if batch_normalize: 191 | bn = nn.BatchNorm2d(filters) 192 | module.add_module("batch_norm_{0}".format(index), bn) 193 | 194 | #Check the activation. 195 | #It is either Linear or a Leaky ReLU for YOLO 196 | if activation == "leaky": 197 | activn = nn.LeakyReLU(0.1, inplace = True) 198 | module.add_module("leaky_{0}".format(index), activn) 199 | 200 | 201 | 202 | #If it's an upsampling layer 203 | #We use Bilinear2dUpsampling 204 | 205 | elif (x["type"] == "upsample"): 206 | stride = int(x["stride"]) 207 | # upsample = Upsample(stride) 208 | upsample = nn.Upsample(scale_factor = 2, mode = "nearest") 209 | # upsample = nn.functional.interpolate(None, scale_factor = 2, mode = "nearest") 210 | module.add_module("upsample_{}".format(index), upsample) 211 | 212 | #If it is a route layer 213 | elif (x["type"] == "route"): 214 | x["layers"] = x["layers"].split(',') 215 | 216 | #Start of a route 217 | start = int(x["layers"][0]) 218 | 219 | #end, if there exists one. 
220 | try: 221 | end = int(x["layers"][1]) 222 | except: 223 | end = 0 224 | 225 | 226 | 227 | #Positive anotation 228 | if start > 0: 229 | start = start - index 230 | 231 | if end > 0: 232 | end = end - index 233 | 234 | 235 | route = EmptyLayer() 236 | module.add_module("route_{0}".format(index), route) 237 | 238 | 239 | 240 | if end < 0: 241 | filters = output_filters[index + start] + output_filters[index + end] 242 | else: 243 | filters= output_filters[index + start] 244 | 245 | 246 | 247 | #shortcut corresponds to skip connection 248 | elif x["type"] == "shortcut": 249 | from_ = int(x["from"]) 250 | shortcut = EmptyLayer() 251 | module.add_module("shortcut_{}".format(index), shortcut) 252 | 253 | 254 | elif x["type"] == "maxpool": 255 | stride = int(x["stride"]) 256 | size = int(x["size"]) 257 | if stride != 1: 258 | maxpool = nn.MaxPool2d(size, stride) 259 | else: 260 | maxpool = MaxPoolStride1(size) 261 | 262 | module.add_module("maxpool_{}".format(index), maxpool) 263 | 264 | #Yolo is the detection layer 265 | elif x["type"] == "yolo": 266 | mask = x["mask"].split(",") 267 | mask = [int(xx) for xx in mask] 268 | 269 | anchors = x["anchors"].split(",") 270 | anchors = [int(a) for a in anchors] 271 | anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors),2)] 272 | anchors = [anchors[i] for i in mask] 273 | 274 | detection = DetectionLayer(anchors) 275 | module.add_module("Detection_{}".format(index), detection) 276 | 277 | 278 | 279 | else: 280 | print("Something I dunno") 281 | assert False 282 | 283 | 284 | module_list.append(module) 285 | prev_filters = filters 286 | output_filters.append(filters) 287 | index += 1 288 | 289 | 290 | return (net_info, module_list) 291 | 292 | 293 | 294 | class Darknet(nn.Module): 295 | def __init__(self, cfgfile): 296 | super(Darknet, self).__init__() 297 | self.blocks = parse_cfg(cfgfile) 298 | self.net_info, self.module_list = create_modules(self.blocks) 299 | self.header = torch.IntTensor([0,0,0,0]) 300 | self.seen = 0 301 | 302 | self.feature_map = None 303 | 304 | 305 | 306 | def get_blocks(self): 307 | return self.blocks 308 | 309 | def get_module_list(self): 310 | return self.module_list 311 | 312 | def get_feature(self): 313 | return self.feature_map 314 | 315 | 316 | def forward(self, x, device, flag_eval): 317 | detections = [] 318 | modules = self.blocks[1:] 319 | outputs = {} #We cache the outputs for the route layer 320 | 321 | layer_index = 1 322 | write = 0 323 | for i in range(len(modules)): 324 | 325 | module_type = (modules[i]["type"]) 326 | if module_type == "convolutional" or module_type == "upsample" or module_type == "maxpool": 327 | 328 | x = self.module_list[i](x) 329 | outputs[i] = x 330 | # print(x.size(), layer_index, module_type) 331 | layer_index += 1 332 | 333 | 334 | elif module_type == "route": 335 | layers = modules[i]["layers"] 336 | layers = [int(a) for a in layers] 337 | 338 | if (layers[0]) > 0: 339 | layers[0] = layers[0] - i 340 | 341 | if len(layers) == 1: 342 | x = outputs[i + (layers[0])] 343 | 344 | else: 345 | if (layers[1]) > 0: 346 | layers[1] = layers[1] - i 347 | 348 | map1 = outputs[i + layers[0]] 349 | map2 = outputs[i + layers[1]] 350 | 351 | x = torch.cat((map1, map2), 1) 352 | outputs[i] = x 353 | # print(x.size(), layer_index, module_type) 354 | layer_index += 1 355 | 356 | elif module_type == "shortcut": 357 | from_ = int(modules[i]["from"]) 358 | x = outputs[i-1] + outputs[i+from_] 359 | outputs[i] = x 360 | # print(x.size(), layer_index, module_type) 361 | 362 | if i == 61: 363 
| self.feature_map = x 364 | 365 | layer_index += 1 366 | 367 | 368 | elif module_type == 'yolo' and flag_eval: 369 | 370 | anchors = self.module_list[i][0].anchors 371 | #Get the input dimensions 372 | inp_dim = int (self.net_info["height"]) 373 | 374 | #Get the number of classes 375 | num_classes = int (modules[i]["classes"]) 376 | 377 | #Output the result 378 | x = x.data 379 | x = predict_transform(x, inp_dim, anchors, num_classes, device) 380 | 381 | if type(x) == int: 382 | continue 383 | 384 | 385 | if not write: 386 | detections = x 387 | write = 1 388 | 389 | else: 390 | detections = torch.cat((detections, x), 1) 391 | 392 | outputs[i] = outputs[i-1] 393 | 394 | 395 | if flag_eval: 396 | try: 397 | return detections 398 | except: 399 | return 0 400 | else: 401 | return outputs[61] 402 | 403 | 404 | def load_weights(self, weightfile): 405 | 406 | #Open the weights file 407 | fp = open(weightfile, "rb") 408 | 409 | #The first 4 values are header information 410 | # 1. Major version number 411 | # 2. Minor Version Number 412 | # 3. Subversion number 413 | # 4. IMages seen 414 | header = np.fromfile(fp, dtype = np.int32, count = 5) 415 | self.header = torch.from_numpy(header) 416 | self.seen = self.header[3] 417 | 418 | #The rest of the values are the weights 419 | # Let's load them up 420 | weights = np.fromfile(fp, dtype = np.float32) 421 | 422 | ptr = 0 423 | for i in range(len(self.module_list)): 424 | module_type = self.blocks[i + 1]["type"] 425 | 426 | if module_type == "convolutional": 427 | model = self.module_list[i] 428 | try: 429 | batch_normalize = int(self.blocks[i+1]["batch_normalize"]) 430 | except: 431 | batch_normalize = 0 432 | 433 | conv = model[0] 434 | 435 | if (batch_normalize): 436 | bn = model[1] 437 | 438 | #Get the number of weights of Batch Norm Layer 439 | num_bn_biases = bn.bias.numel() 440 | 441 | #Load the weights 442 | bn_biases = torch.from_numpy(weights[ptr:ptr + num_bn_biases]) 443 | ptr += num_bn_biases 444 | 445 | bn_weights = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 446 | ptr += num_bn_biases 447 | 448 | bn_running_mean = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 449 | ptr += num_bn_biases 450 | 451 | bn_running_var = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 452 | ptr += num_bn_biases 453 | 454 | #Cast the loaded weights into dims of model weights. 
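# (Editor's note, added as hedged comments.) The slicing above and the copies below
# rely on the standard Darknet .weights layout: a short integer header, then one flat
# float32 array holding every parameter back to back. The header comment a few lines
# up says "first 4 values", but np.fromfile is called with count=5; recent Darknet
# versions store major, minor and subversion as int32 and keep the "images seen"
# counter in the remaining header words, which is why five 32-bit values are read
# and self.seen is taken from header[3].
# For a convolutional block with batch norm, the flat array is consumed in this
# order: bn biases, bn weights, bn running mean, bn running var, then the conv
# kernel weights; without batch norm it is conv biases followed by conv kernel
# weights. A sketch of the pattern used below (the helper name `copy_flat` is
# illustrative only, not part of this repo):
#
#   def copy_flat(dst, flat, ptr):
#       n = dst.numel()                                   # how many values this tensor needs
#       dst.copy_(torch.from_numpy(flat[ptr:ptr + n]).view_as(dst))
#       return ptr + n                                    # advance the read pointer
#
# view_as(...) only reshapes the flat slice to the destination tensor's shape;
# copy_(...) then writes it into the model in place.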
455 | bn_biases = bn_biases.view_as(bn.bias.data) 456 | bn_weights = bn_weights.view_as(bn.weight.data) 457 | bn_running_mean = bn_running_mean.view_as(bn.running_mean) 458 | bn_running_var = bn_running_var.view_as(bn.running_var) 459 | 460 | #Copy the data to model 461 | bn.bias.data.copy_(bn_biases) 462 | bn.weight.data.copy_(bn_weights) 463 | bn.running_mean.copy_(bn_running_mean) 464 | bn.running_var.copy_(bn_running_var) 465 | 466 | else: 467 | #Number of biases 468 | num_biases = conv.bias.numel() 469 | 470 | #Load the weights 471 | conv_biases = torch.from_numpy(weights[ptr: ptr + num_biases]) 472 | ptr = ptr + num_biases 473 | 474 | #reshape the loaded weights according to the dims of the model weights 475 | conv_biases = conv_biases.view_as(conv.bias.data) 476 | 477 | #Finally copy the data 478 | conv.bias.data.copy_(conv_biases) 479 | 480 | 481 | #Let us load the weights for the Convolutional layers 482 | num_weights = conv.weight.numel() 483 | 484 | #Do the same as above for weights 485 | conv_weights = torch.from_numpy(weights[ptr:ptr+num_weights]) 486 | ptr = ptr + num_weights 487 | 488 | conv_weights = conv_weights.view_as(conv.weight.data) 489 | conv.weight.data.copy_(conv_weights) 490 | 491 | def save_weights(self, savedfile, cutoff = 0): 492 | 493 | if cutoff <= 0: 494 | cutoff = len(self.blocks) - 1 495 | 496 | fp = open(savedfile, 'wb') 497 | 498 | # Attach the header at the top of the file 499 | self.header[3] = self.seen 500 | header = self.header 501 | 502 | header = header.numpy() 503 | header.tofile(fp) 504 | 505 | # Now, let us save the weights 506 | for i in range(len(self.module_list)): 507 | module_type = self.blocks[i+1]["type"] 508 | 509 | if (module_type) == "convolutional": 510 | model = self.module_list[i] 511 | try: 512 | batch_normalize = int(self.blocks[i+1]["batch_normalize"]) 513 | except: 514 | batch_normalize = 0 515 | 516 | conv = model[0] 517 | 518 | if (batch_normalize): 519 | bn = model[1] 520 | 521 | #If the parameters are on GPU, convert them back to CPU 522 | #We don't convert the parameter to GPU 523 | #Instead. 
we copy the parameter and then convert it to CPU 524 | #This is done as weight are need to be saved during training 525 | cpu(bn.bias.data).numpy().tofile(fp) 526 | cpu(bn.weight.data).numpy().tofile(fp) 527 | cpu(bn.running_mean).numpy().tofile(fp) 528 | cpu(bn.running_var).numpy().tofile(fp) 529 | 530 | 531 | else: 532 | cpu(conv.bias.data).numpy().tofile(fp) 533 | 534 | 535 | #Let us save the weights for the Convolutional layers 536 | cpu(conv.weight.data).numpy().tofile(fp) 537 | 538 | 539 | 540 | 541 | 542 | # 543 | #dn = Darknet('cfg/yolov3.cfg') 544 | #dn.load_weights("yolov3.weights") 545 | #inp = get_test_input() 546 | #a, interms = dn(inp) 547 | #dn.eval() 548 | #a_i, interms_i = dn(inp) 549 | -------------------------------------------------------------------------------- /data/clothing_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/data/clothing_vocab.pkl -------------------------------------------------------------------------------- /data/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /data/voc.names: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor 21 | -------------------------------------------------------------------------------- /file_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | import pickle 5 | import os 6 | from os import listdir, getcwd 7 | import os.path as osp 8 | import glob 9 | from torchvision import transforms 10 | from model import EncoderClothing, DecoderClothing 11 | from darknet import Darknet 12 | from PIL import Image 13 | from util import * 14 | import cv2 15 | import pickle as pkl 16 | import random 17 | from preprocess import prep_image 18 | 19 | 20 | import sys 21 | if sys.version_info >= (3,0): 22 | from roi_align.roi_align import RoIAlign 23 | else : 24 | from roi_align import RoIAlign 25 | 26 | # "winter scarf", "cane", "bag", "shoes", "hat", "face"] 27 | #attribute categories = #6 28 | colors_a = ["", "white", 
"black", "gray", "pink", "red", "green", "blue", "brown", "navy", "beige", \ 29 | "yellow", "purple", "orange", "mixed color"] #0-14 30 | pattern_a = ["", "no pattern", "checker", "dotted", "floral", "striped", "custom pattern"] #0-6 31 | gender_a = ["", "man", "woman"] #0-2 32 | season_a = ["", "spring", "summer", "autumn", "winter"] #0-4 33 | upper_t_a = ["", "shirt", "jumper", "jacket", "vest", "parka", "coat", "dress"]#0-7 34 | u_sleeves_a = ["", "short sleeves", "long sleeves", "no sleeves"]#0-3 35 | 36 | lower_t_a = ["", "pants", "skirt"]#0-2 37 | l_sleeves_a = ["", "short", "long"]#0-2 38 | leg_pose_a = ["", "standing", "sitting", "lying"]#0-3 39 | 40 | glasses_a = ["", "glasses"] 41 | 42 | attribute_pool = [colors_a, pattern_a, gender_a, season_a, upper_t_a, u_sleeves_a, 43 | colors_a, pattern_a, gender_a, season_a, lower_t_a, l_sleeves_a, leg_pose_a] 44 | 45 | # Device configuration 46 | #device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 47 | device = torch.device('cpu') 48 | def main(args): 49 | # Image preprocessing 50 | transform = transforms.Compose([ 51 | transforms.ToTensor()]) 52 | 53 | # Load vocabulary wrapper 54 | 55 | # Build the models 56 | #CUDA = torch.cuda.is_available() 57 | 58 | num_classes = 80 59 | yolov3 = Darknet(args.cfg_file) 60 | yolov3.load_weights(args.weights_file) 61 | 62 | 63 | print("yolo-v3 network successfully loaded") 64 | 65 | attribute_size = [15, 7, 3, 5, 8, 4, 15, 7, 3, 5, 3, 3, 4] 66 | 67 | encoder = EncoderClothing(args.embed_size, device, args.roi_size, attribute_size) 68 | 69 | # Prepare an image 70 | images = "test" 71 | 72 | 73 | try: 74 | list_dir = os.listdir(images) 75 | # list_dir.sort(key=lambda x: int(x[:-4])) 76 | imlist = [osp.join(osp.realpath('.'), images, img) for img in list_dir if os.path.splitext(img)[1].lower() =='.jpg' or os.path.splitext(img)[1].lower() =='.jpeg' or os.path.splitext(img)[1] =='.png'] 77 | print(imlist) 78 | except NotADirectoryError: 79 | imlist = [] 80 | imlist.append(osp.join(osp.realpath('.'), images)) 81 | print('Not a directory error') 82 | except FileNotFoundError: 83 | print ("No file or directory with the name {}".format(images)) 84 | exit() 85 | 86 | yolov3.to(device) 87 | encoder.to(device) 88 | 89 | yolov3.eval() 90 | encoder.eval() 91 | 92 | encoder.load_state_dict(torch.load(args.encoder_path, map_location=device)) 93 | 94 | for inx, image_filename in enumerate(imlist): 95 | print ('---------------------------') 96 | print(image_filename) 97 | preload_img = cv2.imread(image_filename) 98 | pixel = preload_img.shape[1] * preload_img.shape[0] 99 | 100 | if pixel < 352 * 240: 101 | yolov3.net_info["height"] = 1024 102 | if pixel >= 352 * 240: 103 | yolov3.net_info["height"] = 1024 104 | if pixel >= 480 * 360: 105 | yolov3.net_info["height"] = 1024 106 | if pixel >= 858 * 480: 107 | yolov3.net_info["height"] = 832 108 | if pixel >= 1280 * 720: 109 | yolov3.net_info["height"] = 832 110 | 111 | inp_dim = int(yolov3.net_info["height"]) 112 | 113 | image, orig_img, im_dim = prep_image(image_filename, inp_dim) 114 | 115 | im_dim = torch.FloatTensor(im_dim).repeat(1, 2) 116 | 117 | image_tensor = image.to(device) 118 | im_dim = im_dim.to(device) 119 | 120 | # Generate an caption from the image 121 | detections = yolov3(image_tensor, device, True) # prediction mode for yolo-v3 122 | detections = write_results(detections, args.confidence, num_classes, device, nms=True, nms_conf=args.nms_thresh) 123 | # original image dimension --> im_dim 124 | #view_image(detections) 125 | 126 | 127 
| if type(detections) != int: 128 | if detections.shape[0]: 129 | bboxs = detections[:, 1:5].clone() 130 | im_dim = im_dim.repeat(detections.shape[0], 1) 131 | scaling_factor = torch.min(inp_dim/im_dim, 1)[0].view(-1, 1) 132 | 133 | detections[:, [1, 3]] -= (inp_dim - scaling_factor*im_dim[:, 0].view(-1, 1))/2 134 | detections[:, [2, 4]] -= (inp_dim - scaling_factor*im_dim[:, 1].view(-1, 1))/2 135 | 136 | detections[:, 1:5] /= scaling_factor 137 | 138 | small_object_ratio = torch.FloatTensor(detections.shape[0]) 139 | 140 | for i in range(detections.shape[0]): 141 | detections[i, [1, 3]] = torch.clamp(detections[i, [1, 3]], 0.0, im_dim[i, 0]) 142 | detections[i, [2, 4]] = torch.clamp(detections[i, [2, 4]], 0.0, im_dim[i, 1]) 143 | 144 | object_area = (detections[i, 3] - detections[i, 1])*(detections[i, 4] - detections[i, 2]) 145 | orig_img_area = im_dim[i, 0]*im_dim[i, 1] 146 | small_object_ratio[i] = object_area/orig_img_area 147 | 148 | detections = detections[small_object_ratio > 0.02] 149 | im_dim = im_dim[small_object_ratio > 0.02] 150 | 151 | if detections.size(0) > 0: 152 | feature = yolov3.get_feature() 153 | feature = feature.repeat(detections.size(0), 1, 1, 1) 154 | 155 | #orig_img_dim = im_dim[:, 1:] 156 | #orig_img_dim = orig_img_dim.repeat(1, 2) 157 | 158 | scaling_val = 16 159 | 160 | bboxs /= scaling_val 161 | bboxs = bboxs.round() 162 | bboxs_index = torch.arange(bboxs.size(0), dtype=torch.int) 163 | bboxs_index = bboxs_index.to(device) 164 | bboxs = bboxs.to(device) 165 | 166 | roi_align = RoIAlign(args.roi_size, args.roi_size, transform_fpcoor=True).to(device) 167 | roi_features = roi_align(feature, bboxs, bboxs_index) 168 | # print(roi_features) 169 | # print(roi_features.size()) 170 | 171 | #roi_features = roi_features.reshape(roi_features.size(0), -1) 172 | 173 | #roi_align_feature = encoder(roi_features) 174 | 175 | outputs = encoder(roi_features) 176 | #attribute_size = [15, 7, 3, 5, 7, 4, 15, 7, 3, 5, 4, 3, 4] 177 | #losses = [criteria[i](outputs[i], targets[i]) for i in range(len(attribute_size))] 178 | 179 | for i in range(detections.shape[0]): 180 | 181 | sampled_caption = [] 182 | #attr_fc = outputs[] 183 | for j in range(len(outputs)): 184 | #temp = outputs[j][i].data 185 | max_index = torch.max(outputs[j][i].data, 0)[1] 186 | word = attribute_pool[j][max_index] 187 | sampled_caption.append(word) 188 | 189 | c11 = sampled_caption[11] 190 | sampled_caption[11] = sampled_caption[10] 191 | sampled_caption[10] = c11 192 | 193 | sentence = '/'.join(sampled_caption) 194 | 195 | # again sampling for testing 196 | 197 | print (str(i+1) + ': ' + sentence) 198 | write(detections[i], orig_img, sampled_caption,sentence, i+1, coco_classes, colors) 199 | #list(map(lambda x: write(x, orig_img, captions), detections[i].unsqueeze(0))) 200 | 201 | cv2.imshow("frame", orig_img) 202 | key = cv2.waitKey(0) 203 | 204 | if key & 0xFF == ord('q'): 205 | break 206 | 207 | # image = Image.open(args.image) 208 | # plt.imshow(np.asarray(image)) 209 | 210 | if __name__ == '__main__': 211 | parser = argparse.ArgumentParser() 212 | 213 | parser.add_argument('--encoder_path', type=str, default='encoder-12-1170.ckpt', help='path for trained encoder') 214 | 215 | parser.add_argument('--vocab_path1', type=str, default='json/train_up_vocab.pkl', help='path for vocabulary wrapper') 216 | parser.add_argument('--vocab_path2', type=str, default='clothing_vocab_accessory2.pkl', help='path for vocabulary wrapper') 217 | 218 | # Encoder - Yolo-v3 parameters 219 | parser.add_argument('--confidence', 
type=float, default = 0.5, help = 'Object Confidence to filter predictions') 220 | parser.add_argument('--nms_thresh', type=float , default = 0.4, help = 'NMS Threshhold') 221 | parser.add_argument('--cfg_file', type = str, default = 'cfg/yolov3.cfg', help ='Config file') 222 | parser.add_argument('--weights_file', type = str, default = 'yolov3.weights', help = 'weightsfile') 223 | parser.add_argument('--reso', type=str, default = '832', help = 'Input resolution of the network. Increase to increase accuracy. Decrease to increase speed') 224 | parser.add_argument('--scales', type=str, default = '1,2,3', help = 'Scales to use for detection') 225 | 226 | # Model parameters (should be same as paramters in train.py) 227 | parser.add_argument('--embed_size', type=int , default=256, help='dimension of word embedding vectors') 228 | parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states') 229 | parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm') 230 | parser.add_argument('--roi_size', type=int, default = 13) 231 | args = parser.parse_args() 232 | 233 | coco_classes = load_classes('data/coco.names') 234 | colors = pkl.load(open("pallete2", "rb")) 235 | 236 | main(args) 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | import numpy as np 6 | from roi_align.roi_align import RoIAlign 7 | #from darknet import Darknet 8 | 9 | import cv2 10 | 11 | def view_image(bboxes): 12 | img = np.full((416, 416, 3), 100, dtype='uint8') 13 | for bbox in bboxes: 14 | c1 = tuple(bbox[0:2].int()*16) 15 | c2 = tuple(bbox[2:4].int()*16) 16 | cv2.rectangle(img, c1, c2, 128, 3) 17 | cv2.imshow("02", img) 18 | cv2.waitKey(0) 19 | cv2.destroyAllWindows() 20 | 21 | class Reshape(nn.Module): 22 | def __init__(self, *args): 23 | super(Reshape, self).__init__() 24 | self.shape = args 25 | 26 | def forward(self, x): 27 | return x.view(self.shape) 28 | 29 | class EncoderClothing(nn.Module): 30 | def __init__(self, embed_size, device, pool_size, attribute_dim): 31 | """Load the pretrained yolo-v3 """ 32 | super(EncoderClothing, self).__init__() 33 | self.device = device 34 | self.linear = nn.Linear(512*pool_size*pool_size, embed_size) 35 | self.relu = nn.ReLU() 36 | #self.module_list = nn.ModuleList([nn.Linear(embed_size, att_size) for att_size in attribute_dim]) 37 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 38 | self.dropout = nn.Dropout(0.5) 39 | self.pool_size = pool_size 40 | 41 | self.module_list = nn.ModuleList([self.conv_bn(512, 256, 1, embed_size, att_size) for att_size in attribute_dim]) 42 | 43 | 44 | 45 | def conv_bn(self, in_planes, out_planes, kernel_size, embed_size, att_size, stride=1, padding=0, bias=False): 46 | #"convolution with batchnorm, relu" 47 | return nn.Sequential( 48 | nn.Conv2d(in_planes, out_planes, 1, stride=stride, 49 | padding=padding, bias=False), 50 | nn.BatchNorm2d(out_planes, eps=1e-3), 51 | nn.ReLU(), 52 | nn.AdaptiveAvgPool2d((1,1)), 53 | nn.BatchNorm2d(out_planes, eps=1e-3), 54 | nn.ReLU(), 55 | nn.Dropout(0.5), 56 | Reshape(-1, embed_size), 57 | nn.Linear(embed_size, att_size) 58 | ) 59 | 60 | def forward(self, x): 61 | """change feature dimension to original image dimension""" 62 | outputs = {} 63 | #features = 
self.bn(self.linear(features)) 64 | #x = self.relu(self.bn(self.linear(features))) 65 | # x = self.dropout(x) 66 | 67 | for i in range(len(self.module_list)): 68 | output = self.module_list[i](x) 69 | outputs[i] = output 70 | 71 | return outputs 72 | 73 | class EncoderClothing1(nn.Module): 74 | def __init__(self, embed_size, device, pool_size, attribute_dim): 75 | """Load the pretrained yolo-v3 """ 76 | super(EncoderClothing, self).__init__() 77 | self.device = device 78 | self.linear = nn.Linear(512*pool_size*pool_size, embed_size) 79 | self.module_list = nn.ModuleList([nn.Linear(512*pool_size*pool_size, att_size) for att_size in attribute_dim]) 80 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 81 | self.pool_size = pool_size 82 | 83 | def forward(self, features): 84 | """change feature dimension to original image dimension""" 85 | outputs = {} 86 | #features = self.bn(self.linear(features)) 87 | for i in range(len(self.module_list)): 88 | x = self.module_list[i](features) 89 | outputs[i] = x 90 | 91 | return outputs 92 | 93 | 94 | class DecoderClothing(nn.Module): 95 | def __init__(self, embed_size, hidden_size, vocab_size, vocab, num_layers, max_seq_length=30): 96 | """Set the hyper-parameters and build the layers.""" 97 | super(DecoderClothing, self).__init__() 98 | self.embed = nn.Embedding(vocab_size, embed_size) 99 | self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) 100 | self.linear = nn.Linear(hidden_size, vocab_size) 101 | self.max_seg_length = max_seq_length 102 | self.vocab = vocab 103 | self.clothing_class = { 104 | 'shirt':0, 'jumper':1, 'jacket':2, 'vest':3, 'coat':4, 105 | 'dress':5, 'pants':6, 'skirt':7, 'scarf':8, 'cane':9, 'bag':10, 'shoes':11, 106 | 'hat':12, 'face':13, 'glasses':14 } 107 | 108 | def forward(self, features, captions, lengths): # for training 109 | """Decode image feature vectors and generates captions.""" 110 | embeddings = self.embed(captions) # [B, 10, 256] for captions = [B, 10] 111 | embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) 112 | packed = pack_padded_sequence(embeddings, lengths, batch_first=True) 113 | hiddens, _ = self.lstm(packed) 114 | outputs = self.linear(hiddens[0]) 115 | return outputs 116 | 117 | def sample1(self, features, states=None): # for prediction 118 | """Generate captions for given image features using greedy search.""" 119 | sampled_ids = [] 120 | inputs = features.unsqueeze(1) 121 | for i in range(self.max_seg_length): 122 | hiddens, states = self.lstm(inputs, states) # hiddens: (batch_size, 1, hidden_size) 123 | outputs = self.linear(hiddens.squeeze(1)) # outputs: (batch_size, vocab_size) 124 | _, predicted = outputs.max(1) # predicted: (batch_size) 125 | sampled_ids.append(predicted) 126 | print(predicted) 127 | inputs = self.embed(predicted) # inputs: (batch_size, embed_size) 128 | inputs = inputs.unsqueeze(1) # inputs: (batch_size, 1, embed_size) 129 | sampled_ids = torch.stack(sampled_ids, 1) # sampled_ids: (batch_size, max_seq_length) 130 | return sampled_ids 131 | 132 | def sample(self, features, states=None): # for predicton 133 | """Generate captions for given image features using greedy search.""" 134 | sampled_ids = [] 135 | prob_ids = [] 136 | k_samples = [] 137 | k_probs = [] 138 | sampling_num = 30 139 | prob_thresh = 0.1 140 | histogram_clothing = np.zeros(15, dtype=int) 141 | inputs = features.unsqueeze(1) 142 | 143 | for i in range(2): 144 | hiddens, states = self.lstm(inputs, states) 145 | outputs = self.linear(hiddens.squeeze(1)) 146 | if i == 0 : 147 | 
prob_pred, predicted = outputs.max(1) 148 | inputs = self.embed(predicted) 149 | inputs = inputs.unsqueeze(1) 150 | # states_m = states 151 | else : 152 | top_k_prob, top_k = outputs.topk(sampling_num) 153 | #top_k = top_k.squeeze(0) 154 | 155 | for i in range(sampling_num): 156 | inputs = self.embed(top_k[:,i]) 157 | inputs = inputs.unsqueeze(1) 158 | word_prob = top_k_prob[:,i] 159 | if word_prob < prob_thresh: 160 | break 161 | sampled_ids.append(top_k[:,i]) 162 | # print(self.vocab.idx2word[top_k[:,i].cpu().numpy()[0]]) 163 | prob_ids.append(word_prob) 164 | _states = states # re-load 165 | duplicate_tag = False 166 | 167 | for j in range(self.max_seg_length): 168 | _hiddens, _states = self.lstm(inputs, _states) 169 | outputs = self.linear(_hiddens.squeeze(1)) 170 | prob_pred, predicted = outputs.max(1) 171 | 172 | word = self.vocab.idx2word[predicted.cpu().numpy()[0]] 173 | if word == '': 174 | break 175 | 176 | class_index = self.clothing_class.get(word, '') 177 | if class_index is not '': 178 | if histogram_clothing[class_index] > 0: 179 | duplicate_tag = True 180 | break 181 | else: 182 | if word == 'jacket' or word == 'coat' or word == 'jumper': 183 | class_index = self.clothing_class.get('jacket') 184 | histogram_clothing[class_index] += 1 185 | class_index = self.clothing_class.get('coat') 186 | histogram_clothing[class_index] += 1 187 | class_index = self.clothing_class.get('jumper') 188 | histogram_clothing[class_index] += 1 189 | else: 190 | histogram_clothing[class_index] += 1 191 | 192 | sampled_ids.append(predicted) 193 | prob_ids.append(prob_pred) 194 | inputs = self.embed(predicted) 195 | inputs = inputs.unsqueeze(1) 196 | if duplicate_tag : 197 | duplicate_tag = False 198 | sampled_ids = [] 199 | prob_ids = [] 200 | continue 201 | sampled_ids = torch.stack(sampled_ids, 1) 202 | prob_ids = torch.stack(prob_ids, 1) 203 | k_samples.append(sampled_ids) 204 | k_probs.append(prob_ids) 205 | sampled_ids = [] 206 | prob_ids = [] 207 | 208 | return k_samples, k_probs 209 | -------------------------------------------------------------------------------- /nice_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/nice_example.png -------------------------------------------------------------------------------- /pallete: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/pallete -------------------------------------------------------------------------------- /pallete2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/pallete2 -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import cv2 9 | 10 | from util import count_parameters as count 11 | from util import convert2cpu as cpu 12 | from PIL import Image, ImageDraw 13 | 14 | 15 | def letterbox_image(img, inp_dim): 16 | '''resize image with unchanged aspect ratio using padding''' 17 | img_w, img_h = 
img.shape[1], img.shape[0] 18 | w, h = inp_dim 19 | new_w = int(img_w * min(w/img_w, h/img_h)) 20 | new_h = int(img_h * min(w/img_w, h/img_h)) 21 | resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC) 22 | 23 | canvas = np.full((inp_dim[1], inp_dim[0], 3), 128, dtype='uint8') 24 | 25 | canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w, :] = resized_image 26 | 27 | return canvas 28 | 29 | 30 | 31 | def prep_image(img, inp_dim): 32 | """ 33 | Prepare image for inputting to the neural network. 34 | 35 | Returns a Variable 36 | """ 37 | 38 | orig_im = cv2.imread(img) 39 | dim = orig_im.shape[1], orig_im.shape[0] 40 | img = (letterbox_image(orig_im, (inp_dim, inp_dim))) 41 | img_ = img[:,:,::-1].transpose((2,0,1)).copy() 42 | img_ = torch.from_numpy(img_).float().div(255.0).unsqueeze(0) 43 | return img_, orig_im, dim 44 | 45 | def prep_image2(img, inp_dim): 46 | """ 47 | Prepare image for inputting to the neural network. 48 | 49 | Returns a Variable 50 | """ 51 | 52 | orig_im = img 53 | dim = orig_im.shape[1], orig_im.shape[0] 54 | img = (letterbox_image(orig_im, (inp_dim, inp_dim))) 55 | img_ = img[:,:,::-1].transpose((2,0,1)).copy() 56 | img_ = torch.from_numpy(img_).float().div(255.0).unsqueeze(0) 57 | return img_, orig_im, dim 58 | 59 | def prep_image1(img, inp_dim): 60 | """ 61 | Prepare image for inputting to the neural network. 62 | 63 | Returns a Variable 64 | """ 65 | 66 | orig_im = cv2.imread(img) 67 | dim = [orig_im.shape[1], orig_im.shape[0]] 68 | img = (letterbox_image(orig_im, (inp_dim, inp_dim))) 69 | img_ = img[:,:,::-1].transpose((2,0,1)).copy() 70 | img_ = torch.from_numpy(img_).float().div(255.0).unsqueeze(0) 71 | return img_, orig_im, dim 72 | 73 | def prep_image_pil(img, network_dim): 74 | orig_im = Image.open(img) 75 | img = orig_im.convert('RGB') 76 | dim = img.size 77 | img = img.resize(network_dim) 78 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())) 79 | # img = img.view(*network_dim, 3).transpose(0,1).transpose(0,2).contiguous() 80 | img = img.view(1, 3,*network_dim) 81 | img = img.float().div(255.0) 82 | return (img, orig_im, dim) 83 | 84 | def inp_to_image(inp): 85 | inp = inp.cpu().squeeze() 86 | inp = inp*255 87 | try: 88 | inp = inp.data.numpy() 89 | except RuntimeError: 90 | inp = inp.numpy() 91 | inp = inp.transpose(1,2,0) 92 | 93 | inp = inp[:,:,::-1] 94 | return inp 95 | 96 | 97 | -------------------------------------------------------------------------------- /test/819.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/test/819.jpg -------------------------------------------------------------------------------- /test/828.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/test/828.jpg -------------------------------------------------------------------------------- /test/829.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/normalclone/fashion-ai-analysis/f92aa864070cb798dda901ec6280846e15ba0632/test/829.jpg -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import division 3 | 4 | import torch 5 | 
import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import cv2 10 | import random 11 | 12 | from bbox import bbox_iou, bbox_iou2 13 | import sys 14 | 15 | 16 | def prep_image(img, inp_dim): 17 | """ 18 | Prepare image for inputting to the neural network. 19 | 20 | Returns a Variable 21 | """ 22 | 23 | orig_im = img 24 | dim = orig_im.shape[1], orig_im.shape[0] 25 | img = cv2.resize(orig_im, (inp_dim, inp_dim)) 26 | # img_ = img[:,:,::-1].transpose((2,0,1)).copy() 27 | img_ = img.transpose((2,0,1)).copy() 28 | img_ = torch.from_numpy(img_).float().div(255.0).unsqueeze(0) 29 | return img_, orig_im, dim 30 | 31 | def load_image2(image_path, transform=None): 32 | image = Image.open(image_path) 33 | orig_img = np.array(image) 34 | dim = image.size 35 | image = image.resize([416, 416], Image.LANCZOS) #224 416 36 | 37 | if transform is not None: 38 | image = transform(image).unsqueeze(0) 39 | 40 | return image, orig_img, dim 41 | 42 | def write(x, img, phrases, coco_classes, colors): 43 | c1 = tuple(x[1:3].int()) 44 | c2 = tuple(x[3:5].int()) 45 | cls = int(x[-1]) 46 | if cls == 0: # for human bbox only detection 47 | label = "{0}".format(coco_classes[cls]) 48 | color = random.choice(colors) 49 | cv2.rectangle(img, c1, c2,color, 3) 50 | t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1 , 1)[0] 51 | c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4 52 | cv2.rectangle(img, c1, c2, color, -1) 53 | cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1) 54 | for i, phrase in enumerate(phrases): 55 | cv2.putText(img, phrase, (c1[0]-200, c1[1] + t_size[1] + 150 + i*20), cv2.FONT_HERSHEY_PLAIN, 1, color, 1 ) 56 | return img 57 | 58 | def write(x, img, raw_pharses,phrases, order, coco_classes, colors): 59 | c1 = tuple(x[1:3].int()) 60 | c2 = tuple(x[3:5].int()) 61 | cls = int(x[-1]) 62 | attrs_name = ['Tops', 'color', 'pattern', 'gender', 'season', 'type', 'spleeves', 'Bottoms', 'color', 'pattern', 63 | 'gender', 'season', 'spleeves', 'type', "legpose"] 64 | 65 | if cls == 0: # for human bbox only detection 66 | #label = "{0}: {1}".format(coco_classes[cls], order) 67 | label = " " 68 | color = random.choice(colors) 69 | cv2.rectangle(img, c1, c2,color, 3) 70 | t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0] 71 | c2 = c1[0] + t_size[0] + 3, c1[1] + len(attrs_name)*t_size[1] + 4 72 | cv2.rectangle(img, c1, c2, color, -1) 73 | 74 | j = 0 75 | for i in range(len(attrs_name)): 76 | if attrs_name[i] == "Bottoms" or attrs_name[i] == "Tops": 77 | text = attrs_name[i] 78 | j = j + 1 79 | else: 80 | 81 | text = " "+attrs_name[i] + ":" + raw_pharses[i - j] 82 | 83 | cv2.putText(img, label + " " + text, (c1[0] -t_size[0], c1[1] + t_size[1] + 5 + i*t_size[1]), cv2.FONT_HERSHEY_PLAIN, 1, 84 | [225, 255, 255], 1) 85 | 86 | 87 | #cv2.putText(img, label+" "+phrases, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1) 88 | 89 | #for i, phrase in enumerate(phrases): 90 | # cv2.putText(img, phrase, (c1[0], c1[1] + t_size[1] + 150 + i*20), cv2.FONT_HERSHEY_PLAIN, 1, color, 1 ) 91 | return img 92 | 93 | def view_image(bboxes): 94 | img = np.full((416, 416, 3), 100, dtype='uint8') 95 | for bbox in bboxes: 96 | c1 = tuple(bbox[1:3].int()) 97 | c2 = tuple(bbox[3:5].int()) 98 | cv2.rectangle(img, c1, c2, 128+20, 3) 99 | cv2.imshow("01", img) 100 | cv2.waitKey(0) 101 | cv2.destroyAllWindows() 102 | 103 | def sampling_sentence(samples, probs, vocab): 104 | 105 | j = 
0 106 | captions = [] 107 | for sampled_ids in samples: 108 | sampled_ids = sampled_ids[0].cpu().numpy() 109 | # (1, max_seq_length) -> (max_seq_length) 110 | probs_ids = probs[j][0].detach() 111 | probs_ids = probs_ids.cpu().numpy() 112 | if j == 1 : 113 | break 114 | j+=1 115 | 116 | # Convert word_ids to words 117 | sampled_caption = list() 118 | t_caption = list() 119 | no_flag = False 120 | for word_id in sampled_ids: 121 | word = vocab.idx2word[word_id] 122 | if word == '': 123 | break 124 | sampled_caption.append(word) 125 | 126 | for word in sampled_caption: 127 | if word == 'no': 128 | no_flag = True 129 | continue 130 | else: 131 | if not no_flag: 132 | t_caption.append(word) 133 | no_flag = False 134 | else: 135 | continue 136 | if t_caption.__len__() > 0: 137 | if t_caption[-1] == 'and': 138 | t_caption.pop() 139 | sentence = ' '.join(t_caption) 140 | p = sentence.find('and') 141 | if p > 0: 142 | upper = sentence[:p] 143 | if upper.find('vest') > -1: 144 | upper = upper + ' no sleeves' 145 | captions.append(upper) 146 | captions.append(sentence[p:]) 147 | else: 148 | if sentence.find('vest') > -1: 149 | sentence = sentence + ' no sleeves' 150 | captions.append(sentence) 151 | #sentence = ' '.join(sampled_caption) 152 | 153 | # Print out the image and the generated caption 154 | #print (sentence) 155 | #sys.stdout.write(sentence) 156 | #print (probs_ids) 157 | #captions.append(sentence) 158 | return captions 159 | 160 | def sampling_sentence1(sampled_ids, vocab): 161 | sampled_ids = sampled_ids[0].cpu().numpy() 162 | # (1, max_seq_length) -> (max_seq_length) 163 | 164 | # Convert word_ids to words 165 | sampled_caption = list() 166 | t_caption = list() 167 | no_flag = False 168 | for word_id in sampled_ids[0]: 169 | word = vocab.idx2word[word_id] 170 | if word == '': 171 | break 172 | sampled_caption.append(word) 173 | 174 | for word in sampled_caption: 175 | if word == 'no': 176 | no_flag = True 177 | continue 178 | else: 179 | if not no_flag: 180 | t_caption.append(word) 181 | no_flag = False 182 | else: 183 | continue 184 | if t_caption.__len__() > 0: 185 | if t_caption[-1] == 'and': 186 | t_caption.pop() 187 | sentence = ' '.join(t_caption) 188 | 189 | 190 | # Print out the image and the generated caption 191 | print (sentence) 192 | #print (probs_ids) 193 | 194 | 195 | return sentence 196 | 197 | def count_parameters(model): 198 | return sum(p.numel() for p in model.parameters()) 199 | 200 | def count_learnable_parameters(model): 201 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 202 | 203 | def convert2cpu(matrix): 204 | if matrix.is_cuda: 205 | return torch.FloatTensor(matrix.size()).copy_(matrix) 206 | else: 207 | return matrix 208 | 209 | def predict_transform(prediction, inp_dim, anchors, num_classes, device): 210 | batch_size = prediction.size(0) 211 | stride = inp_dim // prediction.size(2) 212 | grid_size = inp_dim // stride 213 | bbox_attrs = 5 + num_classes 214 | num_anchors = len(anchors) 215 | 216 | anchors = [(a[0]/stride, a[1]/stride) for a in anchors] 217 | 218 | prediction = prediction.view(batch_size, bbox_attrs*num_anchors, grid_size*grid_size) 219 | prediction = prediction.transpose(1,2).contiguous() 220 | prediction = prediction.view(batch_size, grid_size*grid_size*num_anchors, bbox_attrs) 221 | 222 | 223 | #Sigmoid the centre_X, centre_Y. 
and object confidencce 224 | prediction[:,:,0] = torch.sigmoid(prediction[:,:,0]) 225 | prediction[:,:,1] = torch.sigmoid(prediction[:,:,1]) 226 | prediction[:,:,4] = torch.sigmoid(prediction[:,:,4]) 227 | 228 | 229 | #Add the center offsets 230 | grid_len = np.arange(grid_size) 231 | a,b = np.meshgrid(grid_len, grid_len) 232 | 233 | x_offset = torch.FloatTensor(a).view(-1,1) 234 | y_offset = torch.FloatTensor(b).view(-1,1) 235 | 236 | 237 | x_offset = x_offset.to(device) 238 | y_offset = y_offset.to(device) 239 | 240 | x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1,num_anchors).view(-1,2).unsqueeze(0) 241 | 242 | prediction[:,:,:2] += x_y_offset 243 | 244 | #log space transform height and the width 245 | anchors = torch.FloatTensor(anchors) 246 | 247 | 248 | anchors = anchors.to(device) 249 | 250 | anchors = anchors.repeat(grid_size*grid_size, 1).unsqueeze(0) 251 | prediction[:,:,2:4] = torch.exp(prediction[:,:,2:4])*anchors 252 | 253 | #Softmax the class scores 254 | prediction[:,:,5: 5 + num_classes] = torch.sigmoid((prediction[:,:, 5 : 5 + num_classes])) 255 | 256 | prediction[:,:,:4] *= stride 257 | 258 | 259 | return prediction 260 | 261 | def load_classes(namesfile): 262 | fp = open(namesfile, "r") 263 | names = fp.read().split("\n")[:-1] 264 | return names 265 | 266 | def get_im_dim(im): 267 | im = cv2.imread(im) 268 | w,h = im.shape[1], im.shape[0] 269 | return w,h 270 | 271 | def unique(tensor): 272 | tensor_np = tensor.cpu().numpy() 273 | unique_np = np.unique(tensor_np) 274 | unique_tensor = torch.from_numpy(unique_np) 275 | 276 | tensor_res = tensor.new(unique_tensor.shape) 277 | tensor_res.copy_(unique_tensor) 278 | return tensor_res 279 | 280 | def write_results(prediction, confidence, num_classes, device, nms = True, nms_conf = 0.4): 281 | conf_mask = (prediction[:,:,4] > confidence).float().unsqueeze(2) 282 | prediction = prediction*conf_mask 283 | 284 | 285 | try: 286 | ind_nz = torch.nonzero(prediction[:,:,4]).transpose(0,1).contiguous() 287 | except: 288 | return 0 289 | 290 | 291 | box_a = prediction.new(prediction.shape) 292 | box_a[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2) 293 | box_a[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2) 294 | box_a[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) 295 | box_a[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2) 296 | prediction[:,:,:4] = box_a[:,:,:4] 297 | 298 | 299 | 300 | batch_size = prediction.size(0) 301 | 302 | output = prediction.new(1, prediction.size(2) + 1) 303 | write = False 304 | 305 | 306 | for ind in range(batch_size): 307 | #select the image from the batch 308 | image_pred = prediction[ind] 309 | 310 | 311 | 312 | #Get the class having maximum score, and the index of that class 313 | #Get rid of num_classes softmax scores 314 | #Add the class index and the class score of class having maximum score 315 | max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1) 316 | max_conf = max_conf.float().unsqueeze(1) 317 | max_conf_score = max_conf_score.float().unsqueeze(1) 318 | seq = (image_pred[:,:5], max_conf, max_conf_score) 319 | image_pred = torch.cat(seq, 1) 320 | 321 | 322 | 323 | #Get rid of the zero entries 324 | non_zero_ind = (torch.nonzero(image_pred[:,4])) 325 | 326 | 327 | image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7) 328 | 329 | #Get the various classes detected in the image 330 | try: 331 | img_classes = unique(image_pred_[:,-1]) 332 | except: 333 | continue 334 | #WE will do NMS classwise 335 | for cls in img_classes: 336 | if cls == 0: 337 | 
#get the detections with one particular class 338 | cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1) 339 | class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze() 340 | 341 | 342 | image_pred_class = image_pred_[class_mask_ind].view(-1,7) 343 | 344 | 345 | 346 | #sort the detections such that the entry with the maximum objectness 347 | #confidence is at the top 348 | conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1] 349 | image_pred_class = image_pred_class[conf_sort_index] 350 | idx = image_pred_class.size(0) 351 | 352 | #if nms has to be done 353 | if nms: 354 | #For each detection 355 | for i in range(idx): 356 | #Get the IOUs of all boxes that come after the one we are looking at 357 | #in the loop 358 | try: 359 | ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:], device) 360 | except ValueError: 361 | break 362 | 363 | except IndexError: 364 | break 365 | 366 | #Zero out all the detections that have IoU > treshhold 367 | iou_mask = (ious < nms_conf).float().unsqueeze(1) 368 | image_pred_class[i+1:] *= iou_mask 369 | 370 | #Remove the non-zero entries 371 | non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze() 372 | image_pred_class = image_pred_class[non_zero_ind].view(-1,7) 373 | 374 | 375 | 376 | #Concatenate the batch_id of the image to the detection 377 | #this helps us identify which image does the detection correspond to 378 | #We use a linear straucture to hold ALL the detections from the batch 379 | #the batch_dim is flattened 380 | #batch is identified by extra batch column 381 | 382 | 383 | batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind) 384 | seq = batch_ind, image_pred_class 385 | if not write: 386 | output = torch.cat(seq,1) 387 | write = True 388 | else: 389 | out = torch.cat(seq,1) 390 | output = torch.cat((output,out)) 391 | 392 | return output 393 | 394 | #!/usr/bin/env python3 395 | # -*- coding: utf-8 -*- 396 | """ 397 | Created on Sat Mar 24 00:12:16 2018 398 | 399 | @author: ayooshmac 400 | """ 401 | 402 | def predict_transform_half(prediction, inp_dim, anchors, num_classes, CUDA = True): 403 | batch_size = prediction.size(0) 404 | stride = inp_dim // prediction.size(2) 405 | 406 | bbox_attrs = 5 + num_classes 407 | num_anchors = len(anchors) 408 | grid_size = inp_dim // stride 409 | 410 | 411 | prediction = prediction.view(batch_size, bbox_attrs*num_anchors, grid_size*grid_size) 412 | prediction = prediction.transpose(1,2).contiguous() 413 | prediction = prediction.view(batch_size, grid_size*grid_size*num_anchors, bbox_attrs) 414 | 415 | 416 | #Sigmoid the centre_X, centre_Y. 
and object confidencce 417 | prediction[:,:,0] = torch.sigmoid(prediction[:,:,0]) 418 | prediction[:,:,1] = torch.sigmoid(prediction[:,:,1]) 419 | prediction[:,:,4] = torch.sigmoid(prediction[:,:,4]) 420 | 421 | 422 | #Add the center offsets 423 | grid_len = np.arange(grid_size) 424 | a,b = np.meshgrid(grid_len, grid_len) 425 | 426 | x_offset = torch.FloatTensor(a).view(-1,1) 427 | y_offset = torch.FloatTensor(b).view(-1,1) 428 | 429 | if CUDA: 430 | x_offset = x_offset.cuda().half() 431 | y_offset = y_offset.cuda().half() 432 | 433 | x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1,num_anchors).view(-1,2).unsqueeze(0) 434 | 435 | prediction[:,:,:2] += x_y_offset 436 | 437 | #log space transform height and the width 438 | anchors = torch.HalfTensor(anchors) 439 | 440 | if CUDA: 441 | anchors = anchors.cuda() 442 | 443 | anchors = anchors.repeat(grid_size*grid_size, 1).unsqueeze(0) 444 | prediction[:,:,2:4] = torch.exp(prediction[:,:,2:4])*anchors 445 | 446 | #Softmax the class scores 447 | prediction[:,:,5: 5 + num_classes] = nn.Softmax(-1)(Variable(prediction[:,:, 5 : 5 + num_classes])).data 448 | 449 | prediction[:,:,:4] *= stride 450 | 451 | 452 | return prediction 453 | 454 | 455 | def write_results_half(prediction, confidence, num_classes, nms = True, nms_conf = 0.4): 456 | conf_mask = (prediction[:,:,4] > confidence).half().unsqueeze(2) 457 | prediction = prediction*conf_mask 458 | 459 | try: 460 | ind_nz = torch.nonzero(prediction[:,:,4]).transpose(0,1).contiguous() 461 | except: 462 | return 0 463 | 464 | 465 | 466 | box_a = prediction.new(prediction.shape) 467 | box_a[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2) 468 | box_a[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2) 469 | box_a[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) 470 | box_a[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2) 471 | prediction[:,:,:4] = box_a[:,:,:4] 472 | 473 | 474 | 475 | batch_size = prediction.size(0) 476 | 477 | output = prediction.new(1, prediction.size(2) + 1) 478 | write = False 479 | 480 | for ind in range(batch_size): 481 | #select the image from the batch 482 | image_pred = prediction[ind] 483 | 484 | 485 | #Get the class having maximum score, and the index of that class 486 | #Get rid of num_classes softmax scores 487 | #Add the class index and the class score of class having maximum score 488 | max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1) 489 | max_conf = max_conf.half().unsqueeze(1) 490 | max_conf_score = max_conf_score.half().unsqueeze(1) 491 | seq = (image_pred[:,:5], max_conf, max_conf_score) 492 | image_pred = torch.cat(seq, 1) 493 | 494 | 495 | #Get rid of the zero entries 496 | non_zero_ind = (torch.nonzero(image_pred[:,4])) 497 | try: 498 | image_pred_ = image_pred[non_zero_ind.squeeze(),:] 499 | except: 500 | continue 501 | 502 | #Get the various classes detected in the image 503 | img_classes = unique(image_pred_[:,-1].long()).half() 504 | 505 | 506 | 507 | 508 | #WE will do NMS classwise 509 | for cls in img_classes: 510 | #get the detections with one particular class 511 | cls_mask = image_pred_*(image_pred_[:,-1] == cls).half().unsqueeze(1) 512 | class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze() 513 | 514 | 515 | image_pred_class = image_pred_[class_mask_ind] 516 | 517 | 518 | #sort the detections such that the entry with the maximum objectness 519 | #confidence is at the top 520 | conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1] 521 | image_pred_class = image_pred_class[conf_sort_index] 522 | 
| idx = image_pred_class.size(0) 523 | 524 | #if nms has to be done 525 | if nms: 526 | #For each detection 527 | for i in range(idx): 528 | #Get the IOUs of all boxes that come after the one we are looking at 529 | #in the loop 530 | try: 531 | ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:], prediction.device) # this half-precision variant takes no device argument, so reuse the prediction tensor's device 532 | except ValueError: 533 | break 534 | 535 | except IndexError: 536 | break 537 | 538 | #Zero out all the detections that have IoU > threshold 539 | iou_mask = (ious < nms_conf).half().unsqueeze(1) 540 | image_pred_class[i+1:] *= iou_mask 541 | 542 | #Remove the non-zero entries 543 | non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze() 544 | image_pred_class = image_pred_class[non_zero_ind] 545 | 546 | 547 | 548 | #Concatenate the batch_id of the image to the detection 549 | #this helps us identify which image the detection corresponds to 550 | #We use a linear structure to hold ALL the detections from the batch 551 | #the batch_dim is flattened 552 | #batch is identified by extra batch column 553 | batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind) 554 | seq = batch_ind, image_pred_class 555 | 556 | if not write: 557 | output = torch.cat(seq,1) 558 | write = True 559 | else: 560 | out = torch.cat(seq,1) 561 | output = torch.cat((output,out)) 562 | 563 | return output 564 | --------------------------------------------------------------------------------
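Usage note (editor's addition): the detection half of the code above can be exercised on its own, without the attribute encoder. The sketch below is a minimal, hedged illustration; the image path, resolution and thresholds are borrowed from the defaults and test files already present in this repo, and the canonical entry point remains file_demo.py.

import torch
from darknet import Darknet
from preprocess import prep_image
from util import write_results

device = torch.device('cpu')                       # file_demo.py also defaults to CPU
model = Darknet('cfg/yolov3.cfg')
model.load_weights('yolov3.weights')               # Darknet-format weights file, downloaded separately
model.net_info["height"] = 832                     # file_demo.py picks 832 or 1024 depending on image size
model.to(device).eval()

inp_dim = int(model.net_info["height"])
image, orig_img, dim = prep_image('test/819.jpg', inp_dim)
with torch.no_grad():
    raw = model(image.to(device), device, True)    # flag_eval=True returns the stacked YOLO predictions
detections = write_results(raw, 0.5, 80, device, nms=True, nms_conf=0.4)
# write_results can return the int 0 instead of a tensor, which is why file_demo.py
# checks `type(detections) != int` before indexing. Each surviving row is
# [batch_idx, x1, y1, x2, y2, objectness, class score, class index] in the
# letterboxed inp_dim x inp_dim frame, so it still has to be rescaled to orig_img
# the way file_demo.py does before drawing.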