├── README.md ├── data ├── testIdxs.txt └── trainIdxs.txt ├── fcrn.py ├── loader.py ├── model └── readme.txt ├── result ├── gt_depth_epoch_100.png ├── input_rgb_epoch_100.png └── pred_depth_epoch_100.png ├── test.py ├── train.py ├── utils.py └── weights.py /README.md: -------------------------------------------------------------------------------- 1 | # fcrn_pytorch 2 | A PyTorch implementation of Deeper Depth Prediction with Fully Convolutional Residual Networks (2016 IEEE 3D Vision) 3 | 4 | Paper: https://arxiv.org/pdf/1606.00373.pdf 5 | 6 | Main references: the official implementation https://github.com/iro-cp/FCRN-DepthPrediction 7 | and an earlier PyTorch port https://github.com/XPFly1989/FCRN 8 | >fcrn_pytorch: file structure 9 | >>data: the data to be processed 10 | 11 | >>>testIdxs.txt trainIdxs.txt nyu_depth_v2_labeled 12 | 13 | >>model: saved models 14 | 15 | >>>NYU_ResNet-UpProj.npy model_300.pth 16 | 17 | >>result: sample model outputs 18 | 19 | >>fcrn.py: the network 20 | 21 | >>loader.py: data loading and preprocessing 22 | 23 | >>test.py: evaluates the model 24 | 25 | >>train.py: trains the model (supports resuming from a checkpoint) 26 | 27 | >>weights.py: loads the official TensorFlow weights 28 | 29 | >>utils.py: utility functions 30 | 31 | (1) Download the NYU Depth Dataset V2 labeled dataset: http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat and place it in the data folder. Mirror: https://pan.baidu.com/s/1rIUbsEUjkZJheEZ5wTb5aA password: bfi4 32 | 33 | (2) Download the official TensorFlow pretrained weights: http://campar.in.tum.de/files/rupprecht/depthpred/NYU_ResNet-UpProj.npy and place them in the model folder; you can also download my trained model and continue training from it. Link: https://pan.baidu.com/s/1KCJ8ssTHmr1JPkoConC33w extraction code: 2d8m 34 | 35 | -------------------------------------------------------------------------------- /data/testIdxs.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 9 4 | 14 5 | 15 6 | 16 7 | 17 8 | 18 9 | 21 10 | 28 11 | 29 12 | 30 13 | 31 14 | 32 15 | 33 16 | 34 17 | 35 18 | 36 19 | 37 20 | 38 21 | 39 22 | 40 23 | 41 24 | 42 25 | 43 26 | 46 27 | 47 28 | 56 29 | 57 30 | 59 31 | 60 32 | 61 33 | 62 34 | 63 35 | 76 36 | 77 37 | 78 38 | 79 39 | 84 40 | 85 41 | 86 42 | 87 43 | 88 44 | 89 45 | 90 46 | 91 47 | 117 48 | 118 49 | 119 50 | 125 51 | 126 52 | 127 53 | 128 54 | 129 55 | 131 56 | 132 57 | 133 58 | 134 59 | 137 60 | 153 61 | 154 62 | 155 63 | 167 64 | 168 65 | 169 66 | 171 67 | 172 68 | 173 69 | 174 70 | 175 71 | 176 72 | 180 73 | 181 74 | 182 75 | 183 76 | 184 77 | 185 78 | 186 79 | 187 80 | 188 81 | 189 82 | 190 83 | 191 84 | 192 85 | 193 86 | 194 87 | 195 88 | 196 89 | 197 90 | 198 91 | 199 92 | 200 93 | 201 94 | 202 95 | 207 96 | 208 97 | 209 98 | 210 99 | 211 100 | 212 101 | 220 102 | 221 103 | 222 104 | 250 105 | 264 106 | 271 107 | 272 108 | 273 109 | 279 110 | 280 111 | 281 112 | 282 113 | 283 114 | 284 115 | 285 116 | 296 117 | 297 118 | 298 119 | 299 120 | 300 121 | 301 122 | 302 123 | 310 124 | 311 125 | 312 126 | 315 127 | 316 128 | 317 129 | 325 130 | 326 131 | 327 132 | 328 133 | 329 134 | 330 135 | 331 136 | 332 137 | 333 138 | 334 139 | 335 140 | 351 141 | 352 142 | 355 143 | 356 144 | 357 145 | 358 146 | 359 147 | 360 148 | 361 149 | 362 150 | 363 151 | 364 152 | 384 153 | 385 154 | 386 155 | 387 156 | 388 157 | 389 158 | 390 159 | 395 160 | 396 161 | 397 162 | 411 163 | 412 164 | 413 165 | 414 166 | 430 167 | 431 168 | 432 169 | 433 170 | 434 171 | 435 172 | 441 173 | 442 174 | 443 175 | 444 176 | 445 177 | 446 178 | 447 179 | 448 180 | 462 181 | 463 182 | 464 183 | 465 184 | 466 185 | 469 186 | 470 187 | 471 188 | 472 189 | 473 190 | 474 191 | 475 192 | 476 193 | 477 194 | 508 195 | 509 196 | 510 197 | 511 198 | 512 199 | 513 200 | 515 201 | 516 202 | 517 203 | 518 204 | 519 205 | 520 206 | 521 207 | 522 208 | 523 209 | 524 210 | 525 211 | 526 212 | 
531 213 | 532 214 | 533 215 | 537 216 | 538 217 | 539 218 | 549 219 | 550 220 | 551 221 | 555 222 | 556 223 | 557 224 | 558 225 | 559 226 | 560 227 | 561 228 | 562 229 | 563 230 | 564 231 | 565 232 | 566 233 | 567 234 | 568 235 | 569 236 | 570 237 | 571 238 | 579 239 | 580 240 | 581 241 | 582 242 | 583 243 | 591 244 | 592 245 | 593 246 | 594 247 | 603 248 | 604 249 | 605 250 | 606 251 | 607 252 | 612 253 | 613 254 | 617 255 | 618 256 | 619 257 | 620 258 | 621 259 | 633 260 | 634 261 | 635 262 | 636 263 | 637 264 | 638 265 | 644 266 | 645 267 | 650 268 | 651 269 | 656 270 | 657 271 | 658 272 | 663 273 | 664 274 | 668 275 | 669 276 | 670 277 | 671 278 | 672 279 | 673 280 | 676 281 | 677 282 | 678 283 | 679 284 | 680 285 | 681 286 | 686 287 | 687 288 | 688 289 | 689 290 | 690 291 | 693 292 | 694 293 | 697 294 | 698 295 | 699 296 | 706 297 | 707 298 | 708 299 | 709 300 | 710 301 | 711 302 | 712 303 | 713 304 | 717 305 | 718 306 | 724 307 | 725 308 | 726 309 | 727 310 | 728 311 | 731 312 | 732 313 | 733 314 | 734 315 | 743 316 | 744 317 | 759 318 | 760 319 | 761 320 | 762 321 | 763 322 | 764 323 | 765 324 | 766 325 | 767 326 | 768 327 | 769 328 | 770 329 | 771 330 | 772 331 | 773 332 | 774 333 | 775 334 | 776 335 | 777 336 | 778 337 | 779 338 | 780 339 | 781 340 | 782 341 | 783 342 | 784 343 | 785 344 | 786 345 | 787 346 | 800 347 | 801 348 | 802 349 | 803 350 | 804 351 | 810 352 | 811 353 | 812 354 | 813 355 | 814 356 | 821 357 | 822 358 | 823 359 | 833 360 | 834 361 | 835 362 | 836 363 | 837 364 | 838 365 | 839 366 | 840 367 | 841 368 | 842 369 | 843 370 | 844 371 | 845 372 | 846 373 | 850 374 | 851 375 | 852 376 | 857 377 | 858 378 | 859 379 | 860 380 | 861 381 | 862 382 | 869 383 | 870 384 | 871 385 | 906 386 | 907 387 | 908 388 | 917 389 | 918 390 | 919 391 | 926 392 | 927 393 | 928 394 | 932 395 | 933 396 | 934 397 | 935 398 | 945 399 | 946 400 | 947 401 | 959 402 | 960 403 | 961 404 | 962 405 | 965 406 | 966 407 | 967 408 | 970 409 | 971 410 | 972 411 | 973 412 | 974 413 | 975 414 | 976 415 | 977 416 | 991 417 | 992 418 | 993 419 | 994 420 | 995 421 | 1001 422 | 1002 423 | 1003 424 | 1004 425 | 1010 426 | 1011 427 | 1012 428 | 1021 429 | 1022 430 | 1023 431 | 1032 432 | 1033 433 | 1034 434 | 1038 435 | 1039 436 | 1048 437 | 1049 438 | 1052 439 | 1053 440 | 1057 441 | 1058 442 | 1075 443 | 1076 444 | 1077 445 | 1078 446 | 1079 447 | 1080 448 | 1081 449 | 1082 450 | 1083 451 | 1084 452 | 1088 453 | 1089 454 | 1090 455 | 1091 456 | 1092 457 | 1093 458 | 1094 459 | 1095 460 | 1096 461 | 1098 462 | 1099 463 | 1100 464 | 1101 465 | 1102 466 | 1103 467 | 1104 468 | 1106 469 | 1107 470 | 1108 471 | 1109 472 | 1117 473 | 1118 474 | 1119 475 | 1123 476 | 1124 477 | 1125 478 | 1126 479 | 1127 480 | 1128 481 | 1129 482 | 1130 483 | 1131 484 | 1135 485 | 1136 486 | 1144 487 | 1145 488 | 1146 489 | 1147 490 | 1148 491 | 1149 492 | 1150 493 | 1151 494 | 1152 495 | 1153 496 | 1154 497 | 1155 498 | 1156 499 | 1157 500 | 1158 501 | 1162 502 | 1163 503 | 1164 504 | 1165 505 | 1166 506 | 1167 507 | 1170 508 | 1171 509 | 1174 510 | 1175 511 | 1176 512 | 1179 513 | 1180 514 | 1181 515 | 1182 516 | 1183 517 | 1184 518 | 1192 519 | 1193 520 | 1194 521 | 1195 522 | 1196 523 | 1201 524 | 1202 525 | 1203 526 | 1204 527 | 1205 528 | 1206 529 | 1207 530 | 1208 531 | 1209 532 | 1210 533 | 1211 534 | 1212 535 | 1216 536 | 1217 537 | 1218 538 | 1219 539 | 1220 540 | 1226 541 | 1227 542 | 1228 543 | 1229 544 | 1230 545 | 1233 546 | 1234 547 | 1235 548 | 1247 549 | 1248 550 | 1249 551 | 1250 552 | 1254 553 | 1255 554 | 
1256 555 | 1257 556 | 1258 557 | 1259 558 | 1260 559 | 1261 560 | 1262 561 | 1263 562 | 1264 563 | 1265 564 | 1275 565 | 1276 566 | 1277 567 | 1278 568 | 1279 569 | 1280 570 | 1285 571 | 1286 572 | 1287 573 | 1288 574 | 1289 575 | 1290 576 | 1291 577 | 1292 578 | 1293 579 | 1294 580 | 1295 581 | 1297 582 | 1298 583 | 1299 584 | 1302 585 | 1303 586 | 1304 587 | 1305 588 | 1306 589 | 1307 590 | 1308 591 | 1314 592 | 1315 593 | 1329 594 | 1330 595 | 1331 596 | 1332 597 | 1335 598 | 1336 599 | 1337 600 | 1338 601 | 1339 602 | 1340 603 | 1347 604 | 1348 605 | 1349 606 | 1353 607 | 1354 608 | 1355 609 | 1356 610 | 1364 611 | 1365 612 | 1368 613 | 1369 614 | 1384 615 | 1385 616 | 1386 617 | 1387 618 | 1388 619 | 1389 620 | 1390 621 | 1391 622 | 1394 623 | 1395 624 | 1396 625 | 1397 626 | 1398 627 | 1399 628 | 1400 629 | 1401 630 | 1407 631 | 1408 632 | 1409 633 | 1410 634 | 1411 635 | 1412 636 | 1413 637 | 1414 638 | 1421 639 | 1422 640 | 1423 641 | 1424 642 | 1430 643 | 1431 644 | 1432 645 | 1433 646 | 1441 647 | 1442 648 | 1443 649 | 1444 650 | 1445 651 | 1446 652 | 1447 653 | 1448 654 | 1449 -------------------------------------------------------------------------------- /data/trainIdxs.txt: -------------------------------------------------------------------------------- 1 | 3 2 | 4 3 | 5 4 | 6 5 | 7 6 | 8 7 | 10 8 | 11 9 | 12 10 | 13 11 | 19 12 | 20 13 | 22 14 | 23 15 | 24 16 | 25 17 | 26 18 | 27 19 | 44 20 | 45 21 | 48 22 | 49 23 | 50 24 | 51 25 | 52 26 | 53 27 | 54 28 | 55 29 | 58 30 | 64 31 | 65 32 | 66 33 | 67 34 | 68 35 | 69 36 | 70 37 | 71 38 | 72 39 | 73 40 | 74 41 | 75 42 | 80 43 | 81 44 | 82 45 | 83 46 | 92 47 | 93 48 | 94 49 | 95 50 | 96 51 | 97 52 | 98 53 | 99 54 | 100 55 | 101 56 | 102 57 | 103 58 | 104 59 | 105 60 | 106 61 | 107 62 | 108 63 | 109 64 | 110 65 | 111 66 | 112 67 | 113 68 | 114 69 | 115 70 | 116 71 | 120 72 | 121 73 | 122 74 | 123 75 | 124 76 | 130 77 | 135 78 | 136 79 | 138 80 | 139 81 | 140 82 | 141 83 | 142 84 | 143 85 | 144 86 | 145 87 | 146 88 | 147 89 | 148 90 | 149 91 | 150 92 | 151 93 | 152 94 | 156 95 | 157 96 | 158 97 | 159 98 | 160 99 | 161 100 | 162 101 | 163 102 | 164 103 | 165 104 | 166 105 | 170 106 | 177 107 | 178 108 | 179 109 | 203 110 | 204 111 | 205 112 | 206 113 | 213 114 | 214 115 | 215 116 | 216 117 | 217 118 | 218 119 | 219 120 | 223 121 | 224 122 | 225 123 | 226 124 | 227 125 | 228 126 | 229 127 | 230 128 | 231 129 | 232 130 | 233 131 | 234 132 | 235 133 | 236 134 | 237 135 | 238 136 | 239 137 | 240 138 | 241 139 | 242 140 | 243 141 | 244 142 | 245 143 | 246 144 | 247 145 | 248 146 | 249 147 | 251 148 | 252 149 | 253 150 | 254 151 | 255 152 | 256 153 | 257 154 | 258 155 | 259 156 | 260 157 | 261 158 | 262 159 | 263 160 | 265 161 | 266 162 | 267 163 | 268 164 | 269 165 | 270 166 | 274 167 | 275 168 | 276 169 | 277 170 | 278 171 | 286 172 | 287 173 | 288 174 | 289 175 | 290 176 | 291 177 | 292 178 | 293 179 | 294 180 | 295 181 | 303 182 | 304 183 | 305 184 | 306 185 | 307 186 | 308 187 | 309 188 | 313 189 | 314 190 | 318 191 | 319 192 | 320 193 | 321 194 | 322 195 | 323 196 | 324 197 | 336 198 | 337 199 | 338 200 | 339 201 | 340 202 | 341 203 | 342 204 | 343 205 | 344 206 | 345 207 | 346 208 | 347 209 | 348 210 | 349 211 | 350 212 | 353 213 | 354 214 | 365 215 | 366 216 | 367 217 | 368 218 | 369 219 | 370 220 | 371 221 | 372 222 | 373 223 | 374 224 | 375 225 | 376 226 | 377 227 | 378 228 | 379 229 | 380 230 | 381 231 | 382 232 | 383 233 | 391 234 | 392 235 | 393 236 | 394 237 | 398 238 | 399 239 | 400 240 | 401 241 | 402 242 | 403 243 | 404 244 
| 405 245 | 406 246 | 407 247 | 408 248 | 409 249 | 410 250 | 415 251 | 416 252 | 417 253 | 418 254 | 419 255 | 420 256 | 421 257 | 422 258 | 423 259 | 424 260 | 425 261 | 426 262 | 427 263 | 428 264 | 429 265 | 436 266 | 437 267 | 438 268 | 439 269 | 440 270 | 449 271 | 450 272 | 451 273 | 452 274 | 453 275 | 454 276 | 455 277 | 456 278 | 457 279 | 458 280 | 459 281 | 460 282 | 461 283 | 467 284 | 468 285 | 478 286 | 479 287 | 480 288 | 481 289 | 482 290 | 483 291 | 484 292 | 485 293 | 486 294 | 487 295 | 488 296 | 489 297 | 490 298 | 491 299 | 492 300 | 493 301 | 494 302 | 495 303 | 496 304 | 497 305 | 498 306 | 499 307 | 500 308 | 501 309 | 502 310 | 503 311 | 504 312 | 505 313 | 506 314 | 507 315 | 514 316 | 527 317 | 528 318 | 529 319 | 530 320 | 534 321 | 535 322 | 536 323 | 540 324 | 541 325 | 542 326 | 543 327 | 544 328 | 545 329 | 546 330 | 547 331 | 548 332 | 552 333 | 553 334 | 554 335 | 572 336 | 573 337 | 574 338 | 575 339 | 576 340 | 577 341 | 578 342 | 584 343 | 585 344 | 586 345 | 587 346 | 588 347 | 589 348 | 590 349 | 595 350 | 596 351 | 597 352 | 598 353 | 599 354 | 600 355 | 601 356 | 602 357 | 608 358 | 609 359 | 610 360 | 611 361 | 614 362 | 615 363 | 616 364 | 622 365 | 623 366 | 624 367 | 625 368 | 626 369 | 627 370 | 628 371 | 629 372 | 630 373 | 631 374 | 632 375 | 639 376 | 640 377 | 641 378 | 642 379 | 643 380 | 646 381 | 647 382 | 648 383 | 649 384 | 652 385 | 653 386 | 654 387 | 655 388 | 659 389 | 660 390 | 661 391 | 662 392 | 665 393 | 666 394 | 667 395 | 674 396 | 675 397 | 682 398 | 683 399 | 684 400 | 685 401 | 691 402 | 692 403 | 695 404 | 696 405 | 700 406 | 701 407 | 702 408 | 703 409 | 704 410 | 705 411 | 714 412 | 715 413 | 716 414 | 719 415 | 720 416 | 721 417 | 722 418 | 723 419 | 729 420 | 730 421 | 735 422 | 736 423 | 737 424 | 738 425 | 739 426 | 740 427 | 741 428 | 742 429 | 745 430 | 746 431 | 747 432 | 748 433 | 749 434 | 750 435 | 751 436 | 752 437 | 753 438 | 754 439 | 755 440 | 756 441 | 757 442 | 758 443 | 788 444 | 789 445 | 790 446 | 791 447 | 792 448 | 793 449 | 794 450 | 795 451 | 796 452 | 797 453 | 798 454 | 799 455 | 805 456 | 806 457 | 807 458 | 808 459 | 809 460 | 815 461 | 816 462 | 817 463 | 818 464 | 819 465 | 820 466 | 824 467 | 825 468 | 826 469 | 827 470 | 828 471 | 829 472 | 830 473 | 831 474 | 832 475 | 847 476 | 848 477 | 849 478 | 853 479 | 854 480 | 855 481 | 856 482 | 863 483 | 864 484 | 865 485 | 866 486 | 867 487 | 868 488 | 872 489 | 873 490 | 874 491 | 875 492 | 876 493 | 877 494 | 878 495 | 879 496 | 880 497 | 881 498 | 882 499 | 883 500 | 884 501 | 885 502 | 886 503 | 887 504 | 888 505 | 889 506 | 890 507 | 891 508 | 892 509 | 893 510 | 894 511 | 895 512 | 896 513 | 897 514 | 898 515 | 899 516 | 900 517 | 901 518 | 902 519 | 903 520 | 904 521 | 905 522 | 909 523 | 910 524 | 911 525 | 912 526 | 913 527 | 914 528 | 915 529 | 916 530 | 920 531 | 921 532 | 922 533 | 923 534 | 924 535 | 925 536 | 929 537 | 930 538 | 931 539 | 936 540 | 937 541 | 938 542 | 939 543 | 940 544 | 941 545 | 942 546 | 943 547 | 944 548 | 948 549 | 949 550 | 950 551 | 951 552 | 952 553 | 953 554 | 954 555 | 955 556 | 956 557 | 957 558 | 958 559 | 963 560 | 964 561 | 968 562 | 969 563 | 978 564 | 979 565 | 980 566 | 981 567 | 982 568 | 983 569 | 984 570 | 985 571 | 986 572 | 987 573 | 988 574 | 989 575 | 990 576 | 996 577 | 997 578 | 998 579 | 999 580 | 1000 581 | 1005 582 | 1006 583 | 1007 584 | 1008 585 | 1009 586 | 1013 587 | 1014 588 | 1015 589 | 1016 590 | 1017 591 | 1018 592 | 1019 593 | 1020 594 | 1024 595 | 1025 596 | 1026 597 | 1027 
598 | 1028 599 | 1029 600 | 1030 601 | 1031 602 | 1035 603 | 1036 604 | 1037 605 | 1040 606 | 1041 607 | 1042 608 | 1043 609 | 1044 610 | 1045 611 | 1046 612 | 1047 613 | 1050 614 | 1051 615 | 1054 616 | 1055 617 | 1056 618 | 1059 619 | 1060 620 | 1061 621 | 1062 622 | 1063 623 | 1064 624 | 1065 625 | 1066 626 | 1067 627 | 1068 628 | 1069 629 | 1070 630 | 1071 631 | 1072 632 | 1073 633 | 1074 634 | 1085 635 | 1086 636 | 1087 637 | 1097 638 | 1105 639 | 1110 640 | 1111 641 | 1112 642 | 1113 643 | 1114 644 | 1115 645 | 1116 646 | 1120 647 | 1121 648 | 1122 649 | 1132 650 | 1133 651 | 1134 652 | 1137 653 | 1138 654 | 1139 655 | 1140 656 | 1141 657 | 1142 658 | 1143 659 | 1159 660 | 1160 661 | 1161 662 | 1168 663 | 1169 664 | 1172 665 | 1173 666 | 1177 667 | 1178 668 | 1185 669 | 1186 670 | 1187 671 | 1188 672 | 1189 673 | 1190 674 | 1191 675 | 1197 676 | 1198 677 | 1199 678 | 1200 679 | 1213 680 | 1214 681 | 1215 682 | 1221 683 | 1222 684 | 1223 685 | 1224 686 | 1225 687 | 1231 688 | 1232 689 | 1236 690 | 1237 691 | 1238 692 | 1239 693 | 1240 694 | 1241 695 | 1242 696 | 1243 697 | 1244 698 | 1245 699 | 1246 700 | 1251 701 | 1252 702 | 1253 703 | 1266 704 | 1267 705 | 1268 706 | 1269 707 | 1270 708 | 1271 709 | 1272 710 | 1273 711 | 1274 712 | 1281 713 | 1282 714 | 1283 715 | 1284 716 | 1296 717 | 1300 718 | 1301 719 | 1309 720 | 1310 721 | 1311 722 | 1312 723 | 1313 724 | 1316 725 | 1317 726 | 1318 727 | 1319 728 | 1320 729 | 1321 730 | 1322 731 | 1323 732 | 1324 733 | 1325 734 | 1326 735 | 1327 736 | 1328 737 | 1333 738 | 1334 739 | 1341 740 | 1342 741 | 1343 742 | 1344 743 | 1345 744 | 1346 745 | 1350 746 | 1351 747 | 1352 748 | 1357 749 | 1358 750 | 1359 751 | 1360 752 | 1361 753 | 1362 754 | 1363 755 | 1366 756 | 1367 757 | 1370 758 | 1371 759 | 1372 760 | 1373 761 | 1374 762 | 1375 763 | 1376 764 | 1377 765 | 1378 766 | 1379 767 | 1380 768 | 1381 769 | 1382 770 | 1383 771 | 1392 772 | 1393 773 | 1402 774 | 1403 775 | 1404 776 | 1405 777 | 1406 778 | 1415 779 | 1416 780 | 1417 781 | 1418 782 | 1419 783 | 1420 784 | 1425 785 | 1426 786 | 1427 787 | 1428 788 | 1429 789 | 1434 790 | 1435 791 | 1436 792 | 1437 793 | 1438 794 | 1439 795 | 1440 -------------------------------------------------------------------------------- /fcrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional 4 | import math 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | expansion = 4 9 | 10 | def __init__(self, inplanes, planes, stride=1, downsample=None): 11 | super(Bottleneck, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, 16 | bias=False) 17 | self.bn2 = nn.BatchNorm2d(planes) 18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 19 | self.bn3 = nn.BatchNorm2d(planes * 4) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv3(out) 36 | out = self.bn3(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class UpProject(nn.Module): 48 
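    # Fast up-projection block from Laina et al., "Deeper Depth Prediction with Fully
    # Convolutional Residual Networks" (3DV 2016). Instead of unpooling followed by a
    # 5x5 convolution, each branch applies four smaller convolutions (3x3, 2x3, 3x2 and
    # 2x2) to the low-resolution input and interleaves their outputs spatially to form
    # the 2x-upsampled feature map, which is equivalent to the unpool+conv formulation
    # but cheaper to compute. Two such branches form a residual pair: the first goes
    # through BN -> ReLU -> 3x3 conv -> BN, the second only through BN, and their sum
    # is passed through a final ReLU (see forward() below).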
| 49 | def __init__(self, in_channels, out_channels, batch_size): 50 | super(UpProject, self).__init__() 51 | self.batch_size = batch_size 52 | 53 | self.conv1_1 = nn.Conv2d(in_channels, out_channels, 3) 54 | self.conv1_2 = nn.Conv2d(in_channels, out_channels, (2, 3)) 55 | self.conv1_3 = nn.Conv2d(in_channels, out_channels, (3, 2)) 56 | self.conv1_4 = nn.Conv2d(in_channels, out_channels, 2) 57 | 58 | self.conv2_1 = nn.Conv2d(in_channels, out_channels, 3) 59 | self.conv2_2 = nn.Conv2d(in_channels, out_channels, (2, 3)) 60 | self.conv2_3 = nn.Conv2d(in_channels, out_channels, (3, 2)) 61 | self.conv2_4 = nn.Conv2d(in_channels, out_channels, 2) 62 | 63 | self.bn1_1 = nn.BatchNorm2d(out_channels) 64 | self.bn1_2 = nn.BatchNorm2d(out_channels) 65 | 66 | self.relu = nn.ReLU(inplace=True) 67 | 68 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=1) 69 | 70 | self.bn2 = nn.BatchNorm2d(out_channels) 71 | 72 | def forward(self, x): 73 | # b, 10, 8, 1024 74 | out1_1 = self.conv1_1(nn.functional.pad(x, (1, 1, 1, 1))) 75 | out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 0, 1)))#right interleaving padding 76 | #out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 1, 0)))#author's interleaving pading in github 77 | out1_3 = self.conv1_3(nn.functional.pad(x, (0, 1, 1, 1)))#right interleaving padding 78 | #out1_3 = self.conv1_3(nn.functional.pad(x, (1, 0, 1, 1)))#author's interleaving pading in github 79 | out1_4 = self.conv1_4(nn.functional.pad(x, (0, 1, 0, 1)))#right interleaving padding 80 | #out1_4 = self.conv1_4(nn.functional.pad(x, (1, 0, 1, 0)))#author's interleaving pading in github 81 | 82 | out2_1 = self.conv2_1(nn.functional.pad(x, (1, 1, 1, 1))) 83 | out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 0, 1)))#right interleaving padding 84 | #out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 1, 0)))#author's interleaving pading in github 85 | out2_3 = self.conv2_3(nn.functional.pad(x, (0, 1, 1, 1)))#right interleaving padding 86 | #out2_3 = self.conv2_3(nn.functional.pad(x, (1, 0, 1, 1)))#author's interleaving pading in github 87 | out2_4 = self.conv2_4(nn.functional.pad(x, (0, 1, 0, 1)))#right interleaving padding 88 | #out2_4 = self.conv2_4(nn.functional.pad(x, (1, 0, 1, 0)))#author's interleaving pading in github 89 | 90 | height = out1_1.size()[2] 91 | width = out1_1.size()[3] 92 | 93 | out1_1_2 = torch.stack((out1_1, out1_2), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view( 94 | self.batch_size, -1, height, width * 2) 95 | out1_3_4 = torch.stack((out1_3, out1_4), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view( 96 | self.batch_size, -1, height, width * 2) 97 | 98 | out1_1234 = torch.stack((out1_1_2, out1_3_4), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view( 99 | self.batch_size, -1, height * 2, width * 2) 100 | 101 | out2_1_2 = torch.stack((out2_1, out2_2), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view( 102 | self.batch_size, -1, height, width * 2) 103 | out2_3_4 = torch.stack((out2_3, out2_4), dim=-3).permute(0, 1, 3, 4, 2).contiguous().view( 104 | self.batch_size, -1, height, width * 2) 105 | 106 | out2_1234 = torch.stack((out2_1_2, out2_3_4), dim=-3).permute(0, 1, 3, 2, 4).contiguous().view( 107 | self.batch_size, -1, height * 2, width * 2) 108 | 109 | out1 = self.bn1_1(out1_1234) 110 | out1 = self.relu(out1) 111 | out1 = self.conv3(out1) 112 | out1 = self.bn2(out1) 113 | 114 | out2 = self.bn1_2(out2_1234) 115 | 116 | out = out1 + out2 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class FCRN(nn.Module): 123 | 124 | def __init__(self, batch_size): 
125 | super(FCRN, self).__init__() 126 | self.inplanes = 64 127 | self.batch_size = batch_size 128 | # b, 304, 228, 3 129 | # ResNet with out avrgpool & fc 130 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)# b, 152 114, 64 131 | self.bn1 = nn.BatchNorm2d(64) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# b, 76, 57, 64 134 | self.layer1 = self._make_layer(Bottleneck, 64, 3) #b, 76, 57, 256 135 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)# b, 38, 29, 512 136 | self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)# b, 19, 15, 1024 137 | self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)# b, 10, 8, 2048 138 | 139 | # Up-Conv layers 140 | self.conv2 = nn.Conv2d(2048, 1024, kernel_size=1, bias=False)# b, 10, 8, 1024 141 | self.bn2 = nn.BatchNorm2d(1024) 142 | 143 | self.up1 = self._make_upproj_layer(UpProject, 1024, 512, self.batch_size) 144 | self.up2 = self._make_upproj_layer(UpProject, 512, 256, self.batch_size) 145 | self.up3 = self._make_upproj_layer(UpProject, 256, 128, self.batch_size) 146 | self.up4 = self._make_upproj_layer(UpProject, 128, 64, self.batch_size) 147 | 148 | self.drop = nn.Dropout2d() 149 | 150 | self.conv3 = nn.Conv2d(64, 1, 3, padding=1) 151 | 152 | self.upsample = nn.Upsample((228, 304), mode='bilinear') 153 | 154 | # initialize 155 | initialize = False 156 | if initialize: 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 160 | m.weight.data.normal_(0, math.sqrt(2. / n)) 161 | elif isinstance(m, nn.BatchNorm2d): 162 | m.weight.data.fill_(1) 163 | m.bias.data.zero_() 164 | 165 | def _make_layer(self, block, planes, blocks, stride=1): 166 | downsample = None 167 | if stride != 1 or self.inplanes != planes * block.expansion: 168 | downsample = nn.Sequential( 169 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, 170 | stride=stride, bias=False), 171 | nn.BatchNorm2d(planes * block.expansion), 172 | ) 173 | 174 | layers = [] 175 | layers.append(block(self.inplanes, planes, stride, downsample)) 176 | self.inplanes = planes * block.expansion 177 | for i in range(1, blocks): 178 | layers.append(block(self.inplanes, planes)) 179 | 180 | return nn.Sequential(*layers) 181 | 182 | def _make_upproj_layer(self, block, in_channels, out_channels, batch_size): 183 | return block(in_channels, out_channels, batch_size) 184 | 185 | def forward(self, x): 186 | x = self.conv1(x) 187 | x = self.bn1(x) 188 | x = self.relu(x) 189 | x = self.maxpool(x) 190 | 191 | x = self.layer1(x) 192 | x = self.layer2(x) 193 | x = self.layer3(x) 194 | x = self.layer4(x) 195 | 196 | x = self.conv2(x) 197 | x = self.bn2(x) 198 | 199 | x = self.up1(x) 200 | x = self.up2(x) 201 | x = self.up3(x) 202 | x = self.up4(x) 203 | 204 | x = self.drop(x) 205 | 206 | x = self.conv3(x) 207 | x = self.relu(x) 208 | 209 | x = self.upsample(x) 210 | 211 | return x 212 | 213 | from torchsummary import summary 214 | # 测试网络模型 215 | if __name__ == '__main__': 216 | batch_size = 1 217 | net = FCRN(batch_size).cuda() 218 | x = torch.zeros(batch_size, 3,304,228).cuda() 219 | print(net(x).size()) 220 | summary(net, (3, 304, 228)) 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy 
as np 3 | import h5py 4 | from PIL import Image 5 | import torch 6 | import torch.utils.data as data 7 | import torchvision.transforms as transforms 8 | from utils import load_split 9 | 10 | 11 | class NyuDepthLoader(data.Dataset): 12 | def __init__(self, data_path, lists): 13 | self.data_path = data_path 14 | self.lists = lists 15 | 16 | self.nyu = h5py.File(self.data_path) 17 | 18 | self.imgs = self.nyu['images'] 19 | self.dpts = self.nyu['depths'] 20 | 21 | def __getitem__(self, index): 22 | img_idx = self.lists[index] 23 | img = self.imgs[img_idx].transpose(2, 1, 0) # HWC 24 | dpt = self.dpts[img_idx].transpose(1, 0) 25 | img = Image.fromarray(img) 26 | dpt = Image.fromarray(dpt) 27 | input_transform = transforms.Compose([transforms.Resize(228), 28 | transforms.ToTensor()]) 29 | 30 | target_depth_transform = transforms.Compose([transforms.Resize(228), 31 | transforms.ToTensor()]) 32 | 33 | img = input_transform(img) 34 | dpt = target_depth_transform(dpt) 35 | return img, dpt 36 | 37 | def __len__(self): 38 | return len(self.lists) 39 | 40 | # 测试数据加载 41 | import matplotlib 42 | matplotlib.use('Agg') 43 | import matplotlib.pyplot as plt 44 | def test_loader(): 45 | batch_size = 16 46 | data_path = './data/nyu_depth_v2_labeled.mat' 47 | # 1.Load data 48 | train_lists, val_lists, test_lists = load_split() 49 | 50 | print("Loading data...") 51 | train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists), 52 | batch_size=batch_size, shuffle=True, drop_last=True) 53 | for input, depth in train_loader: 54 | print(input.size()) 55 | break 56 | #input_rgb_image = input[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 57 | input_rgb_image = input[0].data.permute(1, 2, 0) 58 | input_gt_depth_image = depth[0][0].data.cpu().numpy().astype(np.float32) 59 | 60 | input_gt_depth_image /= np.max(input_gt_depth_image) 61 | plt.imshow(input_rgb_image) 62 | plt.show() 63 | plt.imshow(input_gt_depth_image, cmap="viridis") 64 | plt.show() 65 | # plot.imsave('input_rgb_epoch_0.png', input_rgb_image) 66 | # plot.imsave('gt_depth_epoch_0.png', input_gt_depth_image, cmap="viridis") 67 | 68 | if __name__ == '__main__': 69 | test_loader() 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /model/readme.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhohn0/fcrn_pytorch/5cb562272d9939d9f9886b0fdb8768d84b998e90/model/readme.txt -------------------------------------------------------------------------------- /result/gt_depth_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhohn0/fcrn_pytorch/5cb562272d9939d9f9886b0fdb8768d84b998e90/result/gt_depth_epoch_100.png -------------------------------------------------------------------------------- /result/input_rgb_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhohn0/fcrn_pytorch/5cb562272d9939d9f9886b0fdb8768d84b998e90/result/input_rgb_epoch_100.png -------------------------------------------------------------------------------- /result/pred_depth_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhohn0/fcrn_pytorch/5cb562272d9939d9f9886b0fdb8768d84b998e90/result/pred_depth_epoch_100.png 
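For reference, the quantities accumulated per image in test.py below are the standard depth-estimation error measures, computed over the $n$ pixels with valid ground truth ($y_i > 10^{-3}$), where $\hat{y}$ is the predicted and $y$ the ground-truth depth map (both are first normalized by their own maximum, so the reported numbers are on relative rather than metric depth):

$$\delta_k = \frac{1}{n}\Bigl|\Bigl\{ i : \max\bigl(\tfrac{\hat{y}_i}{y_i}, \tfrac{y_i}{\hat{y}_i}\bigr) < 1.25^k \Bigr\}\Bigr|,\quad k \in \{1,2,3\}$$

$$\mathrm{RMSE} = \sqrt{\tfrac{1}{n}\textstyle\sum_i (\hat{y}_i - y_i)^2},\qquad \mathrm{RMSE}_{\log} = \sqrt{\tfrac{1}{n}\textstyle\sum_i (\log \hat{y}_i - \log y_i)^2}$$

$$\mathrm{ARD} = \tfrac{1}{n}\textstyle\sum_i \tfrac{|\hat{y}_i - y_i|}{y_i},\qquad \mathrm{SRD} = \tfrac{1}{n}\textstyle\sum_i \tfrac{(\hat{y}_i - y_i)^2}{y_i}$$

The scale-invariant log error is accumulated as $\tfrac{1}{n}\sum_i d_i^2 + \tfrac{1}{n^2}\bigl(\sum_i d_i\bigr)^2$ with $d_i = \log y_i - \log \hat{y}_i$; note that the second term is added here, whereas the scale-invariant error of Eigen et al. subtracts it.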
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from fcrn import FCRN 5 | from train import load_split 6 | from torch.autograd import Variable 7 | from loader import NyuDepthLoader 8 | import matplotlib.pyplot as plot 9 | 10 | data_path = './data/nyu_depth_v2_labeled.mat' 11 | dtype = torch.cuda.FloatTensor 12 | 13 | batch_size = 1 14 | resume_from_file = True 15 | Threshold_1_25 = 0 16 | Threshold_1_25_2 = 0 17 | Threshold_1_25_3 = 0 18 | RMSE_linear = 0.0 19 | RMSE_log = 0.0 20 | RMSE_log_scale_invariant = 0.0 21 | ARD = 0.0 22 | SRD = 0.0 23 | 24 | model = FCRN(batch_size) 25 | model = model.cuda() 26 | loss_fn = torch.nn.MSELoss().cuda() 27 | 28 | resume_file = './model/model_100.pth' 29 | 30 | if resume_from_file: 31 | if os.path.isfile(resume_file): 32 | print("=> loading checkpoint '{}'".format(resume_file)) 33 | checkpoint = torch.load(resume_file) 34 | start_epoch = checkpoint['epoch'] 35 | model.load_state_dict(checkpoint['state_dict']) 36 | print("=> loaded checkpoint '{}' (epoch {})" 37 | .format(resume_file, checkpoint['epoch'])) 38 | else: 39 | print("=> no checkpoint found at '{}'".format(resume_file)) 40 | 41 | _, _, test_lists = load_split() 42 | num_samples = len(test_lists) 43 | 44 | test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists), 45 | batch_size=batch_size, shuffle=False, drop_last=False) 46 | model.eval() 47 | idx = 0 48 | with torch.no_grad(): 49 | for input, gt in test_loader: 50 | input_var = Variable(input.type(dtype)) 51 | gt_var = Variable(gt.type(dtype)) 52 | 53 | output = model(input_var) 54 | 55 | #input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 56 | input_rgb_image = input[0].data.permute(1, 2, 0) 57 | input_gt_depth_image = gt_var[0].data.squeeze().cpu().numpy().astype(np.float32) 58 | pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32) 59 | 60 | input_gt_depth_image /= np.max(input_gt_depth_image) 61 | pred_depth_image /= np.max(pred_depth_image) 62 | 63 | idx = idx + 1 64 | if idx + 1 == len(test_loader): 65 | print('predict complete.') 66 | plot.imsave('Test_input_rgb_{:05d}.png'.format(idx), input_rgb_image) 67 | plot.imsave('Test_gt_depth_{:05d}.png'.format(idx), input_gt_depth_image, cmap="viridis") 68 | plot.imsave('Test_pred_depth_{:05d}.png'.format(idx), pred_depth_image, cmap="viridis") 69 | 70 | 71 | n = np.sum(input_gt_depth_image > 1e-3) #计算值大于1e-3的个数 72 | 73 | idxs = (input_gt_depth_image <= 1e-3) # 返回与原始数据同维的布尔值 74 | pred_depth_image[idxs] = 1 # 将小于1e-3赋值成1 75 | input_gt_depth_image[idxs] = 1 76 | 77 | pred_d_gt = pred_depth_image / input_gt_depth_image 78 | pred_d_gt[idxs] = 100 79 | gt_d_pred = input_gt_depth_image / pred_depth_image 80 | gt_d_pred[idxs] = 100 81 | 82 | Threshold_1_25 += np.sum(np.maximum(pred_d_gt, gt_d_pred) < 1.25) / n #np.maximum返回相对较大的值 83 | Threshold_1_25_2 += np.sum(np.maximum(pred_d_gt, gt_d_pred) < 1.25 * 1.25) / n 84 | Threshold_1_25_3 += np.sum(np.maximum(pred_d_gt, gt_d_pred) < 1.25 * 1.25 * 1.25) / n 85 | 86 | log_pred = np.log(pred_depth_image) 87 | log_gt = np.log(input_gt_depth_image) 88 | 89 | d_i = log_gt - log_pred 90 | 91 | RMSE_linear += np.sqrt(np.sum((pred_depth_image - input_gt_depth_image) ** 2) / n) 92 | RMSE_log += np.sqrt(np.sum((log_pred - log_gt) ** 2) / n) 93 | RMSE_log_scale_invariant += np.sum(d_i ** 2) / n + (np.sum(d_i) 
** 2) / (n ** 2) 94 | ARD += np.sum(np.abs((pred_depth_image - input_gt_depth_image)) / input_gt_depth_image) / n 95 | SRD += np.sum(((pred_depth_image - input_gt_depth_image) ** 2) / input_gt_depth_image) / n 96 | 97 | Threshold_1_25 /= num_samples 98 | Threshold_1_25_2 /= num_samples 99 | Threshold_1_25_3 /= num_samples 100 | RMSE_linear /= num_samples 101 | RMSE_log /= num_samples 102 | RMSE_log_scale_invariant /= num_samples 103 | ARD /= num_samples 104 | SRD /= num_samples 105 | 106 | print('Threshold_1_25: {}'.format(Threshold_1_25)) 107 | print('Threshold_1_25_2: {}'.format(Threshold_1_25_2)) 108 | print('Threshold_1_25_3: {}'.format(Threshold_1_25_3)) 109 | print('RMSE_linear: {}'.format(RMSE_linear)) 110 | print('RMSE_log: {}'.format(RMSE_log)) 111 | print('RMSE_log_scale_invariant: {}'.format(RMSE_log_scale_invariant)) 112 | print('ARD: {}'.format(ARD)) 113 | print('SRD: {}'.format(SRD)) 114 | 115 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | from loader import * 3 | import os 4 | from fcrn import FCRN 5 | from torch.autograd import Variable 6 | from weights import load_weights 7 | from utils import load_split, loss_mse, loss_huber 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plot 11 | 12 | dtype = torch.cuda.FloatTensor 13 | weights_file = "./model/NYU_ResNet-UpProj.npy" 14 | 15 | 16 | def main(): 17 | batch_size = 16 18 | data_path = './data/nyu_depth_v2_labeled.mat' 19 | learning_rate = 1.0e-4 20 | monentum = 0.9 21 | weight_decay = 0.0005 22 | num_epochs = 100 23 | 24 | 25 | # 1.Load data 26 | train_lists, val_lists, test_lists = load_split() 27 | print("Loading data...") 28 | train_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, train_lists), 29 | batch_size=batch_size, shuffle=False, drop_last=True) 30 | val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists), 31 | batch_size=batch_size, shuffle=True, drop_last=True) 32 | test_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, test_lists), 33 | batch_size=batch_size, shuffle=True, drop_last=True) 34 | print(train_loader) 35 | # 2.Load model 36 | print("Loading model...") 37 | model = FCRN(batch_size) 38 | model.load_state_dict(load_weights(model, weights_file, dtype)) #加载官方参数,从tensorflow转过来 39 | #加载训练模型 40 | resume_from_file = False 41 | resume_file = './model/model_300.pth' 42 | if resume_from_file: 43 | if os.path.isfile(resume_file): 44 | checkpoint = torch.load(resume_file) 45 | start_epoch = checkpoint['epoch'] 46 | model.load_state_dict(checkpoint['state_dict']) 47 | print("loaded checkpoint '{}' (epoch {})" 48 | .format(resume_file, checkpoint['epoch'])) 49 | else: 50 | print("can not find!") 51 | model = model.cuda() 52 | 53 | # 3.Loss 54 | # 官方MSE 55 | # loss_fn = torch.nn.MSELoss() 56 | # 自定义MSE 57 | # loss_fn = loss_mse() 58 | # 论文的loss,the reverse Huber 59 | loss_fn = loss_huber() 60 | print("loss_fn set...") 61 | 62 | # 4.Optim 63 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 64 | print("optimizer set...") 65 | 66 | # 5.Train 67 | best_val_err = 1.0e-4 68 | start_epoch = 0 69 | 70 | for epoch in range(num_epochs): 71 | print('Starting train epoch %d / %d' % (start_epoch + epoch + 1, num_epochs + start_epoch)) 72 | model.train() 73 | running_loss = 0 74 | count = 0 75 | epoch_loss = 0 76 | for input, depth in train_loader: 77 | 78 | input_var = Variable(input.type(dtype)) 
79 | depth_var = Variable(depth.type(dtype)) 80 | 81 | output = model(input_var) 82 | loss = loss_fn(output, depth_var) 83 | print('loss: %f' % loss.data.cpu().item()) 84 | count += 1 85 | running_loss += loss.data.cpu().numpy() 86 | 87 | optimizer.zero_grad() 88 | loss.backward() 89 | optimizer.step() 90 | 91 | epoch_loss = running_loss / count 92 | print('epoch loss:', epoch_loss) 93 | 94 | # validate 95 | model.eval() 96 | num_correct, num_samples = 0, 0 97 | loss_local = 0 98 | with torch.no_grad(): 99 | for input, depth in val_loader: 100 | input_var = Variable(input.type(dtype)) 101 | depth_var = Variable(depth.type(dtype)) 102 | 103 | output = model(input_var) 104 | if num_epochs == epoch + 1: 105 | # 关于保存的测试图片可以参考 loader 的写法 106 | # input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 107 | input_rgb_image = input[0].data.permute(1, 2, 0) 108 | input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32) 109 | pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32) 110 | 111 | input_gt_depth_image /= np.max(input_gt_depth_image) 112 | pred_depth_image /= np.max(pred_depth_image) 113 | 114 | plot.imsave('./result/input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1), input_rgb_image) 115 | plot.imsave('./result/gt_depth_epoch_{}.png'.format(start_epoch + epoch + 1), input_gt_depth_image, cmap="viridis") 116 | plot.imsave('./result/pred_depth_epoch_{}.png'.format(start_epoch + epoch + 1), pred_depth_image, cmap="viridis") 117 | 118 | loss_local += loss_fn(output, depth_var) 119 | 120 | num_samples += 1 121 | 122 | err = float(loss_local) / num_samples 123 | print('val_error: %f' % err) 124 | 125 | if err < best_val_err or epoch == num_epochs - 1: 126 | best_val_err = err 127 | torch.save({ 128 | 'epoch': start_epoch + epoch + 1, 129 | 'state_dict': model.state_dict(), 130 | 'optimizer': optimizer.state_dict(), 131 | }, './model/model_' + str(start_epoch + epoch + 1) + '.pth') 132 | 133 | if epoch % 10 == 0: 134 | learning_rate = learning_rate * 0.8 135 | 136 | 137 | if __name__ == '__main__': 138 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy as np 6 | # 自定义损失函数 7 | class loss_huber(nn.Module): 8 | def __init__(self): 9 | super(loss_huber,self).__init__() 10 | 11 | def forward(self, pred, truth): 12 | c = pred.shape[1] #通道 13 | h = pred.shape[2] #高 14 | w = pred.shape[3] #宽 15 | pred = pred.view(-1, c * h * w) 16 | truth = truth.view(-1, c * h * w) 17 | # 根据当前batch所有像素计算阈值 18 | t = 0.2 * torch.max(torch.abs(pred - truth)) 19 | # 计算L1范数 20 | l1 = torch.mean(torch.mean(torch.abs(pred - truth), 1), 0) 21 | # 计算论文中的L2 22 | l2 = torch.mean(torch.mean(((pred - truth)**2 + t**2) / t / 2, 1), 0) 23 | 24 | if l1 > t: 25 | return l2 26 | else: 27 | return l1 28 | 29 | class loss_mse(nn.Module): 30 | def __init__(self): 31 | super(loss_mse, self).__init__() 32 | def forward(self, pred, truth): 33 | c = pred.shape[1] 34 | h = pred.shape[2] 35 | w = pred.shape[3] 36 | pred = pred.view(-1, c * h * w) 37 | truth = truth.view(-1, c * h * w) 38 | return torch.mean(torch.mean((pred - truth), 1)**2, 0) 39 | 40 | if __name__ == '__main__': 41 | loss = loss_huber() 42 | x = torch.zeros(2, 1, 2, 2) 43 | y = torch.ones(2, 1, 2, 2) 44 | c = x.shape[1] 45 | h = x.shape[2] 46 | w = x.shape[3] 47 
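    # Worked example for the berHu (reverse Huber) loss defined above: with pred = 0
    # and truth = 1 everywhere, |pred - truth| = 1 for every element, so the threshold
    # is t = 0.2 * max|pred - truth| = 0.2. The mean L1 error is 1 > t, so the L2-style
    # branch is returned: mean((diff^2 + t^2) / (2*t)) = (1 + 0.04) / 0.4 = 2.6, and
    # r below evaluates to 2.6. Note that this implementation applies the berHu switch
    # to the batch-mean error, whereas the paper applies it per element.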
| r = loss(x, y) 48 | print(r) 49 | 50 | 51 | # 加载数据集的index 52 | def load_split(): 53 | current_directoty = os.getcwd() 54 | train_lists_path = current_directoty + '/data/trainIdxs.txt' 55 | test_lists_path = current_directoty + '/data/testIdxs.txt' 56 | 57 | train_f = open(train_lists_path) 58 | test_f = open(test_lists_path) 59 | 60 | train_lists = [] 61 | test_lists = [] 62 | 63 | train_lists_line = train_f.readline() 64 | while train_lists_line: 65 | train_lists.append(int(train_lists_line) - 1) 66 | train_lists_line = train_f.readline() 67 | train_f.close() 68 | 69 | test_lists_line = test_f.readline() 70 | while test_lists_line: 71 | test_lists.append(int(test_lists_line) - 1) 72 | test_lists_line = test_f.readline() 73 | test_f.close() 74 | 75 | val_start_idx = int(len(train_lists) * 0.8) 76 | 77 | val_lists = train_lists[val_start_idx:-1] 78 | train_lists = train_lists[0:val_start_idx] 79 | 80 | return train_lists, val_lists, test_lists 81 | 82 | # 测试网络 83 | def validate(model, val_loader, loss_fn, dtype): 84 | # validate 85 | model.eval() 86 | num_correct, num_samples = 0, 0 87 | loss_local = 0 88 | with torch.no_grad(): 89 | for input, depth in val_loader: 90 | input_var = Variable(input.type(dtype)) 91 | depth_var = Variable(depth.type(dtype)) 92 | 93 | output = model(input_var) 94 | if num_epochs == epoch + 1: 95 | # 关于保存的测试图片可以参考 loader 的写法 96 | # input_rgb_image = input_var[0].data.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 97 | input_rgb_image = input[0].data.permute(1, 2, 0) 98 | input_gt_depth_image = depth_var[0][0].data.cpu().numpy().astype(np.float32) 99 | pred_depth_image = output[0].data.squeeze().cpu().numpy().astype(np.float32) 100 | 101 | input_gt_depth_image /= np.max(input_gt_depth_image) 102 | pred_depth_image /= np.max(pred_depth_image) 103 | 104 | plot.imsave('./result/input_rgb_epoch_{}.png'.format(start_epoch + epoch + 1), input_rgb_image) 105 | plot.imsave('./result/gt_depth_epoch_{}.png'.format(start_epoch + epoch + 1), input_gt_depth_image, 106 | cmap="viridis") 107 | plot.imsave('./result/pred_depth_epoch_{}.png'.format(start_epoch + epoch + 1), pred_depth_image, 108 | cmap="viridis") 109 | 110 | loss_local += loss_fn(output, depth_var) 111 | 112 | num_samples += 1 113 | 114 | err = float(loss_local) / num_samples 115 | print('val_error: %f' % err) -------------------------------------------------------------------------------- /weights.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def load_weights(model, weights_file, dtype): 6 | 7 | model_params = model.state_dict() 8 | data_dict = np.load(weights_file, encoding='latin1').item() 9 | 10 | if True: 11 | model_params['conv1.weight'] = torch.from_numpy(data_dict['conv1']['weights']).type(dtype).permute(3,2,0,1) 12 | #model_params['conv1.bias'] = torch.from_numpy(data_dict['conv1']['biases']).type(dtype) 13 | model_params['bn1.weight'] = torch.from_numpy(data_dict['bn_conv1']['scale']).type(dtype) 14 | model_params['bn1.bias'] = torch.from_numpy(data_dict['bn_conv1']['offset']).type(dtype) 15 | 16 | model_params['layer1.0.downsample.0.weight'] = torch.from_numpy(data_dict['res2a_branch1']['weights']).type(dtype).permute(3,2,0,1) 17 | model_params['layer1.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn2a_branch1']['scale']).type(dtype) 18 | model_params['layer1.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn2a_branch1']['offset']).type(dtype) 19 | 20 | model_params['layer1.0.conv1.weight'] = 
torch.from_numpy(data_dict['res2a_branch2a']['weights']).type(dtype).permute(3,2,0,1) 21 | model_params['layer1.0.bn1.weight'] = torch.from_numpy(data_dict['bn2a_branch2a']['scale']).type(dtype) 22 | model_params['layer1.0.bn1.bias'] = torch.from_numpy(data_dict['bn2a_branch2a']['offset']).type(dtype) 23 | 24 | model_params['layer1.0.conv2.weight'] = torch.from_numpy(data_dict['res2a_branch2b']['weights']).type(dtype).permute(3,2,0,1) 25 | model_params['layer1.0.bn2.weight'] = torch.from_numpy(data_dict['bn2a_branch2b']['scale']).type(dtype) 26 | model_params['layer1.0.bn2.bias'] = torch.from_numpy(data_dict['bn2a_branch2b']['offset']).type(dtype) 27 | 28 | model_params['layer1.0.conv3.weight'] = torch.from_numpy(data_dict['res2a_branch2c']['weights']).type(dtype).permute(3,2,0,1) 29 | model_params['layer1.0.bn3.weight'] = torch.from_numpy(data_dict['bn2a_branch2c']['scale']).type(dtype) 30 | model_params['layer1.0.bn3.bias'] = torch.from_numpy(data_dict['bn2a_branch2c']['offset']).type(dtype) 31 | 32 | model_params['layer1.1.conv1.weight'] = torch.from_numpy(data_dict['res2b_branch2a']['weights']).type(dtype).permute(3,2,0,1) 33 | model_params['layer1.1.bn1.weight'] = torch.from_numpy(data_dict['bn2b_branch2a']['scale']).type(dtype) 34 | model_params['layer1.1.bn1.bias'] = torch.from_numpy(data_dict['bn2b_branch2a']['offset']).type(dtype) 35 | 36 | model_params['layer1.1.conv2.weight'] = torch.from_numpy(data_dict['res2b_branch2b']['weights']).type(dtype).permute(3,2,0,1) 37 | model_params['layer1.1.bn2.weight'] = torch.from_numpy(data_dict['bn2b_branch2b']['scale']).type(dtype) 38 | model_params['layer1.1.bn2.bias'] = torch.from_numpy(data_dict['bn2b_branch2b']['offset']).type(dtype) 39 | 40 | model_params['layer1.1.conv3.weight'] = torch.from_numpy(data_dict['res2b_branch2c']['weights']).type(dtype).permute(3,2,0,1) 41 | model_params['layer1.1.bn3.weight'] = torch.from_numpy(data_dict['bn2b_branch2c']['scale']).type(dtype) 42 | model_params['layer1.1.bn3.bias'] = torch.from_numpy(data_dict['bn2b_branch2c']['offset']).type(dtype) 43 | 44 | model_params['layer1.2.conv1.weight'] = torch.from_numpy(data_dict['res2c_branch2a']['weights']).type(dtype).permute(3,2,0,1) 45 | model_params['layer1.2.bn1.weight'] = torch.from_numpy(data_dict['bn2c_branch2a']['scale']).type(dtype) 46 | model_params['layer1.2.bn1.bias'] = torch.from_numpy(data_dict['bn2c_branch2a']['offset']).type(dtype) 47 | 48 | model_params['layer1.2.conv2.weight'] = torch.from_numpy(data_dict['res2c_branch2b']['weights']).type(dtype).permute(3,2,0,1) 49 | model_params['layer1.2.bn2.weight'] = torch.from_numpy(data_dict['bn2c_branch2b']['scale']).type(dtype) 50 | model_params['layer1.2.bn2.bias'] = torch.from_numpy(data_dict['bn2c_branch2b']['offset']).type(dtype) 51 | 52 | model_params['layer1.2.conv3.weight'] = torch.from_numpy(data_dict['res2c_branch2c']['weights']).type(dtype).permute(3,2,0,1) 53 | model_params['layer1.2.bn3.weight'] = torch.from_numpy(data_dict['bn2c_branch2c']['scale']).type(dtype) 54 | model_params['layer1.2.bn3.bias'] = torch.from_numpy(data_dict['bn2c_branch2c']['offset']).type(dtype) 55 | 56 | model_params['layer2.0.downsample.0.weight'] = torch.from_numpy(data_dict['res3a_branch1']['weights']).type(dtype).permute(3,2,0,1) 57 | model_params['layer2.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn3a_branch1']['scale']).type(dtype) 58 | model_params['layer2.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn3a_branch1']['offset']).type(dtype) 59 | 60 | model_params['layer2.0.conv1.weight'] = 
torch.from_numpy(data_dict['res3a_branch2a']['weights']).type(dtype).permute(3,2,0,1) 61 | model_params['layer2.0.bn1.weight'] = torch.from_numpy(data_dict['bn3a_branch2a']['scale']).type(dtype) 62 | model_params['layer2.0.bn1.bias'] = torch.from_numpy(data_dict['bn3a_branch2a']['offset']).type(dtype) 63 | 64 | model_params['layer2.0.conv2.weight'] = torch.from_numpy(data_dict['res3a_branch2b']['weights']).type(dtype).permute(3,2,0,1) 65 | model_params['layer2.0.bn2.weight'] = torch.from_numpy(data_dict['bn3a_branch2b']['scale']).type(dtype) 66 | model_params['layer2.0.bn2.bias'] = torch.from_numpy(data_dict['bn3a_branch2b']['offset']).type(dtype) 67 | 68 | model_params['layer2.0.conv3.weight'] = torch.from_numpy(data_dict['res3a_branch2c']['weights']).type(dtype).permute(3,2,0,1) 69 | model_params['layer2.0.bn3.weight'] = torch.from_numpy(data_dict['bn3a_branch2c']['scale']).type(dtype) 70 | model_params['layer2.0.bn3.bias'] = torch.from_numpy(data_dict['bn3a_branch2c']['offset']).type(dtype) 71 | 72 | model_params['layer2.1.conv1.weight'] = torch.from_numpy(data_dict['res3b_branch2a']['weights']).type(dtype).permute(3,2,0,1) 73 | model_params['layer2.1.bn1.weight'] = torch.from_numpy(data_dict['bn3b_branch2a']['scale']).type(dtype) 74 | model_params['layer2.1.bn1.bias'] = torch.from_numpy(data_dict['bn3b_branch2a']['offset']).type(dtype) 75 | 76 | model_params['layer2.1.conv2.weight'] = torch.from_numpy(data_dict['res3b_branch2b']['weights']).type(dtype).permute(3,2,0,1) 77 | model_params['layer2.1.bn2.weight'] = torch.from_numpy(data_dict['bn3b_branch2b']['scale']).type(dtype) 78 | model_params['layer2.1.bn2.bias'] = torch.from_numpy(data_dict['bn3b_branch2b']['offset']).type(dtype) 79 | 80 | model_params['layer2.1.conv3.weight'] = torch.from_numpy(data_dict['res3b_branch2c']['weights']).type(dtype).permute(3,2,0,1) 81 | model_params['layer2.1.bn3.weight'] = torch.from_numpy(data_dict['bn3b_branch2c']['scale']).type(dtype) 82 | model_params['layer2.1.bn3.bias'] = torch.from_numpy(data_dict['bn3b_branch2c']['offset']).type(dtype) 83 | 84 | model_params['layer2.2.conv1.weight'] = torch.from_numpy(data_dict['res3c_branch2a']['weights']).type(dtype).permute(3,2,0,1) 85 | model_params['layer2.2.bn1.weight'] = torch.from_numpy(data_dict['bn3c_branch2a']['scale']).type(dtype) 86 | model_params['layer2.2.bn1.bias'] = torch.from_numpy(data_dict['bn3c_branch2a']['offset']).type(dtype) 87 | 88 | model_params['layer2.2.conv2.weight'] = torch.from_numpy(data_dict['res3c_branch2b']['weights']).type(dtype).permute(3,2,0,1) 89 | model_params['layer2.2.bn2.weight'] = torch.from_numpy(data_dict['bn3c_branch2b']['scale']).type(dtype) 90 | model_params['layer2.2.bn2.bias'] = torch.from_numpy(data_dict['bn3c_branch2b']['offset']).type(dtype) 91 | 92 | model_params['layer2.2.conv3.weight'] = torch.from_numpy(data_dict['res3c_branch2c']['weights']).type(dtype).permute(3,2,0,1) 93 | model_params['layer2.2.bn3.weight'] = torch.from_numpy(data_dict['bn3c_branch2c']['scale']).type(dtype) 94 | model_params['layer2.2.bn3.bias'] = torch.from_numpy(data_dict['bn3c_branch2c']['offset']).type(dtype) 95 | 96 | model_params['layer2.3.conv1.weight'] = torch.from_numpy(data_dict['res3d_branch2a']['weights']).type(dtype).permute(3,2,0,1) 97 | model_params['layer2.3.bn1.weight'] = torch.from_numpy(data_dict['bn3d_branch2a']['scale']).type(dtype) 98 | model_params['layer2.3.bn1.bias'] = torch.from_numpy(data_dict['bn3d_branch2a']['offset']).type(dtype) 99 | 100 | model_params['layer2.3.conv2.weight'] = 
torch.from_numpy(data_dict['res3d_branch2b']['weights']).type(dtype).permute(3,2,0,1) 101 | model_params['layer2.3.bn2.weight'] = torch.from_numpy(data_dict['bn3d_branch2b']['scale']).type(dtype) 102 | model_params['layer2.3.bn2.bias'] = torch.from_numpy(data_dict['bn3d_branch2b']['offset']).type(dtype) 103 | 104 | model_params['layer2.3.conv3.weight'] = torch.from_numpy(data_dict['res3d_branch2c']['weights']).type(dtype).permute(3,2,0,1) 105 | model_params['layer2.3.bn3.weight'] = torch.from_numpy(data_dict['bn3d_branch2c']['scale']).type(dtype) 106 | model_params['layer2.3.bn3.bias'] = torch.from_numpy(data_dict['bn3d_branch2c']['offset']).type(dtype) 107 | 108 | model_params['layer3.0.downsample.0.weight'] = torch.from_numpy(data_dict['res4a_branch1']['weights']).type(dtype).permute(3,2,0,1) 109 | model_params['layer3.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn4a_branch1']['scale']).type(dtype) 110 | model_params['layer3.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn4a_branch1']['offset']).type(dtype) 111 | 112 | model_params['layer3.0.conv1.weight'] = torch.from_numpy(data_dict['res4a_branch2a']['weights']).type(dtype).permute(3,2,0,1) 113 | model_params['layer3.0.bn1.weight'] = torch.from_numpy(data_dict['bn4a_branch2a']['scale']).type(dtype) 114 | model_params['layer3.0.bn1.bias'] = torch.from_numpy(data_dict['bn4a_branch2a']['offset']).type(dtype) 115 | 116 | model_params['layer3.0.conv2.weight'] = torch.from_numpy(data_dict['res4a_branch2b']['weights']).type(dtype).permute(3,2,0,1) 117 | model_params['layer3.0.bn2.weight'] = torch.from_numpy(data_dict['bn4a_branch2b']['scale']).type(dtype) 118 | model_params['layer3.0.bn2.bias'] = torch.from_numpy(data_dict['bn4a_branch2b']['offset']).type(dtype) 119 | 120 | model_params['layer3.0.conv3.weight'] = torch.from_numpy(data_dict['res4a_branch2c']['weights']).type(dtype).permute(3,2,0,1) 121 | model_params['layer3.0.bn3.weight'] = torch.from_numpy(data_dict['bn4a_branch2c']['scale']).type(dtype) 122 | model_params['layer3.0.bn3.bias'] = torch.from_numpy(data_dict['bn4a_branch2c']['offset']).type(dtype) 123 | 124 | model_params['layer3.1.conv1.weight'] = torch.from_numpy(data_dict['res4b_branch2a']['weights']).type(dtype).permute(3,2,0,1) 125 | model_params['layer3.1.bn1.weight'] = torch.from_numpy(data_dict['bn4b_branch2a']['scale']).type(dtype) 126 | model_params['layer3.1.bn1.bias'] = torch.from_numpy(data_dict['bn4b_branch2a']['offset']).type(dtype) 127 | 128 | model_params['layer3.1.conv2.weight'] = torch.from_numpy(data_dict['res4b_branch2b']['weights']).type(dtype).permute(3,2,0,1) 129 | model_params['layer3.1.bn2.weight'] = torch.from_numpy(data_dict['bn4b_branch2b']['scale']).type(dtype) 130 | model_params['layer3.1.bn2.bias'] = torch.from_numpy(data_dict['bn4b_branch2b']['offset']).type(dtype) 131 | 132 | model_params['layer3.1.conv3.weight'] = torch.from_numpy(data_dict['res4b_branch2c']['weights']).type(dtype).permute(3,2,0,1) 133 | model_params['layer3.1.bn3.weight'] = torch.from_numpy(data_dict['bn4b_branch2c']['scale']).type(dtype) 134 | model_params['layer3.1.bn3.bias'] = torch.from_numpy(data_dict['bn4b_branch2c']['offset']).type(dtype) 135 | 136 | model_params['layer3.2.conv1.weight'] = torch.from_numpy(data_dict['res4c_branch2a']['weights']).type(dtype).permute(3,2,0,1) 137 | model_params['layer3.2.bn1.weight'] = torch.from_numpy(data_dict['bn4c_branch2a']['scale']).type(dtype) 138 | model_params['layer3.2.bn1.bias'] = torch.from_numpy(data_dict['bn4c_branch2a']['offset']).type(dtype) 139 | 140 | 
model_params['layer3.2.conv2.weight'] = torch.from_numpy(data_dict['res4c_branch2b']['weights']).type(dtype).permute(3,2,0,1) 141 | model_params['layer3.2.bn2.weight'] = torch.from_numpy(data_dict['bn4c_branch2b']['scale']).type(dtype) 142 | model_params['layer3.2.bn2.bias'] = torch.from_numpy(data_dict['bn4c_branch2b']['offset']).type(dtype) 143 | 144 | model_params['layer3.2.conv3.weight'] = torch.from_numpy(data_dict['res4c_branch2c']['weights']).type(dtype).permute(3,2,0,1) 145 | model_params['layer3.2.bn3.weight'] = torch.from_numpy(data_dict['bn4c_branch2c']['scale']).type(dtype) 146 | model_params['layer3.2.bn3.bias'] = torch.from_numpy(data_dict['bn4c_branch2c']['offset']).type(dtype) 147 | 148 | model_params['layer3.3.conv1.weight'] = torch.from_numpy(data_dict['res4d_branch2a']['weights']).type(dtype).permute(3,2,0,1) 149 | model_params['layer3.3.bn1.weight'] = torch.from_numpy(data_dict['bn4d_branch2a']['scale']).type(dtype) 150 | model_params['layer3.3.bn1.bias'] = torch.from_numpy(data_dict['bn4d_branch2a']['offset']).type(dtype) 151 | 152 | model_params['layer3.3.conv2.weight'] = torch.from_numpy(data_dict['res4d_branch2b']['weights']).type(dtype).permute(3,2,0,1) 153 | model_params['layer3.3.bn2.weight'] = torch.from_numpy(data_dict['bn4d_branch2b']['scale']).type(dtype) 154 | model_params['layer3.3.bn2.bias'] = torch.from_numpy(data_dict['bn4d_branch2b']['offset']).type(dtype) 155 | 156 | model_params['layer3.3.conv3.weight'] = torch.from_numpy(data_dict['res4d_branch2c']['weights']).type(dtype).permute(3,2,0,1) 157 | model_params['layer3.3.bn3.weight'] = torch.from_numpy(data_dict['bn4d_branch2c']['scale']).type(dtype) 158 | model_params['layer3.3.bn3.bias'] = torch.from_numpy(data_dict['bn4d_branch2c']['offset']).type(dtype) 159 | 160 | model_params['layer3.4.conv1.weight'] = torch.from_numpy(data_dict['res4e_branch2a']['weights']).type(dtype).permute(3,2,0,1) 161 | model_params['layer3.4.bn1.weight'] = torch.from_numpy(data_dict['bn4e_branch2a']['scale']).type(dtype) 162 | model_params['layer3.4.bn1.bias'] = torch.from_numpy(data_dict['bn4e_branch2a']['offset']).type(dtype) 163 | 164 | model_params['layer3.4.conv2.weight'] = torch.from_numpy(data_dict['res4e_branch2b']['weights']).type(dtype).permute(3,2,0,1) 165 | model_params['layer3.4.bn2.weight'] = torch.from_numpy(data_dict['bn4e_branch2b']['scale']).type(dtype) 166 | model_params['layer3.4.bn2.bias'] = torch.from_numpy(data_dict['bn4e_branch2b']['offset']).type(dtype) 167 | 168 | model_params['layer3.4.conv3.weight'] = torch.from_numpy(data_dict['res4e_branch2c']['weights']).type(dtype).permute(3,2,0,1) 169 | model_params['layer3.4.bn3.weight'] = torch.from_numpy(data_dict['bn4e_branch2c']['scale']).type(dtype) 170 | model_params['layer3.4.bn3.bias'] = torch.from_numpy(data_dict['bn4e_branch2c']['offset']).type(dtype) 171 | 172 | model_params['layer3.5.conv1.weight'] = torch.from_numpy(data_dict['res4f_branch2a']['weights']).type(dtype).permute(3,2,0,1) 173 | model_params['layer3.5.bn1.weight'] = torch.from_numpy(data_dict['bn4f_branch2a']['scale']).type(dtype) 174 | model_params['layer3.5.bn1.bias'] = torch.from_numpy(data_dict['bn4f_branch2a']['offset']).type(dtype) 175 | 176 | model_params['layer3.5.conv2.weight'] = torch.from_numpy(data_dict['res4f_branch2b']['weights']).type(dtype).permute(3,2,0,1) 177 | model_params['layer3.5.bn2.weight'] = torch.from_numpy(data_dict['bn4f_branch2b']['scale']).type(dtype) 178 | model_params['layer3.5.bn2.bias'] = 
torch.from_numpy(data_dict['bn4f_branch2b']['offset']).type(dtype) 179 | 180 | model_params['layer3.5.conv3.weight'] = torch.from_numpy(data_dict['res4f_branch2c']['weights']).type(dtype).permute(3,2,0,1) 181 | model_params['layer3.5.bn3.weight'] = torch.from_numpy(data_dict['bn4f_branch2c']['scale']).type(dtype) 182 | model_params['layer3.5.bn3.bias'] = torch.from_numpy(data_dict['bn4f_branch2c']['offset']).type(dtype) 183 | 184 | model_params['layer4.0.downsample.0.weight'] = torch.from_numpy(data_dict['res5a_branch1']['weights']).type(dtype).permute(3,2,0,1) 185 | model_params['layer4.0.downsample.1.weight'] = torch.from_numpy(data_dict['bn5a_branch1']['scale']).type(dtype) 186 | model_params['layer4.0.downsample.1.bias'] = torch.from_numpy(data_dict['bn5a_branch1']['offset']).type(dtype) 187 | 188 | model_params['layer4.0.conv1.weight'] = torch.from_numpy(data_dict['res5a_branch2a']['weights']).type(dtype).permute(3,2,0,1) 189 | model_params['layer4.0.bn1.weight'] = torch.from_numpy(data_dict['bn5a_branch2a']['scale']).type(dtype) 190 | model_params['layer4.0.bn1.bias'] = torch.from_numpy(data_dict['bn5a_branch2a']['offset']).type(dtype) 191 | 192 | model_params['layer4.0.conv2.weight'] = torch.from_numpy(data_dict['res5a_branch2b']['weights']).type(dtype).permute(3,2,0,1) 193 | model_params['layer4.0.bn2.weight'] = torch.from_numpy(data_dict['bn5a_branch2b']['scale']).type(dtype) 194 | model_params['layer4.0.bn2.bias'] = torch.from_numpy(data_dict['bn5a_branch2b']['offset']).type(dtype) 195 | 196 | model_params['layer4.0.conv3.weight'] = torch.from_numpy(data_dict['res5a_branch2c']['weights']).type(dtype).permute(3,2,0,1) 197 | model_params['layer4.0.bn3.weight'] = torch.from_numpy(data_dict['bn5a_branch2c']['scale']).type(dtype) 198 | model_params['layer4.0.bn3.bias'] = torch.from_numpy(data_dict['bn5a_branch2c']['offset']).type(dtype) 199 | 200 | model_params['layer4.1.conv1.weight'] = torch.from_numpy(data_dict['res5b_branch2a']['weights']).type(dtype).permute(3,2,0,1) 201 | model_params['layer4.1.bn1.weight'] = torch.from_numpy(data_dict['bn5b_branch2a']['scale']).type(dtype) 202 | model_params['layer4.1.bn1.bias'] = torch.from_numpy(data_dict['bn5b_branch2a']['offset']).type(dtype) 203 | 204 | model_params['layer4.1.conv2.weight'] = torch.from_numpy(data_dict['res5b_branch2b']['weights']).type(dtype).permute(3,2,0,1) 205 | model_params['layer4.1.bn2.weight'] = torch.from_numpy(data_dict['bn5b_branch2b']['scale']).type(dtype) 206 | model_params['layer4.1.bn2.bias'] = torch.from_numpy(data_dict['bn5b_branch2b']['offset']).type(dtype) 207 | 208 | model_params['layer4.1.conv3.weight'] = torch.from_numpy(data_dict['res5b_branch2c']['weights']).type(dtype).permute(3,2,0,1) 209 | model_params['layer4.1.bn3.weight'] = torch.from_numpy(data_dict['bn5b_branch2c']['scale']).type(dtype) 210 | model_params['layer4.1.bn3.bias'] = torch.from_numpy(data_dict['bn5b_branch2c']['offset']).type(dtype) 211 | 212 | model_params['layer4.2.conv1.weight'] = torch.from_numpy(data_dict['res5c_branch2a']['weights']).type(dtype).permute(3,2,0,1) 213 | model_params['layer4.2.bn1.weight'] = torch.from_numpy(data_dict['bn5c_branch2a']['scale']).type(dtype) 214 | model_params['layer4.2.bn1.bias'] = torch.from_numpy(data_dict['bn5c_branch2a']['offset']).type(dtype) 215 | 216 | model_params['layer4.2.conv2.weight'] = torch.from_numpy(data_dict['res5c_branch2b']['weights']).type(dtype).permute(3,2,0,1) 217 | model_params['layer4.2.bn2.weight'] = torch.from_numpy(data_dict['bn5c_branch2b']['scale']).type(dtype) 218 
| model_params['layer4.2.bn2.bias'] = torch.from_numpy(data_dict['bn5c_branch2b']['offset']).type(dtype) 219 | 220 | model_params['layer4.2.conv3.weight'] = torch.from_numpy(data_dict['res5c_branch2c']['weights']).type(dtype).permute(3,2,0,1) 221 | model_params['layer4.2.bn3.weight'] = torch.from_numpy(data_dict['bn5c_branch2c']['scale']).type(dtype) 222 | model_params['layer4.2.bn3.bias'] = torch.from_numpy(data_dict['bn5c_branch2c']['offset']).type(dtype) 223 | 224 | model_params['conv2.weight'] = torch.from_numpy(data_dict['layer1']['weights']).type(dtype).permute(3,2,0,1) 225 | #model_params['conv2.bias'] = torch.from_numpy(data_dict['layer1']['biases']).type(dtype) 226 | model_params['bn2.weight'] = torch.from_numpy(data_dict['layer1_BN']['scale']).type(dtype) 227 | model_params['bn2.bias'] = torch.from_numpy(data_dict['layer1_BN']['offset']).type(dtype) 228 | 229 | # set True to enable weight import, or set False to initialize by yourself 230 | if True: 231 | 232 | model_params['up1.conv1_1.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1) 233 | model_params['up1.conv1_1.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvA']['biases']).type(dtype) 234 | 235 | model_params['up1.conv1_2.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1) 236 | model_params['up1.conv1_2.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvB']['biases']).type(dtype) 237 | 238 | model_params['up1.conv1_3.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1) 239 | model_params['up1.conv1_3.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvC']['biases']).type(dtype) 240 | 241 | model_params['up1.conv1_4.weight'] = torch.from_numpy(data_dict['layer2x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1) 242 | model_params['up1.conv1_4.bias'] = torch.from_numpy(data_dict['layer2x_br1_ConvD']['biases']).type(dtype) 243 | 244 | model_params['up1.bn1_1.weight'] = torch.from_numpy(data_dict['layer2x_br1_BN']['scale']).type(dtype) 245 | model_params['up1.bn1_1.bias'] = torch.from_numpy(data_dict['layer2x_br1_BN']['offset']).type(dtype) 246 | 247 | model_params['up1.conv2_1.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1) 248 | model_params['up1.conv2_1.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvA']['biases']).type(dtype) 249 | 250 | model_params['up1.conv2_2.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1) 251 | model_params['up1.conv2_2.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvB']['biases']).type(dtype) 252 | 253 | model_params['up1.conv2_3.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1) 254 | model_params['up1.conv2_3.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvC']['biases']).type(dtype) 255 | 256 | model_params['up1.conv2_4.weight'] = torch.from_numpy(data_dict['layer2x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1) 257 | model_params['up1.conv2_4.bias'] = torch.from_numpy(data_dict['layer2x_br2_ConvD']['biases']).type(dtype) 258 | 259 | model_params['up1.bn1_2.weight'] = torch.from_numpy(data_dict['layer2x_br2_BN']['scale']).type(dtype) 260 | model_params['up1.bn1_2.bias'] = torch.from_numpy(data_dict['layer2x_br2_BN']['offset']).type(dtype) 261 | 262 | model_params['up1.conv3.weight'] = torch.from_numpy(data_dict['layer2x_Conv']['weights']).type(dtype).permute(3,2,0,1) 263 | 
model_params['up1.conv3.bias'] = torch.from_numpy(data_dict['layer2x_Conv']['biases']).type(dtype) 264 | 265 | model_params['up1.bn2.weight'] = torch.from_numpy(data_dict['layer2x_BN']['scale']).type(dtype) 266 | model_params['up1.bn2.bias'] = torch.from_numpy(data_dict['layer2x_BN']['offset']).type(dtype) 267 | 268 | model_params['up2.conv1_1.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1) 269 | model_params['up2.conv1_1.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvA']['biases']).type(dtype) 270 | 271 | model_params['up2.conv1_2.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1) 272 | model_params['up2.conv1_2.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvB']['biases']).type(dtype) 273 | 274 | model_params['up2.conv1_3.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1) 275 | model_params['up2.conv1_3.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvC']['biases']).type(dtype) 276 | 277 | model_params['up2.conv1_4.weight'] = torch.from_numpy(data_dict['layer4x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1) 278 | model_params['up2.conv1_4.bias'] = torch.from_numpy(data_dict['layer4x_br1_ConvD']['biases']).type(dtype) 279 | 280 | model_params['up2.bn1_1.weight'] = torch.from_numpy(data_dict['layer4x_br1_BN']['scale']).type(dtype) 281 | model_params['up2.bn1_1.bias'] = torch.from_numpy(data_dict['layer4x_br1_BN']['offset']).type(dtype) 282 | 283 | model_params['up2.conv2_1.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1) 284 | model_params['up2.conv2_1.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvA']['biases']).type(dtype) 285 | 286 | model_params['up2.conv2_2.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1) 287 | model_params['up2.conv2_2.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvB']['biases']).type(dtype) 288 | 289 | model_params['up2.conv2_3.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1) 290 | model_params['up2.conv2_3.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvC']['biases']).type(dtype) 291 | 292 | model_params['up2.conv2_4.weight'] = torch.from_numpy(data_dict['layer4x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1) 293 | model_params['up2.conv2_4.bias'] = torch.from_numpy(data_dict['layer4x_br2_ConvD']['biases']).type(dtype) 294 | 295 | model_params['up2.bn1_2.weight'] = torch.from_numpy(data_dict['layer4x_br2_BN']['scale']).type(dtype) 296 | model_params['up2.bn1_2.bias'] = torch.from_numpy(data_dict['layer4x_br2_BN']['offset']).type(dtype) 297 | 298 | model_params['up2.conv3.weight'] = torch.from_numpy(data_dict['layer4x_Conv']['weights']).type(dtype).permute(3,2,0,1) 299 | model_params['up2.conv3.bias'] = torch.from_numpy(data_dict['layer4x_Conv']['biases']).type(dtype) 300 | 301 | model_params['up2.bn2.weight'] = torch.from_numpy(data_dict['layer4x_BN']['scale']).type(dtype) 302 | model_params['up2.bn2.bias'] = torch.from_numpy(data_dict['layer4x_BN']['offset']).type(dtype) 303 | 304 | model_params['up3.conv1_1.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1) 305 | model_params['up3.conv1_1.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvA']['biases']).type(dtype) 306 | 307 | model_params['up3.conv1_2.weight'] = 
torch.from_numpy(data_dict['layer8x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1) 308 | model_params['up3.conv1_2.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvB']['biases']).type(dtype) 309 | 310 | model_params['up3.conv1_3.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1) 311 | model_params['up3.conv1_3.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvC']['biases']).type(dtype) 312 | 313 | model_params['up3.conv1_4.weight'] = torch.from_numpy(data_dict['layer8x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1) 314 | model_params['up3.conv1_4.bias'] = torch.from_numpy(data_dict['layer8x_br1_ConvD']['biases']).type(dtype) 315 | 316 | model_params['up3.bn1_1.weight'] = torch.from_numpy(data_dict['layer8x_br1_BN']['scale']).type(dtype) 317 | model_params['up3.bn1_1.bias'] = torch.from_numpy(data_dict['layer8x_br1_BN']['offset']).type(dtype) 318 | 319 | model_params['up3.conv2_1.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1) 320 | model_params['up3.conv2_1.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvA']['biases']).type(dtype) 321 | 322 | model_params['up3.conv2_2.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1) 323 | model_params['up3.conv2_2.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvB']['biases']).type(dtype) 324 | 325 | model_params['up3.conv2_3.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1) 326 | model_params['up3.conv2_3.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvC']['biases']).type(dtype) 327 | 328 | model_params['up3.conv2_4.weight'] = torch.from_numpy(data_dict['layer8x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1) 329 | model_params['up3.conv2_4.bias'] = torch.from_numpy(data_dict['layer8x_br2_ConvD']['biases']).type(dtype) 330 | 331 | model_params['up3.bn1_2.weight'] = torch.from_numpy(data_dict['layer8x_br2_BN']['scale']).type(dtype) 332 | model_params['up3.bn1_2.bias'] = torch.from_numpy(data_dict['layer8x_br2_BN']['offset']).type(dtype) 333 | 334 | model_params['up3.conv3.weight'] = torch.from_numpy(data_dict['layer8x_Conv']['weights']).type(dtype).permute(3,2,0,1) 335 | model_params['up3.conv3.bias'] = torch.from_numpy(data_dict['layer8x_Conv']['biases']).type(dtype) 336 | 337 | model_params['up3.bn2.weight'] = torch.from_numpy(data_dict['layer8x_BN']['scale']).type(dtype) 338 | model_params['up3.bn2.bias'] = torch.from_numpy(data_dict['layer8x_BN']['offset']).type(dtype) 339 | 340 | model_params['up4.conv1_1.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvA']['weights']).type(dtype).permute(3,2,0,1) 341 | model_params['up4.conv1_1.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvA']['biases']).type(dtype) 342 | 343 | model_params['up4.conv1_2.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvB']['weights']).type(dtype).permute(3,2,0,1) 344 | model_params['up4.conv1_2.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvB']['biases']).type(dtype) 345 | 346 | model_params['up4.conv1_3.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvC']['weights']).type(dtype).permute(3,2,0,1) 347 | model_params['up4.conv1_3.bias'] = torch.from_numpy(data_dict['layer16x_br1_ConvC']['biases']).type(dtype) 348 | 349 | model_params['up4.conv1_4.weight'] = torch.from_numpy(data_dict['layer16x_br1_ConvD']['weights']).type(dtype).permute(3,2,0,1) 350 | model_params['up4.conv1_4.bias'] = 
torch.from_numpy(data_dict['layer16x_br1_ConvD']['biases']).type(dtype) 351 | 352 | model_params['up4.bn1_1.weight'] = torch.from_numpy(data_dict['layer16x_br1_BN']['scale']).type(dtype) 353 | model_params['up4.bn1_1.bias'] = torch.from_numpy(data_dict['layer16x_br1_BN']['offset']).type(dtype) 354 | 355 | model_params['up4.conv2_1.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvA']['weights']).type(dtype).permute(3,2,0,1) 356 | model_params['up4.conv2_1.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvA']['biases']).type(dtype) 357 | 358 | model_params['up4.conv2_2.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvB']['weights']).type(dtype).permute(3,2,0,1) 359 | model_params['up4.conv2_2.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvB']['biases']).type(dtype) 360 | 361 | model_params['up4.conv2_3.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvC']['weights']).type(dtype).permute(3,2,0,1) 362 | model_params['up4.conv2_3.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvC']['biases']).type(dtype) 363 | 364 | model_params['up4.conv2_4.weight'] = torch.from_numpy(data_dict['layer16x_br2_ConvD']['weights']).type(dtype).permute(3,2,0,1) 365 | model_params['up4.conv2_4.bias'] = torch.from_numpy(data_dict['layer16x_br2_ConvD']['biases']).type(dtype) 366 | 367 | model_params['up4.bn1_2.weight'] = torch.from_numpy(data_dict['layer16x_br2_BN']['scale']).type(dtype) 368 | model_params['up4.bn1_2.bias'] = torch.from_numpy(data_dict['layer16x_br2_BN']['offset']).type(dtype) 369 | 370 | model_params['up4.conv3.weight'] = torch.from_numpy(data_dict['layer16x_Conv']['weights']).type(dtype).permute(3,2,0,1) 371 | model_params['up4.conv3.bias'] = torch.from_numpy(data_dict['layer16x_Conv']['biases']).type(dtype) 372 | 373 | model_params['up4.bn2.weight'] = torch.from_numpy(data_dict['layer16x_BN']['scale']).type(dtype) 374 | model_params['up4.bn2.bias'] = torch.from_numpy(data_dict['layer16x_BN']['offset']).type(dtype) 375 | 376 | model_params['conv3.weight'] = torch.from_numpy(data_dict['ConvPred']['weights']).type(dtype).permute(3,2,0,1) 377 | model_params['conv3.bias'] = torch.from_numpy(data_dict['ConvPred']['biases']).type(dtype) 378 | 379 | print('weights loaded...') 380 | return model_params 381 | --------------------------------------------------------------------------------
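Two conventions in the mapping above are worth spelling out. First, the TensorFlow checkpoint stores convolution kernels in HWIO order ([height, width, in_channels, out_channels]), while PyTorch expects OIHW; that is what every `.permute(3, 2, 0, 1)` does. Second, each batch-norm entry exposes `scale` and `offset`, which correspond to PyTorch's `weight` (gamma) and `bias` (beta); those are the only BN tensors copied here. The sketch below shows how the returned dictionary is typically consumed. It is illustrative only: the function name `load_weights(model, weights_file, dtype)` and the `FCRN(batch_size)` class assumed to live in `fcrn.py` follow the reference implementation credited elsewhere in this repository and are not guaranteed by the code shown above.

```python
# Minimal usage sketch (assumptions: load_weights(model, weights_file, dtype) is the
# function whose body is shown above, and fcrn.FCRN(batch_size) is the model class
# defined in fcrn.py; adjust the names if your copy differs).
import torch

from fcrn import FCRN              # assumed model class
from weights import load_weights   # assumed name of the function shown above

dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

model = FCRN(batch_size=1)
model.type(dtype)

# Convert the official TensorFlow checkpoint (NYU_ResNet-UpProj.npy) into a PyTorch
# state dict. Conv kernels are permuted from TF's HWIO layout to PyTorch's OIHW
# inside load_weights; BN 'scale'/'offset' become weight/bias.
model_params = load_weights(model, 'model/NYU_ResNet-UpProj.npy', dtype)

# strict=False tolerates any state_dict keys (e.g. BN running statistics) that the
# conversion above does not overwrite explicitly.
model.load_state_dict(model_params, strict=False)
model.eval()
```

If the earlier part of this file seeds `model_params` from `model.state_dict()`, the returned dictionary already contains every key of the network and a plain `strict=True` load would also work; `strict=False` is simply the safer default for a sketch like this.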