├── .fig-framework.png
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── Seeds
│   │   └── .gitkeep
│   ├── VOC2012
│   │   ├── train_aug.txt
│   │   └── val.txt
│   └── pretrained
│       └── .gitkeep
├── lib
│   ├── __init__.py
│   ├── loader
│   │   ├── __init__.py
│   │   ├── voc_seg_group_loader.py
│   │   └── voc_seg_loader.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cian_models.py
│   │   ├── cian_modules.py
│   │   ├── layers.py
│   │   ├── layers_custom
│   │   │   ├── __init__.py
│   │   │   ├── constant.py
│   │   │   └── layers_custom_v1.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── efficientnet.py
│   │   │   ├── inception_bn.py
│   │   │   ├── layers_custom
│   │   │   │   ├── __init__.py
│   │   │   │   ├── activations.py
│   │   │   │   ├── constant.py
│   │   │   │   ├── drop_connect.py
│   │   │   │   └── layers_custom_v1.py
│   │   │   ├── resnet.py
│   │   │   ├── resnet_v1.py
│   │   │   ├── vgg.py
│   │   │   └── wide_resnet.py
│   │   ├── multi_scale.py
│   │   └── resnet.py
│   └── utils
│       ├── __init__.py
│       ├── dataset_tools.py
│       ├── image_tools.py
│       └── mxnet_tools.py
├── run_cian.sh
└── scripts
    ├── __init__.py
    ├── eval_segment.py
    ├── generate_retrain.py
    └── train_infer_segment.py

/.fig-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/.fig-framework.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | *.py[cod]
3 | *.so
4 | *.egg
5 | *.egg-info
6 | __pycache__
7 | 
8 | # Params
9 | *.params
10 | *.states
11 | *.npy
12 | *.pkl
13 | 
14 | # caffe
15 | *.caffemodel
16 | 
17 | # Vim
18 | *.swp
19 | 
20 | snapshot
21 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Junsong Fan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CIAN: Cross-Image Affinity Net for Weakly Supervised Semantic Segmentation
2 | 
3 | This is the code for:
4 | 
5 | CIAN: Cross-Image Affinity Net for Weakly Supervised Semantic Segmentation, Junsong Fan, Zhaoxiang Zhang, Tieniu Tan, Chunfeng Song, Jun Xiao, AAAI 2020 [[paper]](https://arxiv.org/abs/1811.10842).
6 | 
7 | ## Introduction
8 | 
9 | ![fig-framework](.fig-framework.png)
10 | 
11 | Framework of the approach. Our approach learns cross-image relationships to help weakly supervised semantic segmentation. The CIAN module takes features from two different images as input, then extracts and exchanges information across them to augment the original features.
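For intuition, the heart of the CIAN module is a cross-image attention step: one image's features query the key/value embeddings of its partner image, and the gathered information is added back onto the original features as a residual. Below is a minimal NumPy sketch of that computation (illustrative only; the repository's actual implementation is built from MXNet symbols in `lib/models/cian_modules.py`, and all names here are made up for the example):

```python
import numpy as np

def cross_affinity(query_a, key_b, value_b):
    """Gather information for image A from image B (toy single-pair version).

    query_a:         (C, N) embedded features of image A, N = H*W locations
    key_b, value_b:  (C, M) embedded features of image B
    """
    sim = key_b.T @ query_a                            # (M, N) pairwise similarities
    sim = np.exp(sim - sim.max(axis=0, keepdims=True)) # softmax over B's locations
    sim /= sim.sum(axis=0, keepdims=True)
    return value_b @ sim                               # (C, N) residual passed from B to A

# Toy usage: two 'images' with C=4 channels and N=M=6 spatial locations each.
rng = np.random.default_rng(0)
query_a, key_b, value_b = rng.normal(size=(3, 4, 6))
residual = cross_affinity(query_a, key_b, value_b)     # added back onto A's raw features
```

In the actual code this residual is computed both within an image (`build_self_affinity`) and across the images of a group (`build_cross_affinity`), and the result is added to the backbone features before the final 3×3 dilated classifier (`fc1`).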
12 | 
13 | 
14 | 
15 | ## Prerequisites
16 | 
17 | - Python 3.7, MXNet 1.3.1, NumPy, OpenCV, [pydensecrf](https://github.com/lucasb-eyer/pydensecrf).
18 | - NVIDIA GPUs
19 | 
20 | 
21 | 
22 | ## Usage
23 | 
24 | #### Prepare the dataset and pretrained parameters:
25 | 
26 | - Download the VOC 2012 dataset, and configure the paths accordingly in `run_cian.sh`.
27 | 
28 | - Download the seeds [here](https://drive.google.com/open?id=1w2WIEtQe2F1tgxlINpk5mi-BlL_gue8z), untar them, and put the folder `CIAN_SEEDS` into `CIAN/data/Seeds/`. We use the VGG16-based [CAM](http://cnnlocalization.csail.mit.edu/Zhou_Learning_Deep_Features_CVPR_2016_paper.pdf) to generate the foreground seeds and the saliency model [DRFI](https://github.com/playerkk/drfi_cpp) to generate the background seeds. You can also generate the seeds yourself.
29 | 
30 | - Download the ImageNet-pretrained parameters and put them into the folder `CIAN/data/pretrained`. We adopt the pretrained parameters provided by the official MXNet [model-zoo](https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/common/modelzoo.py):
31 |   - [ResNet101](http://data.mxnet.io/models/imagenet/resnet/101-layers/resnet-101-0000.params)
32 |   - [ResNet50](http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50-0000.params)
33 | 
34 | #### Run:
35 | 
36 | ```bash
37 | ./run_cian.sh
38 | ```
39 | 
40 | This script automatically runs the training, testing (on the val set), and retraining pipeline. Checkpoints and predictions are saved in the folder `CIAN/snapshot/CIAN`.
41 | 
42 | 
43 | 
44 | ## Citation
45 | 
46 | If you find the code useful, please consider citing:
47 | 
48 | ```
49 | @inproceedings{fan2020cian,
50 |   title={CIAN: Cross-Image Affinity Net for Weakly Supervised Semantic Segmentation},
51 |   author={Fan, Junsong and Zhang, Zhaoxiang and Tan, Tieniu and Song, Chunfeng and Xiao, Jun},
52 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
53 |   year={2020}
54 | }
55 | ```
56 | 
57 | 
58 | 
59 | 
60 | 
--------------------------------------------------------------------------------
/data/Seeds/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/data/Seeds/.gitkeep
--------------------------------------------------------------------------------
/data/VOC2012/val.txt:
--------------------------------------------------------------------------------
1 | 2007_000033
2 | 2007_000042
3 | 2007_000061
4 | 2007_000123
5 | 2007_000129
6 | 2007_000175
7 | 2007_000187
8 | 2007_000323
9 | 2007_000332
10 | 2007_000346
11 | 2007_000452
12 | 2007_000464
13 | 2007_000491
14 | 2007_000529
15 | 2007_000559
16 | 2007_000572
17 | 2007_000629
18 | 2007_000636
19 | 2007_000661
20 | 2007_000663
21 | 2007_000676
22 | 2007_000727
23 | 2007_000762
24 | 2007_000783
25 | 2007_000799
26 | 2007_000804
27 | 2007_000830
28 | 2007_000837
29 | 2007_000847
30 | 2007_000862
31 | 2007_000925
32 | 2007_000999
33 | 2007_001154
34 | 2007_001175
35 | 2007_001239
36 | 2007_001284
37 | 2007_001288
38 | 2007_001289
39 | 2007_001299
40 | 2007_001311
41 | 2007_001321
42 | 2007_001377
43 | 2007_001408
44 | 2007_001423
45 | 2007_001430
46 | 2007_001457
47 | 2007_001458
48 | 2007_001526
49 | 2007_001568
50 | 
2007_001585 51 | 2007_001586 52 | 2007_001587 53 | 2007_001594 54 | 2007_001630 55 | 2007_001677 56 | 2007_001678 57 | 2007_001717 58 | 2007_001733 59 | 2007_001761 60 | 2007_001763 61 | 2007_001774 62 | 2007_001884 63 | 2007_001955 64 | 2007_002046 65 | 2007_002094 66 | 2007_002119 67 | 2007_002132 68 | 2007_002260 69 | 2007_002266 70 | 2007_002268 71 | 2007_002284 72 | 2007_002376 73 | 2007_002378 74 | 2007_002387 75 | 2007_002400 76 | 2007_002412 77 | 2007_002426 78 | 2007_002427 79 | 2007_002445 80 | 2007_002470 81 | 2007_002539 82 | 2007_002565 83 | 2007_002597 84 | 2007_002618 85 | 2007_002619 86 | 2007_002624 87 | 2007_002643 88 | 2007_002648 89 | 2007_002719 90 | 2007_002728 91 | 2007_002823 92 | 2007_002824 93 | 2007_002852 94 | 2007_002903 95 | 2007_003011 96 | 2007_003020 97 | 2007_003022 98 | 2007_003051 99 | 2007_003088 100 | 2007_003101 101 | 2007_003106 102 | 2007_003110 103 | 2007_003131 104 | 2007_003134 105 | 2007_003137 106 | 2007_003143 107 | 2007_003169 108 | 2007_003188 109 | 2007_003194 110 | 2007_003195 111 | 2007_003201 112 | 2007_003349 113 | 2007_003367 114 | 2007_003373 115 | 2007_003499 116 | 2007_003503 117 | 2007_003506 118 | 2007_003530 119 | 2007_003571 120 | 2007_003587 121 | 2007_003611 122 | 2007_003621 123 | 2007_003682 124 | 2007_003711 125 | 2007_003714 126 | 2007_003742 127 | 2007_003786 128 | 2007_003841 129 | 2007_003848 130 | 2007_003861 131 | 2007_003872 132 | 2007_003917 133 | 2007_003957 134 | 2007_003991 135 | 2007_004033 136 | 2007_004052 137 | 2007_004112 138 | 2007_004121 139 | 2007_004143 140 | 2007_004189 141 | 2007_004190 142 | 2007_004193 143 | 2007_004241 144 | 2007_004275 145 | 2007_004281 146 | 2007_004380 147 | 2007_004392 148 | 2007_004405 149 | 2007_004468 150 | 2007_004483 151 | 2007_004510 152 | 2007_004538 153 | 2007_004558 154 | 2007_004644 155 | 2007_004649 156 | 2007_004712 157 | 2007_004722 158 | 2007_004856 159 | 2007_004866 160 | 2007_004902 161 | 2007_004969 162 | 2007_005058 163 | 2007_005074 164 | 2007_005107 165 | 2007_005114 166 | 2007_005149 167 | 2007_005173 168 | 2007_005281 169 | 2007_005294 170 | 2007_005296 171 | 2007_005304 172 | 2007_005331 173 | 2007_005354 174 | 2007_005358 175 | 2007_005428 176 | 2007_005460 177 | 2007_005469 178 | 2007_005509 179 | 2007_005547 180 | 2007_005600 181 | 2007_005608 182 | 2007_005626 183 | 2007_005689 184 | 2007_005696 185 | 2007_005705 186 | 2007_005759 187 | 2007_005803 188 | 2007_005813 189 | 2007_005828 190 | 2007_005844 191 | 2007_005845 192 | 2007_005857 193 | 2007_005911 194 | 2007_005915 195 | 2007_005978 196 | 2007_006028 197 | 2007_006035 198 | 2007_006046 199 | 2007_006076 200 | 2007_006086 201 | 2007_006117 202 | 2007_006171 203 | 2007_006241 204 | 2007_006260 205 | 2007_006277 206 | 2007_006348 207 | 2007_006364 208 | 2007_006373 209 | 2007_006444 210 | 2007_006449 211 | 2007_006549 212 | 2007_006553 213 | 2007_006560 214 | 2007_006647 215 | 2007_006678 216 | 2007_006680 217 | 2007_006698 218 | 2007_006761 219 | 2007_006802 220 | 2007_006837 221 | 2007_006841 222 | 2007_006864 223 | 2007_006866 224 | 2007_006946 225 | 2007_007007 226 | 2007_007084 227 | 2007_007109 228 | 2007_007130 229 | 2007_007165 230 | 2007_007168 231 | 2007_007195 232 | 2007_007196 233 | 2007_007203 234 | 2007_007211 235 | 2007_007235 236 | 2007_007341 237 | 2007_007414 238 | 2007_007417 239 | 2007_007470 240 | 2007_007477 241 | 2007_007493 242 | 2007_007498 243 | 2007_007524 244 | 2007_007534 245 | 2007_007624 246 | 2007_007651 247 | 2007_007688 248 | 2007_007748 249 | 2007_007795 250 | 
2007_007810 251 | 2007_007815 252 | 2007_007818 253 | 2007_007836 254 | 2007_007849 255 | 2007_007881 256 | 2007_007996 257 | 2007_008051 258 | 2007_008084 259 | 2007_008106 260 | 2007_008110 261 | 2007_008204 262 | 2007_008222 263 | 2007_008256 264 | 2007_008260 265 | 2007_008339 266 | 2007_008374 267 | 2007_008415 268 | 2007_008430 269 | 2007_008543 270 | 2007_008547 271 | 2007_008596 272 | 2007_008645 273 | 2007_008670 274 | 2007_008708 275 | 2007_008722 276 | 2007_008747 277 | 2007_008802 278 | 2007_008815 279 | 2007_008897 280 | 2007_008944 281 | 2007_008964 282 | 2007_008973 283 | 2007_008980 284 | 2007_009015 285 | 2007_009068 286 | 2007_009084 287 | 2007_009088 288 | 2007_009096 289 | 2007_009221 290 | 2007_009245 291 | 2007_009251 292 | 2007_009252 293 | 2007_009258 294 | 2007_009320 295 | 2007_009323 296 | 2007_009331 297 | 2007_009346 298 | 2007_009392 299 | 2007_009413 300 | 2007_009419 301 | 2007_009446 302 | 2007_009458 303 | 2007_009521 304 | 2007_009562 305 | 2007_009592 306 | 2007_009654 307 | 2007_009655 308 | 2007_009684 309 | 2007_009687 310 | 2007_009691 311 | 2007_009706 312 | 2007_009750 313 | 2007_009756 314 | 2007_009764 315 | 2007_009794 316 | 2007_009817 317 | 2007_009841 318 | 2007_009897 319 | 2007_009911 320 | 2007_009923 321 | 2007_009938 322 | 2008_000009 323 | 2008_000016 324 | 2008_000073 325 | 2008_000075 326 | 2008_000080 327 | 2008_000107 328 | 2008_000120 329 | 2008_000123 330 | 2008_000149 331 | 2008_000182 332 | 2008_000213 333 | 2008_000215 334 | 2008_000223 335 | 2008_000233 336 | 2008_000234 337 | 2008_000239 338 | 2008_000254 339 | 2008_000270 340 | 2008_000271 341 | 2008_000345 342 | 2008_000359 343 | 2008_000391 344 | 2008_000401 345 | 2008_000464 346 | 2008_000469 347 | 2008_000474 348 | 2008_000501 349 | 2008_000510 350 | 2008_000533 351 | 2008_000573 352 | 2008_000589 353 | 2008_000602 354 | 2008_000630 355 | 2008_000657 356 | 2008_000661 357 | 2008_000662 358 | 2008_000666 359 | 2008_000673 360 | 2008_000700 361 | 2008_000725 362 | 2008_000731 363 | 2008_000763 364 | 2008_000765 365 | 2008_000782 366 | 2008_000795 367 | 2008_000811 368 | 2008_000848 369 | 2008_000853 370 | 2008_000863 371 | 2008_000911 372 | 2008_000919 373 | 2008_000943 374 | 2008_000992 375 | 2008_001013 376 | 2008_001028 377 | 2008_001040 378 | 2008_001070 379 | 2008_001074 380 | 2008_001076 381 | 2008_001078 382 | 2008_001135 383 | 2008_001150 384 | 2008_001170 385 | 2008_001231 386 | 2008_001249 387 | 2008_001260 388 | 2008_001283 389 | 2008_001308 390 | 2008_001379 391 | 2008_001404 392 | 2008_001433 393 | 2008_001439 394 | 2008_001478 395 | 2008_001491 396 | 2008_001504 397 | 2008_001513 398 | 2008_001514 399 | 2008_001531 400 | 2008_001546 401 | 2008_001547 402 | 2008_001580 403 | 2008_001629 404 | 2008_001640 405 | 2008_001682 406 | 2008_001688 407 | 2008_001715 408 | 2008_001821 409 | 2008_001874 410 | 2008_001885 411 | 2008_001895 412 | 2008_001966 413 | 2008_001971 414 | 2008_001992 415 | 2008_002043 416 | 2008_002152 417 | 2008_002205 418 | 2008_002212 419 | 2008_002239 420 | 2008_002240 421 | 2008_002241 422 | 2008_002269 423 | 2008_002273 424 | 2008_002358 425 | 2008_002379 426 | 2008_002383 427 | 2008_002429 428 | 2008_002464 429 | 2008_002467 430 | 2008_002492 431 | 2008_002495 432 | 2008_002504 433 | 2008_002521 434 | 2008_002536 435 | 2008_002588 436 | 2008_002623 437 | 2008_002680 438 | 2008_002681 439 | 2008_002775 440 | 2008_002778 441 | 2008_002835 442 | 2008_002859 443 | 2008_002864 444 | 2008_002900 445 | 2008_002904 446 | 2008_002929 447 | 
2008_002936 448 | 2008_002942 449 | 2008_002958 450 | 2008_003003 451 | 2008_003026 452 | 2008_003034 453 | 2008_003076 454 | 2008_003105 455 | 2008_003108 456 | 2008_003110 457 | 2008_003135 458 | 2008_003141 459 | 2008_003155 460 | 2008_003210 461 | 2008_003238 462 | 2008_003270 463 | 2008_003330 464 | 2008_003333 465 | 2008_003369 466 | 2008_003379 467 | 2008_003451 468 | 2008_003461 469 | 2008_003477 470 | 2008_003492 471 | 2008_003499 472 | 2008_003511 473 | 2008_003546 474 | 2008_003576 475 | 2008_003577 476 | 2008_003676 477 | 2008_003709 478 | 2008_003733 479 | 2008_003777 480 | 2008_003782 481 | 2008_003821 482 | 2008_003846 483 | 2008_003856 484 | 2008_003858 485 | 2008_003874 486 | 2008_003876 487 | 2008_003885 488 | 2008_003886 489 | 2008_003926 490 | 2008_003976 491 | 2008_004069 492 | 2008_004101 493 | 2008_004140 494 | 2008_004172 495 | 2008_004175 496 | 2008_004212 497 | 2008_004279 498 | 2008_004339 499 | 2008_004345 500 | 2008_004363 501 | 2008_004367 502 | 2008_004396 503 | 2008_004399 504 | 2008_004453 505 | 2008_004477 506 | 2008_004552 507 | 2008_004562 508 | 2008_004575 509 | 2008_004610 510 | 2008_004612 511 | 2008_004621 512 | 2008_004624 513 | 2008_004654 514 | 2008_004659 515 | 2008_004687 516 | 2008_004701 517 | 2008_004704 518 | 2008_004705 519 | 2008_004754 520 | 2008_004758 521 | 2008_004854 522 | 2008_004910 523 | 2008_004995 524 | 2008_005049 525 | 2008_005089 526 | 2008_005097 527 | 2008_005105 528 | 2008_005145 529 | 2008_005197 530 | 2008_005217 531 | 2008_005242 532 | 2008_005245 533 | 2008_005254 534 | 2008_005262 535 | 2008_005338 536 | 2008_005398 537 | 2008_005399 538 | 2008_005422 539 | 2008_005439 540 | 2008_005445 541 | 2008_005525 542 | 2008_005544 543 | 2008_005628 544 | 2008_005633 545 | 2008_005637 546 | 2008_005642 547 | 2008_005676 548 | 2008_005680 549 | 2008_005691 550 | 2008_005727 551 | 2008_005738 552 | 2008_005812 553 | 2008_005904 554 | 2008_005915 555 | 2008_006008 556 | 2008_006036 557 | 2008_006055 558 | 2008_006063 559 | 2008_006108 560 | 2008_006130 561 | 2008_006143 562 | 2008_006159 563 | 2008_006216 564 | 2008_006219 565 | 2008_006229 566 | 2008_006254 567 | 2008_006275 568 | 2008_006325 569 | 2008_006327 570 | 2008_006341 571 | 2008_006408 572 | 2008_006480 573 | 2008_006523 574 | 2008_006526 575 | 2008_006528 576 | 2008_006553 577 | 2008_006554 578 | 2008_006703 579 | 2008_006722 580 | 2008_006752 581 | 2008_006784 582 | 2008_006835 583 | 2008_006874 584 | 2008_006981 585 | 2008_006986 586 | 2008_007025 587 | 2008_007031 588 | 2008_007048 589 | 2008_007120 590 | 2008_007123 591 | 2008_007143 592 | 2008_007194 593 | 2008_007219 594 | 2008_007273 595 | 2008_007350 596 | 2008_007378 597 | 2008_007392 598 | 2008_007402 599 | 2008_007497 600 | 2008_007498 601 | 2008_007507 602 | 2008_007513 603 | 2008_007527 604 | 2008_007548 605 | 2008_007596 606 | 2008_007677 607 | 2008_007737 608 | 2008_007797 609 | 2008_007804 610 | 2008_007811 611 | 2008_007814 612 | 2008_007828 613 | 2008_007836 614 | 2008_007945 615 | 2008_007994 616 | 2008_008051 617 | 2008_008103 618 | 2008_008127 619 | 2008_008221 620 | 2008_008252 621 | 2008_008268 622 | 2008_008296 623 | 2008_008301 624 | 2008_008335 625 | 2008_008362 626 | 2008_008392 627 | 2008_008393 628 | 2008_008421 629 | 2008_008434 630 | 2008_008469 631 | 2008_008629 632 | 2008_008682 633 | 2008_008711 634 | 2008_008746 635 | 2009_000012 636 | 2009_000013 637 | 2009_000022 638 | 2009_000032 639 | 2009_000037 640 | 2009_000039 641 | 2009_000074 642 | 2009_000080 643 | 2009_000087 644 | 
2009_000096 645 | 2009_000121 646 | 2009_000136 647 | 2009_000149 648 | 2009_000156 649 | 2009_000201 650 | 2009_000205 651 | 2009_000219 652 | 2009_000242 653 | 2009_000309 654 | 2009_000318 655 | 2009_000335 656 | 2009_000351 657 | 2009_000354 658 | 2009_000387 659 | 2009_000391 660 | 2009_000412 661 | 2009_000418 662 | 2009_000421 663 | 2009_000426 664 | 2009_000440 665 | 2009_000446 666 | 2009_000455 667 | 2009_000457 668 | 2009_000469 669 | 2009_000487 670 | 2009_000488 671 | 2009_000523 672 | 2009_000573 673 | 2009_000619 674 | 2009_000628 675 | 2009_000641 676 | 2009_000664 677 | 2009_000675 678 | 2009_000704 679 | 2009_000705 680 | 2009_000712 681 | 2009_000716 682 | 2009_000723 683 | 2009_000727 684 | 2009_000730 685 | 2009_000731 686 | 2009_000732 687 | 2009_000771 688 | 2009_000825 689 | 2009_000828 690 | 2009_000839 691 | 2009_000840 692 | 2009_000845 693 | 2009_000879 694 | 2009_000892 695 | 2009_000919 696 | 2009_000924 697 | 2009_000931 698 | 2009_000935 699 | 2009_000964 700 | 2009_000989 701 | 2009_000991 702 | 2009_000998 703 | 2009_001008 704 | 2009_001082 705 | 2009_001108 706 | 2009_001160 707 | 2009_001215 708 | 2009_001240 709 | 2009_001255 710 | 2009_001278 711 | 2009_001299 712 | 2009_001300 713 | 2009_001314 714 | 2009_001332 715 | 2009_001333 716 | 2009_001363 717 | 2009_001391 718 | 2009_001411 719 | 2009_001433 720 | 2009_001505 721 | 2009_001535 722 | 2009_001536 723 | 2009_001565 724 | 2009_001607 725 | 2009_001644 726 | 2009_001663 727 | 2009_001683 728 | 2009_001684 729 | 2009_001687 730 | 2009_001718 731 | 2009_001731 732 | 2009_001765 733 | 2009_001768 734 | 2009_001775 735 | 2009_001804 736 | 2009_001816 737 | 2009_001818 738 | 2009_001850 739 | 2009_001851 740 | 2009_001854 741 | 2009_001941 742 | 2009_001991 743 | 2009_002012 744 | 2009_002035 745 | 2009_002042 746 | 2009_002082 747 | 2009_002094 748 | 2009_002097 749 | 2009_002122 750 | 2009_002150 751 | 2009_002155 752 | 2009_002164 753 | 2009_002165 754 | 2009_002171 755 | 2009_002185 756 | 2009_002202 757 | 2009_002221 758 | 2009_002238 759 | 2009_002239 760 | 2009_002265 761 | 2009_002268 762 | 2009_002291 763 | 2009_002295 764 | 2009_002317 765 | 2009_002320 766 | 2009_002346 767 | 2009_002366 768 | 2009_002372 769 | 2009_002382 770 | 2009_002390 771 | 2009_002415 772 | 2009_002445 773 | 2009_002487 774 | 2009_002521 775 | 2009_002527 776 | 2009_002535 777 | 2009_002539 778 | 2009_002549 779 | 2009_002562 780 | 2009_002568 781 | 2009_002571 782 | 2009_002573 783 | 2009_002584 784 | 2009_002591 785 | 2009_002594 786 | 2009_002604 787 | 2009_002618 788 | 2009_002635 789 | 2009_002638 790 | 2009_002649 791 | 2009_002651 792 | 2009_002727 793 | 2009_002732 794 | 2009_002749 795 | 2009_002753 796 | 2009_002771 797 | 2009_002808 798 | 2009_002856 799 | 2009_002887 800 | 2009_002888 801 | 2009_002928 802 | 2009_002936 803 | 2009_002975 804 | 2009_002982 805 | 2009_002990 806 | 2009_003003 807 | 2009_003005 808 | 2009_003043 809 | 2009_003059 810 | 2009_003063 811 | 2009_003065 812 | 2009_003071 813 | 2009_003080 814 | 2009_003105 815 | 2009_003123 816 | 2009_003193 817 | 2009_003196 818 | 2009_003217 819 | 2009_003224 820 | 2009_003241 821 | 2009_003269 822 | 2009_003273 823 | 2009_003299 824 | 2009_003304 825 | 2009_003311 826 | 2009_003323 827 | 2009_003343 828 | 2009_003378 829 | 2009_003387 830 | 2009_003406 831 | 2009_003433 832 | 2009_003450 833 | 2009_003466 834 | 2009_003481 835 | 2009_003494 836 | 2009_003498 837 | 2009_003504 838 | 2009_003507 839 | 2009_003517 840 | 2009_003523 841 | 
2009_003542 842 | 2009_003549 843 | 2009_003551 844 | 2009_003564 845 | 2009_003569 846 | 2009_003576 847 | 2009_003589 848 | 2009_003607 849 | 2009_003640 850 | 2009_003666 851 | 2009_003696 852 | 2009_003703 853 | 2009_003707 854 | 2009_003756 855 | 2009_003771 856 | 2009_003773 857 | 2009_003804 858 | 2009_003806 859 | 2009_003810 860 | 2009_003849 861 | 2009_003857 862 | 2009_003858 863 | 2009_003895 864 | 2009_003903 865 | 2009_003904 866 | 2009_003928 867 | 2009_003938 868 | 2009_003971 869 | 2009_003991 870 | 2009_004021 871 | 2009_004033 872 | 2009_004043 873 | 2009_004070 874 | 2009_004072 875 | 2009_004084 876 | 2009_004099 877 | 2009_004125 878 | 2009_004140 879 | 2009_004217 880 | 2009_004221 881 | 2009_004247 882 | 2009_004248 883 | 2009_004255 884 | 2009_004298 885 | 2009_004324 886 | 2009_004455 887 | 2009_004494 888 | 2009_004497 889 | 2009_004504 890 | 2009_004507 891 | 2009_004509 892 | 2009_004540 893 | 2009_004568 894 | 2009_004579 895 | 2009_004581 896 | 2009_004590 897 | 2009_004592 898 | 2009_004594 899 | 2009_004635 900 | 2009_004653 901 | 2009_004687 902 | 2009_004721 903 | 2009_004730 904 | 2009_004732 905 | 2009_004738 906 | 2009_004748 907 | 2009_004789 908 | 2009_004799 909 | 2009_004801 910 | 2009_004848 911 | 2009_004859 912 | 2009_004867 913 | 2009_004882 914 | 2009_004886 915 | 2009_004895 916 | 2009_004942 917 | 2009_004969 918 | 2009_004987 919 | 2009_004993 920 | 2009_004994 921 | 2009_005038 922 | 2009_005078 923 | 2009_005087 924 | 2009_005089 925 | 2009_005137 926 | 2009_005148 927 | 2009_005156 928 | 2009_005158 929 | 2009_005189 930 | 2009_005190 931 | 2009_005217 932 | 2009_005219 933 | 2009_005220 934 | 2009_005231 935 | 2009_005260 936 | 2009_005262 937 | 2009_005302 938 | 2010_000003 939 | 2010_000038 940 | 2010_000065 941 | 2010_000083 942 | 2010_000084 943 | 2010_000087 944 | 2010_000110 945 | 2010_000159 946 | 2010_000160 947 | 2010_000163 948 | 2010_000174 949 | 2010_000216 950 | 2010_000238 951 | 2010_000241 952 | 2010_000256 953 | 2010_000272 954 | 2010_000284 955 | 2010_000309 956 | 2010_000318 957 | 2010_000330 958 | 2010_000335 959 | 2010_000342 960 | 2010_000372 961 | 2010_000422 962 | 2010_000426 963 | 2010_000427 964 | 2010_000502 965 | 2010_000530 966 | 2010_000552 967 | 2010_000559 968 | 2010_000572 969 | 2010_000573 970 | 2010_000622 971 | 2010_000628 972 | 2010_000639 973 | 2010_000666 974 | 2010_000679 975 | 2010_000682 976 | 2010_000683 977 | 2010_000724 978 | 2010_000738 979 | 2010_000764 980 | 2010_000788 981 | 2010_000814 982 | 2010_000836 983 | 2010_000874 984 | 2010_000904 985 | 2010_000906 986 | 2010_000907 987 | 2010_000918 988 | 2010_000929 989 | 2010_000941 990 | 2010_000952 991 | 2010_000961 992 | 2010_001000 993 | 2010_001010 994 | 2010_001011 995 | 2010_001016 996 | 2010_001017 997 | 2010_001024 998 | 2010_001036 999 | 2010_001061 1000 | 2010_001069 1001 | 2010_001070 1002 | 2010_001079 1003 | 2010_001104 1004 | 2010_001124 1005 | 2010_001149 1006 | 2010_001151 1007 | 2010_001174 1008 | 2010_001206 1009 | 2010_001246 1010 | 2010_001251 1011 | 2010_001256 1012 | 2010_001264 1013 | 2010_001292 1014 | 2010_001313 1015 | 2010_001327 1016 | 2010_001331 1017 | 2010_001351 1018 | 2010_001367 1019 | 2010_001376 1020 | 2010_001403 1021 | 2010_001448 1022 | 2010_001451 1023 | 2010_001522 1024 | 2010_001534 1025 | 2010_001553 1026 | 2010_001557 1027 | 2010_001563 1028 | 2010_001577 1029 | 2010_001579 1030 | 2010_001646 1031 | 2010_001656 1032 | 2010_001692 1033 | 2010_001699 1034 | 2010_001734 1035 | 2010_001752 1036 | 
2010_001767 1037 | 2010_001768 1038 | 2010_001773 1039 | 2010_001820 1040 | 2010_001830 1041 | 2010_001851 1042 | 2010_001908 1043 | 2010_001913 1044 | 2010_001951 1045 | 2010_001956 1046 | 2010_001962 1047 | 2010_001966 1048 | 2010_001995 1049 | 2010_002017 1050 | 2010_002025 1051 | 2010_002030 1052 | 2010_002106 1053 | 2010_002137 1054 | 2010_002142 1055 | 2010_002146 1056 | 2010_002147 1057 | 2010_002150 1058 | 2010_002161 1059 | 2010_002200 1060 | 2010_002228 1061 | 2010_002232 1062 | 2010_002251 1063 | 2010_002271 1064 | 2010_002305 1065 | 2010_002310 1066 | 2010_002336 1067 | 2010_002348 1068 | 2010_002361 1069 | 2010_002390 1070 | 2010_002396 1071 | 2010_002422 1072 | 2010_002450 1073 | 2010_002480 1074 | 2010_002512 1075 | 2010_002531 1076 | 2010_002536 1077 | 2010_002538 1078 | 2010_002546 1079 | 2010_002623 1080 | 2010_002682 1081 | 2010_002691 1082 | 2010_002693 1083 | 2010_002701 1084 | 2010_002763 1085 | 2010_002792 1086 | 2010_002868 1087 | 2010_002900 1088 | 2010_002902 1089 | 2010_002921 1090 | 2010_002929 1091 | 2010_002939 1092 | 2010_002988 1093 | 2010_003014 1094 | 2010_003060 1095 | 2010_003123 1096 | 2010_003127 1097 | 2010_003132 1098 | 2010_003168 1099 | 2010_003183 1100 | 2010_003187 1101 | 2010_003207 1102 | 2010_003231 1103 | 2010_003239 1104 | 2010_003275 1105 | 2010_003276 1106 | 2010_003293 1107 | 2010_003302 1108 | 2010_003325 1109 | 2010_003362 1110 | 2010_003365 1111 | 2010_003381 1112 | 2010_003402 1113 | 2010_003409 1114 | 2010_003418 1115 | 2010_003446 1116 | 2010_003453 1117 | 2010_003468 1118 | 2010_003473 1119 | 2010_003495 1120 | 2010_003506 1121 | 2010_003514 1122 | 2010_003531 1123 | 2010_003532 1124 | 2010_003541 1125 | 2010_003547 1126 | 2010_003597 1127 | 2010_003675 1128 | 2010_003708 1129 | 2010_003716 1130 | 2010_003746 1131 | 2010_003758 1132 | 2010_003764 1133 | 2010_003768 1134 | 2010_003771 1135 | 2010_003772 1136 | 2010_003781 1137 | 2010_003813 1138 | 2010_003820 1139 | 2010_003854 1140 | 2010_003912 1141 | 2010_003915 1142 | 2010_003947 1143 | 2010_003956 1144 | 2010_003971 1145 | 2010_004041 1146 | 2010_004042 1147 | 2010_004056 1148 | 2010_004063 1149 | 2010_004104 1150 | 2010_004120 1151 | 2010_004149 1152 | 2010_004165 1153 | 2010_004208 1154 | 2010_004219 1155 | 2010_004226 1156 | 2010_004314 1157 | 2010_004320 1158 | 2010_004322 1159 | 2010_004337 1160 | 2010_004348 1161 | 2010_004355 1162 | 2010_004369 1163 | 2010_004382 1164 | 2010_004419 1165 | 2010_004432 1166 | 2010_004472 1167 | 2010_004479 1168 | 2010_004519 1169 | 2010_004520 1170 | 2010_004529 1171 | 2010_004543 1172 | 2010_004550 1173 | 2010_004551 1174 | 2010_004556 1175 | 2010_004559 1176 | 2010_004628 1177 | 2010_004635 1178 | 2010_004662 1179 | 2010_004697 1180 | 2010_004757 1181 | 2010_004763 1182 | 2010_004772 1183 | 2010_004783 1184 | 2010_004789 1185 | 2010_004795 1186 | 2010_004815 1187 | 2010_004825 1188 | 2010_004828 1189 | 2010_004856 1190 | 2010_004857 1191 | 2010_004861 1192 | 2010_004941 1193 | 2010_004946 1194 | 2010_004951 1195 | 2010_004980 1196 | 2010_004994 1197 | 2010_005013 1198 | 2010_005021 1199 | 2010_005046 1200 | 2010_005063 1201 | 2010_005108 1202 | 2010_005118 1203 | 2010_005159 1204 | 2010_005160 1205 | 2010_005166 1206 | 2010_005174 1207 | 2010_005180 1208 | 2010_005187 1209 | 2010_005206 1210 | 2010_005245 1211 | 2010_005252 1212 | 2010_005284 1213 | 2010_005305 1214 | 2010_005344 1215 | 2010_005353 1216 | 2010_005366 1217 | 2010_005401 1218 | 2010_005421 1219 | 2010_005428 1220 | 2010_005432 1221 | 2010_005433 1222 | 2010_005496 1223 | 
2010_005501 1224 | 2010_005508 1225 | 2010_005531 1226 | 2010_005534 1227 | 2010_005575 1228 | 2010_005582 1229 | 2010_005606 1230 | 2010_005626 1231 | 2010_005644 1232 | 2010_005664 1233 | 2010_005705 1234 | 2010_005706 1235 | 2010_005709 1236 | 2010_005718 1237 | 2010_005719 1238 | 2010_005727 1239 | 2010_005762 1240 | 2010_005788 1241 | 2010_005860 1242 | 2010_005871 1243 | 2010_005877 1244 | 2010_005888 1245 | 2010_005899 1246 | 2010_005922 1247 | 2010_005991 1248 | 2010_005992 1249 | 2010_006026 1250 | 2010_006034 1251 | 2010_006054 1252 | 2010_006070 1253 | 2011_000045 1254 | 2011_000051 1255 | 2011_000054 1256 | 2011_000066 1257 | 2011_000070 1258 | 2011_000112 1259 | 2011_000173 1260 | 2011_000178 1261 | 2011_000185 1262 | 2011_000226 1263 | 2011_000234 1264 | 2011_000238 1265 | 2011_000239 1266 | 2011_000248 1267 | 2011_000283 1268 | 2011_000291 1269 | 2011_000310 1270 | 2011_000312 1271 | 2011_000338 1272 | 2011_000396 1273 | 2011_000412 1274 | 2011_000419 1275 | 2011_000435 1276 | 2011_000436 1277 | 2011_000438 1278 | 2011_000455 1279 | 2011_000456 1280 | 2011_000479 1281 | 2011_000481 1282 | 2011_000482 1283 | 2011_000503 1284 | 2011_000512 1285 | 2011_000521 1286 | 2011_000526 1287 | 2011_000536 1288 | 2011_000548 1289 | 2011_000566 1290 | 2011_000585 1291 | 2011_000598 1292 | 2011_000607 1293 | 2011_000618 1294 | 2011_000638 1295 | 2011_000658 1296 | 2011_000661 1297 | 2011_000669 1298 | 2011_000747 1299 | 2011_000780 1300 | 2011_000789 1301 | 2011_000807 1302 | 2011_000809 1303 | 2011_000813 1304 | 2011_000830 1305 | 2011_000843 1306 | 2011_000874 1307 | 2011_000888 1308 | 2011_000900 1309 | 2011_000912 1310 | 2011_000953 1311 | 2011_000969 1312 | 2011_001005 1313 | 2011_001014 1314 | 2011_001020 1315 | 2011_001047 1316 | 2011_001060 1317 | 2011_001064 1318 | 2011_001069 1319 | 2011_001071 1320 | 2011_001082 1321 | 2011_001110 1322 | 2011_001114 1323 | 2011_001159 1324 | 2011_001161 1325 | 2011_001190 1326 | 2011_001232 1327 | 2011_001263 1328 | 2011_001276 1329 | 2011_001281 1330 | 2011_001287 1331 | 2011_001292 1332 | 2011_001313 1333 | 2011_001341 1334 | 2011_001346 1335 | 2011_001350 1336 | 2011_001407 1337 | 2011_001416 1338 | 2011_001421 1339 | 2011_001434 1340 | 2011_001447 1341 | 2011_001489 1342 | 2011_001529 1343 | 2011_001530 1344 | 2011_001534 1345 | 2011_001546 1346 | 2011_001567 1347 | 2011_001589 1348 | 2011_001597 1349 | 2011_001601 1350 | 2011_001607 1351 | 2011_001613 1352 | 2011_001614 1353 | 2011_001619 1354 | 2011_001624 1355 | 2011_001642 1356 | 2011_001665 1357 | 2011_001669 1358 | 2011_001674 1359 | 2011_001708 1360 | 2011_001713 1361 | 2011_001714 1362 | 2011_001722 1363 | 2011_001726 1364 | 2011_001745 1365 | 2011_001748 1366 | 2011_001775 1367 | 2011_001782 1368 | 2011_001793 1369 | 2011_001794 1370 | 2011_001812 1371 | 2011_001862 1372 | 2011_001863 1373 | 2011_001868 1374 | 2011_001880 1375 | 2011_001910 1376 | 2011_001984 1377 | 2011_001988 1378 | 2011_002002 1379 | 2011_002040 1380 | 2011_002041 1381 | 2011_002064 1382 | 2011_002075 1383 | 2011_002098 1384 | 2011_002110 1385 | 2011_002121 1386 | 2011_002124 1387 | 2011_002150 1388 | 2011_002156 1389 | 2011_002178 1390 | 2011_002200 1391 | 2011_002223 1392 | 2011_002244 1393 | 2011_002247 1394 | 2011_002279 1395 | 2011_002295 1396 | 2011_002298 1397 | 2011_002308 1398 | 2011_002317 1399 | 2011_002322 1400 | 2011_002327 1401 | 2011_002343 1402 | 2011_002358 1403 | 2011_002371 1404 | 2011_002379 1405 | 2011_002391 1406 | 2011_002498 1407 | 2011_002509 1408 | 2011_002515 1409 | 2011_002532 1410 | 
2011_002535 1411 | 2011_002548 1412 | 2011_002575 1413 | 2011_002578 1414 | 2011_002589 1415 | 2011_002592 1416 | 2011_002623 1417 | 2011_002641 1418 | 2011_002644 1419 | 2011_002662 1420 | 2011_002675 1421 | 2011_002685 1422 | 2011_002713 1423 | 2011_002730 1424 | 2011_002754 1425 | 2011_002812 1426 | 2011_002863 1427 | 2011_002879 1428 | 2011_002885 1429 | 2011_002929 1430 | 2011_002951 1431 | 2011_002975 1432 | 2011_002993 1433 | 2011_002997 1434 | 2011_003003 1435 | 2011_003011 1436 | 2011_003019 1437 | 2011_003030 1438 | 2011_003055 1439 | 2011_003085 1440 | 2011_003103 1441 | 2011_003114 1442 | 2011_003145 1443 | 2011_003146 1444 | 2011_003182 1445 | 2011_003197 1446 | 2011_003205 1447 | 2011_003240 1448 | 2011_003256 1449 | 2011_003271 -------------------------------------------------------------------------------- /data/pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/data/pretrained/.gitkeep -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/lib/__init__.py -------------------------------------------------------------------------------- /lib/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .voc_seg_group_loader import VOCSegGroupLoader 2 | from .voc_seg_loader import VOCSegLoader 3 | -------------------------------------------------------------------------------- /lib/loader/voc_seg_group_loader.py: -------------------------------------------------------------------------------- 1 | from ..utils import * 2 | from .voc_seg_loader import load_batch_semantic 3 | 4 | 5 | class VOCSegGroupLoader(mx.io.DataIter): 6 | def __init__(self, image_root, label_root, annotation_root, data_list, 7 | batch_size, group_size, num_block, target_size, 8 | pad=False, shuffle=False, rand_scale=False, rand_mirror=False, rand_crop=False, downsample=None): 9 | 10 | assert group_size >= 2, "'group_size': # common-class images, typical value is 2 for pairs" 11 | assert num_block >= 1, "'num_block': should equal # GPU" 12 | assert batch_size % (group_size * num_block) == 0 13 | 14 | with open(data_list, 'r') as f: 15 | data_names = [x.strip() for x in f.readlines()] 16 | 17 | if pad and (len(data_names) % batch_size > 0): 18 | pad_num = batch_size - (len(data_names) % batch_size) 19 | data_names = data_names + data_names[:pad_num] 20 | 21 | self.image_src_list = [os.path.join(image_root, x+'.jpg') for x in data_names] 22 | self.label_src_list = [os.path.join(label_root, x+'.png') for x in data_names] \ 23 | if label_root is not None else [None] * len(data_names) 24 | self.ann_list = [VOC.get_annotation(os.path.join(annotation_root, x+'.xml')) for x in data_names] 25 | 26 | self.batch_size = batch_size 27 | self.group_size = group_size 28 | self.num_block = num_block 29 | self.meta_length = self.batch_size // (self.num_block * self.group_size) 30 | 31 | self.target_size = target_size 32 | self.shuffle = shuffle 33 | self.rand_scale = rand_scale 34 | self.rand_mirror = rand_mirror 35 | self.rand_crop = rand_crop 36 | self.downsample = downsample 37 | 38 | scale_pool = [0.5, 0.75, 1, 1.25, 1.5] 39 | self.scale_sampler = lambda : np.random.choice(scale_pool) 40 | 41 | self.index = 
list(range(len(data_names))) 42 | self.num_batch = len(data_names) // self.batch_size 43 | self.reset() 44 | 45 | def reset(self): 46 | self.index_pointer = 0 47 | self.cache = [] 48 | if self.shuffle: 49 | np.random.shuffle(self.index) 50 | 51 | def pop(self): 52 | if len(self.cache) > 0: 53 | index = self.cache.pop() 54 | elif self.index_pointer < len(self.index): 55 | index = self.index[self.index_pointer] 56 | self.index_pointer += 1 57 | else: 58 | raise StopIteration 59 | return index 60 | 61 | def is_ok(self, a, b): 62 | lbl_a = self.ann_list[a] 63 | lbl_b = self.ann_list[b] 64 | return len(set(lbl_a) - set(lbl_b)) < len(lbl_a) 65 | 66 | def next(self): 67 | indices = [] 68 | while len(indices) < self.batch_size // self.group_size: 69 | cache = [] 70 | partners = [self.pop()] 71 | while len(partners) < self.group_size: 72 | this = self.pop() 73 | while not all([self.is_ok(prev, this) for prev in partners]): 74 | cache.append(this) 75 | this = self.pop() 76 | partners.append(this) 77 | indices.append(partners) 78 | self.cache = cache[::-1] + self.cache 79 | 80 | indices = sum( [sum(zip(*indices[i : i+self.meta_length]), tuple()) \ 81 | for i in range(0, len(indices), self.meta_length)], tuple() ) 82 | 83 | image_src_list = [self.image_src_list[i] for i in indices] 84 | label_src_list = [self.label_src_list[i] for i in indices] 85 | self.cache_image_src_list = image_src_list 86 | 87 | batch = load_batch_semantic(image_src_list, label_src_list, self.target_size, self.scale_sampler, 88 | self.rand_scale, self.rand_mirror, self.rand_crop, self.downsample) 89 | return batch 90 | 91 | -------------------------------------------------------------------------------- /lib/loader/voc_seg_loader.py: -------------------------------------------------------------------------------- 1 | from ..utils import * 2 | 3 | class VOCSegLoader(mx.io.DataIter): 4 | def __init__(self, image_root, label_root, data_list, batch_size, target_size, 5 | pad=False, shuffle=False, rand_scale=False, rand_mirror=False, rand_crop=False, data_slice=None, downsample=None): 6 | 7 | with open(data_list, 'r') as f: 8 | data_names = [x.strip() for x in f.readlines()] 9 | 10 | if data_slice is not None: 11 | data_names = data_names[data_slice] 12 | 13 | if pad and (len(data_names) % batch_size > 0): 14 | pad_num = batch_size - (len(data_names) % batch_size) 15 | data_names = data_names + data_names[:pad_num] 16 | 17 | self.image_src_list = [os.path.join(image_root, x+'.jpg') for x in data_names] 18 | self.label_src_list = [os.path.join(label_root, x+'.png') for x in data_names] \ 19 | if label_root is not None else [None] * len(data_names) 20 | 21 | self.batch_size = batch_size 22 | self.target_size = target_size 23 | 24 | self.shuffle = shuffle 25 | self.rand_scale = rand_scale 26 | self.rand_mirror = rand_mirror 27 | self.rand_crop = rand_crop 28 | self.downsample = downsample 29 | 30 | scale_pool = [0.5, 0.75, 1, 1.25, 1.5] 31 | self.scale_sampler = lambda : np.random.choice(scale_pool) 32 | 33 | self.index = list(range(len(data_names))) 34 | self.num_batch = len(data_names) // self.batch_size 35 | self.reset() 36 | 37 | def reset(self): 38 | self.current = 0 39 | if self.shuffle: 40 | np.random.shuffle(self.index) 41 | 42 | def next(self): 43 | if self.current >= self.num_batch: 44 | raise StopIteration 45 | 46 | index = self.index[self.current*self.batch_size : (self.current+1)*self.batch_size] 47 | self.current += 1 48 | 49 | image_src_list = [self.image_src_list[i] for i in index] 50 | label_src_list = 
[self.label_src_list[i] for i in index] 51 | self.cache_image_src_list = image_src_list 52 | 53 | batch = load_batch_semantic(image_src_list, label_src_list, self.target_size, self.scale_sampler, self.rand_scale, self.rand_mirror, self.rand_crop, self.downsample) 54 | return batch 55 | 56 | 57 | def load_batch_semantic(image_src_list, label_src_list, target_size, scale_sampler, rand_scale, rand_mirror, rand_crop, downsample): 58 | img_batch, seg_batch = [], [] 59 | for image_src, label_src in zip(image_src_list, label_src_list): 60 | img, seg = load_semantic(image_src, label_src, target_size, scale_sampler() if rand_scale else 1, 61 | rand_mirror and (np.random.rand() > 0.5), rand_crop, downsample) 62 | img_batch.append(img) 63 | seg_batch.append(seg) 64 | 65 | img_batch = mx.nd.array(np.array(img_batch)).transpose((0, 3, 1, 2)) 66 | img_batch = img_batch[:, ::-1, :, :] 67 | seg_batch = mx.nd.array(np.array(seg_batch)) 68 | 69 | batch = mx.io.DataBatch(data=[img_batch], label=[seg_batch]) 70 | return batch 71 | 72 | def load_semantic(image_src, label_src, target_size, scale, mirror, rand_crop, downsample=None): 73 | img = cv2.imread(image_src) 74 | h, w = img.shape[:2] 75 | seg = cv2.imread(label_src, 0) if label_src is not None else np.full((h, w), 255, np.uint8) 76 | 77 | if mirror: 78 | img = img[:, ::-1] 79 | seg = seg[:, ::-1] 80 | 81 | if scale != 1: 82 | h = int(h * scale + .5) 83 | w = int(w * scale + .5) 84 | img = cv2.resize(img, (w, h)) 85 | seg = cv2.resize(seg, (w, h), interpolation=cv2.INTER_NEAREST) 86 | 87 | pad_h = max(target_size - h, 0) 88 | pad_w = max(target_size - w, 0) 89 | if pad_h > 0 or pad_w > 0: 90 | img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT) 91 | seg = cv2.copyMakeBorder(seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT) 92 | h, w = img.shape[:2] 93 | 94 | if rand_crop: 95 | h_bgn = np.random.randint(0, h - target_size + 1) 96 | w_bgn = np.random.randint(0, w - target_size + 1) 97 | else: 98 | h_bgn = (h - target_size) // 2 99 | w_bgn = (w - target_size) // 2 100 | 101 | img = img[h_bgn:h_bgn+target_size, w_bgn:w_bgn+target_size] 102 | seg = seg[h_bgn:h_bgn+target_size, w_bgn:w_bgn+target_size] 103 | 104 | if downsample: 105 | d_size = (target_size - 1) // downsample + 1 106 | seg = cv2.resize(seg, (d_size, d_size)) 107 | 108 | return img, seg 109 | 110 | 111 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers_custom import * 2 | from .cian_models import MultiScale 3 | from .cian_models import resnet101_largefov_SA, resnet101_largefov_CA 4 | from .cian_models import resnet50_largefov_SA, resnet50_largefov_CA 5 | 6 | -------------------------------------------------------------------------------- /lib/models/cian_models.py: -------------------------------------------------------------------------------- 1 | from .resnet import _Resnet 2 | from .cian_modules import build_self_affinity, build_cross_affinity 3 | from .layers import * 4 | from .multi_scale import * 5 | 6 | 7 | def resnet101_largefov_SA(x, num_cls, is_downsample=True, 8 | in_embed_type='conv', out_embed_type='convbn', sim_type='dot', 9 | use_global_stats_backbone=False, use_global_stats_affinity=False, 10 | lr_mult=10, reuse=None, **kwargs): 11 | 12 | x_raw = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, 13 | use_global_stats=use_global_stats_backbone, 14 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), lr_mult=1, 
reuse=reuse) 15 | 16 | x_res = build_self_affinity(x_raw, 1024, 2048, is_downsample, 17 | in_embed_type, out_embed_type, sim_type, 18 | use_global_stats_affinity, lr_mult, reuse) 19 | 20 | x = x_raw + x_res 21 | x = Conv(x, num_cls, (3, 3), (1, 1), dilate=(12, 12), pad=(12, 12), 22 | no_bias=True, name='fc1', lr_mult=lr_mult, reuse=reuse) 23 | return x 24 | 25 | def resnet101_largefov_CA(x, num_cls, is_downsample=True, 26 | in_embed_type='conv', out_embed_type='convbn', sim_type='dot', 27 | group_size=2, merge_type='max', merge_self=True, 28 | use_global_stats_backbone=False, use_global_stats_affinity=False, 29 | lr_mult=10, reuse=None): 30 | 31 | x_raw = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, 32 | use_global_stats=use_global_stats_backbone, 33 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), lr_mult=1, reuse=reuse) 34 | 35 | x_res_self, x_res_cross = build_cross_affinity(x_raw, 1024, 2048, is_downsample, 36 | in_embed_type, out_embed_type, sim_type, 37 | group_size, merge_type, merge_self, 38 | use_global_stats_affinity, lr_mult, reuse) 39 | 40 | x_self = x_raw + x_res_self 41 | x_self = Conv(x_self, num_cls, (3, 3), (1, 1), dilate=(12, 12), pad=(12, 12), 42 | no_bias=True, name='fc1', lr_mult=lr_mult, reuse=reuse) 43 | x_cross = x_raw + x_res_cross 44 | x_cross = Conv(x_cross, num_cls, (3, 3), (1, 1), dilate=(12, 12), pad=(12, 12), 45 | no_bias=True, name='fc1', lr_mult=lr_mult, reuse=x_self) 46 | return x_self, x_cross 47 | 48 | def resnet50_largefov_SA(x, num_cls, is_downsample=True, 49 | in_embed_type='conv', out_embed_type='convbn', sim_type='dot', 50 | use_global_stats_backbone=False, use_global_stats_affinity=False, 51 | lr_mult=10, reuse=None, **kwargs): 52 | 53 | x_raw = _Resnet(x, (3, 4, 6, 3), (64, 256, 512, 1024, 2048), True, 54 | use_global_stats=use_global_stats_backbone, 55 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), lr_mult=1, reuse=reuse) 56 | 57 | x_res = build_self_affinity(x_raw, 1024, 2048, is_downsample, 58 | in_embed_type, out_embed_type, sim_type, 59 | use_global_stats_affinity, lr_mult, reuse) 60 | 61 | x = x_raw + x_res 62 | x = Conv(x, num_cls, (3, 3), (1, 1), dilate=(12, 12), pad=(12, 12), 63 | no_bias=True, name='fc1', lr_mult=lr_mult, reuse=reuse) 64 | return x 65 | 66 | def resnet50_largefov_CA(x, num_cls, is_downsample=True, 67 | in_embed_type='conv', out_embed_type='convbn', sim_type='dot', 68 | group_size=2, merge_type='max', merge_self=True, 69 | use_global_stats_backbone=False, use_global_stats_affinity=False, 70 | lr_mult=10, reuse=None): 71 | 72 | x_raw = _Resnet(x, (3, 4, 6, 3), (64, 256, 512, 1024, 2048), True, 73 | use_global_stats=use_global_stats_backbone, 74 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), lr_mult=1, reuse=reuse) 75 | 76 | x_res_self, x_res_cross = build_cross_affinity(x_raw, 1024, 2048, is_downsample, 77 | in_embed_type, out_embed_type, sim_type, 78 | group_size, merge_type, merge_self, 79 | use_global_stats_affinity, lr_mult, reuse) 80 | 81 | x_self = x_raw + x_res_self 82 | x_self = Conv(x_self, num_cls, (3, 3), (1, 1), dilate=(12, 12), pad=(12, 12), 83 | no_bias=True, name='fc1', lr_mult=lr_mult, reuse=reuse) 84 | x_cross = x_raw + x_res_cross 85 | x_cross = Conv(x_cross, num_cls, (3, 3), (1, 1), dilate=(12, 12), pad=(12, 12), 86 | no_bias=True, name='fc1', lr_mult=lr_mult, reuse=x_self) 87 | return x_self, x_cross 88 | 89 | -------------------------------------------------------------------------------- /lib/models/cian_modules.py: 
-------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def in_embedding_conv(x_feat, num_filter_hidden, is_downsample=True, lr_mult=1, reuse=None): 4 | x_query = Conv(x_feat, num_filter_hidden, (1, 1), no_bias=True, 5 | name='conv_embed_q', lr_mult=lr_mult, reuse=reuse) 6 | x_key = Conv(x_feat, num_filter_hidden, (1, 1), no_bias=True, 7 | name='conv_embed_k', lr_mult=lr_mult, reuse=reuse) 8 | x_value = Conv(x_feat, num_filter_hidden, (1, 1), no_bias=True, 9 | name='conv_embed_v', lr_mult=lr_mult, reuse=reuse) 10 | 11 | if is_downsample: 12 | x_key = Pool(x_key, (3, 3), (2, 2), (1, 1)) 13 | x_value = Pool(x_value, (3, 3), (2, 2), (1, 1)) 14 | return x_query, x_key, x_value 15 | 16 | def out_embedding_convbn(x_res, num_filter_out, use_global_stats=False, lr_mult=1, reuse=None): 17 | x_res = Conv(x_res, num_filter_out, (1, 1), no_bias=True, 18 | name='conv_out', lr_mult=lr_mult, reuse=reuse) 19 | x_res = BN(x_res, fix_gamma=False, use_global_stats=use_global_stats, 20 | name='bn_out', lr_mult=lr_mult, reuse=reuse) 21 | return x_res 22 | 23 | def compute_sim_mat(x_key, x_query, sim_type): 24 | if sim_type == 'dot': 25 | sim_mat = mx.sym.batch_dot(x_key, x_query, transpose_a=True) 26 | elif sim_type == 'cosine': 27 | x_key_norm = mx.sym.L2Normalization(x_key, mode='channel') 28 | x_query_norm = mx.sym.L2Normalization(x_query, mode='channel') 29 | sim_mat = mx.sym.batch_dot(x_key_norm, x_query_norm, transpose_a=True) 30 | else: 31 | raise ValueError(sim_type) 32 | return sim_mat 33 | 34 | def build_self_affinity(x_feat, num_filter_hidden, num_filter_out, is_downsample=True, 35 | in_embed_type='conv', out_embed_type='convbn', sim_type='dot', 36 | use_global_stats=False, lr_mult=1, reuse=None, return_internals=False): 37 | get_embedding_in = eval('in_embedding_' + in_embed_type) 38 | get_embedding_out = eval('out_embedding_' + out_embed_type) 39 | 40 | x_query, x_key, x_value = get_embedding_in(x_feat, 41 | num_filter_hidden, is_downsample, lr_mult, reuse) 42 | 43 | x_query = mx.sym.reshape(x_query, (0, 0, -3)) 44 | x_key = mx.sym.reshape(x_key, (0, 0, -3)) 45 | x_value = mx.sym.reshape(x_value, (0, 0, -3)) 46 | 47 | sim_mat = compute_sim_mat(x_key, x_query, sim_type) 48 | sim_mat = mx.sym.softmax(sim_mat, axis=1) 49 | 50 | x_res = mx.sym.batch_dot(x_value, sim_mat) 51 | x_res = mx.sym.reshape_like(x_res, x_feat, lhs_begin=2, lhs_end=3, rhs_begin=2, rhs_end=4) 52 | 53 | x_out = get_embedding_out(x_res, num_filter_out, use_global_stats, lr_mult, reuse) 54 | 55 | if return_internals: 56 | return x_out, x_query, x_key, x_value, sim_mat, x_res 57 | return x_out 58 | 59 | def build_cross_affinity(x_feat, num_filter_hidden, num_filter_out, is_downsample=True, 60 | in_embed_type='conv', out_embed_type='convbn', sim_type='dot', 61 | group_size=2, merge_type='max', merge_self=True, 62 | use_global_stats=False, lr_mult=1, reuse=None): 63 | get_embedding_in = eval('in_embedding_' + in_embed_type) 64 | get_embedding_out = eval('out_embedding_' + out_embed_type) 65 | 66 | x_out_self, x_query, x_key, x_value, sim_mat_self, x_res_self = build_self_affinity( 67 | x_feat, num_filter_hidden, num_filter_out, is_downsample, 68 | in_embed_type, out_embed_type, sim_type, 69 | use_global_stats, lr_mult, reuse, True) 70 | 71 | # split 72 | x_key_sp = list(mx.sym.split(x_key, num_outputs=group_size, axis=0)) 73 | x_value_sp = list(mx.sym.split(x_value, num_outputs=group_size, axis=0)) 74 | 75 | # roll, res 76 | x_res_list = [] 77 | for i in range(group_size - 1): 
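# --- editor's annotation (not part of the original source file) ---
# The loop below realizes the cross-image attention: the batch has been split
# into `group_size` chunks along axis 0 (one chunk per group member), and the
# key/value chunk lists are cyclically rolled once per iteration. After the
# i-th roll, each image's queries attend to the keys/values of the (i+1)-th
# next member of its group, so with group size G every image gathers G-1
# cross-image residuals, which are then merged by 'max' or 'avg' below
# (optionally together with the self-affinity residual).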
78 | x_key_sp = x_key_sp[1:] + x_key_sp[0:1] 79 | x_value_sp = x_value_sp[1:] + x_value_sp[0:1] 80 | 81 | x_key_roll = mx.sym.concat(*x_key_sp, dim=0) 82 | x_value_roll = mx.sym.concat(*x_value_sp, dim=0) 83 | 84 | sim_mat = compute_sim_mat(x_key_roll, x_query, sim_type) 85 | sim_mat = mx.sym.softmax(sim_mat, axis=1) 86 | 87 | x_res = mx.sym.batch_dot(x_value_roll, sim_mat) 88 | x_res = mx.sym.reshape_like(x_res, x_feat, lhs_begin=2, lhs_end=3, rhs_begin=2, rhs_end=4) 89 | x_res_list.append(x_res) 90 | 91 | # merge 92 | if merge_self: 93 | x_res_list.append(x_res_self) 94 | 95 | if merge_type == 'max': 96 | x_res_cross = x_res_list[0] 97 | for x_res in x_res_list[1:]: 98 | x_res_cross = mx.sym.maximum(x_res_cross, x_res) 99 | elif merge_type == 'avg': 100 | x_res_cross = sum(x_res_list) / len(x_res_list) 101 | 102 | # embed out 103 | x_out_cross = get_embedding_out(x_res_cross, num_filter_out, use_global_stats, lr_mult, x_out_self) 104 | 105 | return x_out_self, x_out_cross 106 | 107 | 108 | -------------------------------------------------------------------------------- /lib/models/layers.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | 3 | # ==== layers ==== 4 | 5 | def Convolution(data, num_filter, kernel, stride=None, dilate=None, pad=None, num_group=1, no_bias=False, 6 | weight=None, bias=None, name=None, lr_mult=1, reuse=None, **kwargs): 7 | if reuse is not None: 8 | assert name is not None 9 | name = GetLayerName.get('conv') if name is None else name 10 | 11 | stride = (1,) * len(kernel) if stride is None else stride 12 | dilate = (1,) * len(kernel) if dilate is None else dilate 13 | 14 | # tensorflow 'SAME' padding 15 | if isinstance(pad, str): 16 | input_size = kwargs.get('input_size', None) 17 | if input_size is None: 18 | raise ValueError("`input_size` is needed for padding") 19 | del kwargs['input_size'] 20 | if isinstance(input_size, int): 21 | in_size_h = in_size_w = input_size 22 | else: 23 | in_size_h, in_size_w = input_size 24 | ph0, ph1 = padding_helper(in_size_h, kernel[0], stride[0], pad) 25 | pw0, pw1 = padding_helper(in_size_w, kernel[1], stride[1], pad) 26 | data = mx.sym.pad(data, mode='constant', pad_width=(0,0,0,0,ph0,ph1,pw0,pw1)) 27 | pad = (0,) * len(kernel) 28 | else: 29 | pad = (0,) * len(kernel) if pad is None else pad 30 | assert len(kwargs) == 0, sorted(kwargs) 31 | 32 | W = get_variable(name+'_weight', lr_mult, reuse) if weight is None else weight 33 | if no_bias: 34 | x = mx.sym.Convolution(data, num_filter=num_filter, kernel=kernel, stride=stride, dilate=dilate, pad=pad, 35 | num_group=num_group, no_bias=no_bias, name=name if reuse is None else None, 36 | weight=W) 37 | else: 38 | B = get_variable(name+'_bias', lr_mult, reuse) if bias is None else bias 39 | x = mx.sym.Convolution(data, num_filter=num_filter, kernel=kernel, stride=stride, dilate=dilate, pad=pad, 40 | num_group=num_group, no_bias=no_bias, name=name if reuse is None else None, 41 | weight=W, bias=B) 42 | return x 43 | 44 | def Deconvolution(data, num_filter, kernel, stride=None, dilate=None, pad=None, adj=None, target_shape=None, 45 | num_group=1, no_bias=False, weight=None, bias=None, name=None, lr_mult=1, reuse=None): 46 | if reuse is not None: 47 | assert name is not None 48 | name = GetLayerName.get('deconv') if name is None else name 49 | 50 | stride = (1,) * len(kernel) if stride is None else stride 51 | dilate = (1,) * len(kernel) if dilate is None else dilate 52 | pad = (0,) * len(kernel) if pad is None else pad 53 | 
adj = (0,) * len(kernel) if adj is None else adj 54 | target_shape = tuple([]) if target_shape is None else target_shape 55 | 56 | W = get_variable(name+'_weight', lr_mult, reuse) if weight is None else weight 57 | if no_bias: 58 | x = mx.sym.Deconvolution(data, num_filter=num_filter, kernel=kernel, stride=stride, dilate=dilate, pad=pad, 59 | adj=adj, target_shape=target_shape, num_group=num_group, no_bias=no_bias, 60 | name=name if reuse is None else None, weight=W) 61 | else: 62 | B = get_variable(name+'_bias', lr_mult, reuse) if bias is None else bias 63 | x = mx.sym.Deconvolution(data, num_filter=num_filter, kernel=kernel, stride=stride, dilate=dilate, pad=pad, 64 | adj=adj, target_shape=target_shape, num_group=num_group, no_bias=no_bias, 65 | name=name if reuse is None else None, weight=W, bias=B) 66 | return x 67 | 68 | def FullyConnected(data, num_hidden, flatten=True, no_bias=False, weight=None, bias=None, name=None, lr_mult=1, reuse=None): 69 | if reuse is not None: 70 | assert name is not None 71 | name = GetLayerName.get('fc') if name is None else name 72 | 73 | W = get_variable(name+'_weight', lr_mult, reuse) if weight is None else weight 74 | if no_bias: 75 | x = mx.sym.FullyConnected(data, num_hidden=num_hidden, flatten=flatten, no_bias=no_bias, weight=W, 76 | name=name if reuse is None else None) 77 | else: 78 | B = get_variable(name+'_bias', lr_mult, reuse) if bias is None else bias 79 | x = mx.sym.FullyConnected(data, num_hidden=num_hidden, flatten=flatten, no_bias=no_bias, weight=W, bias=B, 80 | name=name if reuse is None else None) 81 | return x 82 | 83 | def Relu(data, name=None): 84 | name = GetLayerName.get('relu') if name is None else name 85 | x = mx.sym.Activation(data, act_type='relu', name=name) 86 | return x 87 | 88 | def LeakyRelu(data, slope=0.25, name=None): 89 | name = GetLayerName.get('leakyRelu') if name is None else name 90 | x = mx.sym.LeakyReLU(data, slope=slope, act_type='leaky', name=name) 91 | return x 92 | 93 | def Tanh(data, name=None): 94 | name = GetLayerName.get('tanh') if name is None else name 95 | x = mx.sym.tanh(data, name=name) 96 | return x 97 | 98 | def Swish(data, name=None): 99 | name = GetLayerName.get('swish') if name is None else name 100 | x = data * mx.sym.sigmoid(data) 101 | return x 102 | 103 | def Pooling(data, kernel, stride=None, pad=None, pool_type='max', global_pool=False, name=None): 104 | name = GetLayerName.get('pool') if name is None else name 105 | 106 | stride = kernel if stride is None else stride 107 | pad = (0,) * len(kernel) if pad is None else pad 108 | 109 | x = mx.sym.Pooling(data, kernel=kernel, stride=stride, pad=pad, pool_type=pool_type, 110 | global_pool=global_pool, name=name) 111 | return x 112 | 113 | def Dropout(data, p, name=None): 114 | name = GetLayerName.get('drop') if name is None else name 115 | 116 | x = mx.sym.Dropout(data, p=p, name=name) 117 | return x 118 | 119 | def BatchNorm(data, fix_gamma=False, momentum=0.9, eps=1e-5, use_global_stats=False, gamma=None, beta=None, 120 | moving_mean=None, moving_var=None, name=None, lr_mult=1, reuse=None): 121 | if reuse is not None: 122 | assert name is not None 123 | name = GetLayerName.get('bn') if name is None else name 124 | 125 | gamma = get_variable(name+'_gamma', lr_mult, reuse) if gamma is None else gamma 126 | beta = get_variable(name+'_beta', lr_mult, reuse) if beta is None else beta 127 | moving_mean = get_variable(name+'_moving_mean', 1, reuse) if moving_mean is None else moving_mean 128 | moving_var = get_variable(name+'_moving_var', 1, 
reuse) if moving_var is None else moving_var 129 | 130 | x = mx.sym.BatchNorm(data, fix_gamma=fix_gamma, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 131 | gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, 132 | name=name if reuse is None else None) 133 | return x 134 | 135 | def InstanceNorm(data, eps=1e-5, gamma=None, beta=None, name=None, lr_mult=1, reuse=None): 136 | if reuse is not None: 137 | assert name is not None 138 | name = GetLayerName.get('in') if name is None else name 139 | 140 | gamma = get_variable(name+'_gamma', lr_mult, reuse) if gamma is None else gamma 141 | beta = get_variable(name+'_beta', lr_mult, reuse) if beta is None else beta 142 | 143 | x = mx.sym.InstanceNorm(data, eps=eps, gamma=gamma, beta=beta, name=name if reuse is None else None) 144 | return x 145 | 146 | def Flatten(data, name=None): 147 | name = GetLayerName.get('flatten') if name is None else name 148 | x = mx.sym.flatten(data, name=name) 149 | return x 150 | 151 | # ==== shortcuts ==== 152 | Conv = Convolution 153 | Deconv = Deconvolution 154 | FC = FullyConnected 155 | Pool = Pooling 156 | Drop = Dropout 157 | BN = BatchNorm 158 | IN = InstanceNorm 159 | 160 | def ConvRelu(*args, **kwargs): 161 | x = Conv(*args, **kwargs) 162 | x = Relu(x, x.name+'_relu') 163 | return x 164 | 165 | def BNRelu(*args, **kwargs): 166 | x = BN(*args, **kwargs) 167 | x = Relu(x, x.name+'_relu') 168 | return x 169 | 170 | def FCRelu(*args, **kwargs): 171 | x = FC(*args, **kwargs) 172 | x = Relu(x, x.name+'_relu') 173 | return x 174 | 175 | def ConvBNRelu(*args, **kwargs): 176 | x = Conv(*args, **kwargs) 177 | x = BN(x, name=x.name+'_bn', lr_mult=kwargs.get('lr_mult', 1), reuse=kwargs.get('reuse', None)) 178 | x = Relu(x, x.name+'_relu') 179 | return x 180 | 181 | # ==== __utils__ ==== 182 | def get_variable(name, lr_mult=1, reuse=None): 183 | if reuse is None: 184 | return mx.sym.Variable(name, lr_mult=lr_mult) 185 | return reuse.get_internals()[name] 186 | 187 | class GetLayerName(object): 188 | _name_count = {} 189 | 190 | @classmethod 191 | def get(cls, name_prefix): 192 | cnt = cls._name_count.get(name_prefix, 0) 193 | cls._name_count[name_prefix] = cnt + 1 194 | return name_prefix + str(cnt) 195 | 196 | def padding_helper(in_size, kernel_size, stride, pad_type='same'): 197 | pad_type = pad_type.lower() 198 | if pad_type == 'same': 199 | out_size = in_size // stride + int((in_size % stride) > 0) 200 | pad_size = max((out_size - 1) * stride + kernel_size - in_size, 0) 201 | return pad_size // 2, pad_size - pad_size // 2 202 | else: 203 | raise ValueError(pad_type) 204 | -------------------------------------------------------------------------------- /lib/models/layers_custom/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers_custom_v1 import * 2 | from .constant import * 3 | -------------------------------------------------------------------------------- /lib/models/layers_custom/constant.py: -------------------------------------------------------------------------------- 1 | from ..layers import * 2 | import numpy as np 3 | 4 | class OpConstant(mx.operator.CustomOp): 5 | def __init__(self, val): 6 | self.val = val 7 | 8 | def forward(self, is_train, req, in_data, out_data, aux): 9 | self.assign(out_data[0], req[0], self.val) 10 | 11 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 12 | pass 13 | 14 | @mx.operator.register('Constant') 15 | class OpConstantProp(mx.operator.CustomOpProp): 16 | def 
__init__(self, val_str, shape_str, type_str='float32'): 17 | super(OpConstantProp, self).__init__(need_top_grad=False) 18 | val = [float(x) for x in val_str.split(',')] 19 | shape = [int(x) for x in shape_str.split(',')] 20 | self.val = mx.nd.array(val, dtype=type_str).reshape(shape) 21 | 22 | def list_arguments(self): 23 | return [] 24 | 25 | def list_outputs(self): 26 | return ['output'] 27 | 28 | def infer_shape(self, in_shape): 29 | return in_shape, [self.val.shape], [] 30 | 31 | def infer_type(self, in_type): 32 | return in_type, [self.val.dtype], [] 33 | 34 | def create_operator(self, ctx, shapes, dtypes): 35 | return OpConstant(self.val.as_in_context(ctx)) 36 | 37 | def CustomConstantEncoder(value, dtype='float32'): 38 | if not isinstance(value, np.ndarray): 39 | if not isinstance(value, (list, tuple)): 40 | value = [value] 41 | value = np.array(value, dtype=dtype) 42 | return ','.join([str(x) for x in value.ravel()]), ','.join([str(x) for x in value.shape]) 43 | 44 | 45 | def Constant(value, dtype='float32'): 46 | assert isinstance(dtype, str), dtype 47 | val, shape = CustomConstantEncoder(value, dtype) 48 | return mx.sym.Custom(val_str=val, shape_str=shape, type_str=dtype, op_type='Constant') 49 | 50 | -------------------------------------------------------------------------------- /lib/models/layers_custom/layers_custom_v1.py: -------------------------------------------------------------------------------- 1 | from ..layers import * 2 | 3 | # ==== Bilinear Sampling ==== 4 | class BilinearScale(mx.operator.CustomOp): 5 | def __init__(self, scale): 6 | self.scale = scale 7 | 8 | def forward(self, is_train, req, in_data, out_data, aux): 9 | x = in_data[0] 10 | h, w = x.shape[2:] 11 | new_h = int((h - 1) * self.scale) + 1 12 | new_w = int((w - 1) * self.scale) + 1 13 | 14 | x.attach_grad() 15 | with mx.autograd.record(): 16 | new_x = mx.nd.contrib.BilinearResize2D(x, height=new_h, width=new_w) 17 | self.new_x = new_x 18 | self.x = x 19 | 20 | self.assign(out_data[0], req[0], new_x) 21 | 22 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 23 | self.new_x.backward(out_grad[0]) 24 | self.assign(in_grad[0], req[0], self.x.grad) 25 | 26 | @mx.operator.register("BilinearScale") 27 | class BilinearScaleProp(mx.operator.CustomOpProp): 28 | def __init__(self, scale): 29 | super(BilinearScaleProp, self).__init__(need_top_grad=True) 30 | self.scale = float(scale) 31 | 32 | def infer_shape(self, in_shape): 33 | n, c, h, w = in_shape[0] 34 | new_h = int((h - 1) * self.scale) + 1 35 | new_w = int((w - 1) * self.scale) + 1 36 | return in_shape, [(n, c, new_h, new_w)], [] 37 | 38 | def create_operator(self, ctx, shapes, dtypes): 39 | return BilinearScale(self.scale) 40 | 41 | class BilinearScaleLike(mx.operator.CustomOp): 42 | def forward(self, is_train, req, in_data, out_data, aux): 43 | x, x_ref = in_data 44 | new_h, new_w = x_ref.shape[2:] 45 | 46 | x.attach_grad() 47 | with mx.autograd.record(): 48 | new_x = mx.nd.contrib.BilinearResize2D(x, height=new_h, width=new_w) 49 | self.new_x = new_x 50 | self.x = x 51 | 52 | self.assign(out_data[0], req[0], new_x) 53 | 54 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 55 | self.new_x.backward(out_grad[0]) 56 | in_grad[1][:] = 0 57 | self.assign(in_grad[0], req[0], self.x.grad) 58 | 59 | @mx.operator.register("BilinearScaleLike") 60 | class BilinearScaleLikeProp(mx.operator.CustomOpProp): 61 | def __init__(self): 62 | super(BilinearScaleLikeProp, self).__init__(need_top_grad=True) 63 | 64 | def 
list_arguments(self): 65 | return ['d1', 'd2'] 66 | 67 | def infer_shape(self, in_shape): 68 | out_shape = list(in_shape[1]) 69 | out_shape[1] = in_shape[0][1] 70 | return in_shape, [out_shape,], [] 71 | 72 | def create_operator(self, ctx, shapes, dtypes): 73 | return BilinearScaleLike() 74 | 75 | 76 | # ==== Loss ==== 77 | class SegmentLoss(mx.operator.CustomOp): 78 | def __init__(self, has_grad_scale): 79 | self.has_grad_scale = has_grad_scale 80 | 81 | def forward(self, is_train, req, in_data, out_data, aux): 82 | # logit, label, (grad_scale) = in_data 83 | prediction = mx.nd.softmax(in_data[0], axis=1) 84 | self.assign(out_data[0], req[0], prediction) 85 | 86 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 87 | prediction = out_data[0] 88 | label = mx.nd.one_hot(in_data[1], depth=prediction.shape[1]).transpose((0, 3, 1, 2)) 89 | 90 | if prediction.shape[2] != label.shape[2]: 91 | label = mx.nd.contrib.BilinearResize2D(label, 92 | height=prediction.shape[2], width=prediction.shape[3]) 93 | label = mx.nd.one_hot(mx.nd.argmax(label, axis=1), 94 | depth=prediction.shape[1]).transpose((0, 3, 1, 2)) * (mx.nd.max(label, axis=1, keepdims=True) > 0.5) 95 | 96 | mask = label.sum(axis=1, keepdims=True) 97 | num_pixel = mx.nd.maximum(mask.sum() / mask.shape[0], 1e-5) 98 | 99 | grad = (prediction - label) * mask / num_pixel 100 | if self.has_grad_scale: 101 | grad_scale = in_data[2].reshape(-1, 1, 1, 1) 102 | grad = grad * grad_scale 103 | 104 | in_grad[1][:] = 0 105 | self.assign(in_grad[0], req[0], grad) 106 | 107 | @mx.operator.register("SegmentLoss") 108 | class SegmentLossProp(mx.operator.CustomOpProp): 109 | def __init__(self, has_grad_scale=0): 110 | super(SegmentLossProp, self).__init__(need_top_grad=False) 111 | self.has_grad_scale = int(has_grad_scale) > 0 112 | 113 | def list_arguments(self): 114 | if self.has_grad_scale: 115 | return ['data', 'label', 'scale'] 116 | else: 117 | return ['data', 'label'] 118 | 119 | def infer_shape(self, in_shape): 120 | return in_shape, [in_shape[0],], [] 121 | 122 | def create_operator(self, ctx, shapes, dtypes): 123 | return SegmentLoss(self.has_grad_scale) 124 | 125 |
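
`SegmentLoss` above is a hand-derived softmax cross-entropy restricted to seed pixels: for a labeled pixel with one-hot target `y` and prediction `p = softmax(logit)`, the backward pass returns `(p - y)`, normalized by the average number of seed pixels per image; pixels whose label channels sum to zero (e.g. an ignore index outside the class range) contribute no gradient. A hedged usage sketch, with illustrative symbol names:

```python
import mxnet as mx

logit = mx.sym.Variable('logit')       # (N, C, H, W) segmentation scores
seed = mx.sym.Variable('seed_label')   # (N, H, W) class ids; unseeded pixels are masked out
prob = mx.sym.Custom(logit, seed, op_type='SegmentLoss')
```
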
126 | class CompletionLoss(mx.operator.CustomOp): 127 | def __init__(self, has_grad_scale): 128 | self.has_grad_scale = has_grad_scale 129 | 130 | def forward(self, is_train, req, in_data, out_data, aux): 131 | # logit, target, label, (grad_scale) = in_data 132 | prediction = mx.nd.softmax(in_data[0], axis=1) 133 | self.assign(out_data[0], req[0], prediction) 134 | 135 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 136 | logit, target, label = in_data[:3] 137 | prediction = out_data[0] 138 | 139 | onehot = target.argmax(axis=1) 140 | onehot = mx.nd.one_hot(onehot, depth=logit.shape[1]).transpose((0, 3, 1, 2)) 141 | 142 | label = mx.nd.one_hot(label, depth=logit.shape[1]).transpose((0, 3, 1, 2)) 143 | mask = label.max(axis=(2, 3), keepdims=True) 144 | onehot = onehot * mask 145 | 146 | mask = onehot.sum(axis=1, keepdims=True) 147 | num_pixel = mx.nd.maximum(mask.sum() / mask.shape[0], 1e-5) 148 | 149 | grad = (prediction - onehot) * mask / num_pixel 150 | 151 | if self.has_grad_scale: 152 | grad_scale = in_data[3].reshape(-1, 1, 1, 1) 153 | grad = grad * grad_scale 154 | 155 | in_grad[1][:] = 0 156 | in_grad[2][:] = 0 157 | self.assign(in_grad[0], req[0], grad) 158 | 159 | @mx.operator.register("CompletionLoss") 160 | class CompletionLossProp(mx.operator.CustomOpProp): 161 | def __init__(self, has_grad_scale=0): 162 | super(CompletionLossProp, self).__init__(need_top_grad=False) 163 | self.has_grad_scale = int(has_grad_scale) > 0 164 | 165 | def list_arguments(self): 166 | if self.has_grad_scale: 167 | return ['data', 'target', 'label', 'scale'] 168 | else: 169 | return ['data', 'target', 'label'] 170 | 171 | def infer_shape(self, in_shape): 172 | return in_shape, [in_shape[0]], [] 173 | 174 | def create_operator(self, ctx, shapes, dtypes): 175 | return CompletionLoss(self.has_grad_scale) 176 | 177 | 178 | class MultiSigmoidLoss(mx.operator.CustomOp): 179 | def forward(self, is_train, req, in_data, out_data, aux): 180 | logit, label = in_data 181 | prediction = mx.nd.sigmoid(logit) 182 | self.assign(out_data[0], req[0], prediction) 183 | 184 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 185 | prediction = out_data[0] 186 | label = in_data[1] 187 | 188 | grad = prediction - label 189 | 190 | in_grad[1][:] = 0 191 | self.assign(in_grad[0], req[0], grad) 192 | 193 | @mx.operator.register("MultiSigmoidLoss") 194 | class MultiSigmoidLossProp(mx.operator.CustomOpProp): 195 | def __init__(self): 196 | super(MultiSigmoidLossProp, self).__init__(need_top_grad=False) 197 | 198 | def list_arguments(self): 199 | return ['data', 'label'] 200 | 201 | def list_outputs(self): 202 | return ['output'] 203 | 204 | def infer_shape(self, in_shape): 205 | return in_shape, [in_shape[0]], [] 206 | 207 | def create_operator(self, ctx, shapes, dtypes): 208 | return MultiSigmoidLoss() 209 | -------------------------------------------------------------------------------- /lib/models/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/lib/models/model/__init__.py -------------------------------------------------------------------------------- /lib/models/model/efficientnet.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | from .layers_custom import * 3 | import re 4 | import numpy as np 5 | from collections import namedtuple 6 | 7 | DEFAULT_EFFICIENT_PARAMS = { 8 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 9 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 10 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 11 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 12 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 13 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 14 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 15 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5) 16 | } 17 | 18 | DEFAULT_EFFICIENT_BLOCK_ARGS = [ 19 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 20 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 21 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 22 | 'r1_k3_s11_e6_i192_o320_se0.25', 23 | ] 24 | 25 | def config_efficientnet(model_name): 26 | assert re.match(r'^efficientnet-b[0-7]$', model_name), model_name 27 | 28 | #efficientnet_params = { 29 | # 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 30 | # 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 31 | # 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 32 | # 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 33 | # 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 34 | # 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 35 | # 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 36 | # 'efficientnet-b7': (2.0, 3.1, 600, 0.5) 37 | #}[model_name] 38 | efficientnet_params = DEFAULT_EFFICIENT_PARAMS[model_name] 39 | 40 | #block_args = [ 41 | # 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 42 | # 
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 43 | # 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 44 | # 'r1_k3_s11_e6_i192_o320_se0.25', 45 | #] 46 | block_args = DEFAULT_EFFICIENT_BLOCK_ARGS 47 | 48 | #block_args = [ 49 | # 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 50 | # 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 51 | # 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s11_e6_i112_o192_se0.25_d2', 52 | # 'r1_k3_s11_e6_i192_o320_se0.25_d2', 53 | #] 54 | 55 | width_coefficient, depth_coefficient, resolution, dropout_rate = efficientnet_params 56 | global_params = { 57 | 'block_args': block_args, 58 | 'batch_norm_momentum': 0.99, 59 | 'batch_norm_epsilon': 1e-3, 60 | 'dropout_rate': dropout_rate, 61 | 'survival_prob': 0.8, 62 | 'num_classes': 1000, 63 | 'width_coefficient': width_coefficient, 64 | 'depth_coefficient': depth_coefficient, 65 | 'depth_divisor': 8, 66 | 'min_depth': None, 67 | 'use_se': True, 68 | 'clip_projection_output': False 69 | # data_format, relu_fn, batch_norm 70 | } 71 | global_params = namedtuple('global_parmas', sorted(global_params))(**global_params) 72 | 73 | kv_list = [dict([re.split(r'([\d\.]+)', op)[:2] for op in _block_args.split('_')]) for _block_args in block_args] 74 | block_args_list = [{ 75 | 'kernel_size': int(kv['k']), 76 | 'num_repeat': int(kv['r']), 77 | 'input_filters': int(kv['i']), 78 | 'output_filters': int(kv['o']), 79 | 'expand_ratio': int(kv['e']), 80 | 'id_skip': 'noskip' not in block_string, 81 | 'se_ratio': float(kv['se']) if 'se' in kv else None, 82 | 'strides': (int(kv['s'][0]), int(kv['s'][1])), 83 | 'conv_type': int(kv.get('c', '0')), 84 | 'fused_conv': int(kv.get('f', '0')), 85 | 'super_pixel': int(kv.get('p', '0')), 86 | 'dilate': int(kv.get('d', '1')), 87 | 'condconv': 'cc' in block_string, 88 | 'survival_prob': 1.0 89 | } for kv, block_string in zip(kv_list, block_args)] 90 | block_args_list = [namedtuple('block_args', sorted(x))(**x) for x in block_args_list] 91 | 92 | return block_args_list, global_params 93 | 94 | def MBConvBlock(data, block_args, global_params, use_global_stats, block_id, name, lr_mult, reuse, input_size=None): 95 | if block_args.super_pixel: 96 | raise NotImplementedError 97 | if block_args.condconv: 98 | raise NotImplementedError 99 | 100 | kernel = (block_args.kernel_size,)*2 101 | dilate = (1 if kernel[0] == 1 else block_args.dilate,)*2 102 | pad = ( ((kernel[0]-1)*dilate[0]+1)//2, )*2 103 | #pad = (kernel[0]//2,)*2 104 | #dilate = (1, 1) 105 | momentum = global_params.batch_norm_momentum 106 | eps = global_params.batch_norm_epsilon 107 | 108 | num_filters = block_args.input_filters * block_args.expand_ratio 109 | 110 | conv_id, bn_id = 0, 0 111 | 112 | if block_args.fused_conv: 113 | x = Conv(data, num_filters, kernel, block_args.strides, pad=pad, dilate=dilate, 114 | no_bias=True, name=name+'block%d_conv'%block_id, lr_mult=lr_mult, reuse=reuse) 115 | #x = Conv(data, num_filters, kernel, block_args.strides, pad='same', 116 | # no_bias=True, name=name+'block%d_conv'%block_id, lr_mult=lr_mult, reuse=reuse, input_size=input_size) 117 | else: 118 | if block_args.expand_ratio != 1: 119 | x = Conv(data, num_filters, (1,1), no_bias=True, name=name+'block%d_conv%d'%(block_id, conv_id), lr_mult=lr_mult, reuse=reuse) 120 | x = BN(x, momentum=momentum, eps=eps, use_global_stats=use_global_stats, name=name+'block%d_bn%d'%(block_id, bn_id), lr_mult=lr_mult, reuse=reuse) 121 | x = Swish(x) 122 | conv_id, bn_id = conv_id + 1, bn_id + 1 123 | else: 124 | x = 
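
Decoded, each block string parsed by `config_efficientnet` above expands to the MBConv hyper-parameters; for example `'r2_k5_s22_e6_i24_o40_se0.25'` means:

```python
# r2     -> num_repeat = 2            k5  -> kernel_size = 5
# s22    -> strides = (2, 2)          e6  -> expand_ratio = 6
# i24    -> input_filters = 24        o40 -> output_filters = 40
# se0.25 -> squeeze-and-excite ratio = 0.25
# Optional suffixes handled by the parser: 'noskip' disables the residual
# connection, 'd<n>' sets the dilation rate, 'c'/'f'/'p' select conv variants.
```
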
data 125 | 126 | x = Conv(x, num_filters, kernel, block_args.strides, pad=pad, dilate=dilate, 127 | num_group=num_filters, no_bias=True, name=name+'block%d_depthwise_conv0'%block_id, lr_mult=lr_mult, reuse=reuse) 128 | #x = Conv(x, num_filters, kernel, block_args.strides, pad='same', 129 | # num_group=num_filters, no_bias=True, name=name+'block%d_depthwise_conv0'%block_id, lr_mult=lr_mult, reuse=reuse, input_size=input_size) 130 | 131 | x = BN(x, momentum=momentum, eps=eps, use_global_stats=use_global_stats, name=name+'block%d_bn%d'%(block_id, bn_id), lr_mult=lr_mult, reuse=reuse) 132 | x = Swish(x) 133 | bn_id += 1 134 | 135 | has_se = global_params.use_se and block_args.se_ratio is not None and 0 < block_args.se_ratio <= 1 136 | if has_se: 137 | num_filters_rd = max(1, int(block_args.input_filters * block_args.se_ratio)) 138 | x_se = mx.sym.mean(x, axis=(2, 3), keepdims=True) 139 | x_se = Conv(x_se, num_filters_rd, (1,1), name=name+'block%d_se_conv0'%block_id, lr_mult=lr_mult, reuse=reuse) 140 | x_se = Swish(x_se) 141 | x_se = Conv(x_se, num_filters, (1,1), name=name+'block%d_se_conv1'%block_id, lr_mult=lr_mult, reuse=reuse) 142 | x = mx.sym.broadcast_mul(mx.sym.sigmoid(x_se), x) 143 | 144 | x = Conv(x, block_args.output_filters, (1,1), no_bias=True, name=name+'block%d_conv%d'%(block_id, conv_id), lr_mult=lr_mult, reuse=reuse) 145 | x = BN(x, momentum=momentum, eps=eps, use_global_stats=use_global_stats, name=name+'block%d_bn%d'%(block_id, bn_id), lr_mult=lr_mult, reuse=reuse) 146 | conv_id, bn_id = conv_id + 1, bn_id + 1 147 | 148 | if global_params.clip_projection_output: 149 | x = mx.sym.clip(x, a_min=-6, a_max=6) 150 | 151 | if block_args.id_skip and all([s == 1 for s in block_args.strides]) and block_args.input_filters == block_args.output_filters: 152 | if block_args.survival_prob > 0: 153 | x = mx.sym.Custom(x, p=1-block_args.survival_prob, op_type='DropConnect') 154 | x = x + data 155 | 156 | return x 157 | 158 | def MBConvBlockWithoutDepthwise(data, block_args, global_params, use_global_stats, begin_id, name, lr_mult, reuse): 159 | raise NotImplementedError 160 | 161 | def meta_efficientnet(model_name, get_internals=False, input_size=None): 162 | block_args_list, global_params = config_efficientnet(model_name) 163 | 164 | def round_filters(num_filters): 165 | multiplier = global_params.width_coefficient 166 | if not multiplier: 167 | return num_filters 168 | divisor = global_params.depth_divisor 169 | min_depth = global_params.min_depth 170 | 171 | num_filters = num_filters * multiplier 172 | new_num = max(min_depth or divisor, int(num_filters+divisor/2)//divisor*divisor) 173 | if new_num < 0.9 * num_filters: 174 | new_num += divisor 175 | return int(new_num) 176 | 177 | def round_repeats(num_repeat): 178 | multiplier = global_params.depth_coefficient 179 | if not multiplier: 180 | return num_repeat 181 | return int(np.ceil(multiplier * num_repeat)) 182 | 183 | def efficient_model(data, use_global_stats=False, bn_data=False, name=None, lr_mult=1, reuse=None): 184 | name = '' if name is None else name 185 | endpoints = {} 186 | 187 | momentum = global_params.batch_norm_momentum 188 | eps = global_params.batch_norm_epsilon 189 | endpoints['input'] = data 190 | 191 | # data 192 | if bn_data: 193 | data = BN(data, fix_gamma=True, momentum=momentum, eps=eps, name='bn_data', lr_mult=lr_mult, reuse=reuse) 194 | 195 | # Stem 196 | x = Conv(data, round_filters(32), (3,3), (2,2), pad=(1,1), no_bias=True, name=name+'stem_conv0', lr_mult=lr_mult, reuse=reuse) 197 | #x = Conv(data, 
round_filters(32), (3,3), (2,2), pad='same', no_bias=True, name=name+'stem_conv0', lr_mult=lr_mult, reuse=reuse, input_size=input_size) 198 | x = BN(x, momentum=momentum, eps=eps, use_global_stats=use_global_stats, name='stem_bn0', lr_mult=lr_mult, reuse=reuse) 199 | x = Swish(x) 200 | endpoints['stem'] = x 201 | 202 | # Blocks 203 | block_id = 0 204 | total_blocks = sum([block_args.num_repeat for block_args in block_args_list]) 205 | survival_prob = global_params.survival_prob 206 | 207 | for i, block_args in enumerate(block_args_list): 208 | assert block_args.num_repeat > 0 209 | assert block_args.super_pixel in [0, 1, 2] 210 | block_args = block_args._replace( 211 | input_filters=round_filters(block_args.input_filters), 212 | output_filters=round_filters(block_args.output_filters), 213 | num_repeat=round_repeats(block_args.num_repeat), 214 | survival_prob=1.0 - (1.0-global_params.survival_prob)*float(block_id)/total_blocks 215 | ) 216 | 217 | ConvBlock = {0: MBConvBlock, 1: MBConvBlockWithoutDepthwise}[block_args.conv_type] 218 | 219 | # the first block 220 | x = ConvBlock(x, block_args, global_params, use_global_stats=use_global_stats, block_id=block_id, name=name, lr_mult=lr_mult, reuse=reuse) 221 | endpoints['block%d'%block_id] = x 222 | block_id += 1 223 | 224 | # the following blocks 225 | for j in range(block_args.num_repeat - 1): 226 | block_args = block_args._replace( 227 | input_filters=block_args.output_filters, 228 | strides=(1, 1), 229 | survival_prob=1.0 - (1.0-global_params.survival_prob)*float(block_id)/total_blocks 230 | ) 231 | x = ConvBlock(x, block_args, global_params, use_global_stats=use_global_stats, block_id=block_id, name=name, lr_mult=lr_mult, reuse=reuse, input_size=input_size) 232 | endpoints['block%d'%block_id] = x 233 | block_id += 1 234 | 235 | # Head 236 | x = Conv(x, round_filters(1280), (1,1), no_bias=True, name=name+'head_conv0', lr_mult=lr_mult, reuse=reuse) 237 | x = BN(x, momentum=momentum, eps=eps, use_global_stats=use_global_stats, name='head_bn0', lr_mult=lr_mult, reuse=reuse) 238 | x = Swish(x) 239 | if global_params.dropout_rate > 0: 240 | x = Drop(x, p=global_params.dropout_rate) 241 | endpoints['head'] = x 242 | 243 | x = Pool(x, kernel=(1, 1), pool_type='avg', global_pool=True) 244 | x = mx.sym.flatten(x) 245 | x = FC(x, global_params.num_classes, name=name+'head_fc0', lr_mult=lr_mult, reuse=reuse) 246 | endpoints['logit'] = x 247 | x = mx.sym.softmax(x, axis=1) 248 | endpoints['prob'] = x 249 | 250 | if get_internals: 251 | return x, endpoints 252 | return x 253 | 254 | return efficient_model 255 | 256 | efficientnet_b0 = meta_efficientnet('efficientnet-b0') 257 | 258 | 259 | def tf2mx_params(ckpt_file, dst_file=None, name='', use_ema=True): 260 | convert_w = lambda x: mx.nd.array(x.transpose(3, 2, 0, 1) if x.ndim == 4 else x.T) 261 | convert_b = lambda x: mx.nd.array(x) 262 | convert_dp_w = lambda x: mx.nd.array(x.transpose(2, 3, 0, 1)) 263 | lookup_ptype = {'kernel': ('arg', 'weight', convert_w), 'bias': ('arg', 'bias', convert_b), 264 | 'depthwise_kernel': ('arg', 'weight', convert_dp_w), 265 | 'gamma': ('arg', 'gamma', convert_b), 'beta': ('arg', 'beta', convert_b), 266 | 'moving_mean': ('aux', 'moving_mean', convert_b), 'moving_variance': ('aux', 'moving_var', convert_b)} 267 | lookup_op = {'conv2d': 'conv', 'depthwise_conv2d': 'depthwise_conv', 268 | 'tpu_batch_normalization': 'bn', 'dense': 'fc'} 269 | 270 | def mapKey(tf_key): 271 | names = tf_key.split('/') 272 | if not re.match(r'^efficientnet-b[0-7]$', names[0]): 273 | return 
None, None 274 | block, op, ptype = names[1:4] 275 | 276 | if block.startswith('blocks'): 277 | block = 'block' + block.split('_')[-1] 278 | block_name = name + block + '_' 279 | 280 | if op == 'se': 281 | op, ptype = names[3:5] 282 | block_name = block_name + 'se_' 283 | 284 | r = re.match(r'^\w*_(\d+)$', op) 285 | op_id = (r.group(1) if r else '0') 286 | _op = re.match(r'^(\w+)_\d+$', op).group(1) if r else op 287 | try: 288 | prefix, suffix, converter = lookup_ptype[ptype] 289 | except: 290 | raise KeyError("[{}], ({}, {}, {}), {}".format(ptype, block, op, ptype, tf_key)) 291 | 292 | op_name = lookup_op[_op] 293 | return prefix + ':' + block_name+op_name+op_id + '_' + suffix, converter 294 | 295 | try: 296 | import tensorflow.compat.v1 as tf 297 | except ImportError: 298 | import tensorflow as tf 299 | reader = tf.train.load_checkpoint(ckpt_file) 300 | shape_map = reader.get_variable_to_shape_map() 301 | keys = sorted(shape_map.keys()) 302 | 303 | ema_keys = [k for k in keys if k.endswith('ExponentialMovingAverage')] 304 | keys = list(set(list(set(keys) - set(ema_keys)) + [k.rsplit('/', 1)[0] for k in ema_keys])) 305 | keys_ = [k + '/ExponentialMovingAverage' for k in keys] 306 | kk = {k: k_ if (use_ema and (k_ in ema_keys)) else k for k, k_ in zip(keys, keys_)} 307 | 308 | mx_params = {} 309 | for k in kk.keys(): 310 | tf_key = kk[k] 311 | mx_key, converter = mapKey(k) 312 | if mx_key is None: 313 | if tf_key != 'global_step': 314 | print("Cannot parse tf_key: %s" % tf_key) 315 | continue 316 | if mx_key in mx_params: 317 | raise KeyError("Duplicate key: %s, %s, %s" % (k, tf_key, mx_key)) 318 | mx_params[mx_key] = converter(reader.get_tensor(tf_key)) 319 | 320 | if dst_file is not None: 321 | mx.nd.save(dst_file, mx_params) 322 | 323 | arg_params = {k[4:]: v for k, v in mx_params.items() if k.startswith('arg:')} 324 | aux_params = {k[4:]: v for k, v in mx_params.items() if k.startswith('aux:')} 325 | return arg_params, aux_params 326 | -------------------------------------------------------------------------------- /lib/models/model/inception_bn.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def incepConv(data, num_filter, kernel, stride=None, dilate=None, pad=None, momentum=0.9, eps=1e-5, 4 | use_global_stats=False, name=None, lr_mult=1, reuse=None): 5 | assert name is not None 6 | x = Conv(data, num_filter, kernel, stride, dilate, pad, name='conv_%s'%name, lr_mult=lr_mult, reuse=reuse) 7 | x = BN(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 8 | name='bn_%s'%name, lr_mult=lr_mult, reuse=reuse) 9 | x = Relu(x) 10 | return x 11 | 12 | def incepBlockA(data, num_filter_1, num_filter_3r, num_filter_3, num_filter_d3r, num_filter_d3, num_filter_p, 13 | pool_type, dilate=1, momentum=0.9, eps=1e-5, use_global_stats=False, 14 | name=None, lr_mult=1, reuse=None): 15 | assert name is not None 16 | 17 | x1 = incepConv(data, num_filter_1, (1, 1), momentum=momentum, eps=eps, use_global_stats=use_global_stats, 18 | name='%s_1x1'%name, lr_mult=lr_mult, reuse=reuse) 19 | 20 | x3 = incepConv(data, num_filter_3r, (1, 1), momentum=momentum, eps=eps, use_global_stats=use_global_stats, 21 | name='%s_3x3_reduce'%name, lr_mult=lr_mult, reuse=reuse) 22 | x3 = incepConv(x3, num_filter_3, (3, 3), pad=(dilate,)*2, dilate=(dilate,)*2, momentum=momentum, eps=eps, 23 | use_global_stats=use_global_stats, name='%s_3x3'%name, lr_mult=lr_mult, reuse=reuse) 24 | 25 | xd3 = incepConv(data, num_filter_d3r, (1, 1), 
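
Returning to `tf2mx_params` in `efficientnet.py` above: it transposes TensorFlow convolution kernels into MXNet's `OIHW` layout (depthwise kernels use a separate permutation) and prefers EMA-smoothed weights when the checkpoint provides them. A hedged usage sketch; the checkpoint and output paths are hypothetical:

```python
arg_params, aux_params = tf2mx_params(
    'checkpoints/efficientnet-b0/model.ckpt',              # TF checkpoint prefix (assumed path)
    dst_file='data/pretrained/efficientnet-b0-mx.params',  # optional: also dump converted params
    use_ema=True)
```
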
momentum=momentum, eps=eps, use_global_stats=use_global_stats, 26 | name='%s_double_3x3_reduce'%name, lr_mult=lr_mult, reuse=reuse) 27 | xd3 = incepConv(xd3, num_filter_d3, (3, 3), pad=(dilate,)*2, dilate=(dilate,)*2, momentum=momentum, eps=eps, 28 | use_global_stats=use_global_stats, name='%s_double_3x3_0'%name, lr_mult=lr_mult, reuse=reuse) 29 | xd3 = incepConv(xd3, num_filter_d3, (3, 3), pad=(dilate,)*2, dilate=(dilate,)*2, momentum=momentum, eps=eps, 30 | use_global_stats=use_global_stats, name='%s_double_3x3_1'%name, lr_mult=lr_mult, reuse=reuse) 31 | 32 | xp = Pool(data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool_type) 33 | xp = incepConv(xp, num_filter_p, (1, 1), momentum=momentum, eps=eps, use_global_stats=use_global_stats, 34 | name='%s_proj'%name, lr_mult=lr_mult, reuse=reuse) 35 | 36 | x = mx.sym.Concat(x1, x3, xd3, xp, dim=1, name='ch_concat_%s_chconcat'%name) 37 | return x 38 | 39 | def incepBlockB(data, num_filter_3r, num_filter_3, num_filter_d3r, num_filter_d3, 40 | stride=2, dilate=1, momentum=0.9, eps=1e-5, use_global_stats=False, 41 | name=None, lr_mult=1, reuse=None): 42 | assert name is not None 43 | 44 | x3 = incepConv(data, num_filter_3r, (1, 1), momentum=momentum, eps=eps, use_global_stats=use_global_stats, 45 | name='%s_3x3_reduce'%name, lr_mult=lr_mult, reuse=reuse) 46 | x3 = incepConv(x3, num_filter_3, (3, 3), stride=(stride,)*2, pad=(dilate,)*2, dilate=(dilate,)*2, 47 | momentum=momentum, eps=eps, use_global_stats=use_global_stats, name='%s_3x3'%name, lr_mult=lr_mult, reuse=reuse) 48 | 49 | xd3 = incepConv(data, num_filter_d3r, (1, 1), momentum=momentum, eps=eps, use_global_stats=use_global_stats, 50 | name='%s_double_3x3_reduce'%name, lr_mult=lr_mult, reuse=reuse) 51 | xd3 = incepConv(xd3, num_filter_d3, (3, 3), stride=(1, 1), pad=(dilate,)*2, dilate=(dilate,)*2, 52 | momentum=momentum, eps=eps, use_global_stats=use_global_stats, 53 | name='%s_double_3x3_0'%name, lr_mult=lr_mult, reuse=reuse) 54 | xd3 = incepConv(xd3, num_filter_d3, (3, 3), stride=(stride,)*2, pad=(dilate,)*2, dilate=(dilate,)*2, 55 | momentum=momentum, eps=eps, use_global_stats=use_global_stats, 56 | name='%s_double_3x3_1'%name, lr_mult=lr_mult, reuse=reuse) 57 | 58 | xp = Pool(data, kernel=(3, 3), stride=(stride,)*2, pad=(1, 1), pool_type='max') 59 | 60 | x = mx.sym.Concat(x3, xd3, xp, dim=1, name='ch_concat_%s_chconcat'%name) 61 | return x 62 | 63 | 64 | def inceptionBN(x, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, name=None, lr_mult=1, reuse=None): 65 | name = '' if name is None else name 66 | if bn_data: 67 | x = BN(x, fix_gamma=True, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 68 | name=name+'bn_data', reuse=reuse) 69 | 70 | x = incepConv(x, 64, (7, 7), stride=(2, 2), pad=(3, 3), name=name+'1', lr_mult=lr_mult, reuse=reuse) 71 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), name=name+'pool_1', pool_type='max') 72 | 73 | x = incepConv(x, 64, (1, 1), stride=(1, 1), pad=(0, 0), name=name+'2_red', lr_mult=lr_mult, reuse=reuse) 74 | x = incepConv(x, 192, (3, 3), stride=(1, 1), pad=(1, 1), name=name+'2', lr_mult=lr_mult, reuse=reuse) 75 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), name=name+'pool_2', pool_type='max') 76 | 77 | x = incepBlockA(x, 64, 64, 64, 64, 96, 32, 'avg', 1, momentum, eps, use_global_stats, '3a', lr_mult, reuse) 78 | x = incepBlockA(x, 64, 64, 96, 64, 96, 64, 'avg', 1, momentum, eps, use_global_stats, '3b', lr_mult, reuse) 79 | x = incepBlockB(x, 128, 160, 64, 96, 1, 2, momentum, eps, use_global_stats, 
'3c', lr_mult, reuse) 80 | 81 | x = incepBlockA(x, 224, 64, 96, 96, 128, 128, 'avg', 2, momentum, eps, use_global_stats, '4a', lr_mult, reuse) 82 | x = incepBlockA(x, 192, 96, 128, 96, 128, 128, 'avg', 2, momentum, eps, use_global_stats, '4b', lr_mult, reuse) 83 | x = incepBlockA(x, 160, 128, 160, 128, 160, 128, 'avg', 2, momentum, eps, use_global_stats, '4c', lr_mult, reuse) 84 | x = incepBlockA(x, 96, 128, 192, 160, 192, 128, 'avg', 2, momentum, eps, use_global_stats, '4d', lr_mult, reuse) 85 | x = incepBlockB(x, 128, 192, 192, 256, 1, 4, momentum, eps, use_global_stats, '4e', lr_mult, reuse) 86 | 87 | x = incepBlockA(x, 352, 192, 320, 160, 224, 128, 'avg', 4, momentum, eps, use_global_stats, '5a', lr_mult, reuse) 88 | x = incepBlockA(x, 352, 192, 320, 192, 224, 128, 'max', 4, momentum, eps, use_global_stats, '5b', lr_mult, reuse) 89 | return x 90 | 91 | 92 | -------------------------------------------------------------------------------- /lib/models/model/layers_custom/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers_custom_v1 import * 2 | from .drop_connect import * 3 | from .activations import * 4 | from .constant import * 5 | -------------------------------------------------------------------------------- /lib/models/model/layers_custom/activations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/lib/models/model/layers_custom/activations.py -------------------------------------------------------------------------------- /lib/models/model/layers_custom/constant.py: -------------------------------------------------------------------------------- 1 | from ..layers import * 2 | import numpy as np 3 | 4 | class OpConstant(mx.operator.CustomOp): 5 | def __init__(self, val): 6 | self.val = val 7 | 8 | def forward(self, is_train, req, in_data, out_data, aux): 9 | self.assign(out_data[0], req[0], self.val) 10 | 11 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 12 | pass 13 | 14 | @mx.operator.register('Constant') 15 | class OpConstantProp(mx.operator.CustomOpProp): 16 | def __init__(self, val_str, shape_str, type_str='float32'): 17 | super(OpConstantProp, self).__init__(need_top_grad=False) 18 | val = [float(x) for x in val_str.split(',')] 19 | shape = [int(x) for x in shape_str.split(',')] 20 | self.val = mx.nd.array(val, dtype=type_str).reshape(shape) 21 | 22 | def list_arguments(self): 23 | return [] 24 | 25 | def list_outputs(self): 26 | return ['output'] 27 | 28 | def infer_shape(self, in_shape): 29 | return in_shape, [self.val.shape], [] 30 | 31 | def infer_type(self, in_type): 32 | return in_type, [self.val.dtype], [] 33 | 34 | def create_operator(self, ctx, shapes, dtypes): 35 | return OpConstant(self.val.as_in_context(ctx)) 36 | 37 | def CustomConstantEncoder(value, dtype='float32'): 38 | if not isinstance(value, np.ndarray): 39 | if not isinstance(value, (list, tuple)): 40 | value = [value] 41 | value = np.array(value, dtype=dtype) 42 | return ','.join([str(x) for x in value.ravel()]), ','.join([str(x) for x in value.shape]) 43 | 44 | 45 | def Constant(value, dtype='float32'): 46 | assert isinstance(dtype, str), dtype 47 | val, shape = CustomConstantEncoder(value, dtype) 48 | return mx.sym.Custom(val_str=val, shape_str=shape, type_str=dtype, op_type='Constant') 49 | 50 | -------------------------------------------------------------------------------- 
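
The `Constant` helper above bakes a fixed tensor into the symbolic graph by serializing its values and shape into comma-separated strings that the `CustomOpProp` rebuilds on each device. A small round-trip sketch:

```python
import numpy as np
import mxnet as mx

arr = np.arange(6, dtype='float32').reshape(2, 3)
val_str, shape_str = CustomConstantEncoder(arr)   # -> '0.0,1.0,...', '2,3'
const_sym = Constant(arr)                         # mx.sym.Custom(..., op_type='Constant')
print(const_sym.eval(ctx=mx.cpu())[0])            # 2x3 NDArray holding arr's values
```
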
/lib/models/model/layers_custom/drop_connect.py: -------------------------------------------------------------------------------- 1 | from ..layers import * 2 | 3 | class DropConnect(mx.operator.CustomOp): 4 | def __init__(self, p): 5 | self.drop_rate = p 6 | self.mask = None 7 | 8 | def forward(self, is_train, req, in_data, out_data, aux): 9 | data = in_data[0] 10 | if is_train and self.drop_rate > 0: 11 | mask_shape = [data.shape[0]] + [1] * (len(data.shape)-1) 12 | mask = mx.nd.random.uniform(0, 1, mask_shape, ctx=data.context) 13 | mask = (mask > self.drop_rate) / (1 - self.drop_rate) 14 | out = data * mask 15 | self.mask = mask 16 | else: 17 | out = data 18 | self.mask = None 19 | self.assign(out_data[0], req[0], out) 20 | 21 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 22 | if self.mask is None: 23 | grad = out_grad[0].copy() 24 | else: 25 | grad = out_grad[0] * self.mask 26 | 27 | self.assign(in_grad[0], req[0], grad) 28 | 29 | @mx.operator.register("DropConnect") 30 | class DropConnectProp(mx.operator.CustomOpProp): 31 | def __init__(self, p): 32 | super(DropConnectProp, self).__init__(need_top_grad=True) 33 | self.drop_rate = float(p) 34 | assert self.drop_rate >= 0 and self.drop_rate < 1 35 | 36 | def list_arguments(self): 37 | return ['data'] 38 | 39 | def list_outputs(self): 40 | return ['output'] 41 | 42 | def infer_shape(self, in_shape): 43 | return in_shape, in_shape, [] 44 | 45 | def create_operator(self, ctx, shapes, dtypes): 46 | return DropConnect(self.drop_rate) 47 | --------------------------------------------------------------------------------
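
`DropConnect` above is the stochastic-depth regularizer used by EfficientNet: during training it draws one Bernoulli mask per sample (broadcast over C, H, W), dropping the whole residual branch with probability `p` and rescaling survivors by `1/(1-p)`; at inference it is the identity. `MBConvBlock` in `efficientnet.py` above applies it exactly like this:

```python
# From MBConvBlock: drop the residual branch with probability 1 - survival_prob.
x = mx.sym.Custom(x, p=1 - block_args.survival_prob, op_type='DropConnect')
x = x + data
```
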
/lib/models/model/layers_custom/layers_custom_v1.py: -------------------------------------------------------------------------------- 1 | from ..layers import * 2 | 3 | # ==== Bilinear Sampling ==== 4 | class BilinearScale(mx.operator.CustomOp): 5 | def __init__(self, scale): 6 | self.scale = scale 7 | 8 | def forward(self, is_train, req, in_data, out_data, aux): 9 | x = in_data[0] 10 | h, w = x.shape[2:] 11 | new_h = int((h - 1) * self.scale) + 1 12 | new_w = int((w - 1) * self.scale) + 1 13 | 14 | x.attach_grad() 15 | with mx.autograd.record(): 16 | new_x = mx.nd.contrib.BilinearResize2D(x, height=new_h, width=new_w) 17 | self.new_x = new_x 18 | self.x = x 19 | 20 | self.assign(out_data[0], req[0], new_x) 21 | 22 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 23 | self.new_x.backward(out_grad[0]) 24 | self.assign(in_grad[0], req[0], self.x.grad) 25 | 26 | @mx.operator.register("BilinearScale") 27 | class BilinearScaleProp(mx.operator.CustomOpProp): 28 | def __init__(self, scale): 29 | super(BilinearScaleProp, self).__init__(need_top_grad=True) 30 | self.scale = float(scale) 31 | 32 | def infer_shape(self, in_shape): 33 | n, c, h, w = in_shape[0] 34 | new_h = int((h - 1) * self.scale) + 1 35 | new_w = int((w - 1) * self.scale) + 1 36 | return in_shape, [(n, c, new_h, new_w)], [] 37 | 38 | def create_operator(self, ctx, shapes, dtypes): 39 | return BilinearScale(self.scale) 40 | 41 | class BilinearScaleLike(mx.operator.CustomOp): 42 | def forward(self, is_train, req, in_data, out_data, aux): 43 | x, x_ref = in_data 44 | new_h, new_w = x_ref.shape[2:] 45 | 46 | x.attach_grad() 47 | with mx.autograd.record(): 48 | new_x = mx.nd.contrib.BilinearResize2D(x, height=new_h, width=new_w) 49 | self.new_x = new_x 50 | self.x = x 51 | 52 | self.assign(out_data[0], req[0], new_x) 53 | 54 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 55 | self.new_x.backward(out_grad[0]) 56 | in_grad[1][:] = 0 57 | self.assign(in_grad[0], req[0], self.x.grad) 58 | 59 | @mx.operator.register("BilinearScaleLike") 60 | class BilinearScaleLikeProp(mx.operator.CustomOpProp): 61 | def __init__(self): 62 | super(BilinearScaleLikeProp, self).__init__(need_top_grad=True) 63 | 64 | def list_arguments(self): 65 | return ['d1', 'd2'] 66 | 67 | def infer_shape(self, in_shape): 68 | out_shape = list(in_shape[1]) 69 | out_shape[1] = in_shape[0][1] 70 | return in_shape, [out_shape,], [] 71 | 72 | def create_operator(self, ctx, shapes, dtypes): 73 | return BilinearScaleLike() 74 | 75 | 76 | # ==== Loss ==== 77 | class SegmentLoss(mx.operator.CustomOp): 78 | def __init__(self, has_grad_scale): 79 | self.has_grad_scale = has_grad_scale 80 | 81 | def forward(self, is_train, req, in_data, out_data, aux): 82 | # logit, label, (grad_scale) = in_data 83 | prediction = mx.nd.softmax(in_data[0], axis=1) 84 | self.assign(out_data[0], req[0], prediction) 85 | 86 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 87 | prediction = out_data[0] 88 | label = mx.nd.one_hot(in_data[1], depth=prediction.shape[1]).transpose((0, 3, 1, 2)) 89 | 90 | if prediction.shape[2] != label.shape[2]: 91 | label = mx.nd.contrib.BilinearResize2D(label, 92 | height=prediction.shape[2], width=prediction.shape[3]) 93 | label = mx.nd.one_hot(mx.nd.argmax(label, axis=1), 94 | depth=prediction.shape[1]).transpose((0, 3, 1, 2)) * (mx.nd.max(label, axis=1, keepdims=True) > 0.5) 95 | 96 | mask = label.sum(axis=1, keepdims=True) 97 | num_pixel = mx.nd.maximum(mask.sum() / mask.shape[0], 1e-5) 98 | 99 | grad = (prediction - label) * mask / num_pixel 100 | if self.has_grad_scale: 101 | grad_scale = in_data[2].reshape(-1, 1, 1, 1) 102 | grad = grad * grad_scale 103 | 104 | in_grad[1][:] = 0 105 | self.assign(in_grad[0], req[0], grad) 106 | 107 | @mx.operator.register("SegmentLoss") 108 | class SegmentLossProp(mx.operator.CustomOpProp): 109 | def __init__(self, has_grad_scale=0): 110 | super(SegmentLossProp, self).__init__(need_top_grad=False) 111 | self.has_grad_scale = int(has_grad_scale) > 0 112 | 113 | def list_arguments(self): 114 | if self.has_grad_scale: 115 | return ['data', 'label', 'scale'] 116 | else: 117 | return ['data', 'label'] 118 | 119 | def infer_shape(self, in_shape): 120 | return in_shape, [in_shape[0],], [] 121 | 122 | def create_operator(self, ctx, shapes, dtypes): 123 | return SegmentLoss(self.has_grad_scale) 124 | 125 | class MultiSigmoidLoss(mx.operator.CustomOp): 126 | def forward(self, is_train, req, in_data, out_data, aux): 127 | logit, label = in_data 128 | prediction = mx.nd.sigmoid(logit) 129 | self.assign(out_data[0], req[0], prediction) 130 | 131 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 132 | prediction = out_data[0] 133 | label = in_data[1] 134 | 135 | grad = prediction - label 136 | 137 | in_grad[1][:] = 0 138 | self.assign(in_grad[0], req[0], grad) 139 | 140 | @mx.operator.register("MultiSigmoidLoss") 141 | class MultiSigmoidLossProp(mx.operator.CustomOpProp): 142 | def __init__(self): 143 | super(MultiSigmoidLossProp, self).__init__(need_top_grad=False) 144 | 145 | def list_arguments(self): 146 | return ['data', 'label'] 147 | 148 | def list_outputs(self): 149 | return ['output'] 150 | 151 | def infer_shape(self, in_shape): 152 | return in_shape, [in_shape[0]], [] 153 | 154 | def create_operator(self, ctx, shapes, dtypes): 155 | return MultiSigmoidLoss() 156 | 157 | class MultiSoftmaxLoss(mx.operator.CustomOp): 
158 | def forward(self, is_train, req, in_data, out_data, aux): 159 | logit, label = in_data 160 | prediction = mx.nd.softmax(logit, axis=1) 161 | self.assign(out_data[0], req[0], prediction) 162 | 163 | def backward(self, req, out_grad, in_data, out_data, in_grad, aux): 164 | prediction = out_data[0] 165 | label = in_data[1] 166 | 167 | grad = prediction - label 168 | 169 | in_grad[1][:] = 0 170 | self.assign(in_grad[0], req[0], grad) 171 | 172 | @mx.operator.register("MultiSoftmaxLoss") 173 | class MultiSoftmaxLossProp(mx.operator.CustomOpProp): 174 | def __init__(self): 175 | super(MultiSoftmaxLossProp, self).__init__(need_top_grad=False) 176 | 177 | def list_arguments(self): 178 | return ['data', 'label'] 179 | 180 | def list_outputs(self): 181 | return ['output'] 182 | 183 | def infer_shape(self, in_shape): 184 | return in_shape, [in_shape[0]], [] 185 | 186 | def create_operator(self, ctx, shapes, dtypes): 187 | return MultiSoftmaxLoss() 188 | 189 | # ==== Others ==== 190 | 191 | -------------------------------------------------------------------------------- /lib/models/model/resnet.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def ResStem(data, num_filter, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 4 | name=None, lr_mult=1, reuse=None): 5 | name = '' if name is None else name 6 | if bn_data: 7 | x = BN(data, fix_gamma=True, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 8 | name=name+'bn_data', reuse=reuse) 9 | else: 10 | x = data 11 | 12 | x = Conv(x, num_filter=num_filter, kernel=(7, 7), stride=(2, 2), pad=(3, 3), no_bias=True, 13 | name=name+'conv0', lr_mult=lr_mult, reuse=reuse) 14 | x = BN(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 15 | name=name+'bn0', lr_mult=lr_mult, reuse=reuse) 16 | x = Relu(x, name=name+'relu0') 17 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max', name=name+'pool0') 18 | return x 19 | 20 | def ResUnit(data, num_filter, stride, dilate, projection, bottle_neck, momentum=0.9, eps=1e-5, 21 | use_global_stats=False, name=None, lr_mult=1, reuse=None): 22 | assert name is not None 23 | x = BNRelu(data, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 24 | name=name+'_bn1', lr_mult=lr_mult, reuse=reuse) 25 | 26 | if projection: 27 | shortcut = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(stride,)*2, 28 | pad=(0, 0), no_bias=True, name=name+'_sc', lr_mult=lr_mult, reuse=reuse) 29 | else: 30 | shortcut = data 31 | 32 | if bottle_neck: 33 | x = Conv(x, num_filter=int(num_filter/4.), kernel=(1, 1), stride=(1, 1), pad=(0, 0), 34 | no_bias=True, name=name+'_conv1', lr_mult=lr_mult, reuse=reuse) 35 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 36 | name=name+'_bn2', lr_mult=lr_mult, reuse=reuse) 37 | x = Conv(x, num_filter=int(num_filter/4.), kernel=(3, 3), stride=(stride,)*2, pad=(dilate,)*2, 38 | dilate=(dilate,)*2, no_bias=True, name=name+'_conv2', lr_mult=lr_mult, reuse=reuse) 39 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 40 | name=name+'_bn3', lr_mult=lr_mult, reuse=reuse) 41 | x = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), 42 | no_bias=True, name=name+'_conv3', lr_mult=lr_mult, reuse=reuse) 43 | else: 44 | x = Conv(x, num_filter=num_filter, kernel=(3, 3), stride=(stride,)*2, pad=(dilate,)*2, 45 | dilate=(dilate,)*2, 
no_bias=True, name=name+'_conv1', lr_mult=lr_mult, reuse=reuse) 46 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 47 | name=name+'_bn2', lr_mult=lr_mult, reuse=reuse) 48 | x = Conv(x, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1), 49 | no_bias=True, name=name+'_conv2', lr_mult=lr_mult, reuse=reuse) 50 | 51 | x = x + shortcut 52 | return x 53 | 54 | def ResBlock(data, num_unit, num_filter, stride, dilate, bottle_neck, momentum=0.9, eps=1e-5, 55 | use_global_stats=False, name=None, lr_mult=1, reuse=None): 56 | assert name is not None 57 | x = ResUnit(data, num_filter, stride, dilate, True, bottle_neck, momentum, eps, 58 | use_global_stats, name+'_unit1', lr_mult, reuse) 59 | for i in range(1, num_unit): 60 | x = ResUnit(x, num_filter, 1, dilate, False, bottle_neck, momentum, eps, 61 | use_global_stats, name+'_unit%d'%(i+1), lr_mult, reuse) 62 | return x 63 | 64 | def _Resnet(x, num_units, num_filters, bottle_neck, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 65 | strides=(1, 2, 2, 2), dilates=(1, 1, 1, 1), name=None, lr_mult=1, reuse=None): 66 | name = '' if name is None else name 67 | 68 | x = ResStem(x, num_filters[0], momentum, eps, use_global_stats, bn_data, name, lr_mult, reuse) 69 | for i in range(4): 70 | x = ResBlock(x, num_units[i], num_filters[i+1], strides[i], dilates[i], bottle_neck, 71 | momentum, eps, use_global_stats, name+'stage%d'%(i+1), lr_mult, reuse) 72 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 73 | name=name+'bn1', lr_mult=lr_mult, reuse=reuse) 74 | return x 75 | 76 | def resnet18(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 77 | name = '' if name is None else name 78 | x = _Resnet(x, (2, 2, 2, 2), (64, 64, 128, 256, 512), False, momentum, eps, use_global_stats, 79 | name=name, lr_mult=lr_mult, reuse=reuse) 80 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 81 | x = Flatten(x) 82 | return x 83 | 84 | def resnet34(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 85 | name = '' if name is None else name 86 | x = _Resnet(x, (3, 4, 6, 3), (64, 64, 128, 256, 512), False, momentum, eps, use_global_stats, 87 | name=name, lr_mult=lr_mult, reuse=reuse) 88 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 89 | x = Flatten(x) 90 | return x 91 | 92 | def resnet50(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 93 | name = '' if name is None else name 94 | x = _Resnet(x, (3, 4, 6, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 95 | name=name, lr_mult=lr_mult, reuse=reuse) 96 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 97 | x = Flatten(x) 98 | return x 99 | 100 | def resnet101(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 101 | name = '' if name is None else name 102 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 103 | name=name, lr_mult=lr_mult, reuse=reuse) 104 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 105 | x = Flatten(x) 106 | return x 107 | 108 | def resnet101_largefov(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 109 | name = '' if name is None else name 110 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 111 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, 
lr_mult=1, reuse=reuse) 112 | x = Conv(x, num_cls, kernel=(3, 3), dilate=(12, 12), pad=(12, 12), name=name+'fc1', lr_mult=lr_mult, reuse=reuse) 113 | return x 114 | 115 | def resnet101_aspp(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 116 | name = '' if name is None else name 117 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 118 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, lr_mult=1, reuse=reuse) 119 | x_aspp = [] 120 | for d in (6, 12, 18, 24): 121 | x_aspp.append(Conv(x, num_cls, kernel=(3, 3), dilate=(d, d), pad=(d, d), 122 | name=name+'fc1_aspp%d' % d, lr_mult=lr_mult, reuse=reuse)) 123 | x = sum(x_aspp) 124 | return x 125 | 126 | -------------------------------------------------------------------------------- /lib/models/model/resnet_v1.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def ResStemV1(data, num_filter, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 4 | name=None, lr_mult=1, reuse=None): 5 | name = '' if name is None else name 6 | if bn_data: 7 | x = BN(data, fix_gamma=True, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 8 | name=name+'bn_data', reuse=reuse) 9 | else: 10 | x = data 11 | 12 | x = Conv(x, num_filter=num_filter, kernel=(7, 7), stride=(2, 2), pad=(3, 3), no_bias=True, 13 | name=name+'conv0', lr_mult=lr_mult, reuse=reuse) 14 | x = BN(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 15 | name=name+'bn0', lr_mult=lr_mult, reuse=reuse) 16 | x = Relu(x, name=name+'relu0') 17 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max', name=name+'pool0') 18 | return x 19 | 20 | def ResUnitV1(data, num_filter, stride, dilate, projection, bottle_neck, momentum=0.9, eps=1e-5, 21 | use_global_stats=False, name=None, lr_mult=1, reuse=None): 22 | assert name is not None 23 | 24 | if projection: 25 | shortcut = Conv(data, num_filter=num_filter, kernel=(1, 1), stride=(stride,)*2, 26 | pad=(0, 0), no_bias=True, name=name+'_conv0', lr_mult=lr_mult, reuse=reuse) 27 | shortcut = BN(shortcut, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 28 | name=name+'_bn0', lr_mult=lr_mult, reuse=reuse) 29 | else: 30 | shortcut = data 31 | 32 | if bottle_neck: 33 | x = Conv(data, num_filter=int(num_filter/4.), kernel=(1, 1), stride=(stride,)*2, pad=(0, 0), 34 | no_bias=True, name=name+'_conv1', lr_mult=lr_mult, reuse=reuse) 35 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 36 | name=name+'_bn1', lr_mult=lr_mult, reuse=reuse) 37 | x = Conv(x, num_filter=int(num_filter/4.), kernel=(3, 3), stride=(1, 1), pad=(dilate,)*2, 38 | dilate=(dilate,)*2, no_bias=True, name=name+'_conv2', lr_mult=lr_mult, reuse=reuse) 39 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 40 | name=name+'_bn2', lr_mult=lr_mult, reuse=reuse) 41 | x = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), 42 | no_bias=True, name=name+'_conv3', lr_mult=lr_mult, reuse=reuse) 43 | x = BN(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 44 | name=name+'_bn3', lr_mult=lr_mult, reuse=reuse) 45 | else: 46 | raise NotImplementedError 47 | 48 | x = x + shortcut 49 | x = Relu(x) 50 | return x 51 | 52 | def ResBlockV1(data, num_unit, num_filter, stride, dilate, bottle_neck, momentum=0.9, eps=1e-5, 53 
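
`resnet101_largefov` (defined in both `resnet.py` above and `resnet_v1.py` below) is the DeepLab-LargeFOV segmentation head: strides `(1, 2, 1, 1)` with dilations `(1, 1, 2, 4)` keep the feature map at 1/8 of the input resolution, and a dilated 3x3 classifier (rate 12, trained with `lr_mult=10`) maps features to per-class scores. A hedged construction sketch for PASCAL VOC (21 classes including background):

```python
import mxnet as mx

data = mx.sym.Variable('data')                     # (N, 3, H, W) input images
logit = resnet101_largefov(data, num_cls=21,       # 20 VOC classes + background
                           use_global_stats=True)  # freeze BN statistics for fine-tuning
# logit has spatial size (H/8, W/8); upsampling to full resolution and any
# CRF post-processing happen outside the symbol.
```
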
| use_global_stats=False, name=None, lr_mult=1, reuse=None): 54 | assert name is not None 55 | x = ResUnitV1(data, num_filter, stride, dilate, True, bottle_neck, momentum, eps, 56 | use_global_stats, name+'_unit1', lr_mult, reuse) 57 | for i in range(1, num_unit): 58 | x = ResUnitV1(x, num_filter, 1, dilate, False, bottle_neck, momentum, eps, 59 | use_global_stats, name+'_unit%d'%(i+1), lr_mult, reuse) 60 | return x 61 | 62 | def _Resnet(x, num_units, num_filters, bottle_neck, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 63 | strides=(1, 2, 2, 2), dilates=(1, 1, 1, 1), name=None, lr_mult=1, reuse=None): 64 | name = '' if name is None else name 65 | 66 | x = ResStemV1(x, num_filters[0], momentum, eps, use_global_stats, bn_data, name, lr_mult, reuse) 67 | for i in range(4): 68 | x = ResBlockV1(x, num_units[i], num_filters[i+1], strides[i], dilates[i], bottle_neck, 69 | momentum, eps, use_global_stats, name+'stage%d'%(i+1), lr_mult, reuse) 70 | return x 71 | 72 | def resnet18(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 73 | name = '' if name is None else name 74 | x = _Resnet(x, (2, 2, 2, 2), (64, 64, 128, 256, 512), False, momentum, eps, use_global_stats, 75 | name=name, lr_mult=lr_mult, reuse=reuse) 76 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 77 | x = Flatten(x) 78 | return x 79 | 80 | def resnet34(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 81 | name = '' if name is None else name 82 | x = _Resnet(x, (3, 4, 6, 3), (64, 64, 128, 256, 512), False, momentum, eps, use_global_stats, 83 | name=name, lr_mult=lr_mult, reuse=reuse) 84 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 85 | x = Flatten(x) 86 | return x 87 | 88 | def resnet50(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 89 | name = '' if name is None else name 90 | x = _Resnet(x, (3, 4, 6, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 91 | name=name, lr_mult=lr_mult, reuse=reuse) 92 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 93 | x = Flatten(x) 94 | return x 95 | 96 | def resnet101(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 97 | name = '' if name is None else name 98 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 99 | name=name, lr_mult=lr_mult, reuse=reuse) 100 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 101 | x = Flatten(x) 102 | return x 103 | 104 | def resnet101_largefov(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 105 | name = '' if name is None else name 106 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 107 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, lr_mult=1, reuse=reuse) 108 | x = Conv(x, num_cls, kernel=(3, 3), dilate=(12, 12), pad=(12, 12), name=name+'fc1', lr_mult=lr_mult, reuse=reuse) 109 | return x 110 | 111 | def resnet101_aspp(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 112 | name = '' if name is None else name 113 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 114 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, lr_mult=1, reuse=reuse) 115 | x_aspp = [] 116 | for d in (6, 12, 18, 24): 117 | x_aspp.append(Conv(x, num_cls, kernel=(3, 3), dilate=(d, d), pad=(d, d), 118 | 
name=name+'fc1_aspp%d' % d, lr_mult=lr_mult, reuse=reuse)) 119 | x = sum(x_aspp) 120 | return x 121 | 122 | -------------------------------------------------------------------------------- /lib/models/model/vgg.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | 4 | def vgg16(x, name=None, lr_mult=1, reuse=None): 5 | name = '' if name is None else name 6 | 7 | x = ConvRelu(x, 64, (3, 3), pad=(1, 1), name=name+'conv1_1', lr_mult=lr_mult, reuse=reuse) 8 | x = ConvRelu(x, 64, (3, 3), pad=(1, 1), name=name+'conv1_2', lr_mult=lr_mult, reuse=reuse) 9 | x = Pool(x, (2, 2), name=name+'pool1') 10 | 11 | x = ConvRelu(x, 128, (3, 3), pad=(1, 1), name=name+'conv2_1', lr_mult=lr_mult, reuse=reuse) 12 | x = ConvRelu(x, 128, (3, 3), pad=(1, 1), name=name+'conv2_2', lr_mult=lr_mult, reuse=reuse) 13 | x = Pool(x, (2, 2), name=name+'pool2') 14 | 15 | x = ConvRelu(x, 256, (3, 3), pad=(1, 1), name=name+'conv3_1', lr_mult=lr_mult, reuse=reuse) 16 | x = ConvRelu(x, 256, (3, 3), pad=(1, 1), name=name+'conv3_2', lr_mult=lr_mult, reuse=reuse) 17 | x = ConvRelu(x, 256, (3, 3), pad=(1, 1), name=name+'conv3_3', lr_mult=lr_mult, reuse=reuse) 18 | x = Pool(x, (2, 2), name=name+'pool3') 19 | 20 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv4_1', lr_mult=lr_mult, reuse=reuse) 21 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv4_2', lr_mult=lr_mult, reuse=reuse) 22 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv4_3', lr_mult=lr_mult, reuse=reuse) 23 | x = Pool(x, (2, 2), name=name+'pool4') 24 | 25 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv5_1', lr_mult=lr_mult, reuse=reuse) 26 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv5_2', lr_mult=lr_mult, reuse=reuse) 27 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv5_3', lr_mult=lr_mult, reuse=reuse) 28 | x = Pool(x, (2, 2), name=name+'pool5') 29 | 30 | x = Flatten(x, name=name+'flatten') 31 | x = FCRelu(x, num_hidden=4096, name=name+'fc6', lr_mult=lr_mult, reuse=reuse) 32 | x = Drop(x, p=0.5, name=name+'drop6') 33 | x = FCRelu(x, num_hidden=4096, name=name+'fc7', lr_mult=lr_mult, reuse=reuse) 34 | x = Drop(x, p=0.5, name=name+'drop7') 35 | return x 36 | 37 | def vgg16_deeplab(x, name=None, lr_mult=1, reuse=None): 38 | name = '' if name is None else name 39 | 40 | x = ConvRelu(x, 64, (3, 3), pad=(1, 1), name=name+'conv1_1', lr_mult=lr_mult, reuse=reuse) 41 | x = ConvRelu(x, 64, (3, 3), pad=(1, 1), name=name+'conv1_2', lr_mult=lr_mult, reuse=reuse) 42 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), name=name+'pool1') 43 | 44 | x = ConvRelu(x, 128, (3, 3), pad=(1, 1), name=name+'conv2_1', lr_mult=lr_mult, reuse=reuse) 45 | x = ConvRelu(x, 128, (3, 3), pad=(1, 1), name=name+'conv2_2', lr_mult=lr_mult, reuse=reuse) 46 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), name=name+'pool2') 47 | 48 | x = ConvRelu(x, 256, (3, 3), pad=(1, 1), name=name+'conv3_1', lr_mult=lr_mult, reuse=reuse) 49 | x = ConvRelu(x, 256, (3, 3), pad=(1, 1), name=name+'conv3_2', lr_mult=lr_mult, reuse=reuse) 50 | x = ConvRelu(x, 256, (3, 3), pad=(1, 1), name=name+'conv3_3', lr_mult=lr_mult, reuse=reuse) 51 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), name=name+'pool3') 52 | 53 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv4_1', lr_mult=lr_mult, reuse=reuse) 54 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv4_2', lr_mult=lr_mult, reuse=reuse) 55 | x = ConvRelu(x, 512, (3, 3), pad=(1, 1), name=name+'conv4_3', lr_mult=lr_mult, 
reuse=reuse) 56 | x = Pool(x, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name=name+'pool4') 57 | 58 | x = ConvRelu(x, 512, (3, 3), dilate=(2, 2), pad=(2, 2), name=name+'conv5_1', lr_mult=lr_mult, reuse=reuse) 59 | x = ConvRelu(x, 512, (3, 3), dilate=(2, 2), pad=(2, 2), name=name+'conv5_2', lr_mult=lr_mult, reuse=reuse) 60 | x = ConvRelu(x, 512, (3, 3), dilate=(2, 2), pad=(2, 2), name=name+'conv5_3', lr_mult=lr_mult, reuse=reuse) 61 | x = Pool(x, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name=name+'pool5') 62 | x = Pool(x, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name=name+'pool5a', pool_type='avg') 63 | return x 64 | 65 | def vgg16_largefov(x, num_cls, name=None, lr_mult=10, reuse=None): 66 | name = '' if name is None else name 67 | 68 | x = vgg16_deeplab(x, name, lr_mult=1, reuse=reuse) 69 | 70 | x = ConvRelu(x, 1024, (3, 3), dilate=(12, 12), pad=(12, 12), name=name+'fc6', reuse=reuse) 71 | x = Drop(x, 0.5, name=name+'drop6') 72 | 73 | x = ConvRelu(x, 1024, (1, 1), name=name+'fc7', reuse=reuse) 74 | x = Drop(x, 0.5, name=name+'drop7') 75 | 76 | x = Conv(x, num_cls, (1, 1), name=name+'fc8', lr_mult=lr_mult, reuse=reuse) 77 | return x 78 | 79 | def vgg16_aspp(x, num_cls, name=None, lr_mult=10, reuse=None): 80 | name = '' if name is None else name 81 | 82 | x_backbone = vgg16_deeplab(x, name, lr_mult=1, reuse=reuse) 83 | 84 | x_aspp = [] 85 | for d in (6, 12, 18, 24): 86 | x = ConvRelu(x_backbone, 1024, (3, 3), dilate=(d, d), pad=(d, d), name=name+'fc6_aspp%d'%d, reuse=reuse) 87 | x = Drop(x, 0.5) 88 | 89 | x = ConvRelu(x, 1024, (1, 1), name=name+'fc7_aspp%d'%d, reuse=reuse) 90 | x = Drop(x, 0.5) 91 | 92 | x = Conv(x, num_cls, (1, 1), name=name+'fc8_aspp%d'%d, lr_mult=lr_mult, reuse=reuse) 93 | x_aspp.append(x) 94 | 95 | x = sum(x_aspp) 96 | return x 97 | 98 | -------------------------------------------------------------------------------- /lib/models/model/wide_resnet.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def wResStem(data, num_filter, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 4 | name=None, lr_mult=1, reuse=None): 5 | name = '' if name is None else name 6 | if bn_data: 7 | x = BN(data, fix_gamma=True, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 8 | name=name+'bn_data', reuse=reuse) 9 | else: 10 | x = data 11 | 12 | x = Conv(x, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1), no_bias=True, 13 | name=name+'conv1a', lr_mult=lr_mult, reuse=reuse) 14 | return x 15 | 16 | 17 | def wResUnit(data, num_filter, stride, dilate, projection, bottle_neck, dropout=0, momentum=0.9, eps=1e-5, 18 | use_global_stats=False, name=None, lr_mult=1, reuse=None, **kwargs): 19 | assert name is not None 20 | 21 | x = BNRelu(data, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 22 | name='bn'+name+'_branch2a', lr_mult=lr_mult, reuse=reuse) 23 | 24 | if projection: 25 | shortcut = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(stride,)*2, 26 | pad=(0, 0), no_bias=True, name='res'+name+'_branch1', lr_mult=lr_mult, reuse=reuse) 27 | else: 28 | shortcut = data 29 | 30 | if bottle_neck: 31 | x = Conv(x, num_filter=int(num_filter/4.), kernel=(1, 1), stride=(1, 1), pad=(0, 0), 32 | no_bias=True, name='res'+name+'_branch2a', lr_mult=lr_mult, reuse=reuse) 33 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 34 | name='bn'+name+'_branch2b1', lr_mult=lr_mult, reuse=reuse) 35 | if dropout > 0: 36 | x = Drop(x, 
p=dropout) 37 | x = Conv(x, num_filter=int(num_filter/2.), kernel=(3, 3), stride=(stride,)*2, pad=(dilate,)*2, 38 | dilate=(dilate,)*2, no_bias=True, name='res'+name+'_branch2b1', lr_mult=lr_mult, reuse=reuse) 39 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 40 | name='bn'+name+'_branch2b2', lr_mult=lr_mult, reuse=reuse) 41 | if dropout > 0: 42 | x = Drop(x, p=dropout) 43 | x = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), 44 | no_bias=True, name='res'+name+'_branch2b2', lr_mult=lr_mult, reuse=reuse) 45 | else: 46 | #mid_filter = num_filter//2 if name in ['5', '5a', '5b1', '5b2'] else num_filter 47 | mid_filter = kwargs.get('mid_filter', num_filter) 48 | fst_dilate = kwargs.get('fst_dilate', dilate) 49 | x = Conv(x, num_filter=mid_filter, kernel=(3, 3), stride=(stride,)*2, pad=(fst_dilate,)*2, 50 | dilate=(fst_dilate,)*2, no_bias=True, name='res'+name+'_branch2a', lr_mult=lr_mult, reuse=reuse) 51 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 52 | name='bn'+name+'_branch2b1', lr_mult=lr_mult, reuse=reuse) 53 | x = Conv(x, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(dilate,)*2, 54 | dilate=(dilate,)*2, no_bias=True, name='res'+name+'_branch2b1', lr_mult=lr_mult, reuse=reuse) 55 | 56 | x = x + shortcut 57 | return x 58 | 59 | 60 | def wResBlock(data, num_unit, num_filter, stride, dilate, bottle_neck, dropout=0, momentum=0.9, eps=1e-5, 61 | use_global_stats=False, name=None, lr_mult=1, reuse=None, **kwargs): 62 | assert name is not None 63 | x = wResUnit(data, num_filter, stride, dilate, True, bottle_neck, dropout, momentum, eps, 64 | use_global_stats, name=name+'a', lr_mult=lr_mult, reuse=reuse, **kwargs) 65 | for i in range(1, num_unit): 66 | x = wResUnit(x, num_filter, 1, dilate, False, bottle_neck, dropout, momentum, eps, 67 | use_global_stats, name=name+'b%d'%i, lr_mult=lr_mult, reuse=reuse, **kwargs) 68 | return x 69 | 70 | 71 | def wresnet38(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, out_internals=False, lr_mult=1, reuse=None): 72 | name = '' if name is None else name 73 | internals = [] 74 | 75 | x = wResStem(x, 64, momentum, eps, use_global_stats, bn_data=True, name=name, lr_mult=lr_mult, reuse=reuse) 76 | x = wResBlock(x, 3, 128, 2, 1, False, 0, momentum, eps, use_global_stats, name+'2', lr_mult, reuse) 77 | x = wResBlock(x, 3, 256, 2, 1, False, 0, momentum, eps, use_global_stats, name+'3', lr_mult, reuse) 78 | x = wResBlock(x, 6, 512, 2, 1, False, 0, momentum, eps, use_global_stats, name+'4', lr_mult, reuse) 79 | x = wResBlock(x, 3, 1024, 1, 2, False, 0, momentum, eps, use_global_stats, name+'5', lr_mult, reuse, mid_filter=512, fst_dilate=1) 80 | internals.append(x) 81 | x = wResBlock(x, 1, 2048, 1, 4, True, 0.3, momentum, eps, use_global_stats, name+'6', lr_mult, reuse) 82 | internals.append(x) 83 | x = wResBlock(x, 1, 4096, 1, 4, True, 0.5, momentum, eps, use_global_stats, name+'7', lr_mult, reuse) 84 | internals.append(x) 85 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 86 | name=name+'bn7', lr_mult=lr_mult, reuse=reuse) 87 | 88 | if out_internals: 89 | return x, internals 90 | else: 91 | return x 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /lib/models/multi_scale.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def MultiScale(scales): 4 | scales = 
[s for s in scales if s != 1] 5 | 6 | def func_wrapper(model_func): 7 | def model_func_ms(*args, **kwargs): 8 | assert len(args) > 0, 'Cannot find input variable' 9 | input_var = args[0] 10 | args = args[1:] 11 | 12 | out_0 = model_func(*((input_var,) + args), **kwargs) 13 | assert len(out_0) == 1, 'Only single output implemented' 14 | 15 | reuse = kwargs.get('reuse', None) 16 | if reuse is None: 17 | reuse = out_0 18 | if 'reuse' in kwargs: 19 | del kwargs['reuse'] 20 | 21 | is_tensor4d = len(out_0.infer_shape(data=(1, 3, 100, 100))[1][0]) == 4 22 | 23 | out_ms = [out_0] 24 | for scale in scales: 25 | input_var_s = mx.sym.Custom(input_var, scale=scale, op_type='BilinearScale') 26 | out_s = model_func(*((input_var_s,) + args), reuse=reuse, **kwargs) 27 | if is_tensor4d: 28 | out_s = mx.sym.Custom(out_s, out_0, op_type='BilinearScaleLike') 29 | out_ms.append(out_s) 30 | 31 | out_max = out_ms[0] 32 | for out_s in out_ms[1:]: 33 | out_max = mx.sym.maximum(out_max, out_s) 34 | out_ms.append(out_max) 35 | 36 | return mx.sym.Group(out_ms) 37 | return model_func_ms 38 | return func_wrapper 39 | -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | 3 | def ResStem(data, num_filter, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 4 | name=None, lr_mult=1, reuse=None): 5 | name = '' if name is None else name 6 | if bn_data: 7 | x = BN(data, fix_gamma=True, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 8 | name=name+'bn_data', reuse=reuse) 9 | else: 10 | x = data 11 | 12 | x = Conv(x, num_filter=num_filter, kernel=(7, 7), stride=(2, 2), pad=(3, 3), no_bias=True, 13 | name=name+'conv0', lr_mult=lr_mult, reuse=reuse) 14 | x = BN(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 15 | name=name+'bn0', lr_mult=lr_mult, reuse=reuse) 16 | x = Relu(x, name=name+'relu0') 17 | x = Pool(x, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max', name=name+'pool0') 18 | return x 19 | 20 | def ResUnit(data, num_filter, stride, dilate, projection, bottle_neck, momentum=0.9, eps=1e-5, 21 | use_global_stats=False, name=None, lr_mult=1, reuse=None): 22 | assert name is not None 23 | x = BNRelu(data, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 24 | name=name+'_bn1', lr_mult=lr_mult, reuse=reuse) 25 | 26 | if projection: 27 | shortcut = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(stride,)*2, 28 | pad=(0, 0), no_bias=True, name=name+'_sc', lr_mult=lr_mult, reuse=reuse) 29 | else: 30 | shortcut = data 31 | 32 | if bottle_neck: 33 | x = Conv(x, num_filter=int(num_filter/4.), kernel=(1, 1), stride=(1, 1), pad=(0, 0), 34 | no_bias=True, name=name+'_conv1', lr_mult=lr_mult, reuse=reuse) 35 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 36 | name=name+'_bn2', lr_mult=lr_mult, reuse=reuse) 37 | x = Conv(x, num_filter=int(num_filter/4.), kernel=(3, 3), stride=(stride,)*2, pad=(dilate,)*2, 38 | dilate=(dilate,)*2, no_bias=True, name=name+'_conv2', lr_mult=lr_mult, reuse=reuse) 39 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 40 | name=name+'_bn3', lr_mult=lr_mult, reuse=reuse) 41 | x = Conv(x, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), 42 | no_bias=True, name=name+'_conv3', lr_mult=lr_mult, reuse=reuse) 43 | else: 44 | x = Conv(x, 
num_filter=num_filter, kernel=(3, 3), stride=(stride,)*2, pad=(dilate,)*2, 45 | dilate=(dilate,)*2, no_bias=True, name=name+'_conv1', lr_mult=lr_mult, reuse=reuse) 46 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 47 | name=name+'_bn2', lr_mult=lr_mult, reuse=reuse) 48 | x = Conv(x, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1), 49 | no_bias=True, name=name+'_conv2', lr_mult=lr_mult, reuse=reuse) 50 | 51 | x = x + shortcut 52 | return x 53 | 54 | def ResBlock(data, num_unit, num_filter, stride, dilate, bottle_neck, momentum=0.9, eps=1e-5, 55 | use_global_stats=False, name=None, lr_mult=1, reuse=None): 56 | assert name is not None 57 | x = ResUnit(data, num_filter, stride, dilate, True, bottle_neck, momentum, eps, 58 | use_global_stats, name+'_unit1', lr_mult, reuse) 59 | for i in range(1, num_unit): 60 | x = ResUnit(x, num_filter, 1, dilate, False, bottle_neck, momentum, eps, 61 | use_global_stats, name+'_unit%d'%(i+1), lr_mult, reuse) 62 | return x 63 | 64 | def _Resnet(x, num_units, num_filters, bottle_neck, momentum=0.9, eps=1e-5, use_global_stats=False, bn_data=True, 65 | strides=(1, 2, 2, 2), dilates=(1, 1, 1, 1), name=None, lr_mult=1, reuse=None): 66 | name = '' if name is None else name 67 | 68 | x = ResStem(x, num_filters[0], momentum, eps, use_global_stats, bn_data, name, lr_mult, reuse) 69 | for i in range(4): 70 | x = ResBlock(x, num_units[i], num_filters[i+1], strides[i], dilates[i], bottle_neck, 71 | momentum, eps, use_global_stats, name+'stage%d'%(i+1), lr_mult, reuse) 72 | x = BNRelu(x, fix_gamma=False, momentum=momentum, eps=eps, use_global_stats=use_global_stats, 73 | name=name+'bn1', lr_mult=lr_mult, reuse=reuse) 74 | return x 75 | 76 | def resnet18(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 77 | name = '' if name is None else name 78 | x = _Resnet(x, (2, 2, 2, 2), (64, 64, 128, 256, 512), False, momentum, eps, use_global_stats, 79 | name=name, lr_mult=lr_mult, reuse=reuse) 80 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 81 | x = Flatten(x) 82 | return x 83 | 84 | def resnet34(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 85 | name = '' if name is None else name 86 | x = _Resnet(x, (3, 4, 6, 3), (64, 64, 128, 256, 512), False, momentum, eps, use_global_stats, 87 | name=name, lr_mult=lr_mult, reuse=reuse) 88 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 89 | x = Flatten(x) 90 | return x 91 | 92 | def resnet50(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 93 | name = '' if name is None else name 94 | x = _Resnet(x, (3, 4, 6, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 95 | name=name, lr_mult=lr_mult, reuse=reuse) 96 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 97 | x = Flatten(x) 98 | return x 99 | 100 | def resnet101(x, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=1, reuse=None): 101 | name = '' if name is None else name 102 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 103 | name=name, lr_mult=lr_mult, reuse=reuse) 104 | x = Pool(x, (1, 1), pool_type='avg', global_pool=True) 105 | x = Flatten(x) 106 | return x 107 | 108 | def resnet50_largefov(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 109 | name = '' if name is None else name 110 | x = _Resnet(x, (3, 4, 6, 3), (64, 256, 512, 1024, 2048), True, 
momentum, eps, use_global_stats, 111 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, lr_mult=1, reuse=reuse) 112 | x = Conv(x, num_cls, kernel=(3, 3), dilate=(12, 12), pad=(12, 12), name=name+'fc1', lr_mult=lr_mult, reuse=reuse) 113 | return x 114 | 115 | def resnet101_largefov(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 116 | name = '' if name is None else name 117 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 118 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, lr_mult=1, reuse=reuse) 119 | x = Conv(x, num_cls, kernel=(3, 3), dilate=(12, 12), pad=(12, 12), name=name+'fc1', lr_mult=lr_mult, reuse=reuse) 120 | return x 121 | 122 | ''' 123 | def resnet101_aspp(x, num_cls, momentum=0.9, eps=1e-5, use_global_stats=False, name=None, lr_mult=10, reuse=None): 124 | name = '' if name is None else name 125 | x = _Resnet(x, (3, 4, 23, 3), (64, 256, 512, 1024, 2048), True, momentum, eps, use_global_stats, 126 | strides=(1, 2, 1, 1), dilates=(1, 1, 2, 4), name=name, lr_mult=1, reuse=reuse) 127 | x_aspp = [] 128 | for d in (6, 12, 18, 24): 129 | x_aspp.append(Conv(x, num_cls, kernel=(3, 3), dilate=(d, d), pad=(d, d), 130 | name=name+'fc1_aspp%d' % d, lr_mult=lr_mult, reuse=reuse)) 131 | x = sum(x_aspp) 132 | return x 133 | ''' 134 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .mxnet_tools import * 2 | from .image_tools import * 3 | from .dataset_tools import VOC 4 | -------------------------------------------------------------------------------- /lib/utils/dataset_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xml.etree.ElementTree as ET 3 | 4 | class _VOC_proto(object): 5 | @staticmethod 6 | def _get_palette(): 7 | def bitget(bit, idx): 8 | return (bit & (1 << idx)) > 0 9 | cmap = [] 10 | for i in range(256): 11 | r, g, b = 0, 0, 0 12 | idx = i 13 | for j in range(8): 14 | r = r | (bitget(idx, 0) << (7 - j)) 15 | g = g | (bitget(idx, 1) << (7 - j)) 16 | b = b | (bitget(idx, 2) << (7 - j)) 17 | idx = idx >> 3 18 | cmap.append((b, g, r)) 19 | return np.array(cmap).astype(np.uint8) 20 | 21 | def __init__(self): 22 | self.categories = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 23 | 'bus', 'car', 'cat', 'chair', 'cow', 24 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 25 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] 26 | self.palette = self._get_palette() 27 | 28 | def name2index(self, name): 29 | return self.categories.index(name) 30 | 31 | def index2name(self, index): 32 | return self.categories[index] 33 | 34 | def get_annotation(self, filename, use_diff=False): 35 | tree = ET.parse(filename) 36 | root = tree.getroot() 37 | annotation = [] 38 | tmp_annotation = [] 39 | for obj in root.findall('object'): 40 | cat = obj.find('name').text 41 | non_diff = 1 - int(obj.find('difficult').text) 42 | if use_diff or non_diff: 43 | annotation.append(self.name2index(cat)) 44 | else: 45 | tmp_annotation.append(self.name2index(cat)) 46 | annotation = list(set(annotation)) 47 | 48 | if len(annotation) == 0: 49 | annotation += list(set(tmp_annotation)) 50 | annotation.sort() 51 | return annotation 52 | 53 | VOC = _VOC_proto() 54 | 55 | 56 | -------------------------------------------------------------------------------- /lib/utils/image_tools.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import pickle 4 | import numpy as np 5 | 6 | 7 | def imwrite(filename, image): 8 | dirname = os.path.dirname(filename) 9 | if not os.path.exists(dirname): 10 | try: 11 | os.makedirs(dirname) 12 | except: 13 | pass 14 | cv2.imwrite(filename, image) 15 | 16 | def npsave(filename, data): 17 | dirname = os.path.dirname(filename) 18 | if not os.path.exists(dirname): 19 | try: 20 | os.makedirs(dirname) 21 | except: 22 | pass 23 | np.save(filename, data) 24 | 25 | def pkldump(filename, data): 26 | dirname = os.path.dirname(filename) 27 | if not os.path.exists(dirname): 28 | try: 29 | os.makedirs(dirname) 30 | except: 31 | pass 32 | with open(filename, 'wb') as f: 33 | pickle.dump(data, f) 34 | 35 | def imhstack(images, height=None): 36 | images = as_list(images) 37 | images = list(map(image2C3, images)) 38 | 39 | if height is None: 40 | height = np.array([img.shape[0] for img in images]).max() 41 | images = [resize_height(img, height) for img in images] 42 | 43 | if len(images) == 1: 44 | return images[0] 45 | 46 | images = [[img, np.full((height, 3, 3), 255, np.uint8)] for img in images] 47 | images = np.hstack(sum(images, [])) 48 | return images 49 | 50 | def imvstack(images, width=None): 51 | images = as_list(images) 52 | images = list(map(image2C3, images)) 53 | 54 | if width is None: 55 | width = np.array([img.shape[1] for img in images]).max() 56 | images = [resize_width(img, width) for img in images] 57 | 58 | if len(images) == 1: 59 | return images[0] 60 | 61 | images = [[img, np.full((3, width, 3), 255, np.uint8)] for img in images] 62 | images = np.vstack(sum(images, [])) 63 | return images 64 | 65 | def as_list(data): 66 | if not isinstance(data, (list, tuple)): 67 | return [data] 68 | return list(data) 69 | 70 | def image2C3(image): 71 | if image.ndim == 3: 72 | return image 73 | if image.ndim == 2: 74 | return np.repeat(image[..., np.newaxis], 3, axis=2) 75 | raise ValueError("image.ndim = {}, invalid image.".format(image.ndim)) 76 | 77 | def resize_height(image, height): 78 | if image.shape[0] == height: 79 | return image 80 | h, w = image.shape[:2] 81 | width = height * w // h 82 | image = cv2.resize(image, (width, height)) 83 | return image 84 | 85 | def resize_width(image, width): 86 | if image.shape[1] == width: 87 | return image 88 | h, w = image.shape[:2] 89 | height = width * h // w 90 | image = cv2.resize(image, (width, height)) 91 | return image 92 | 93 | def imtext(image, text, space=(3, 3), color=(0, 0, 0), thickness=1, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1.): 94 | assert isinstance(text, str), type(text) 95 | size = cv2.getTextSize(text, fontFace, fontScale, thickness) 96 | image = cv2.putText(image, text, (space[0], size[1]+space[1]), fontFace, fontScale, color, thickness) 97 | return image 98 | 99 | -------------------------------------------------------------------------------- /lib/utils/mxnet_tools.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import time 4 | import os 5 | import logging 6 | from datetime import datetime 7 | from subprocess import call 8 | from types import ModuleType 9 | 10 | def setGPU(gpus): 11 | len_gpus = len(gpus.split(',')) 12 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 13 | gpus = ','.join(map(str, range(len_gpus))) 14 | return gpus 15 | 16 | def getTime(): 17 | return datetime.now().strftime('%m-%d %H:%M:%S') 18 | 19 | class 
Timer(object): 20 | curr_record = None 21 | prev_record = None 22 | 23 | @classmethod 24 | def record(cls): 25 | cls.prev_record = cls.curr_record 26 | cls.curr_record = time.time() 27 | 28 | @classmethod 29 | def interval(cls): 30 | if cls.prev_record is None: 31 | return 0 32 | return cls.curr_record - cls.prev_record 33 | 34 | def wrapColor(string, color): 35 | try: 36 | header = { 37 | 'red': '\033[91m', 38 | 'green': '\033[92m', 39 | 'yellow': '\033[93m', 40 | 'blue': '\033[94m', 41 | 'purple': '\033[95m', 42 | 'cyan': '\033[96m', 43 | 'darkcyan': '\033[36m', 44 | 'bold': '\033[1m', 45 | 'underline': '\033[4m'}[color.lower()] 46 | except KeyError: 47 | raise ValueError("Unknown color: {}".format(color)) 48 | return header + string + '\033[0m' 49 | 50 | def info(logger, msg, color=None): 51 | msg = '[{}]'.format(getTime()) + msg 52 | if logger is not None: 53 | logger.info(msg) 54 | 55 | if color is not None: 56 | msg = wrapColor(msg, color) 57 | print(msg) 58 | 59 | def summaryArgs(logger, args, color=None): 60 | if isinstance(args, ModuleType): 61 | args = vars(args) 62 | keys = [key for key in args.keys() if key[:2] != '__'] 63 | keys.sort() 64 | length = max([len(x) for x in keys]) 65 | msg = [('{:<'+str(length)+'}: {}').format(k, args[k]) for k in keys] 66 | 67 | msg = '\n' + '\n'.join(msg) 68 | info(logger, msg, color) 69 | 70 | def loadParams(filename): 71 | data = mx.nd.load(filename) 72 | arg_params, aux_params = {}, {} 73 | for name, value in data.items(): 74 | if name[:3] == 'arg': 75 | arg_params[name[4:]] = value 76 | elif name[:3] == 'aux': 77 | aux_params[name[4:]] = value 78 | if len(arg_params) == 0: 79 | arg_params = None 80 | if len(aux_params) == 0: 81 | aux_params = None 82 | return arg_params, aux_params 83 | 84 | class SaveParams(object): 85 | def __init__(self, model, snapshot, model_name, num_save=5): 86 | self.model = model 87 | self.snapshot = snapshot 88 | self.model_name = model_name 89 | self.num_save = num_save 90 | self.save_params = [] 91 | 92 | def save(self, n_epoch): 93 | self.save_params += [ 94 | os.path.join(self.snapshot, '{}-{:04d}.params'.format(self.model_name, n_epoch)), 95 | os.path.join(self.snapshot, '{}-{:04d}.states'.format(self.model_name, n_epoch))] 96 | self.model.save_params(self.save_params[-2]) 97 | self.model.save_optimizer_states(self.save_params[-1]) 98 | 99 | if len(self.save_params) > 2 * self.num_save: 100 | call(['rm', self.save_params[0], self.save_params[1]]) 101 | self.save_params = self.save_params[2:] 102 | return self.save_params[-2:] 103 | 104 | def __call__(self, n_epoch): 105 | return self.save(n_epoch) 106 | 107 | def getLogger(snapshot, model_name): 108 | if not os.path.exists(snapshot): 109 | os.makedirs(snapshot) 110 | logging.basicConfig(filename=os.path.join(snapshot, model_name+'.log'), level=logging.INFO) 111 | logger = logging.getLogger() 112 | return logger 113 | 114 | class LrScheduler(object): 115 | def __init__(self, method, init_lr, kwargs): 116 | self.method = method 117 | self.init_lr = init_lr 118 | 119 | if method == 'step': 120 | self.step_list = kwargs['step_list'] 121 | self.factor = kwargs['factor'] 122 | self.get = self._step 123 | elif method == 'poly': 124 | self.num_epoch = kwargs['num_epoch'] 125 | self.power = kwargs['power'] 126 | self.get = self._poly 127 | elif method == 'ramp': 128 | self.ramp_up = kwargs['ramp_up'] 129 | self.ramp_down = kwargs['ramp_down'] 130 | self.num_epoch = kwargs['num_epoch'] 131 | self.scale = kwargs['scale'] 132 | self.get = self._ramp 133 | else: 134 | 
raise ValueError(method) 135 | 136 | def _step(self, current_epoch): 137 | lr = self.init_lr 138 | step_list = [x for x in self.step_list] 139 | while len(step_list) > 0 and current_epoch >= step_list[0]: 140 | lr *= self.factor 141 | del step_list[0] 142 | return lr 143 | 144 | def _poly(self, current_epoch): 145 | lr = self.init_lr * ((1. - float(current_epoch)/self.num_epoch)**self.power) 146 | return lr 147 | 148 | def _ramp(self, current_epoch): 149 | if current_epoch < self.ramp_up: 150 | decay = np.exp(-(1 - float(current_epoch)/self.ramp_up)**2 * self.scale) 151 | elif current_epoch > (self.num_epoch - self.ramp_down): 152 | decay = np.exp(-(float(current_epoch+self.ramp_down-self.num_epoch)/self.ramp_down)**2 * self.scale) 153 | else: 154 | decay = 1. 155 | lr = self.init_lr * decay 156 | return lr 157 | 158 | class GradBuffer(object): 159 | def __init__(self, model): 160 | self.model = model 161 | self.cache = None 162 | 163 | def write(self): 164 | if self.cache is None: 165 | self.cache = [[None if g is None else g.copyto(g.context) for g in g_list]\ 166 | for g_list in self.model._exec_group.grad_arrays] 167 | else: 168 | for gs_src, gs_dst in zip(self.model._exec_group.grad_arrays, self.cache): 169 | for g_src, g_dst in zip(gs_src, gs_dst): 170 | if g_src is None: 171 | continue 172 | g_src.copyto(g_dst) 173 | 174 | def read_add(self): 175 | assert self.cache is not None 176 | for gs_src, gs_dst in zip(self.model._exec_group.grad_arrays, self.cache): 177 | for g_src, g_dst in zip(gs_src, gs_dst): 178 | if g_src is None: 179 | continue 180 | g_src += g_dst 181 | 182 | def initNormal(mean, std, name, shape): 183 | if name.endswith('_weight'): 184 | return mx.nd.normal(mean, std, shape) 185 | if name.endswith('_bias'): 186 | return mx.nd.zeros(shape) 187 | if name.endswith('_gamma'): 188 | return mx.nd.ones(shape) 189 | if name.endswith('_beta'): 190 | return mx.nd.zeros(shape) 191 | if name.endswith('_moving_mean'): 192 | return mx.nd.zeros(shape) 193 | if name.endswith('_moving_var'): 194 | return mx.nd.ones(shape) 195 | raise ValueError("Unknown name type for `{}`".format(name)) 196 | 197 | def checkParams(mod, arg_params, aux_params, auto_fix=True, initializer=mx.init.Normal(0.01), logger=None): 198 | arg_params = {} if arg_params is None else arg_params 199 | aux_params = {} if aux_params is None else aux_params 200 | 201 | arg_shapes = {name: array[0].shape for name, array in \ 202 | zip(mod._exec_group.param_names, mod._exec_group.param_arrays)} 203 | aux_shapes = {name: array[0].shape for name, array in \ 204 | zip(mod._exec_group.aux_names, mod._exec_group.aux_arrays)} 205 | 206 | extra_arg_params, extra_aux_params = [], [] 207 | for name in arg_params.keys(): 208 | if name not in arg_shapes: 209 | extra_arg_params.append(name) 210 | for name in aux_params.keys(): 211 | if name not in aux_shapes: 212 | extra_aux_params.append(name) 213 | 214 | miss_arg_params, miss_aux_params = [], [] 215 | for name in arg_shapes.keys(): 216 | if name not in arg_params: 217 | miss_arg_params.append(name) 218 | for name in aux_shapes.keys(): 219 | if name not in aux_params: 220 | miss_aux_params.append(name) 221 | 222 | mismatch_arg_params, mismatch_aux_params = [], [] 223 | for name in arg_params.keys(): 224 | if (name in arg_shapes) and (arg_shapes[name] != arg_params[name].shape): 225 | mismatch_arg_params.append(name) 226 | for name in aux_params.keys(): 227 | if (name in aux_shapes) and (aux_shapes[name] != aux_params[name].shape): 228 | mismatch_aux_params.append(name) 229 | 230 | 
for name in extra_arg_params: 231 | info(logger, "Find extra arg_params: {}: given {}".format(name, arg_params[name].shape), 'red') 232 | for name in extra_aux_params: 233 | info(logger, "Find extra aux_params: {}: given {}".format(name, aux_params[name].shape), 'red') 234 | for name in miss_arg_params: 235 | info(logger, "Find missing arg_params: {}: target {}".format(name, arg_shapes[name]), 'red') 236 | for name in miss_aux_params: 237 | info(logger, "Find missing aux_params: {}: target {}".format(name, aux_shapes[name]), 'red') 238 | for name in mismatch_arg_params: 239 | info(logger, "Find mismatch arg_params: {}: given {}, target {}".format( 240 | name, arg_params[name].shape, arg_shapes[name]), 'red') 241 | for name in mismatch_aux_params: 242 | info(logger, "Find mismatch aux_params: {}: given {}, target {}".format( 243 | name, aux_params[name].shape, aux_shapes[name]), 'red') 244 | 245 | if len(extra_arg_params + extra_aux_params + \ 246 | miss_arg_params + miss_aux_params + \ 247 | mismatch_arg_params + mismatch_aux_params) == 0: 248 | return arg_params, aux_params 249 | 250 | if not auto_fix: 251 | info(logger, "Bad params not fixed.", 'red') 252 | return arg_params, aux_params 253 | 254 | for name in (extra_arg_params + mismatch_arg_params): 255 | del arg_params[name] 256 | for name in (extra_aux_params + mismatch_aux_params): 257 | del aux_params[name] 258 | 259 | attrs = mod._symbol.attr_dict() 260 | for name in (miss_arg_params + mismatch_arg_params): 261 | arg_params[name] = mx.nd.zeros(arg_shapes[name]) 262 | try: 263 | initializer(mx.init.InitDesc(name, attrs.get(name, None)), arg_params[name]) 264 | except ValueError: 265 | initializer(name, arg_params[name]) 266 | for name in (miss_aux_params + mismatch_aux_params): 267 | aux_params[name] = mx.nd.zeros(aux_shapes[name]) 268 | try: 269 | initializer(mx.init.InitDesc(name, attrs.get(name, None)), aux_params[name]) 270 | except ValueError: 271 | initializer(name, aux_params[name]) 272 | info(logger, "Bad params auto fixed successfully.", 'red') 273 | return arg_params, aux_params 274 | 275 | 276 | -------------------------------------------------------------------------------- /run_cian.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPUS=0,1 3 | GPU_IDS=(${GPUS//,/ }) 4 | 5 | # Config dataset path 6 | DATASET=/path/of/VOC2012 7 | image_root=$DATASET/JPEGImages 8 | annotation_root=$DATASET/Annotations 9 | groundtruth_root=$DATASET/extra/SegmentationClassAug 10 | 11 | # Config model 12 | model=resnet101_largefov 13 | pretrained=./data/pretrained/resnet-101-0000.params 14 | seeds=./data/Seeds/CIAN_SEEDS 15 | snapshot=./snapshot/CIAN/$model 16 | 17 | 18 | # ========== Training & Testing ========== 19 | # Train 20 | python ./scripts/train_infer_segment.py --image-root $image_root --label-root $seeds --annotation-root $annotation_root --snapshot $snapshot --model $model --pretrained $pretrained --gpus $GPUS 21 | 22 | # Test on val set 23 | ( IFS=$'\n'; echo "${!GPU_IDS[*]}" ) | xargs -I{} -P${#GPU_IDS[@]} python ./scripts/train_infer_segment.py --image-root $image_root --annotation-root $annotation_root --snapshot $snapshot --model $model --gpus $GPUS --infer --pid {} 24 | python ./scripts/eval_segment.py --groundtruth-root $groundtruth_root --prediction-root $snapshot/pred_crf 25 | 26 | 27 | 28 | # ========== Retraining & Testing ========== 29 | # Generate new seeds 30 | ( IFS=$'\n'; echo "${!GPU_IDS[*]}" ) | xargs -I{} -P${#GPU_IDS[@]} python 
./scripts/generate_retrain.py --image-root $image_root --annotation-root $annotation_root --snapshot $snapshot --model $model --gpus $GPUS --pid {} 31 | 32 | # Train 33 | seeds=$snapshot/train_aug_pred_crf 34 | python ./scripts/train_infer_segment.py --image-root $image_root --label-root $seeds --annotation-root $annotation_root --snapshot $snapshot --model $model --pretrained $pretrained --gpus $GPUS --retrain 35 | 36 | # Test on val set 37 | ( IFS=$'\n'; echo "${!GPU_IDS[*]}" ) | xargs -I{} -P${#GPU_IDS[@]} python ./scripts/train_infer_segment.py --image-root $image_root --annotation-root $annotation_root --snapshot $snapshot --model $model --gpus $GPUS --infer --pid {} --retrain 38 | python ./scripts/eval_segment.py --groundtruth-root $groundtruth_root --prediction-root $snapshot'_retrain/pred_crf' 39 | 40 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/js-fan/CIAN/680c6259c8679123107ea3b3ee1d48a1b70d8179/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/eval_segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import multiprocessing as mp 4 | import cv2 5 | import os 6 | 7 | def run_eval(data_list, pred_root, gt_root, num_cls): 8 | def compute_confusion_matrix(names, label_root, pred_root, num_cls, num_threads=16, arr_=None): 9 | if num_threads == 1: 10 | mat = np.zeros((num_cls, num_cls), np.float32) 11 | for name in names: 12 | gt = cv2.imread(os.path.join(label_root, name+'.png'), 0).astype(np.int32) 13 | pred = cv2.imread(os.path.join(pred_root, name+'.png'), 0).astype(np.int32) 14 | if gt.shape != pred.shape: 15 | print("NAME {}, gt.shape != pred.shape: [{} vs. {}]".format(name, gt.shape, pred.shape))  # this standalone script does not import lib.utils, so `info` is undefined here; report with a plain print
16 | continue 17 | 18 | valid = gt < num_cls 19 | mat += np.bincount(gt[valid] * num_cls + pred[valid], minlength=num_cls**2).reshape(num_cls, -1) 20 | 21 | if arr_ is not None: 22 | arr_mat = np.frombuffer(arr_.get_obj(), np.float32) 23 | arr_mat += mat.ravel() 24 | return mat 25 | else: 26 | workload = np.full((num_threads,), len(names)//num_threads, np.int32) 27 | if workload.sum() < len(names): 28 | workload[:(len(names) - workload.sum())] += 1 29 | workload = np.cumsum(np.hstack([0, workload])) 30 | 31 | names_split = [names[i:j] for i, j in zip(workload[:-1], workload[1:])] 32 | 33 | arr_ = mp.Array('f', np.zeros((num_cls * num_cls,), np.float32)) 34 | mat = np.frombuffer(arr_.get_obj(), np.float32).reshape(num_cls, -1) 35 | 36 | jobs = [mp.Process(target=compute_confusion_matrix, args=(_names, label_root, pred_root, num_cls, 1, arr_)) \ 37 | for _names in names_split] 38 | res = [job.start() for job in jobs] 39 | res = [job.join() for job in jobs] 40 | return mat.copy() 41 | 42 | def compute_eval_results(confmat): 43 | iou = np.diag(confmat) / np.maximum(confmat.sum(axis=0) + confmat.sum(axis=1) - np.diag(confmat), 1e-10) 44 | return iou 45 | 46 | # 47 | with open(data_list) as f: 48 | names = [x.strip() for x in f.readlines()] 49 | 50 | confmat = compute_confusion_matrix(names, gt_root, pred_root, num_cls) 51 | iou = compute_eval_results(confmat) 52 | 53 | msg = "mIOU: {}\n{}\n\n".format(iou.mean(), iou) 54 | print(msg) 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | 59 | parser.add_argument('--num-cls', type=int, default=21) 60 | parser.add_argument('--data-list', type=str, default='./data/VOC2012/val.txt') 61 | parser.add_argument('--prediction-root', type=str, required=True) 62 | parser.add_argument('--groundtruth-root', type=str, required=True) 63 | 64 | args = parser.parse_args() 65 | run_eval(args.data_list, args.prediction_root, args.groundtruth_root, args.num_cls) 66 | 67 | -------------------------------------------------------------------------------- /scripts/generate_retrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from lib.loader import VOCSegGroupLoader, VOCSegLoader 3 | from lib.utils import * 4 | from lib.models import * 5 | import pydensecrf.densecrf as dcrf 6 | from scripts.train_infer_segment import build_model 7 | 8 | 9 | def generate_retrain_seeds(args, pid=-1): 10 | data_slice = None 11 | if pid >= 0: 12 | gpus = args.gpus.split(',') 13 | data_slice = slice(pid, None, len(gpus)) 14 | args.gpus = gpus[pid] 15 | args.batch_size = len(args.gpus.split(',')) 16 | mod = build_model(args, False) 17 | 18 | with open(args.data_list, 'r') as f: 19 | data_names = [x.strip() for x in f.readlines()] 20 | if data_slice is not None: 21 | data_names = data_names[data_slice] 22 | image_src_list = [os.path.join(args.image_root, name+'.jpg') for name in data_names] 23 | pred_root = args.snapshot 24 | 25 | for name, img_src in zip(data_names, image_src_list): 26 | img = cv2.imread(img_src)[..., ::-1].copy() 27 | h, w = img.shape[:2] 28 | 29 | img_ = cv2.resize(img, (args.image_size, args.image_size)) 30 | batch = mx.io.DataBatch(data=[mx.nd.array(img_[np.newaxis].transpose(0,3,1,2))]) 31 | 32 | mod.forward(batch, is_train=False) 33 | probs = mod.get_outputs()[0].copy() 34 | 35 | if not args.no_mirror: 36 | batch2 = mx.io.DataBatch(data=[batch.data[0][:, :, :, ::-1]]) 37 | mod.forward(batch2, is_train=False) 38 | probs_mirror = 
mod.get_outputs()[0][:, :, :, ::-1].copy() 39 | probs = (probs + probs_mirror) / 2 40 | prob = mx.nd.contrib.BilinearResize2D(probs, height=h, width=w).asnumpy()[0] 41 | 42 | pred = prob.argmax(axis=0).astype(np.uint8) 43 | 44 | d = dcrf.DenseCRF2D(w, h, prob.shape[0]) 45 | u = - prob.reshape(prob.shape[0], -1) 46 | d.setUnaryEnergy(u) 47 | 48 | d.addPairwiseGaussian(sxy=3, compat=3) 49 | d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=img, compat=10) 50 | #d.addPairwiseGaussian(sxy=1, compat=3) 51 | #d.addPairwiseBilateral(sxy=67, srgb=3, rgbim=img, compat=4) 52 | 53 | prob_crf = d.inference(10) 54 | prob_crf = np.array(prob_crf).reshape(-1, h, w) 55 | pred_crf = prob_crf.argmax(axis=0).astype(np.uint8) 56 | 57 | imwrite(os.path.join(pred_root, 'train_aug_pred', name+'.png'), pred) 58 | imwrite(os.path.join(pred_root, 'train_aug_pred_crf', name+'.png'), pred_crf) 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--image-root', type=str, required=True) 64 | parser.add_argument('--annotation-root', type=str, required=True) 65 | parser.add_argument('--label-root', type=str, default='') 66 | parser.add_argument('--train-list', type=str, default='data/VOC2012/train_aug.txt') 67 | parser.add_argument('--test-list', type=str, default='data/VOC2012/val.txt') 68 | parser.add_argument('--data-list', type=str, default='') 69 | parser.add_argument('--snapshot', type=str, required=True) 70 | 71 | parser.add_argument('--model', type=str, required=True) 72 | parser.add_argument('--pretrained', type=str, default='') 73 | 74 | # train 75 | parser.add_argument('--begin-epoch', type=int, default=0) 76 | parser.add_argument('--num-epoch', type=int, default=20) 77 | parser.add_argument('--batch-size', type=int, default=1) 78 | parser.add_argument('--image-size', type=int, default=513) 79 | parser.add_argument('--num-cls', type=int, default=21) 80 | parser.add_argument('--lr', type=float, default=5e-4) 81 | 82 | parser.add_argument('--in-embed-type', type=str, default='conv') 83 | parser.add_argument('--out-embed-type', type=str, default='convbn') 84 | parser.add_argument('--merge-type', type=str, default='max') 85 | parser.add_argument('--group_size', type=int, default=2) 86 | 87 | parser.add_argument('--log-frequency', type=int, default=50) 88 | parser.add_argument('--num-save', type=int, default=5) 89 | parser.add_argument('--gpus', type=str, default='0', help='e.g. 
0,1,2,3') 90 | 91 | # eval 92 | parser.add_argument('--pid', type=int, default=-1) 93 | parser.add_argument('--no-mirror', action='store_true') 94 | parser.add_argument('--no-ms', action='store_true') 95 | 96 | args = parser.parse_args() 97 | 98 | # 99 | if not args.data_list: 100 | args.data_list = args.train_list 101 | 102 | generate_retrain_seeds(args, args.pid) 103 | 104 | -------------------------------------------------------------------------------- /scripts/train_infer_segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from lib.loader import VOCSegGroupLoader, VOCSegLoader 3 | from lib.utils import * 4 | from lib.models import * 5 | import pydensecrf.densecrf as dcrf 6 | 7 | 8 | def build_model(args, for_training=True): 9 | x_img = mx.sym.Variable('data') 10 | x_lbl = mx.sym.Variable('label') 11 | 12 | symbol_func = eval(args.model.lower() + \ 13 | ('_CA' if for_training else '_SA') ) 14 | if (not args.no_ms) and (not for_training): 15 | symbol_func = MultiScale([0.75, 1, 1.25])(symbol_func) 16 | 17 | x_pred = symbol_func(x_img, args.num_cls, 18 | use_global_stats_backbone=True, 19 | use_global_stats_affinity=not for_training, 20 | merge_self=True, 21 | in_embed_type=args.in_embed_type, 22 | out_embed_type=args.out_embed_type, 23 | merge_type=args.merge_type, 24 | group_size=args.group_size ) 25 | 26 | if for_training: 27 | # cross entropy loss 28 | loss_list = [mx.sym.Custom(_pred, x_lbl, op_type='SegmentLoss') for _pred in x_pred] 29 | 30 | # completion loss 31 | x_cross = x_pred[1] 32 | loss_list += [mx.sym.Custom(_pred, x_cross, x_lbl, op_type='CompletionLoss') for _pred in x_pred] 33 | 34 | symbol = mx.sym.Group(loss_list) 35 | else: 36 | x_pred = [mx.sym.contrib.BilinearResize2D(_pred, 37 | height=args.image_size, width=args.image_size) for _pred in x_pred] 38 | x_pred = x_pred[0] if len(x_pred) == 1 else sum(x_pred) / len(x_pred) 39 | symbol = x_pred 40 | 41 | # freeze backbone bn params 42 | aff_bn_names = [] 43 | for suffix in ['_gamma', '_beta']: 44 | for aff_bn_name in ['bn_out', 'bn_embed_q', 'bn_embed_k', 'bn_embed_v']: 45 | aff_bn_names.append(aff_bn_name + suffix) 46 | 47 | fixed_param_names = [] 48 | if for_training: 49 | fixed_param_names += [name for name in symbol.list_arguments() \ 50 | if (name.endswith('_gamma') or name.endswith('_beta')) and (name not in aff_bn_names) ] 51 | 52 | # build model 53 | mod = mx.mod.Module(symbol, data_names=('data',), 54 | label_names=('label',) if for_training else None, 55 | context=[mx.gpu(int(gpu_id)) for gpu_id in setGPU(args.gpus).split(',')], 56 | fixed_param_names=fixed_param_names) 57 | label_size = (args.image_size - 1) // 8 + 1 58 | mod.bind(data_shapes=[('data', (args.batch_size, 3, args.image_size, args.image_size))], 59 | label_shapes=[('label', (args.batch_size, label_size, label_size))] if for_training else None ) 60 | 61 | # load / initialize parameters 62 | if for_training: 63 | pretrained = args.pretrained 64 | if args.retrain: 65 | pretrained = os.path.join(args.snapshot.rsplit('_', 1)[0], '%s-%04d.params' % (args.model, args.num_epoch-1)) 66 | else: 67 | pretrained = os.path.join(args.snapshot, '%s-%04d.params' % (args.model, args.num_epoch-1)) 68 | assert os.path.exists(pretrained), pretrained 69 | info(None, "Using pretrained params: {}".format(pretrained), 'red') 70 | 71 | arg_params, aux_params = loadParams(pretrained) 72 | arg_params, aux_params = checkParams(mod, arg_params, aux_params, initializer=mx.init.Normal(0.01), 
auto_fix=for_training) 73 | 74 | if for_training and (not args.retrain): 75 | arg_params['bn_out_gamma'] *= 0 76 | 77 | mod.init_params(arg_params=arg_params, aux_params=aux_params) 78 | mod.init_optimizer(optimizer='sgd', optimizer_params={ 79 | 'learning_rate': args.lr, 80 | 'momentum': 0.9, 81 | 'wd': 5e-4}, 82 | kvstore='device') 83 | 84 | return mod 85 | 86 | 87 | def run_training(args): 88 | mod = build_model(args) 89 | 90 | loader = VOCSegGroupLoader(args.image_root, args.label_root, args.annotation_root, 91 | args.data_list, args.batch_size, args.group_size, len(mod._context), args.image_size, 92 | pad=False, shuffle=True, rand_scale=True, rand_mirror=True, rand_crop=True, downsample=8) 93 | 94 | saveParams = SaveParams(mod, args.snapshot, args.model, args.num_save) 95 | lrScheduler = LrScheduler('poly', args.lr, {'num_epoch': args.num_epoch, 'power': 0.9} ) 96 | logger = getLogger(args.snapshot, args.model) 97 | summaryArgs(logger, vars(args), 'green') 98 | 99 | # train 100 | for n_epoch in range(args.begin_epoch, args.num_epoch): 101 | loader.reset() 102 | confmat = np.zeros((args.num_cls, args.num_cls), np.float32) 103 | loss = 0 104 | 105 | mod._optimizer.lr = lrScheduler.get(n_epoch) 106 | info(logger, "Learning rate: {}".format(mod._optimizer.lr), 'yellow') 107 | 108 | # monitor 109 | Timer.record() 110 | for n_batch, batch in enumerate(loader, 1): 111 | mod.forward_backward(batch) 112 | mod.update() 113 | 114 | if n_batch % args.log_frequency == 0: 115 | probs = mod.get_outputs()[0].as_in_context(mx.cpu()) 116 | label = mx.nd.one_hot(batch.label[0], args.num_cls).transpose((0, 3, 1, 2)) 117 | if probs.shape[2] != label.shape[2]: 118 | label = mx.nd.contrib.BilinearResize2D(label, height=probs.shape[2], width=probs.shape[3]) 119 | mask = label.sum(axis=1) 120 | _loss = -( (mx.nd.log((probs * label).sum(axis=1) + 1e-5) * mask).sum(axis=(1,2)) / \ 121 | mx.nd.maximum(mask.sum(axis=(1,2)), 1e-5) ).mean() 122 | 123 | loss_mom = (float(n_batch) - args.log_frequency) / n_batch  # true division: '//' would floor this ratio to 0 and disable the running average
124 | loss = loss_mom * loss + (1 - loss_mom) * float(_loss.asnumpy()) 125 | 126 | gt = label.argmax(axis=1).asnumpy().astype(np.int32) 127 | pred = probs.argmax(axis=1).asnumpy().astype(np.int32) 128 | assert gt.shape == pred.shape 129 | idx = label.max(axis=1).asnumpy() > 0.01 130 | confmat += np.bincount(gt[idx] * args.num_cls + pred[idx], 131 | minlength=args.num_cls**2).reshape(args.num_cls, -1) 132 | iou = float((np.diag(confmat) / (confmat.sum(axis=0)+confmat.sum(axis=1)-np.diag(confmat)+1e-5)).mean()) 133 | 134 | Timer.record() 135 | msg = "Epoch={}, Batch={}, miou={:.4f}, loss={:.4f}, speed={:.1f} b/s" 136 | msg = msg.format(n_epoch, n_batch, iou, loss, args.log_frequency/Timer.interval()) 137 | info(logger, msg) 138 | 139 | saved_params = saveParams(n_epoch) 140 | info(logger, "Saved checkpoint:" + "\n ".join(saved_params), 'green') 141 | 142 | 143 | def run_infer(args, pid=-1): 144 | data_slice = None 145 | if pid >= 0: 146 | gpus = args.gpus.split(',') 147 | data_slice = slice(pid, None, len(gpus)) 148 | args.gpus = gpus[pid] 149 | args.batch_size = len(args.gpus.split(',')) 150 | mod = build_model(args, False) 151 | 152 | loader = VOCSegLoader(args.image_root, None, args.data_list, 153 | args.batch_size, args.image_size, pad=True, shuffle=False, 154 | rand_scale=False, rand_mirror=False, rand_crop=False, data_slice=data_slice) 155 | pred_root = args.snapshot 156 | 157 | # inference 158 | for n_batch, batch in enumerate(loader, 1): 159 | image_src_list = loader.cache_image_src_list 160 | 
mod.forward(batch, is_train=False) 161 | probs = mod.get_outputs()[0].asnumpy() 162 | 163 | if not args.no_mirror: 164 | batch2 = mx.io.DataBatch(data=[batch.data[0][:, :, :, ::-1]]) 165 | mod.forward(batch2, is_train=False) 166 | probs_mirror = mod.get_outputs()[0].asnumpy()[:, :, :, ::-1] 167 | probs = (probs + probs_mirror) / 2 168 | 169 | for img_src, prob in zip(image_src_list, probs): 170 | img = cv2.imread(img_src)[..., ::-1].copy() 171 | h, w = img.shape[:2] 172 | prob = prob[:, :h, :w] 173 | pred = prob.argmax(axis=0).astype(np.uint8) 174 | 175 | d = dcrf.DenseCRF2D(w, h, prob.shape[0]) 176 | u = - prob.reshape(prob.shape[0], -1) 177 | d.setUnaryEnergy(u) 178 | 179 | # default CRF params 180 | d.addPairwiseGaussian(sxy=3, compat=3) 181 | d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=img, compat=10) 182 | 183 | prob_crf = d.inference(5) 184 | prob_crf = np.array(prob_crf).reshape(-1, h, w) 185 | pred_crf = prob_crf.argmax(axis=0).astype(np.uint8) 186 | 187 | name = os.path.basename(img_src).rsplit('.', 1)[0] 188 | imwrite(os.path.join(pred_root, 'pred', name+'.png'), pred) 189 | imwrite(os.path.join(pred_root, 'pred_crf', name+'.png'), pred_crf) 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('--image-root', type=str, required=True) 195 | parser.add_argument('--annotation-root', type=str, required=True) 196 | parser.add_argument('--label-root', type=str, default='') 197 | parser.add_argument('--train-list', type=str, default='data/VOC2012/train_aug.txt') 198 | parser.add_argument('--test-list', type=str, default='data/VOC2012/val.txt') 199 | parser.add_argument('--data-list', type=str, default='') 200 | parser.add_argument('--snapshot', type=str, required=True) 201 | 202 | parser.add_argument('--model', type=str, required=True) 203 | parser.add_argument('--pretrained', type=str, default='') 204 | 205 | # train 206 | parser.add_argument('--begin-epoch', type=int, default=0) 207 | parser.add_argument('--num-epoch', type=int, default=20) 208 | parser.add_argument('--batch-size', type=int, default=16) 209 | parser.add_argument('--image-size', type=int, default=321) 210 | parser.add_argument('--num-cls', type=int, default=21) 211 | parser.add_argument('--lr', type=float, default=5e-4) 212 | 213 | parser.add_argument('--in-embed-type', type=str, default='conv') 214 | parser.add_argument('--out-embed-type', type=str, default='convbn') 215 | parser.add_argument('--merge-type', type=str, default='max') 216 | parser.add_argument('--group_size', type=int, default=2) 217 | 218 | parser.add_argument('--log-frequency', type=int, default=50) 219 | parser.add_argument('--num-save', type=int, default=5) 220 | parser.add_argument('--gpus', type=str, default='0', help='e.g. 
0,1,2,3') 221 | 222 | # eval 223 | parser.add_argument('--infer', action='store_true') 224 | parser.add_argument('--pid', type=int, default=-1) 225 | parser.add_argument('--no-mirror', action='store_true') 226 | parser.add_argument('--no-ms', action='store_true') 227 | 228 | # retrain 229 | parser.add_argument('--retrain', action='store_true') 230 | 231 | args = parser.parse_args() 232 | 233 | # 234 | if not args.data_list: 235 | args.data_list = args.test_list if args.infer else args.train_list 236 | 237 | if args.retrain: 238 | args.snapshot += '_retrain' 239 | 240 | if args.infer: 241 | args.image_size = 513 242 | run_infer(args, args.pid) 243 | else: 244 | assert args.label_root 245 | if not args.retrain: 246 | assert args.pretrained 247 | run_training(args) 248 | 249 | --------------------------------------------------------------------------------
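A note on the `--pid` flag used by the inference and retrain-seed scripts above: `run_cian.sh` launches one process per GPU via `xargs -P`, and each process keeps every `len(gpus)`-th name from the image list through `data_slice = slice(pid, None, len(gpus))`. A minimal sketch of that round-robin sharding (the names below are placeholders taken from the style of `val.txt`):

```python
# Round-robin sharding as in run_infer / generate_retrain_seeds:
# worker `pid` keeps names[pid::num_workers]. Example names are placeholders.
names = ['2007_000033', '2007_000042', '2007_000061', '2007_000123']
num_workers = 2  # e.g. GPUS=0,1 in run_cian.sh
for pid in range(num_workers):
    shard = names[slice(pid, None, num_workers)]  # same slice object as data_slice
    print(pid, shard)
# 0 ['2007_000033', '2007_000061']
# 1 ['2007_000042', '2007_000123']
```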
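The dense-CRF refinement step in `run_infer` and `generate_retrain_seeds` can also be read in isolation. Below is a minimal, self-contained sketch with the same pairwise parameters as the "default CRF params" above; `image_path` and `prob` are hypothetical inputs, with `prob` standing for the `(num_cls, h, w)` float32 probability map produced by the network. Note that these scripts feed the negative class probabilities directly as unary energies, whereas the more common pydensecrf recipe is `-log` of the softmax (`unary_from_softmax`).

```python
# Minimal sketch of the CRF post-processing used by run_infer /
# generate_retrain_seeds. `image_path` and `prob` are placeholders.
import cv2
import numpy as np
import pydensecrf.densecrf as dcrf

def crf_refine(image_path, prob, n_iter=5):
    img = cv2.imread(image_path)[..., ::-1].copy()  # BGR -> contiguous RGB, as in run_infer
    h, w = img.shape[:2]
    num_cls = prob.shape[0]

    d = dcrf.DenseCRF2D(w, h, num_cls)
    # Negative probabilities as unary energies, mirroring the scripts above.
    u = np.ascontiguousarray(-prob.reshape(num_cls, -1).astype(np.float32))
    d.setUnaryEnergy(u)

    # Same pairwise terms as the "default CRF params" in run_infer.
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=img, compat=10)

    q = d.inference(n_iter)  # run_infer uses 5 iterations, generate_retrain uses 10
    return np.array(q).reshape(num_cls, h, w).argmax(axis=0).astype(np.uint8)
```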
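Finally, the evaluation math in `eval_segment.py`, checked on a toy confusion matrix (values made up for illustration): with rows as ground truth and columns as predictions, per-class IoU is `diag(M) / (row_sum + col_sum - diag(M))`, and mIoU is its mean.

```python
# Toy check of compute_eval_results from eval_segment.py; the 3x3 matrix
# below is made-up data, not a real evaluation result.
import numpy as np

confmat = np.array([[50., 2., 3.],
                    [ 4., 40., 1.],
                    [ 6., 0., 30.]])
iou = np.diag(confmat) / np.maximum(
    confmat.sum(axis=0) + confmat.sum(axis=1) - np.diag(confmat), 1e-10)
print("IoU per class:", iou, "mIoU:", iou.mean())
```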