├── .gitignore ├── README.md ├── datasets ├── IRSTD-1k │ ├── test.txt │ └── trainval.txt ├── NUDT-SIRST00 │ ├── test.txt │ └── trainval.txt └── Sirstv2_512 │ ├── test.txt │ └── trainval.txt ├── demo.py ├── segment_anything_training ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── build_IRSAM.cpython-38.pyc │ └── build_sam.cpython-38.pyc ├── build_IRSAM.py ├── build_sam.py ├── modeling │ ├── IRSAM_decoder.py │ ├── IRSAM_edge.py │ ├── IRSAM_encoder.py │ ├── PMD.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── IRSAM_decoder.cpython-38.pyc │ │ ├── IRSAM_edge.cpython-38.pyc │ │ ├── IRSAM_encoder.cpython-38.pyc │ │ ├── PMD.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── common.cpython-38.pyc │ │ ├── image_encoder.cpython-38.pyc │ │ ├── mask_decoder.cpython-38.pyc │ │ ├── prompt_encoder.cpython-38.pyc │ │ ├── sam.cpython-38.pyc │ │ └── transformer.cpython-38.pyc │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py └── utils │ ├── PMD.py │ ├── __init__.py │ ├── __pycache__ │ ├── PMD.cpython-38.pyc │ └── __init__.cpython-38.pyc │ └── transforms.py └── utils ├── __pycache__ ├── dataloader.cpython-38.pyc ├── log.cpython-38.pyc ├── loss_mask.cpython-38.pyc ├── metric.cpython-38.pyc ├── metrics.cpython-38.pyc └── misc.cpython-38.pyc ├── dataloader.py ├── log.py ├── loss_mask.py ├── metric.py ├── metrics.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training instruction for IRSAM 2 | 3 | We organize the training folder as follows. 4 | ``` 5 | train 6 | |____datasets 7 | |____segment_anything_training 8 | |____demo.py 9 | |____utils 10 | | |____dataloader.py 11 | | |____log.py 12 | | |____loss_mask.py 13 | | |____metric.py 14 | | |____metrics.py 15 | | |____misc.py 16 | |____workdirs 17 | |____mobile_sam.pt 18 | ``` 19 | 20 | According to the given trainval.txt/test.txt splits, place the training/testing images and labels into four separate folders (see the example layout below). 21 | 22 | The IRSAM_encoder currently provided is the Mobile-SAM baseline; our IRSAM_encoder with WPMD will be released soon.
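For example, one possible layout for a single dataset after this step is sketched below. The `test_images`/`test_masks` folder names match the `im_dir`/`gt_dir` paths referenced in `demo.py`; the `trainval_images`/`trainval_masks` names for the training split are an assumption and should match whatever paths your training script is configured to read.

```
datasets
|____Sirstv2_512
| |____trainval.txt
| |____test.txt
| |____trainval_images    <- training images listed in trainval.txt (folder name assumed)
| |____trainval_masks     <- training labels listed in trainval.txt (folder name assumed)
| |____test_images        <- test images listed in test.txt (as used in demo.py)
| |____test_masks         <- test labels listed in test.txt (as used in demo.py)
```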
24 | 25 | ## Evaluation 26 | To evaluate IRSAM on various dataset, modify the dataset path in the train-IRSAM.py file 27 | 28 | ### Example 29 | ``` 30 | python train_IRSAM.py --output workdirs/your_workdir --checkpoint your_checkpoint --eval 31 | ``` 32 | -------------------------------------------------------------------------------- /datasets/IRSTD-1k/test.txt: -------------------------------------------------------------------------------- 1 | XDU189 2 | XDU935 3 | XDU672 4 | XDU231 5 | XDU818 6 | XDU888 7 | XDU146 8 | XDU48 9 | XDU492 10 | XDU241 11 | XDU195 12 | XDU801 13 | XDU104 14 | XDU637 15 | XDU996 16 | XDU482 17 | XDU406 18 | XDU889 19 | XDU558 20 | XDU117 21 | XDU777 22 | XDU134 23 | XDU223 24 | XDU943 25 | XDU762 26 | XDU662 27 | XDU54 28 | XDU685 29 | XDU167 30 | XDU489 31 | XDU505 32 | XDU527 33 | XDU817 34 | XDU253 35 | XDU193 36 | XDU597 37 | XDU151 38 | XDU404 39 | XDU596 40 | XDU97 41 | XDU321 42 | XDU279 43 | XDU93 44 | XDU205 45 | XDU9 46 | XDU219 47 | XDU674 48 | XDU501 49 | XDU316 50 | XDU343 51 | XDU885 52 | XDU426 53 | XDU485 54 | XDU850 55 | XDU516 56 | XDU216 57 | XDU160 58 | XDU176 59 | XDU504 60 | XDU883 61 | XDU244 62 | XDU919 63 | XDU781 64 | XDU369 65 | XDU398 66 | XDU441 67 | XDU75 68 | XDU240 69 | XDU805 70 | XDU108 71 | XDU709 72 | XDU352 73 | XDU747 74 | XDU209 75 | XDU845 76 | XDU557 77 | XDU775 78 | XDU56 79 | XDU657 80 | XDU753 81 | XDU788 82 | XDU682 83 | XDU794 84 | XDU877 85 | XDU421 86 | XDU733 87 | XDU546 88 | XDU999 89 | XDU5 90 | XDU63 91 | XDU966 92 | XDU922 93 | XDU789 94 | XDU295 95 | XDU863 96 | XDU578 97 | XDU743 98 | XDU46 99 | XDU115 100 | XDU876 101 | XDU932 102 | XDU289 103 | XDU855 104 | XDU933 105 | XDU517 106 | XDU329 107 | XDU3 108 | XDU451 109 | XDU694 110 | XDU878 111 | XDU259 112 | XDU708 113 | XDU442 114 | XDU829 115 | XDU833 116 | XDU648 117 | XDU381 118 | XDU868 119 | XDU803 120 | XDU673 121 | XDU415 122 | XDU667 123 | XDU968 124 | XDU169 125 | XDU525 126 | XDU164 127 | XDU704 128 | XDU711 129 | XDU111 130 | XDU354 131 | XDU927 132 | XDU758 133 | XDU87 134 | XDU697 135 | XDU957 136 | XDU49 137 | XDU563 138 | XDU954 139 | XDU45 140 | XDU429 141 | XDU902 142 | XDU302 143 | XDU523 144 | XDU41 145 | XDU816 146 | XDU785 147 | XDU759 148 | XDU872 149 | XDU185 150 | XDU881 151 | XDU447 152 | XDU129 153 | XDU614 154 | XDU920 155 | XDU334 156 | XDU257 157 | XDU892 158 | XDU103 159 | XDU698 160 | XDU862 161 | XDU33 162 | XDU416 163 | XDU40 164 | XDU715 165 | XDU203 166 | XDU589 167 | XDU142 168 | XDU50 169 | XDU455 170 | XDU620 171 | XDU67 172 | XDU371 173 | XDU192 174 | XDU28 175 | XDU43 176 | XDU661 177 | XDU692 178 | XDU463 179 | XDU745 180 | XDU258 181 | XDU842 182 | XDU459 183 | XDU147 184 | XDU319 185 | XDU225 186 | XDU178 187 | XDU567 188 | XDU925 189 | XDU394 190 | XDU110 191 | XDU663 192 | XDU376 193 | XDU450 194 | XDU10 195 | XDU955 196 | XDU374 197 | XDU278 198 | XDU393 199 | XDU570 200 | XDU217 201 | -------------------------------------------------------------------------------- /datasets/IRSTD-1k/trainval.txt: -------------------------------------------------------------------------------- 1 | XDU514 2 | XDU646 3 | XDU904 4 | XDU660 5 | XDU347 6 | XDU962 7 | XDU92 8 | XDU838 9 | XDU907 10 | XDU496 11 | XDU83 12 | XDU606 13 | XDU307 14 | XDU138 15 | XDU357 16 | XDU993 17 | XDU693 18 | XDU493 19 | XDU891 20 | XDU410 21 | XDU288 22 | XDU562 23 | XDU849 24 | XDU23 25 | XDU199 26 | XDU370 27 | XDU537 28 | XDU871 29 | XDU656 30 | XDU331 31 | XDU328 32 | XDU403 33 | XDU230 34 | XDU529 35 | XDU229 36 | XDU476 37 | XDU792 38 | 
XDU412 39 | XDU689 40 | XDU51 41 | XDU532 42 | XDU356 43 | XDU303 44 | XDU161 45 | XDU879 46 | XDU867 47 | XDU773 48 | XDU323 49 | XDU836 50 | XDU236 51 | XDU749 52 | XDU807 53 | XDU72 54 | XDU128 55 | XDU822 56 | XDU480 57 | XDU270 58 | XDU651 59 | XDU815 60 | XDU30 61 | XDU548 62 | XDU386 63 | XDU555 64 | XDU122 65 | XDU798 66 | XDU264 67 | XDU725 68 | XDU806 69 | XDU440 70 | XDU332 71 | XDU875 72 | XDU325 73 | XDU486 74 | XDU659 75 | XDU835 76 | XDU335 77 | XDU330 78 | XDU132 79 | XDU89 80 | XDU580 81 | XDU8 82 | XDU4 83 | XDU280 84 | XDU895 85 | XDU869 86 | XDU799 87 | XDU419 88 | XDU772 89 | XDU896 90 | XDU604 91 | XDU116 92 | XDU85 93 | XDU91 94 | XDU158 95 | XDU130 96 | XDU611 97 | XDU98 98 | XDU299 99 | XDU114 100 | XDU923 101 | XDU94 102 | XDU800 103 | XDU934 104 | XDU998 105 | XDU338 106 | XDU959 107 | XDU712 108 | XDU754 109 | XDU636 110 | XDU624 111 | XDU918 112 | XDU26 113 | XDU559 114 | XDU761 115 | XDU909 116 | XDU340 117 | XDU262 118 | XDU964 119 | XDU29 120 | XDU443 121 | XDU929 122 | XDU739 123 | XDU654 124 | XDU858 125 | XDU551 126 | XDU365 127 | XDU665 128 | XDU705 129 | XDU434 130 | XDU105 131 | XDU887 132 | XDU910 133 | XDU261 134 | XDU360 135 | XDU912 136 | XDU894 137 | XDU727 138 | XDU42 139 | XDU556 140 | XDU613 141 | XDU285 142 | XDU668 143 | XDU592 144 | XDU341 145 | XDU942 146 | XDU982 147 | XDU191 148 | XDU528 149 | XDU720 150 | XDU601 151 | XDU531 152 | XDU21 153 | XDU275 154 | XDU232 155 | XDU344 156 | XDU301 157 | XDU977 158 | XDU390 159 | XDU938 160 | XDU975 161 | XDU688 162 | XDU979 163 | XDU58 164 | XDU494 165 | XDU988 166 | XDU826 167 | XDU571 168 | XDU1 169 | XDU779 170 | XDU633 171 | XDU987 172 | XDU449 173 | XDU313 174 | XDU153 175 | XDU718 176 | XDU978 177 | XDU538 178 | XDU950 179 | XDU740 180 | XDU986 181 | XDU780 182 | XDU965 183 | XDU900 184 | XDU423 185 | XDU194 186 | XDU490 187 | XDU73 188 | XDU675 189 | XDU254 190 | XDU547 191 | XDU983 192 | XDU227 193 | XDU68 194 | XDU445 195 | XDU903 196 | XDU652 197 | XDU642 198 | XDU599 199 | XDU64 200 | XDU221 201 | XDU291 202 | XDU460 203 | XDU623 204 | XDU766 205 | XDU680 206 | XDU714 207 | XDU25 208 | XDU19 209 | XDU65 210 | XDU671 211 | XDU333 212 | XDU375 213 | XDU550 214 | XDU790 215 | XDU397 216 | XDU497 217 | XDU645 218 | XDU470 219 | XDU913 220 | XDU956 221 | XDU218 222 | XDU380 223 | XDU487 224 | XDU66 225 | XDU324 226 | XDU612 227 | XDU384 228 | XDU544 229 | XDU144 230 | XDU542 231 | XDU248 232 | XDU461 233 | XDU148 234 | XDU653 235 | XDU336 236 | XDU866 237 | XDU456 238 | XDU540 239 | XDU351 240 | XDU112 241 | XDU272 242 | XDU53 243 | XDU515 244 | XDU699 245 | XDU417 246 | XDU639 247 | XDU342 248 | XDU31 249 | XDU448 250 | XDU292 251 | XDU728 252 | XDU180 253 | XDU149 254 | XDU706 255 | XDU973 256 | XDU320 257 | XDU625 258 | XDU811 259 | XDU462 260 | XDU388 261 | XDU638 262 | XDU90 263 | XDU109 264 | XDU722 265 | XDU162 266 | XDU972 267 | XDU767 268 | XDU349 269 | XDU263 270 | XDU465 271 | XDU576 272 | XDU507 273 | XDU644 274 | XDU587 275 | XDU255 276 | XDU326 277 | XDU500 278 | XDU586 279 | XDU524 280 | XDU765 281 | XDU890 282 | XDU960 283 | XDU594 284 | XDU80 285 | XDU14 286 | XDU569 287 | XDU953 288 | XDU884 289 | XDU282 290 | XDU387 291 | XDU579 292 | XDU260 293 | XDU252 294 | XDU971 295 | XDU905 296 | XDU994 297 | XDU36 298 | XDU266 299 | XDU405 300 | XDU208 301 | XDU207 302 | XDU723 303 | XDU35 304 | XDU590 305 | XDU678 306 | XDU629 307 | XDU939 308 | XDU804 309 | XDU948 310 | XDU24 311 | XDU967 312 | XDU293 313 | XDU545 314 | XDU901 315 | XDU290 316 | XDU989 317 | XDU322 318 | XDU707 
319 | XDU188 320 | XDU582 321 | XDU810 322 | XDU439 323 | XDU300 324 | XDU237 325 | XDU457 326 | XDU433 327 | XDU831 328 | XDU917 329 | XDU677 330 | XDU82 331 | XDU561 332 | XDU413 333 | XDU834 334 | XDU368 335 | XDU658 336 | XDU898 337 | XDU173 338 | XDU34 339 | XDU467 340 | XDU2 341 | XDU265 342 | XDU735 343 | XDU454 344 | XDU163 345 | XDU81 346 | XDU74 347 | XDU140 348 | XDU166 349 | XDU478 350 | XDU61 351 | XDU880 352 | XDU841 353 | XDU565 354 | XDU377 355 | XDU839 356 | XDU691 357 | XDU607 358 | XDU530 359 | XDU844 360 | XDU847 361 | XDU472 362 | XDU882 363 | XDU17 364 | XDU619 365 | XDU12 366 | XDU736 367 | XDU859 368 | XDU186 369 | XDU985 370 | XDU389 371 | XDU921 372 | XDU355 373 | XDU539 374 | XDU141 375 | XDU916 376 | XDU135 377 | XDU519 378 | XDU621 379 | XDU435 380 | XDU760 381 | XDU783 382 | XDU591 383 | XDU628 384 | XDU464 385 | XDU649 386 | XDU198 387 | XDU560 388 | XDU372 389 | XDU553 390 | XDU458 391 | XDU183 392 | XDU650 393 | XDU622 394 | XDU825 395 | XDU471 396 | XDU190 397 | XDU414 398 | XDU44 399 | XDU741 400 | XDU635 401 | XDU647 402 | XDU573 403 | XDU864 404 | XDU618 405 | XDU364 406 | XDU543 407 | XDU437 408 | XDU502 409 | XDU824 410 | XDU952 411 | XDU125 412 | XDU641 413 | XDU491 414 | XDU201 415 | XDU947 416 | XDU444 417 | XDU79 418 | XDU518 419 | XDU32 420 | XDU283 421 | XDU802 422 | XDU483 423 | XDU59 424 | XDU477 425 | XDU670 426 | XDU234 427 | XDU577 428 | XDU969 429 | XDU669 430 | XDU995 431 | XDU425 432 | XDU506 433 | XDU970 434 | XDU681 435 | XDU353 436 | XDU676 437 | XDU958 438 | XDU782 439 | XDU121 440 | XDU363 441 | XDU411 442 | XDU690 443 | XDU392 444 | XDU536 445 | XDU752 446 | XDU210 447 | XDU774 448 | XDU350 449 | XDU731 450 | XDU930 451 | XDU479 452 | XDU602 453 | XDU856 454 | XDU96 455 | XDU520 456 | XDU13 457 | XDU915 458 | XDU106 459 | XDU853 460 | XDU308 461 | XDU47 462 | XDU769 463 | XDU242 464 | XDU899 465 | XDU484 466 | XDU581 467 | XDU643 468 | XDU827 469 | XDU246 470 | XDU119 471 | XDU746 472 | XDU100 473 | XDU311 474 | XDU512 475 | XDU852 476 | XDU509 477 | XDU874 478 | XDU686 479 | XDU139 480 | XDU928 481 | XDU719 482 | XDU296 483 | XDU750 484 | XDU713 485 | XDU156 486 | XDU634 487 | XDU6 488 | XDU821 489 | XDU488 490 | XDU420 491 | XDU748 492 | XDU126 493 | XDU474 494 | XDU273 495 | XDU742 496 | XDU438 497 | XDU20 498 | XDU27 499 | XDU452 500 | XDU541 501 | XDU598 502 | XDU949 503 | XDU992 504 | XDU513 505 | XDU155 506 | XDU228 507 | XDU206 508 | XDU69 509 | XDU666 510 | XDU860 511 | XDU136 512 | XDU617 513 | XDU436 514 | XDU716 515 | XDU823 516 | XDU38 517 | XDU1000 518 | XDU851 519 | XDU627 520 | XDU974 521 | XDU717 522 | XDU734 523 | XDU796 524 | XDU481 525 | XDU990 526 | XDU679 527 | XDU764 528 | XDU238 529 | XDU848 530 | XDU908 531 | XDU418 532 | XDU696 533 | XDU378 534 | XDU724 535 | XDU632 536 | XDU182 537 | XDU76 538 | XDU791 539 | XDU830 540 | XDU814 541 | XDU306 542 | XDU931 543 | XDU154 544 | XDU564 545 | XDU383 546 | XDU473 547 | XDU84 548 | XDU143 549 | XDU18 550 | XDU683 551 | XDU495 552 | XDU78 553 | XDU840 554 | XDU382 555 | XDU385 556 | XDU233 557 | XDU220 558 | XDU616 559 | XDU655 560 | XDU797 561 | XDU854 562 | XDU312 563 | XDU924 564 | XDU593 565 | XDU150 566 | XDU60 567 | XDU951 568 | XDU812 569 | XDU608 570 | XDU408 571 | XDU184 572 | XDU552 573 | XDU172 574 | XDU820 575 | XDU584 576 | XDU314 577 | XDU702 578 | XDU174 579 | XDU379 580 | XDU511 581 | XDU914 582 | XDU214 583 | XDU315 584 | XDU535 585 | XDU305 586 | XDU294 587 | XDU71 588 | XDU534 589 | XDU133 590 | XDU204 591 | XDU991 592 | XDU475 593 | XDU870 594 | 
XDU793 595 | XDU585 596 | XDU431 597 | XDU786 598 | XDU944 599 | XDU102 600 | XDU245 601 | XDU857 602 | XDU427 603 | XDU738 604 | XDU726 605 | XDU568 606 | XDU843 607 | XDU131 608 | XDU298 609 | XDU498 610 | XDU837 611 | XDU243 612 | XDU588 613 | XDU832 614 | XDU526 615 | XDU710 616 | XDU177 617 | XDU310 618 | XDU165 619 | XDU603 620 | XDU99 621 | XDU22 622 | XDU615 623 | XDU566 624 | XDU286 625 | XDU703 626 | XDU361 627 | XDU795 628 | XDU277 629 | XDU508 630 | XDU521 631 | XDU297 632 | XDU317 633 | XDU861 634 | XDU271 635 | XDU318 636 | XDU572 637 | XDU247 638 | XDU202 639 | XDU946 640 | XDU466 641 | XDU732 642 | XDU226 643 | XDU583 644 | XDU120 645 | XDU926 646 | XDU401 647 | XDU687 648 | XDU984 649 | XDU819 650 | XDU664 651 | XDU400 652 | XDU446 653 | XDU453 654 | XDU730 655 | XDU362 656 | XDU62 657 | XDU175 658 | XDU809 659 | XDU430 660 | XDU124 661 | XDU346 662 | XDU776 663 | XDU605 664 | XDU187 665 | XDU337 666 | XDU211 667 | XDU684 668 | XDU179 669 | XDU981 670 | XDU784 671 | XDU701 672 | XDU358 673 | XDU768 674 | XDU911 675 | XDU235 676 | XDU215 677 | XDU145 678 | XDU609 679 | XDU281 680 | XDU432 681 | XDU196 682 | XDU499 683 | XDU250 684 | XDU304 685 | XDU600 686 | XDU309 687 | XDU171 688 | XDU787 689 | XDU595 690 | XDU808 691 | XDU7 692 | XDU846 693 | XDU428 694 | XDU287 695 | XDU729 696 | XDU213 697 | XDU828 698 | XDU941 699 | XDU395 700 | XDU756 701 | XDU897 702 | XDU239 703 | XDU610 704 | XDU251 705 | XDU373 706 | XDU533 707 | XDU95 708 | XDU57 709 | XDU945 710 | XDU222 711 | XDU168 712 | XDU137 713 | XDU961 714 | XDU906 715 | XDU937 716 | XDU0 717 | XDU770 718 | XDU268 719 | XDU963 720 | XDU113 721 | XDU771 722 | XDU763 723 | XDU339 724 | XDU52 725 | XDU737 726 | XDU755 727 | XDU159 728 | XDU626 729 | XDU16 730 | XDU118 731 | XDU77 732 | XDU574 733 | XDU402 734 | XDU407 735 | XDU88 736 | XDU778 737 | XDU391 738 | XDU15 739 | XDU940 740 | XDU886 741 | XDU359 742 | XDU424 743 | XDU721 744 | XDU399 745 | XDU345 746 | XDU157 747 | XDU107 748 | XDU873 749 | XDU865 750 | XDU39 751 | XDU893 752 | XDU976 753 | XDU695 754 | XDU367 755 | XDU700 756 | XDU422 757 | XDU936 758 | XDU123 759 | XDU503 760 | XDU366 761 | XDU101 762 | XDU631 763 | XDU276 764 | XDU549 765 | XDU212 766 | XDU197 767 | XDU640 768 | XDU200 769 | XDU37 770 | XDU469 771 | XDU522 772 | XDU575 773 | XDU256 774 | XDU409 775 | XDU152 776 | XDU224 777 | XDU86 778 | XDU630 779 | XDU980 780 | XDU813 781 | XDU70 782 | XDU249 783 | XDU396 784 | XDU11 785 | XDU327 786 | XDU269 787 | XDU284 788 | XDU757 789 | XDU348 790 | XDU554 791 | XDU127 792 | XDU267 793 | XDU751 794 | XDU181 795 | XDU468 796 | XDU274 797 | XDU55 798 | XDU510 799 | XDU170 800 | XDU744 801 | -------------------------------------------------------------------------------- /datasets/NUDT-SIRST00/test.txt: -------------------------------------------------------------------------------- 1 | 001000.png 2 | 001170.png 3 | 000159.png 4 | 000300.png 5 | 000506.png 6 | 000839.png 7 | 000189.png 8 | 000321.png 9 | 001207.png 10 | 000588.png 11 | 001165.png 12 | 000394.png 13 | 000876.png 14 | 001066.png 15 | 000026.png 16 | 000084.png 17 | 001010.png 18 | 000136.png 19 | 001216.png 20 | 000051.png 21 | 000111.png 22 | 000716.png 23 | 000162.png 24 | 000212.png 25 | 000777.png 26 | 000831.png 27 | 000711.png 28 | 001045.png 29 | 000203.png 30 | 001276.png 31 | 000479.png 32 | 000028.png 33 | 000427.png 34 | 001268.png 35 | 000481.png 36 | 000628.png 37 | 000972.png 38 | 000487.png 39 | 000053.png 40 | 000013.png 41 | 000264.png 42 | 000621.png 43 | 000391.png 44 | 
000495.png 45 | 000146.png 46 | 000574.png 47 | 000885.png 48 | 001147.png 49 | 000899.png 50 | 001075.png 51 | 001171.png 52 | 000611.png 53 | 000147.png 54 | 000289.png 55 | 000858.png 56 | 001244.png 57 | 000066.png 58 | 001099.png 59 | 000038.png 60 | 001270.png 61 | 000903.png 62 | 000689.png 63 | 001192.png 64 | 001316.png 65 | 000338.png 66 | 001252.png 67 | 000091.png 68 | 000036.png 69 | 000131.png 70 | 001155.png 71 | 000643.png 72 | 000842.png 73 | 000380.png 74 | 000802.png 75 | 001130.png 76 | 000282.png 77 | 000526.png 78 | 000794.png 79 | 000683.png 80 | 000694.png 81 | 000387.png 82 | 001105.png 83 | 000600.png 84 | 000587.png 85 | 000891.png 86 | 000320.png 87 | 000423.png 88 | 001145.png 89 | 000907.png 90 | 000018.png 91 | 001068.png 92 | 001065.png 93 | 000322.png 94 | 001124.png 95 | 000714.png 96 | 000631.png 97 | 000167.png 98 | 000434.png 99 | 000896.png 100 | 001143.png 101 | 001025.png 102 | 001238.png 103 | 000268.png 104 | 001009.png 105 | 000936.png 106 | 000838.png 107 | 000068.png 108 | 000149.png 109 | 000107.png 110 | 000222.png 111 | 000116.png 112 | 000719.png 113 | 000296.png 114 | 000278.png 115 | 000329.png 116 | 000663.png 117 | 000229.png 118 | 001308.png 119 | 000925.png 120 | 000529.png 121 | 000349.png 122 | 000488.png 123 | 000857.png 124 | 000035.png 125 | 001200.png 126 | 000545.png 127 | 000820.png 128 | 001255.png 129 | 000590.png 130 | 000324.png 131 | 000753.png 132 | 000760.png 133 | 000449.png 134 | 000155.png 135 | 000050.png 136 | 001077.png 137 | 001242.png 138 | 000843.png 139 | 000509.png 140 | 000785.png 141 | 001257.png 142 | 001043.png 143 | 000145.png 144 | 000072.png 145 | 000566.png 146 | 000367.png 147 | 000652.png 148 | 000286.png 149 | 000077.png 150 | 000889.png 151 | 000178.png 152 | 001163.png 153 | 001094.png 154 | 000937.png 155 | 000283.png 156 | 001125.png 157 | 000430.png 158 | 000049.png 159 | 000056.png 160 | 000213.png 161 | 001191.png 162 | 001156.png 163 | 001063.png 164 | 000217.png 165 | 001046.png 166 | 000740.png 167 | 001169.png 168 | 000900.png 169 | 000893.png 170 | 001127.png 171 | 000976.png 172 | 000240.png 173 | 001042.png 174 | 000607.png 175 | 000579.png 176 | 000489.png 177 | 000477.png 178 | 000958.png 179 | 000302.png 180 | 000713.png 181 | 000669.png 182 | 000584.png 183 | 001038.png 184 | 000029.png 185 | 001274.png 186 | 000347.png 187 | 000544.png 188 | 000951.png 189 | 000673.png 190 | 000576.png 191 | 000660.png 192 | 000006.png 193 | 000078.png 194 | 000577.png 195 | 000170.png 196 | 000667.png 197 | 000502.png 198 | 001095.png 199 | 000271.png 200 | 001315.png 201 | 000436.png 202 | 001166.png 203 | 001108.png 204 | 000376.png 205 | 001296.png 206 | 000274.png 207 | 000721.png 208 | 000533.png 209 | 001006.png 210 | 000786.png 211 | 000718.png 212 | 000651.png 213 | 001305.png 214 | 000816.png 215 | 001262.png 216 | 000248.png 217 | 000395.png 218 | 000419.png 219 | 000725.png 220 | 000315.png 221 | 001214.png 222 | 000637.png 223 | 000382.png 224 | 000724.png 225 | 000065.png 226 | 001157.png 227 | 000904.png 228 | 000822.png 229 | 001160.png 230 | 001102.png 231 | 000472.png 232 | 000493.png 233 | 000764.png 234 | 000674.png 235 | 001211.png 236 | 001282.png 237 | 000610.png 238 | 000901.png 239 | 001261.png 240 | 000032.png 241 | 000412.png 242 | 001093.png 243 | 000124.png 244 | 000514.png 245 | 001267.png 246 | 000763.png 247 | 000993.png 248 | 000866.png 249 | 000447.png 250 | 000519.png 251 | 000984.png 252 | 001185.png 253 | 000314.png 254 | 000868.png 255 | 000557.png 256 | 
000346.png 257 | 000609.png 258 | 000333.png 259 | 000328.png 260 | 000824.png 261 | 000644.png 262 | 000757.png 263 | 001227.png 264 | 000662.png 265 | 000681.png 266 | 000076.png 267 | -------------------------------------------------------------------------------- /datasets/NUDT-SIRST00/trainval.txt: -------------------------------------------------------------------------------- 1 | 000092.png 2 | 001306.png 3 | 000805.png 4 | 000342.png 5 | 001209.png 6 | 000272.png 7 | 000453.png 8 | 000508.png 9 | 000917.png 10 | 000310.png 11 | 000525.png 12 | 001033.png 13 | 000327.png 14 | 000024.png 15 | 000368.png 16 | 001254.png 17 | 001264.png 18 | 000801.png 19 | 000276.png 20 | 001258.png 21 | 000425.png 22 | 000069.png 23 | 000411.png 24 | 000418.png 25 | 000768.png 26 | 000144.png 27 | 000686.png 28 | 000692.png 29 | 001012.png 30 | 000962.png 31 | 000527.png 32 | 000790.png 33 | 000318.png 34 | 001221.png 35 | 000795.png 36 | 001266.png 37 | 000580.png 38 | 000730.png 39 | 000337.png 40 | 000825.png 41 | 000499.png 42 | 000537.png 43 | 000137.png 44 | 000780.png 45 | 000101.png 46 | 000515.png 47 | 000957.png 48 | 001032.png 49 | 001137.png 50 | 000284.png 51 | 000767.png 52 | 000736.png 53 | 001318.png 54 | 001158.png 55 | 000593.png 56 | 000595.png 57 | 000894.png 58 | 000659.png 59 | 000598.png 60 | 001056.png 61 | 000635.png 62 | 000291.png 63 | 000510.png 64 | 001236.png 65 | 000987.png 66 | 000362.png 67 | 000043.png 68 | 000848.png 69 | 000623.png 70 | 001069.png 71 | 000803.png 72 | 001178.png 73 | 000658.png 74 | 000232.png 75 | 000374.png 76 | 000097.png 77 | 001097.png 78 | 000044.png 79 | 000884.png 80 | 000210.png 81 | 000986.png 82 | 000971.png 83 | 001007.png 84 | 000997.png 85 | 000554.png 86 | 000500.png 87 | 001269.png 88 | 000312.png 89 | 000401.png 90 | 000793.png 91 | 001292.png 92 | 000617.png 93 | 001047.png 94 | 000020.png 95 | 000126.png 96 | 001311.png 97 | 001039.png 98 | 000142.png 99 | 000054.png 100 | 000358.png 101 | 000112.png 102 | 000875.png 103 | 000363.png 104 | 000108.png 105 | 000429.png 106 | 001299.png 107 | 000179.png 108 | 000797.png 109 | 000372.png 110 | 000624.png 111 | 001109.png 112 | 000273.png 113 | 000708.png 114 | 000994.png 115 | 001061.png 116 | 001250.png 117 | 000133.png 118 | 000973.png 119 | 000450.png 120 | 001008.png 121 | 000400.png 122 | 000059.png 123 | 000812.png 124 | 000728.png 125 | 000200.png 126 | 000398.png 127 | 000197.png 128 | 000005.png 129 | 000158.png 130 | 000523.png 131 | 000236.png 132 | 000205.png 133 | 000928.png 134 | 000636.png 135 | 000850.png 136 | 000313.png 137 | 000505.png 138 | 000783.png 139 | 000746.png 140 | 001111.png 141 | 000345.png 142 | 001177.png 143 | 000023.png 144 | 000045.png 145 | 000165.png 146 | 000865.png 147 | 000175.png 148 | 000859.png 149 | 001162.png 150 | 000855.png 151 | 000004.png 152 | 000462.png 153 | 000585.png 154 | 001217.png 155 | 000365.png 156 | 000747.png 157 | 000949.png 158 | 000216.png 159 | 000215.png 160 | 000878.png 161 | 000235.png 162 | 000742.png 163 | 000193.png 164 | 000294.png 165 | 000946.png 166 | 001118.png 167 | 000758.png 168 | 001195.png 169 | 000095.png 170 | 000751.png 171 | 000923.png 172 | 000183.png 173 | 001090.png 174 | 000180.png 175 | 001019.png 176 | 001263.png 177 | 001120.png 178 | 001251.png 179 | 000242.png 180 | 001167.png 181 | 000735.png 182 | 000748.png 183 | 001114.png 184 | 000413.png 185 | 001260.png 186 | 000432.png 187 | 000473.png 188 | 001295.png 189 | 000441.png 190 | 000306.png 191 | 000080.png 192 | 000561.png 193 | 
000890.png 194 | 001036.png 195 | 001017.png 196 | 000445.png 197 | 001313.png 198 | 001284.png 199 | 000620.png 200 | 000403.png 201 | 001161.png 202 | 000371.png 203 | 000117.png 204 | 001133.png 205 | 000555.png 206 | 001180.png 207 | 001001.png 208 | 000534.png 209 | 001051.png 210 | 000727.png 211 | 000298.png 212 | 000293.png 213 | 000870.png 214 | 001011.png 215 | 000709.png 216 | 000064.png 217 | 000869.png 218 | 000245.png 219 | 000882.png 220 | 000530.png 221 | 000438.png 222 | 000592.png 223 | 001152.png 224 | 001279.png 225 | 001291.png 226 | 000833.png 227 | 000030.png 228 | 000301.png 229 | 000653.png 230 | 001314.png 231 | 000649.png 232 | 000100.png 233 | 001219.png 234 | 000814.png 235 | 001194.png 236 | 000397.png 237 | 000563.png 238 | 000459.png 239 | 001183.png 240 | 000934.png 241 | 000428.png 242 | 000867.png 243 | 000929.png 244 | 000685.png 245 | 000192.png 246 | 000942.png 247 | 000699.png 248 | 001259.png 249 | 000871.png 250 | 000945.png 251 | 000782.png 252 | 000087.png 253 | 000199.png 254 | 001323.png 255 | 000664.png 256 | 000330.png 257 | 001224.png 258 | 000055.png 259 | 000392.png 260 | 000096.png 261 | 000285.png 262 | 000573.png 263 | 000234.png 264 | 000464.png 265 | 000888.png 266 | 000920.png 267 | 000779.png 268 | 000737.png 269 | 000575.png 270 | 000513.png 271 | 000761.png 272 | 000770.png 273 | 001322.png 274 | 000408.png 275 | 000486.png 276 | 000539.png 277 | 000532.png 278 | 000369.png 279 | 000370.png 280 | 000041.png 281 | 000007.png 282 | 000655.png 283 | 000810.png 284 | 000743.png 285 | 000756.png 286 | 001286.png 287 | 000042.png 288 | 001289.png 289 | 000241.png 290 | 000874.png 291 | 000410.png 292 | 000016.png 293 | 001246.png 294 | 001060.png 295 | 001176.png 296 | 000386.png 297 | 000661.png 298 | 001055.png 299 | 000130.png 300 | 000517.png 301 | 001024.png 302 | 000639.png 303 | 001220.png 304 | 000343.png 305 | 001247.png 306 | 000963.png 307 | 000558.png 308 | 000311.png 309 | 001154.png 310 | 000881.png 311 | 001222.png 312 | 000033.png 313 | 000819.png 314 | 001189.png 315 | 000265.png 316 | 001248.png 317 | 000641.png 318 | 001078.png 319 | 001225.png 320 | 000299.png 321 | 000225.png 322 | 000691.png 323 | 000040.png 324 | 000416.png 325 | 000496.png 326 | 000629.png 327 | 000703.png 328 | 000204.png 329 | 000454.png 330 | 000331.png 331 | 000952.png 332 | 000177.png 333 | 001106.png 334 | 000926.png 335 | 001293.png 336 | 000619.png 337 | 000991.png 338 | 000325.png 339 | 001230.png 340 | 000905.png 341 | 000693.png 342 | 000119.png 343 | 001018.png 344 | 000431.png 345 | 000845.png 346 | 000755.png 347 | 000811.png 348 | 000132.png 349 | 000854.png 350 | 000549.png 351 | 000698.png 352 | 000195.png 353 | 000083.png 354 | 001281.png 355 | 001231.png 356 | 000011.png 357 | 000670.png 358 | 000012.png 359 | 001139.png 360 | 001168.png 361 | 000999.png 362 | 000927.png 363 | 001016.png 364 | 001203.png 365 | 000104.png 366 | 000966.png 367 | 001092.png 368 | 001206.png 369 | 000648.png 370 | 001205.png 371 | 000690.png 372 | 000256.png 373 | 000606.png 374 | 001014.png 375 | 000110.png 376 | 000960.png 377 | 000491.png 378 | 000003.png 379 | 000541.png 380 | 000270.png 381 | 000422.png 382 | 000102.png 383 | 001312.png 384 | 001073.png 385 | 000879.png 386 | 001287.png 387 | 000060.png 388 | 000057.png 389 | 001113.png 390 | 001087.png 391 | 000996.png 392 | 001122.png 393 | 000955.png 394 | 000448.png 395 | 000796.png 396 | 000169.png 397 | 000326.png 398 | 000194.png 399 | 001049.png 400 | 000538.png 401 | 000250.png 402 | 
000174.png 403 | 001271.png 404 | 000840.png 405 | 000938.png 406 | 000863.png 407 | 000522.png 408 | 000941.png 409 | 000115.png 410 | 000191.png 411 | 000211.png 412 | 000082.png 413 | 000553.png 414 | 000983.png 415 | 001277.png 416 | 001142.png 417 | 000939.png 418 | 001233.png 419 | 000676.png 420 | 000348.png 421 | 000421.png 422 | 000552.png 423 | 001107.png 424 | 000287.png 425 | 000074.png 426 | 000470.png 427 | 000570.png 428 | 000332.png 429 | 000409.png 430 | 000520.png 431 | 001187.png 432 | 000851.png 433 | 000221.png 434 | 001071.png 435 | 001086.png 436 | 000932.png 437 | 000968.png 438 | 000990.png 439 | 000089.png 440 | 000094.png 441 | 000399.png 442 | 000223.png 443 | 000015.png 444 | 000677.png 445 | 000826.png 446 | 001091.png 447 | 000650.png 448 | 001126.png 449 | 000633.png 450 | 000483.png 451 | 000723.png 452 | 000001.png 453 | 001265.png 454 | 000618.png 455 | 000535.png 456 | 000665.png 457 | 000729.png 458 | 000715.png 459 | 000451.png 460 | 000334.png 461 | 000188.png 462 | 000058.png 463 | 000208.png 464 | 001098.png 465 | 001035.png 466 | 001074.png 467 | 001182.png 468 | 001173.png 469 | 001256.png 470 | 000352.png 471 | 000303.png 472 | 000738.png 473 | 000784.png 474 | 000237.png 475 | 000704.png 476 | 000978.png 477 | 000732.png 478 | 000351.png 479 | 000642.png 480 | 000407.png 481 | 000244.png 482 | 000404.png 483 | 000482.png 484 | 001117.png 485 | 001297.png 486 | 000800.png 487 | 000568.png 488 | 001057.png 489 | 001294.png 490 | 000457.png 491 | 000582.png 492 | 000460.png 493 | 001054.png 494 | 000731.png 495 | 000583.png 496 | 000201.png 497 | 000134.png 498 | 000654.png 499 | 000630.png 500 | 000253.png 501 | 000722.png 502 | 000666.png 503 | 001208.png 504 | 000572.png 505 | 001193.png 506 | 000355.png 507 | 000818.png 508 | 000828.png 509 | 000379.png 510 | 001321.png 511 | 001003.png 512 | 000143.png 513 | 000021.png 514 | 001015.png 515 | 001309.png 516 | 000377.png 517 | 001082.png 518 | 001141.png 519 | 000085.png 520 | 000616.png 521 | 000700.png 522 | 000067.png 523 | 000123.png 524 | 000771.png 525 | 000433.png 526 | 001298.png 527 | 000791.png 528 | 000390.png 529 | 001153.png 530 | 001013.png 531 | 000789.png 532 | 000468.png 533 | 000953.png 534 | 000263.png 535 | 000207.png 536 | 001272.png 537 | 000550.png 538 | 000444.png 539 | 000171.png 540 | 000335.png 541 | 000228.png 542 | 000426.png 543 | 000218.png 544 | 000909.png 545 | 000547.png 546 | 000378.png 547 | 000707.png 548 | 000202.png 549 | 000706.png 550 | 001079.png 551 | 001149.png 552 | 000105.png 553 | 000497.png 554 | 000914.png 555 | 001204.png 556 | 000873.png 557 | 000954.png 558 | 000919.png 559 | 000446.png 560 | 000424.png 561 | 000940.png 562 | 000922.png 563 | 000103.png 564 | 000415.png 565 | 001096.png 566 | 001285.png 567 | 000127.png 568 | 000186.png 569 | 000944.png 570 | 000063.png 571 | 000086.png 572 | 000915.png 573 | 001234.png 574 | 001324.png 575 | 000837.png 576 | 000081.png 577 | 000632.png 578 | 000597.png 579 | 001150.png 580 | 000442.png 581 | 000646.png 582 | 000140.png 583 | 000518.png 584 | 000458.png 585 | 000792.png 586 | 000297.png 587 | 000071.png 588 | 001041.png 589 | 000163.png 590 | 000161.png 591 | 000061.png 592 | 000912.png 593 | 000463.png 594 | 000129.png 595 | 000849.png 596 | 000565.png 597 | 000864.png 598 | 000975.png 599 | 000712.png 600 | 000079.png 601 | 000469.png 602 | 000148.png 603 | 000977.png 604 | 001103.png 605 | 001059.png 606 | 000548.png 607 | 000160.png 608 | 001317.png 609 | 000255.png 610 | 000739.png 611 | 
001174.png 612 | 000546.png 613 | 000853.png 614 | 000556.png 615 | 001241.png 616 | 000153.png 617 | 000883.png 618 | 000829.png 619 | 000152.png 620 | 000733.png 621 | 001253.png 622 | 000122.png 623 | 000911.png 624 | 000251.png 625 | 000710.png 626 | 000684.png 627 | 001002.png 628 | 000452.png 629 | 000341.png 630 | 000247.png 631 | 000031.png 632 | 000918.png 633 | 000696.png 634 | 000989.png 635 | 000778.png 636 | 000821.png 637 | 000490.png 638 | 001326.png 639 | 000280.png 640 | 000350.png 641 | 000099.png 642 | 001148.png 643 | 000601.png 644 | 000827.png 645 | 000657.png 646 | 000586.png 647 | 000776.png 648 | 000612.png 649 | 000022.png 650 | 001020.png 651 | 000935.png 652 | 000227.png 653 | 000027.png 654 | 000877.png 655 | 000745.png 656 | 000317.png 657 | 000656.png 658 | 001072.png 659 | 000832.png 660 | 001319.png 661 | 001179.png 662 | 000672.png 663 | 000361.png 664 | 001058.png 665 | 001290.png 666 | 000393.png 667 | 000474.png 668 | 001159.png 669 | 000781.png 670 | 000480.png 671 | 000184.png 672 | 001278.png 673 | 000062.png 674 | 000254.png 675 | 000224.png 676 | 001027.png 677 | 000524.png 678 | 000187.png 679 | 000257.png 680 | 000772.png 681 | 000090.png 682 | 000339.png 683 | 001031.png 684 | 000455.png 685 | 001131.png 686 | 000916.png 687 | 000701.png 688 | 000046.png 689 | 001226.png 690 | 000113.png 691 | 000118.png 692 | 000705.png 693 | 000266.png 694 | 000181.png 695 | 000420.png 696 | 000933.png 697 | 000323.png 698 | 000484.png 699 | 000961.png 700 | 000774.png 701 | 001199.png 702 | 000238.png 703 | 001136.png 704 | 000494.png 705 | 001302.png 706 | 000536.png 707 | 001273.png 708 | 000120.png 709 | 000862.png 710 | 000088.png 711 | 000128.png 712 | 000804.png 713 | 001240.png 714 | 001005.png 715 | 000640.png 716 | 001110.png 717 | 000292.png 718 | 000559.png 719 | 000456.png 720 | 001300.png 721 | 000316.png 722 | 001235.png 723 | 000356.png 724 | 000471.png 725 | 000414.png 726 | 000675.png 727 | 000995.png 728 | 001134.png 729 | 001237.png 730 | 000373.png 731 | 000808.png 732 | 000680.png 733 | 000010.png 734 | 000678.png 735 | 001325.png 736 | 001202.png 737 | 000037.png 738 | 000980.png 739 | 000647.png 740 | 000988.png 741 | 000466.png 742 | 000982.png 743 | 001040.png 744 | 000381.png 745 | 000702.png 746 | 000959.png 747 | 000754.png 748 | 001288.png 749 | 000931.png 750 | 000602.png 751 | 000009.png 752 | 000564.png 753 | 000139.png 754 | 000695.png 755 | 001023.png 756 | 000475.png 757 | 001243.png 758 | 000594.png 759 | 001034.png 760 | 000252.png 761 | 000551.png 762 | 000521.png 763 | 000844.png 764 | 000682.png 765 | 000261.png 766 | 000974.png 767 | 000638.png 768 | 000385.png 769 | 000668.png 770 | 001062.png 771 | 000364.png 772 | 000830.png 773 | 000243.png 774 | 001212.png 775 | 001044.png 776 | 000578.png 777 | 000897.png 778 | 000219.png 779 | 001249.png 780 | 001307.png 781 | 001196.png 782 | 000943.png 783 | 000319.png 784 | 000591.png 785 | 001310.png 786 | 000507.png 787 | 000589.png 788 | 000542.png 789 | 001083.png 790 | 000172.png 791 | 000262.png 792 | 000569.png 793 | 000581.png 794 | 000034.png 795 | 000605.png 796 | 000540.png 797 | 000166.png 798 | 000019.png 799 | 000627.png 800 | 000226.png 801 | 000354.png 802 | 000467.png 803 | 000025.png 804 | 000052.png 805 | 000231.png 806 | 000834.png 807 | 000608.png 808 | 000856.png 809 | 000887.png 810 | 000969.png 811 | 000841.png 812 | 000220.png 813 | 000847.png 814 | 001280.png 815 | 000913.png 816 | 001112.png 817 | 000121.png 818 | 000679.png 819 | 001304.png 820 | 
001186.png 821 | 001123.png 822 | 000501.png 823 | 000886.png 824 | 000734.png 825 | 000560.png 826 | 000898.png 827 | 000806.png 828 | 000070.png 829 | 000384.png 830 | 001164.png 831 | 000295.png 832 | 001084.png 833 | 000093.png 834 | 000614.png 835 | 001197.png 836 | 000543.png 837 | 000125.png 838 | 001135.png 839 | 000985.png 840 | 000965.png 841 | 001245.png 842 | 001303.png 843 | 001228.png 844 | 001146.png 845 | 000275.png 846 | 000269.png 847 | 000615.png 848 | 000190.png 849 | 000813.png 850 | 000073.png 851 | 000892.png 852 | 000239.png 853 | 000823.png 854 | 001320.png 855 | 000440.png 856 | 000788.png 857 | 000308.png 858 | 000852.png 859 | 000359.png 860 | 001151.png 861 | 000836.png 862 | 001022.png 863 | 000528.png 864 | 001190.png 865 | 000930.png 866 | 000759.png 867 | 000604.png 868 | 000872.png 869 | 001081.png 870 | 001037.png 871 | 000154.png 872 | 000948.png 873 | 001101.png 874 | 000396.png 875 | 000336.png 876 | 000281.png 877 | 000150.png 878 | 000769.png 879 | 000439.png 880 | 001184.png 881 | 000625.png 882 | 000726.png 883 | 001121.png 884 | 000531.png 885 | 000613.png 886 | 000305.png 887 | 001181.png 888 | 000749.png 889 | 000375.png 890 | 001129.png 891 | 000492.png 892 | 000964.png 893 | 000992.png 894 | 000008.png 895 | 000895.png 896 | 001088.png 897 | 001218.png 898 | 000002.png 899 | 000109.png 900 | 000762.png 901 | 000947.png 902 | 001188.png 903 | 000014.png 904 | 000498.png 905 | 001132.png 906 | 000098.png 907 | 000979.png 908 | 000233.png 909 | 000596.png 910 | 000230.png 911 | 001215.png 912 | 000209.png 913 | 000998.png 914 | 000435.png 915 | 000156.png 916 | 000645.png 917 | 000773.png 918 | 000406.png 919 | 000259.png 920 | 000290.png 921 | 000906.png 922 | 000752.png 923 | 000157.png 924 | 001064.png 925 | 000047.png 926 | 000288.png 927 | 000846.png 928 | 000634.png 929 | 001115.png 930 | 001301.png 931 | 000344.png 932 | 000603.png 933 | 000950.png 934 | 000075.png 935 | 000967.png 936 | 001275.png 937 | 001232.png 938 | 000687.png 939 | 000402.png 940 | 001239.png 941 | 000860.png 942 | 000437.png 943 | 000279.png 944 | 000626.png 945 | 000360.png 946 | 000512.png 947 | 000357.png 948 | 001004.png 949 | 000478.png 950 | 000817.png 951 | 000353.png 952 | 001144.png 953 | 000168.png 954 | 000383.png 955 | 001198.png 956 | 000910.png 957 | 001140.png 958 | 000138.png 959 | 001085.png 960 | 000750.png 961 | 000309.png 962 | 000503.png 963 | 000039.png 964 | 000970.png 965 | 000106.png 966 | 001100.png 967 | 000766.png 968 | 000135.png 969 | 000465.png 970 | 001030.png 971 | 001067.png 972 | 001076.png 973 | 000443.png 974 | 000307.png 975 | 000567.png 976 | 000048.png 977 | 001028.png 978 | 000260.png 979 | 000880.png 980 | 001175.png 981 | 000765.png 982 | 000599.png 983 | 001021.png 984 | 000405.png 985 | 001052.png 986 | 000504.png 987 | 001119.png 988 | 001104.png 989 | 001050.png 990 | 000908.png 991 | 001327.png 992 | 000185.png 993 | 001201.png 994 | 000981.png 995 | 000277.png 996 | 000815.png 997 | 001213.png 998 | 000902.png 999 | 000182.png 1000 | 000476.png 1001 | 000744.png 1002 | 001080.png 1003 | 001048.png 1004 | 000720.png 1005 | 000717.png 1006 | 000258.png 1007 | 000956.png 1008 | 001070.png 1009 | 001223.png 1010 | 000622.png 1011 | 000516.png 1012 | 000671.png 1013 | 000267.png 1014 | 000388.png 1015 | 000562.png 1016 | 001229.png 1017 | 001026.png 1018 | 000485.png 1019 | 000417.png 1020 | 000809.png 1021 | 000775.png 1022 | 000697.png 1023 | 000141.png 1024 | 001128.png 1025 | 000787.png 1026 | 000924.png 1027 | 
001172.png 1028 | 001138.png 1029 | 000688.png 1030 | 000206.png 1031 | 000861.png 1032 | 000176.png 1033 | 000835.png 1034 | 000173.png 1035 | 000807.png 1036 | 001210.png 1037 | 000340.png 1038 | 001116.png 1039 | 000249.png 1040 | 000741.png 1041 | 000571.png 1042 | 001283.png 1043 | 000246.png 1044 | 000214.png 1045 | 001029.png 1046 | 000304.png 1047 | 000198.png 1048 | 000164.png 1049 | 000366.png 1050 | 000151.png 1051 | 000799.png 1052 | 000196.png 1053 | 001053.png 1054 | 000114.png 1055 | 001089.png 1056 | 000798.png 1057 | 000921.png 1058 | 000461.png 1059 | 000511.png 1060 | 000389.png 1061 | 000017.png 1062 | -------------------------------------------------------------------------------- /datasets/Sirstv2_512/test.txt: -------------------------------------------------------------------------------- 1 | Misc_70 2 | Misc_214 3 | Misc_96 4 | Misc_311 5 | Misc_158 6 | Misc_127 7 | Misc_323 8 | Misc_54 9 | Misc_396 10 | Misc_6 11 | Misc_276 12 | Misc_172 13 | Misc_153 14 | Misc_138 15 | Misc_295 16 | Misc_139 17 | Misc_179 18 | Misc_79 19 | Misc_386 20 | Misc_58 21 | Misc_171 22 | Misc_303 23 | Misc_316 24 | Misc_343 25 | Misc_264 26 | Misc_274 27 | Misc_154 28 | Misc_174 29 | Misc_238 30 | Misc_23 31 | Misc_270 32 | Misc_110 33 | Misc_413 34 | Misc_345 35 | Misc_205 36 | Misc_243 37 | Misc_387 38 | Misc_258 39 | Misc_95 40 | Misc_277 41 | Misc_116 42 | Misc_229 43 | Misc_145 44 | Misc_34 45 | Misc_156 46 | Misc_209 47 | Misc_25 48 | Misc_47 49 | Misc_357 50 | Misc_349 51 | Misc_395 52 | Misc_346 53 | Misc_43 54 | Misc_66 55 | Misc_92 56 | Misc_176 57 | Misc_129 58 | Misc_15 59 | Misc_29 60 | Misc_248 61 | Misc_235 62 | Misc_423 63 | Misc_374 64 | Misc_194 65 | Misc_220 66 | Misc_347 67 | Misc_120 68 | Misc_382 69 | Misc_192 70 | Misc_240 71 | Misc_175 72 | Misc_203 73 | Misc_91 74 | Misc_356 75 | Misc_111 76 | Misc_124 77 | Misc_306 78 | Misc_317 79 | Misc_263 80 | Misc_8 81 | Misc_72 82 | Misc_250 83 | Misc_419 84 | Misc_394 85 | Misc_331 86 | Misc_160 -------------------------------------------------------------------------------- /datasets/Sirstv2_512/trainval.txt: -------------------------------------------------------------------------------- 1 | Misc_181 2 | Misc_366 3 | Misc_418 4 | Misc_189 5 | Misc_265 6 | Misc_342 7 | Misc_393 8 | Misc_13 9 | Misc_202 10 | Misc_402 11 | Misc_325 12 | Misc_365 13 | Misc_33 14 | Misc_198 15 | Misc_52 16 | Misc_320 17 | Misc_4 18 | Misc_64 19 | Misc_426 20 | Misc_379 21 | Misc_30 22 | Misc_2 23 | Misc_378 24 | Misc_271 25 | Misc_404 26 | Misc_118 27 | Misc_302 28 | Misc_17 29 | Misc_68 30 | Misc_50 31 | Misc_245 32 | Misc_391 33 | Misc_221 34 | Misc_392 35 | Misc_242 36 | Misc_287 37 | Misc_7 38 | Misc_155 39 | Misc_232 40 | Misc_85 41 | Misc_210 42 | Misc_143 43 | Misc_3 44 | Misc_360 45 | Misc_162 46 | Misc_422 47 | Misc_99 48 | Misc_244 49 | Misc_369 50 | Misc_101 51 | Misc_186 52 | Misc_267 53 | Misc_106 54 | Misc_164 55 | Misc_226 56 | Misc_137 57 | Misc_257 58 | Misc_90 59 | Misc_146 60 | Misc_309 61 | Misc_126 62 | Misc_65 63 | Misc_334 64 | Misc_305 65 | Misc_367 66 | Misc_321 67 | Misc_142 68 | Misc_239 69 | Misc_247 70 | Misc_370 71 | Misc_76 72 | Misc_281 73 | Misc_104 74 | Misc_31 75 | Misc_69 76 | Misc_353 77 | Misc_384 78 | Misc_254 79 | Misc_406 80 | Misc_424 81 | Misc_318 82 | Misc_180 83 | Misc_427 84 | Misc_48 85 | Misc_125 86 | Misc_400 87 | Misc_272 88 | Misc_412 89 | Misc_350 90 | Misc_206 91 | Misc_130 92 | Misc_341 93 | Misc_256 94 | Misc_97 95 | Misc_152 96 | Misc_35 97 | Misc_409 98 | Misc_87 99 | Misc_222 100 | 
Misc_269 101 | Misc_20 102 | Misc_56 103 | Misc_328 104 | Misc_169 105 | Misc_283 106 | Misc_45 107 | Misc_388 108 | Misc_266 109 | Misc_217 110 | Misc_219 111 | Misc_314 112 | Misc_312 113 | Misc_55 114 | Misc_185 115 | Misc_344 116 | Misc_134 117 | Misc_236 118 | Misc_108 119 | Misc_28 120 | Misc_131 121 | Misc_102 122 | Misc_12 123 | Misc_1 124 | Misc_313 125 | Misc_414 126 | Misc_416 127 | Misc_100 128 | Misc_183 129 | Misc_285 130 | Misc_178 131 | Misc_280 132 | Misc_289 133 | Misc_170 134 | Misc_119 135 | Misc_278 136 | Misc_147 137 | Misc_159 138 | Misc_223 139 | Misc_136 140 | Misc_420 141 | Misc_63 142 | Misc_298 143 | Misc_71 144 | Misc_199 145 | Misc_253 146 | Misc_284 147 | Misc_337 148 | Misc_89 149 | Misc_39 150 | Misc_73 151 | Misc_301 152 | Misc_78 153 | Misc_57 154 | Misc_282 155 | Misc_405 156 | Misc_16 157 | Misc_308 158 | Misc_399 159 | Misc_227 160 | Misc_372 161 | Misc_373 162 | Misc_177 163 | Misc_18 164 | Misc_296 165 | Misc_380 166 | Misc_128 167 | Misc_184 168 | Misc_207 169 | Misc_188 170 | Misc_74 171 | Misc_14 172 | Misc_324 173 | Misc_200 174 | Misc_421 175 | Misc_403 176 | Misc_290 177 | Misc_237 178 | Misc_41 179 | Misc_140 180 | Misc_415 181 | Misc_326 182 | Misc_123 183 | Misc_193 184 | Misc_348 185 | Misc_88 186 | Misc_46 187 | Misc_144 188 | Misc_339 189 | Misc_80 190 | Misc_251 191 | Misc_173 192 | Misc_49 193 | Misc_150 194 | Misc_304 195 | Misc_246 196 | Misc_216 197 | Misc_401 198 | Misc_351 199 | Misc_208 200 | Misc_141 201 | Misc_109 202 | Misc_40 203 | Misc_22 204 | Misc_299 205 | Misc_261 206 | Misc_191 207 | Misc_21 208 | Misc_364 209 | Misc_182 210 | Misc_24 211 | Misc_197 212 | Misc_385 213 | Misc_417 214 | Misc_230 215 | Misc_275 216 | Misc_224 217 | Misc_362 218 | Misc_249 219 | Misc_335 220 | Misc_389 221 | Misc_300 222 | Misc_115 223 | Misc_355 224 | Misc_390 225 | Misc_233 226 | Misc_204 227 | Misc_121 228 | Misc_327 229 | Misc_213 230 | Misc_268 231 | Misc_133 232 | Misc_361 233 | Misc_82 234 | Misc_62 235 | Misc_291 236 | Misc_195 237 | Misc_117 238 | Misc_42 239 | Misc_165 240 | Misc_292 241 | Misc_322 242 | Misc_377 243 | Misc_259 244 | Misc_407 245 | Misc_329 246 | Misc_398 247 | Misc_330 248 | Misc_149 249 | Misc_354 250 | Misc_59 251 | Misc_166 252 | Misc_83 253 | Misc_86 254 | Misc_352 255 | Misc_218 256 | Misc_26 257 | Misc_53 258 | Misc_163 259 | Misc_375 260 | Misc_5 261 | Misc_132 262 | Misc_84 263 | Misc_358 264 | Misc_190 265 | Misc_228 266 | Misc_112 267 | Misc_294 268 | Misc_408 269 | Misc_103 270 | Misc_93 271 | Misc_397 272 | Misc_10 273 | Misc_67 274 | Misc_9 275 | Misc_148 276 | Misc_211 277 | Misc_425 278 | Misc_225 279 | Misc_310 280 | Misc_376 281 | Misc_187 282 | Misc_196 283 | Misc_297 284 | Misc_336 285 | Misc_252 286 | Misc_411 287 | Misc_37 288 | Misc_75 289 | Misc_368 290 | Misc_98 291 | Misc_94 292 | Misc_307 293 | Misc_255 294 | Misc_107 295 | Misc_234 296 | Misc_383 297 | Misc_381 298 | Misc_105 299 | Misc_201 300 | Misc_340 301 | Misc_135 302 | Misc_338 303 | Misc_51 304 | Misc_279 305 | Misc_410 306 | Misc_151 307 | Misc_212 308 | Misc_27 309 | Misc_371 310 | Misc_32 311 | Misc_262 312 | Misc_215 313 | Misc_44 314 | Misc_293 315 | Misc_260 316 | Misc_359 317 | Misc_11 318 | Misc_81 319 | Misc_333 320 | Misc_61 321 | Misc_288 322 | Misc_77 323 | Misc_286 324 | Misc_168 325 | Misc_114 326 | Misc_60 327 | Misc_38 328 | Misc_315 329 | Misc_36 330 | Misc_273 331 | Misc_363 332 | Misc_122 333 | Misc_167 334 | Misc_241 335 | Misc_231 336 | Misc_161 337 | Misc_113 338 | Misc_332 339 | Misc_319 340 | Misc_19 341 | 
Misc_157 -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright by HQ-SAM team 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import os.path as ops 9 | from tqdm import tqdm 10 | import argparse 11 | import logging 12 | import numpy as np 13 | import torch 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | import torch.distributed as dist 17 | import torch.nn.functional as F 18 | import torchvision.transforms as T 19 | from torch.autograd import Variable 20 | import matplotlib.pyplot as plt 21 | import cv2 22 | import random 23 | from typing import Dict, List, Tuple 24 | 25 | from thop import profile 26 | 27 | from segment_anything_training.build_IRSAM import build_sam_IRSAM 28 | 29 | from utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter 30 | from utils.metrics import SigmoidMetric, SamplewiseSigmoidMetric 31 | from utils.metric import PD_FA, ROCMetric 32 | from utils.loss_mask import DICE_loss 33 | from utils.log import initialize_logger 34 | import utils.misc as misc 35 | 36 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 37 | 38 | def get_args_parser(): 39 | parser = argparse.ArgumentParser('HQ-SAM', add_help=False) 40 | 41 | parser.add_argument("--output", type=str, required=True, 42 | help="Path to the directory where masks and checkpoints will be output") 43 | parser.add_argument("--model_type", type=str, default="vit_l", 44 | help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']") 45 | parser.add_argument("--checkpoint", type=str, required=True, 46 | help="The path to the SAM checkpoint to use for mask generation.") 47 | parser.add_argument("--no_prompt_checkpoint", type=str, default=None, 48 | help="The path to the SAM checkpoint trained with no prompt") 49 | parser.add_argument("--device", type=str, default="cuda", 50 | help="The device to run generation on.") 51 | 52 | parser.add_argument('--learning_rate', default=1e-4, type=float) 53 | parser.add_argument('--start_epoch', default=0, type=int) 54 | parser.add_argument('--lr_drop_epoch', default=10, type=int) 55 | parser.add_argument('--max_epoch_num', default=1001, type=int) 56 | parser.add_argument('--dataloader_size', default=[512, 512], type=list) 57 | parser.add_argument('--batch_size_train', default=4, type=int) 58 | parser.add_argument('--batch_size_valid', default=1, type=int) 59 | parser.add_argument('--model_save_fre', default=10, type=int) 60 | 61 | parser.add_argument('--eval', action='store_true') 62 | parser.add_argument('--visualize', action='store_true') 63 | parser.add_argument("--restore-model", type=str, 64 | help="The path to the hq_decoder training checkpoint for evaluation") 65 | 66 | return parser.parse_args() 67 | 68 | 69 | def main(valid_datasets, args): 70 | # --- Step 1: Valid dataset --- 71 | print("--- create valid dataloader ---") 72 | valid_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid") 73 | valid_dataloaders, valid_datasets = create_dataloaders(valid_im_gt_list, 74 | my_transforms=[ 75 | Resize(args.dataloader_size) 76 | ], 77 | batch_size=args.batch_size_valid, 78 | training=False) 79 | print(len(valid_dataloaders), " valid dataloaders created") 80 | 81 | # --- Step 2: Load pretrained Network--- 82 | net = 
build_sam_IRSAM(checkpoint=args.checkpoint) 83 | if torch.cuda.is_available(): 84 | net.cuda() 85 | 86 | # --- Step 3: Train or Evaluate --- 87 | if args.eval: 88 | if args.restore_model: 89 | print("restore model from:", args.restore_model) 90 | if torch.cuda.is_available(): 91 | net.load_state_dict(torch.load(args.restore_model)) 92 | else: 93 | net.load_state_dict(torch.load(args.restore_model, map_location="cpu")) 94 | 95 | evaluate(net, valid_dataloaders) 96 | 97 | def evaluate(net, valid_dataloaders): 98 | net.eval() 99 | metric = dict() 100 | 101 | IoU_metric = SigmoidMetric() 102 | nIoU_metric = SamplewiseSigmoidMetric(1, score_thresh=0.5) 103 | 104 | ROC = ROCMetric(1, 10) 105 | Pd_Fa = PD_FA(1, 10) 106 | 107 | IoU_metric.reset() 108 | nIoU_metric.reset() 109 | Pd_Fa.reset() 110 | for k in range(len(valid_dataloaders)): 111 | valid_dataloader = valid_dataloaders[k] 112 | 113 | tbar = tqdm(valid_dataloader) 114 | for data_val in tbar: 115 | imidx_val, inputs_val, labels_val, shapes_val, labels_ori = data_val['imidx'], data_val['image'], data_val[ 116 | 'label'], data_val['shape'], data_val['ori_label'] 117 | 118 | if torch.cuda.is_available(): 119 | inputs_val = inputs_val.cuda() 120 | labels_ori = labels_ori.cuda() 121 | 122 | imgs = inputs_val.permute(0, 2, 3, 1).cpu().numpy() 123 | 124 | batched_input = [] 125 | for b_i in range(len(imgs)): 126 | dict_input = dict() 127 | input_image = (torch.as_tensor((imgs[b_i]).astype(dtype=np.uint8), device=net.device) 128 | .permute(2, 0, 1).contiguous()) 129 | dict_input['image'] = input_image 130 | dict_input['original_size'] = imgs[b_i].shape[:2] 131 | batched_input.append(dict_input) 132 | 133 | masks, edges = net(batched_input) 134 | 135 | torch.cuda.synchronize() 136 | 137 | IoU_metric.update(masks.cpu(), (labels_ori / 255.).cpu().detach()) 138 | nIoU_metric.update(masks.cpu(), (labels_ori / 255.).cpu().detach()) 139 | Pd_Fa.update(masks.cpu(), (labels_ori / 255.).cpu().detach()) 140 | 141 | FA, PD = Pd_Fa.get(len(valid_dataloader)) 142 | _, IoU = IoU_metric.get() 143 | _, nIoU = nIoU_metric.get() 144 | 145 | tbar.set_description('IoU:%f, nIoU:%f, PD:%.8lf, FA:%.8lf' 146 | % (IoU, nIoU, PD[0], FA[0])) 147 | 148 | metric['iou'] = IoU 149 | metric['niou'] = nIoU 150 | metric['pd'] = PD[0] 151 | metric['fa'] = FA[0] 152 | return metric 153 | 154 | 155 | if __name__ == "__main__": 156 | # --------------- Configuring the Valid datasets --------------- 157 | dataset_val_nuaa = {"name": "Sirstv2_512", 158 | "im_dir": "datasets/Sirstv2_512/test_images", 159 | "gt_dir": "datasets/Sirstv2_512/test_masks", 160 | "im_ext": ".png", 161 | "gt_ext": ".png"} 162 | 163 | dataset_val_NUDT = {"name": "NUDT", 164 | "im_dir": "datasets/NUDT-SIRST00/test_images", 165 | "gt_dir": "datasets/NUDT-SIRST00/test_masks", 166 | "im_ext": ".png", 167 | "gt_ext": ".png"} 168 | 169 | dataset_val_IRSTD = {"name": "IRSTD", 170 | "im_dir": "datasets/IRSTD-1k/test_images", 171 | "gt_dir": "datasets/IRSTD-1k/test_masks", 172 | "im_ext": ".png", 173 | "gt_ext": ".png"} 174 | 175 | valid_datasets = [dataset_val_nuaa] 176 | 177 | args = get_args_parser() 178 | 179 | main(valid_datasets, args) 180 | -------------------------------------------------------------------------------- /segment_anything_training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | -------------------------------------------------------------------------------- /segment_anything_training/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/__pycache__/build_IRSAM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/__pycache__/build_IRSAM.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/__pycache__/build_sam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/__pycache__/build_sam.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/build_IRSAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | 4 | from .modeling import MaskDecoder, PromptEncoder, TwoWayTransformer, Sam 5 | from .modeling.IRSAM_decoder import MaskDecoder as EdgeDecoder 6 | from .modeling.IRSAM_encoder import TinyViT as EdgeEncoder 7 | from .modeling.IRSAM_edge import Sam as EdgeIRSAM 8 | 9 | 10 | def build_sam_IRSAM(checkpoint=None): 11 | prompt_embed_dim = 256 12 | image_size = 1024 13 | vit_patch_size = 16 14 | image_embedding_size = image_size // vit_patch_size 15 | mobile_sam = EdgeIRSAM( 16 | image_encoder=EdgeEncoder(img_size=1024, in_chans=3, num_classes=1000, 17 | embed_dims=[64, 128, 160, 320], 18 | depths=[2, 2, 6, 2], 19 | num_heads=[2, 4, 5, 10], 20 | window_sizes=[7, 7, 14, 7], 21 | mlp_ratio=4., 22 | drop_rate=0., 23 | drop_path_rate=0.0, 24 | use_checkpoint=False, 25 | mbconv_expand_ratio=4.0, 26 | local_conv_size=3, 27 | layer_lr_decay=0.8 28 | ), 29 | prompt_encoder=PromptEncoder( 30 | embed_dim=prompt_embed_dim, 31 | image_embedding_size=(image_embedding_size, image_embedding_size), 32 | input_image_size=(image_size, image_size), 33 | mask_in_chans=16, 34 | ), 35 | mask_decoder=EdgeDecoder( 36 | transformer=TwoWayTransformer( 37 | depth=2, 38 | embedding_dim=prompt_embed_dim, 39 | mlp_dim=2048, 40 | num_heads=8, 41 | ), 42 | transformer_dim=prompt_embed_dim, 43 | ), 44 | pixel_mean=[123.675, 116.28, 103.53], 45 | pixel_std=[58.395, 57.12, 57.375], 46 | ) 47 | 48 | mobile_sam.eval() 49 | if checkpoint is not None: 50 | with open(checkpoint, "rb") as f: 51 | state_dict = torch.load(f) 52 | mobile_sam.load_state_dict(state_dict, strict=False) 53 | return mobile_sam -------------------------------------------------------------------------------- /segment_anything_training/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, MaskDecoder_test 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | sam_model_registry = { 47 | "default": build_sam_vit_h, 48 | "vit_h": build_sam_vit_h, 49 | "vit_l": build_sam_vit_l, 50 | "vit_b": build_sam_vit_b, 51 | } 52 | 53 | 54 | def _build_sam( 55 | encoder_embed_dim, 56 | encoder_depth, 57 | encoder_num_heads, 58 | encoder_global_attn_indexes, 59 | checkpoint=None, 60 | ): 61 | prompt_embed_dim = 256 62 | image_size = 1024 63 | vit_patch_size = 16 64 | image_embedding_size = image_size // vit_patch_size 65 | sam = Sam( 66 | image_encoder=ImageEncoderViT( 67 | depth=encoder_depth, 68 | embed_dim=encoder_embed_dim, 69 | img_size=image_size, 70 | mlp_ratio=4, 71 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 72 | num_heads=encoder_num_heads, 73 | patch_size=vit_patch_size, 74 | qkv_bias=True, 75 | use_rel_pos=True, 76 | global_attn_indexes=encoder_global_attn_indexes, 77 | window_size=14, 78 | out_chans=prompt_embed_dim, 79 | ), 80 | prompt_encoder=PromptEncoder( 81 | embed_dim=prompt_embed_dim, 82 | image_embedding_size=(image_embedding_size, image_embedding_size), 83 | input_image_size=(image_size, image_size), 84 | mask_in_chans=16, 85 | ), 86 | mask_decoder=MaskDecoder_test( 87 | num_multimask_outputs=3, 88 | transformer=TwoWayTransformer( 89 | depth=2, 90 | embedding_dim=prompt_embed_dim, 91 | mlp_dim=2048, 92 | num_heads=8, 93 | ), 94 | transformer_dim=prompt_embed_dim, 95 | iou_head_depth=3, 96 | iou_head_hidden_dim=256, 97 | ), 98 | pixel_mean=[123.675, 116.28, 103.53], 99 | pixel_std=[58.395, 57.12, 57.375], 100 | ) 101 | sam.eval() 102 | if checkpoint is not None: 103 | with open(checkpoint, "rb") as f: 104 | state_dict = torch.load(f) 105 | sam.image_encoder.load_state_dict(state_dict, strict=False) 106 | return sam 107 | 108 | 109 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/IRSAM_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
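Taken together, `build_sam.py` and `build_IRSAM.py` expose two entry points: the stock SAM builders registered in `sam_model_registry`, and `build_sam_IRSAM`, which pairs the TinyViT (Mobile-SAM) encoder with the edge-aware decoder. A minimal usage sketch follows; the checkpoint path is only a placeholder, and passing a checkpoint is optional.

```python
import torch

from segment_anything_training import sam_model_registry
from segment_anything_training.build_IRSAM import build_sam_IRSAM

# Stock SAM ViT-B backbone (weights are optional).
sam = sam_model_registry["vit_b"](checkpoint=None)

# IRSAM model on the Mobile-SAM/TinyViT encoder; the path below is a placeholder.
irsam = build_sam_IRSAM(checkpoint="mobile_sam.pt")
irsam = irsam.to("cuda" if torch.cuda.is_available() else "cpu")
```
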
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | # edge tokens 72 | self.edge_token = nn.Embedding(1, transformer_dim) 73 | self.edge_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 74 | self.num_mask_tokens = self.num_mask_tokens + 1 75 | 76 | self.compress_vit_feat = nn.Sequential( 77 | nn.ConvTranspose2d(160, transformer_dim, 2, 2), 78 | LayerNorm2d(transformer_dim), 79 | nn.GELU(), 80 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, 2, 2) 81 | ) 82 | self.embedding_encoder = nn.Sequential( 83 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, 2, 2), 84 | LayerNorm2d(transformer_dim // 4), 85 | nn.GELU(), 86 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, 2, 2) 87 | ) 88 | self.embedding_maskfeature = nn.Sequential( 89 | nn.ConvTranspose2d(transformer_dim // 8, transformer_dim // 4, 3,1,1), 90 | LayerNorm2d(transformer_dim // 4), 91 | nn.GELU(), 92 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, 3,1,1) 93 | ) 94 | self.sigmoid = nn.Sigmoid() 95 | 96 | def forward( 97 | self, 98 | image_embeddings: torch.Tensor, 99 | edge_embeddings: torch.Tensor, 100 | image_pe: torch.Tensor, 101 | sparse_prompt_embeddings: torch.Tensor, 102 | dense_prompt_embeddings: torch.Tensor, 103 | multimask_output: bool = None, 104 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 105 | """ 106 | Predict masks given image and prompt embeddings. 
107 | 108 | Arguments: 109 | image_embeddings (torch.Tensor): the embeddings from the image encoder 110 | edge_embeddings (torch.Tensor): intermediate encoder features used to build the edge branch 111 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 112 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 113 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 114 | multimask_output (bool): Whether to return multiple masks or a single mask. 115 | 116 | Returns: 117 | torch.Tensor: batched predicted masks 118 | torch.Tensor: batched predicted edge maps; torch.Tensor: batched predictions of mask quality 119 | """ 120 | edge_features = edge_embeddings.permute(0, 3, 1, 2) 121 | # edge_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(edge_features) # shallow + deep features 122 | # edge_features = self.compress_vit_feat(edge_features) # shallow features only 123 | edge_features = self.embedding_encoder(image_embeddings) # deep features only 124 | 125 | masks, edges, iou_pred = self.predict_masks( 126 | image_embeddings=image_embeddings, 127 | edge_embeddings=edge_features, 128 | image_pe=image_pe, 129 | sparse_prompt_embeddings=sparse_prompt_embeddings, 130 | dense_prompt_embeddings=dense_prompt_embeddings, 131 | ) 132 | 133 | # Select the correct mask or masks for output 134 | if multimask_output: 135 | mask_slice = slice(1, None) 136 | else: 137 | mask_slice = slice(0, 1) 138 | masks = masks[:, mask_slice, :, :] 139 | iou_pred = iou_pred[:, mask_slice] 140 | 141 | # Prepare output 142 | return masks, edges, iou_pred 143 | 144 | def predict_masks( 145 | self, 146 | image_embeddings: torch.Tensor, 147 | edge_embeddings: torch.Tensor, 148 | image_pe: torch.Tensor, 149 | sparse_prompt_embeddings: torch.Tensor, 150 | dense_prompt_embeddings: torch.Tensor, 151 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 152 | """Predicts masks. 
See 'forward' for more details.""" 153 | # Concatenate output tokens 154 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.edge_token.weight], dim=0) 155 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 156 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 157 | 158 | # Expand per-image data in batch direction to be per-mask 159 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 160 | src = src + dense_prompt_embeddings 161 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 162 | b, c, h, w = src.shape 163 | 164 | # Run the transformer 165 | hs, src = self.transformer(src, pos_src, tokens) 166 | iou_token_out = hs[:, 0, :] 167 | mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] 168 | 169 | # Upscale mask embeddings and predict masks using the mask tokens 170 | src = src.transpose(1, 2).view(b, c, h, w) 171 | upscaled_embedding = self.output_upscaling(src) 172 | 173 | edge_embedding = self.embedding_maskfeature(upscaled_embedding) + edge_embeddings.repeat(b, 1, 1, 1) 174 | 175 | hyper_in_list: List[torch.Tensor] = [] 176 | for i in range(self.num_mask_tokens): 177 | if i < self.num_mask_tokens-1: 178 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 179 | else: 180 | hyper_in_list.append(self.edge_mlp(mask_tokens_out[:, i, :])) 181 | hyper_in = torch.stack(hyper_in_list, dim=1) 182 | 183 | b, c, h, w = upscaled_embedding.shape 184 | masks = (hyper_in[:, :self.num_mask_tokens-1] @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 185 | edge = (hyper_in[:, self.num_mask_tokens-1:] @ edge_embedding.view(b, c, h * w)).view(b, -1, h, w) 186 | 187 | edge = self.sigmoid(edge) 188 | 189 | masks = masks * edge + masks 190 | 191 | # Generate mask quality predictions 192 | iou_pred = self.iou_prediction_head(iou_token_out) 193 | 194 | return masks, edge, iou_pred 195 | 196 | 197 | # Lightly adapted from 198 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 199 | class MLP(nn.Module): 200 | def __init__( 201 | self, 202 | input_dim: int, 203 | hidden_dim: int, 204 | output_dim: int, 205 | num_layers: int, 206 | sigmoid_output: bool = False, 207 | ) -> None: 208 | super().__init__() 209 | self.num_layers = num_layers 210 | h = [hidden_dim] * (num_layers - 1) 211 | self.layers = nn.ModuleList( 212 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 213 | ) 214 | self.sigmoid_output = sigmoid_output 215 | 216 | def forward(self, x): 217 | for i, layer in enumerate(self.layers): 218 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 219 | if self.sigmoid_output: 220 | x = F.sigmoid(x) 221 | return x 222 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/IRSAM_edge.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | import cv2 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
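For reference, the edge-aware decoder above can be exercised on its own with dummy tensors. The sizes below mirror the configuration in `build_IRSAM.py` (256-dim transformer, a 64x64 image-embedding grid, 1024x1024 model input) and are assumptions for illustration, not requirements.

```python
import torch

from segment_anything_training.modeling import PromptEncoder, TwoWayTransformer
from segment_anything_training.modeling.IRSAM_decoder import MaskDecoder as EdgeDecoder

decoder = EdgeDecoder(
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    transformer_dim=256,
)
prompt_encoder = PromptEncoder(
    embed_dim=256, image_embedding_size=(64, 64),
    input_image_size=(1024, 1024), mask_in_chans=16,
)
sparse, dense = prompt_encoder(points=None, boxes=None, masks=None)

image_emb = torch.randn(1, 256, 64, 64)   # encoder output after the neck
edge_emb = torch.randn(1, 32, 32, 160)    # intermediate ViT feature in (B, H, W, C) layout

masks, edges, iou = decoder(
    image_embeddings=image_emb,
    edge_embeddings=edge_emb,
    image_pe=prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse,
    dense_prompt_embeddings=dense,
)
print(masks.shape, edges.shape, iou.shape)  # (1, 1, 256, 256), (1, 1, 256, 256), (1, 1)
```
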
6 | 7 | import os.path as ops 8 | import torch 9 | from torch import nn, Tensor 10 | from torch.nn import functional as F 11 | 12 | from typing import Any, Dict, List, Tuple, Union 13 | 14 | from .image_encoder import ImageEncoderViT 15 | from .mask_decoder import MaskDecoder 16 | from .prompt_encoder import PromptEncoder 17 | 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | class Sam(nn.Module): 22 | mask_threshold: float = 0.0 23 | image_format: str = "RGB" 24 | 25 | def __init__( 26 | self, 27 | image_encoder: ImageEncoderViT, 28 | prompt_encoder: PromptEncoder, 29 | mask_decoder: MaskDecoder, 30 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 31 | pixel_std: List[float] = [58.395, 57.12, 57.375], 32 | ) -> None: 33 | """ 34 | SAM predicts object masks from an image and input prompts. 35 | 36 | Arguments: 37 | image_encoder (ImageEncoderViT): The backbone used to encode the 38 | image into image embeddings that allow for efficient mask prediction. 39 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 40 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 41 | and encoded prompts. 42 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 43 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 44 | """ 45 | super().__init__() 46 | self.image_encoder = image_encoder 47 | self.prompt_encoder = prompt_encoder 48 | self.mask_decoder = mask_decoder 49 | 50 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 51 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 52 | 53 | @property 54 | def device(self) -> Any: 55 | return self.pixel_mean.device 56 | 57 | def forward( 58 | self, 59 | batched_input: List[Dict[str, Any]], 60 | ) -> [Tensor, Tensor]: 61 | """ 62 | Predicts masks end-to-end from provided images and prompts. 63 | If prompts are not known in advance, using SamPredictor is 64 | recommended over calling the model directly. 65 | 66 | Arguments: 67 | batched_input (list(dict)): A list over input images, each a 68 | dictionary with the following keys. A prompt key can be 69 | excluded if it is not present. 70 | 'image': The image as a torch tensor in 3xHxW format, 71 | already transformed for input to the model. 72 | 'original_size': (tuple(int, int)) The original size of 73 | the image before transformation, as (H, W). 74 | 'point_coords': (torch.Tensor) Batched point prompts for 75 | this image, with shape BxNx2. Already transformed to the 76 | input frame of the model. 77 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 78 | with shape BxN. 79 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 80 | Already transformed to the input frame of the model. 81 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 82 | in the form Bx1xHxW. 83 | multimask_output (bool): Whether the model should predict multiple 84 | disambiguating masks, or return a single mask. 85 | 86 | Returns: 87 | (list(dict)): A list over input images, where each element is 88 | as dictionary with the following keys. 89 | 'masks': (torch.Tensor) Batched binary mask predictions, 90 | with shape BxCxHxW, where B is the number of input promts, 91 | C is determiend by multimask_output, and (H, W) is the 92 | original size of the image. 93 | 'iou_predictions': (torch.Tensor) The model's predictions 94 | of mask quality, in shape BxC. 
95 | 'low_res_logits': (torch.Tensor) Low resolution logits with 96 | shape BxCxHxW, where H=W=256. Can be passed as mask input 97 | to subsequent iterations of prediction. 98 | """ 99 | 100 | input_images = torch.cat([self.preprocess(x["image"]) for x in batched_input], dim=0) 101 | 102 | image_embeddings, edge_embeddings = self.image_encoder(input_images) 103 | # print("max:", torch.max(image_embeddings[0][211]), " min:", torch.min(image_embeddings[0][211])) 104 | # print(image_embeddings.shape) 105 | # 106 | # plt.imshow(image_embeddings[0][211].cpu().detach().numpy()*255.) 107 | # plt.show() 108 | # cv2.imwrite("show_sample.png", image_embeddings[0][211].cpu().detach().numpy()*255.) 109 | 110 | outputs = [] 111 | for image_record, curr_embedding, edge_embedding in zip(batched_input, image_embeddings, edge_embeddings): 112 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 113 | points=None, 114 | boxes=image_record.get("boxes", None), 115 | masks=image_record.get("mask_inputs", None), 116 | ) 117 | 118 | low_res_mask, low_res_edge, iou = self.mask_decoder( 119 | image_embeddings=curr_embedding.unsqueeze(0), 120 | edge_embeddings=edge_embedding.unsqueeze(0), 121 | image_pe=self.prompt_encoder.get_dense_pe(), 122 | sparse_prompt_embeddings=sparse_embeddings, 123 | dense_prompt_embeddings=dense_embeddings, 124 | ) 125 | 126 | mask = self.postprocess_masks( 127 | low_res_mask, 128 | input_size=image_record["image"].shape[-2:], 129 | original_size=image_record["original_size"], 130 | ) 131 | 132 | edge = self.postprocess_masks( 133 | low_res_edge, 134 | input_size=image_record["image"].shape[-2:], 135 | original_size=image_record["original_size"], 136 | ) 137 | 138 | outputs.append( 139 | { 140 | "mask": mask, 141 | "edge": edge, 142 | "low_res_logits": low_res_mask, 143 | } 144 | ) 145 | masks = torch.cat([x["mask"] for x in outputs], dim=0) 146 | edges = torch.cat([x["edge"] for x in outputs], dim=0) 147 | 148 | return masks, edges 149 | 150 | def postprocess_masks( 151 | self, 152 | masks: torch.Tensor, 153 | input_size: Tuple[int, ...], 154 | original_size: Tuple[int, ...], 155 | ) -> torch.Tensor: 156 | """ 157 | Remove padding and upscale masks to the original image size. 158 | 159 | Arguments: 160 | masks (torch.Tensor): Batched masks from the mask_decoder, 161 | in BxCxHxW format. 162 | input_size (tuple(int, int)): The size of the image input to the 163 | model, in (H, W) format. Used to remove padding. 164 | original_size (tuple(int, int)): The original size of the image 165 | before resizing for input to the model, in (H, W) format. 166 | 167 | Returns: 168 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 169 | is given by original_size. 
170 | """ 171 | # plt.subplot(1, 2, 2) 172 | # plt.imshow(masks[0].permute(1, 2, 0).detach().cpu().numpy()) 173 | # plt.show() 174 | masks = F.interpolate(masks, (512, 512), mode="bilinear") 175 | 176 | return masks 177 | 178 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 179 | """Normalize pixel values and pad to a square input.""" 180 | # Normalize colors 181 | x = (x - self.pixel_mean) / self.pixel_std 182 | 183 | x = F.interpolate(x.unsqueeze(0), (self.image_encoder.img_size, self.image_encoder.img_size), mode="nearest") 184 | # plt.subplot(1, 2, 1) 185 | # plt.imshow(x[0].permute(1, 2, 0).detach().cpu().numpy()) 186 | return x 187 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/PMD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Get_gradient_nopadding(nn.Module): 7 | def __init__(self): 8 | super(Get_gradient_nopadding, self).__init__() 9 | kernel_v = [[0, -1, 0], 10 | [0, 0, 0], 11 | [0, 1, 0]] 12 | kernel_h = [[0, 0, 0], 13 | [-1, 0, 1], 14 | [0, 0, 0]] 15 | kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) 16 | kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) 17 | self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) 18 | self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) 19 | 20 | def forward(self, x): 21 | x_list = [] 22 | for i in range(x.shape[1]): 23 | x_i = x[:, i] 24 | x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) 25 | x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) 26 | x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) 27 | x_list.append(x_i) 28 | 29 | # print(x_list[1]-x_list[0])  # debug output, disabled 30 | x = torch.cat(x_list, dim=1) 31 | return x 32 | 33 | 34 | class Get_curvature(nn.Module): 35 | def __init__(self): 36 | super(Get_curvature, self).__init__() 37 | kernel_v1 = [[0, -1, 0], 38 | [0, 0, 0], 39 | [0, 1, 0]] 40 | kernel_h1 = [[0, 0, 0], 41 | [-1, 0, 1], 42 | [0, 0, 0]] 43 | kernel_h2 = [[0, 0, 0, 0, 0], 44 | [0, 0, 0, 0, 0], 45 | [1, 0, -2, 0, 1], 46 | [0, 0, 0, 0, 0], 47 | [0, 0, 0, 0, 0]] 48 | kernel_v2 = [[0, 0, 1, 0, 0], 49 | [0, 0, 0, 0, 0], 50 | [0, 0, -2, 0, 0], 51 | [0, 0, 0, 0, 0], 52 | [0, 0, 1, 0, 0]] 53 | kernel_w2 = [[1, 0, -1], 54 | [0, 0, 0], 55 | [-1, 0, 1]] 56 | kernel_h1 = torch.FloatTensor(kernel_h1).unsqueeze(0).unsqueeze(0) 57 | kernel_v1 = torch.FloatTensor(kernel_v1).unsqueeze(0).unsqueeze(0) 58 | kernel_v2 = torch.FloatTensor(kernel_v2).unsqueeze(0).unsqueeze(0) 59 | kernel_h2 = torch.FloatTensor(kernel_h2).unsqueeze(0).unsqueeze(0) 60 | kernel_w2 = torch.FloatTensor(kernel_w2).unsqueeze(0).unsqueeze(0) 61 | self.weight_h1 = nn.Parameter(data=kernel_h1, requires_grad=False) 62 | self.weight_v1 = nn.Parameter(data=kernel_v1, requires_grad=False) 63 | self.weight_v2 = nn.Parameter(data=kernel_v2, requires_grad=False) 64 | self.weight_h2 = nn.Parameter(data=kernel_h2, requires_grad=False) 65 | self.weight_w2 = nn.Parameter(data=kernel_w2, requires_grad=False) 66 | 67 | def forward(self, x): 68 | x_list = [] 69 | for i in range(x.shape[1]): 70 | x_i = x[:, i] 71 | x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v1, padding=1) 72 | x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h1, padding=1) 73 | x_i_v2 = F.conv2d(x_i.unsqueeze(1), self.weight_v2, padding=2) 74 | x_i_h2 = F.conv2d(x_i.unsqueeze(1), self.weight_h2, padding=2) 75 | x_i_w2 = 
F.conv2d(x_i.unsqueeze(1), self.weight_w2, padding=1) 76 | sum = torch.pow((torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2)), 3 / 2) 77 | fg = torch.mul(torch.pow(x_i_v, 2), x_i_v2) + 2 * torch.mul(torch.mul(x_i_v, x_i_h), x_i_w2) + torch.mul( 78 | torch.pow(x_i_h, 2), x_i_h2) 79 | fh = torch.mul(torch.pow(x_i_v, 2), x_i_h2) - 2 * torch.mul(torch.mul(x_i_v, x_i_h), x_i_w2) + torch.mul( 80 | torch.pow(x_i_h, 2), x_i_v2) 81 | x_i = torch.div(torch.abs(fg - fh), sum + 1e-10) 82 | x_i = torch.div(torch.abs(fh), sum + 1e-10) 83 | x_list.append(x_i) 84 | x = torch.cat(x_list, dim=1) 85 | return x 86 | 87 | 88 | class FeatureEncoder(nn.Module): 89 | def __init__(self, out_dims): 90 | super(FeatureEncoder, self).__init__() 91 | 92 | self.conv1 = nn.Conv2d(3, out_dims[0], kernel_size=3, padding=1) 93 | self.relu1 = nn.ReLU(inplace=True) 94 | self.conv2 = nn.Conv2d(out_dims[0], out_dims[0], kernel_size=3, padding=1) 95 | self.relu2 = nn.ReLU(inplace=True) 96 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 97 | 98 | self.conv3 = nn.Conv2d(out_dims[0], out_dims[1], kernel_size=3, padding=1) 99 | self.relu3 = nn.ReLU(inplace=True) 100 | self.conv4 = nn.Conv2d(out_dims[1], out_dims[1], kernel_size=3, padding=1) 101 | self.relu4 = nn.ReLU(inplace=True) 102 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 103 | 104 | self.conv5 = nn.Conv2d(out_dims[1], out_dims[2], kernel_size=3, padding=1) 105 | self.relu5 = nn.ReLU(inplace=True) 106 | self.conv6 = nn.Conv2d(out_dims[2], out_dims[2], kernel_size=3, padding=1) 107 | self.relu6 = nn.ReLU(inplace=True) 108 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2) 109 | 110 | self.conv7 = nn.Conv2d(out_dims[2], out_dims[3], kernel_size=3, padding=1) 111 | self.relu7 = nn.ReLU(inplace=True) 112 | self.conv8 = nn.Conv2d(out_dims[3], out_dims[3], kernel_size=3, padding=1) 113 | self.relu8 = nn.ReLU(inplace=True) 114 | 115 | def forward(self, x): 116 | # Stage 1 117 | x = self.conv1(x) 118 | x = self.relu1(x) 119 | x = self.conv2(x) 120 | x = self.relu2(x) 121 | x = self.maxpool1(x) 122 | x1 = x 123 | 124 | # Stage 2 125 | x = self.conv3(x) 126 | x = self.relu3(x) 127 | x = self.conv4(x) 128 | x = self.relu4(x) 129 | x = self.maxpool2(x) 130 | x2 = x 131 | 132 | # Stage 3 133 | x = self.conv5(x) 134 | x = self.relu5(x) 135 | x = self.conv6(x) 136 | x = self.relu6(x) 137 | x = self.maxpool3(x) 138 | x3 = x 139 | 140 | # Stage 4 141 | x = self.conv7(x) 142 | x = self.relu7(x) 143 | x = self.conv8(x) 144 | x = self.relu8(x) 145 | x4 = x 146 | 147 | return x1, x2, x3, x4 148 | 149 | 150 | class PMD_features(nn.Module): 151 | def __init__(self, out_dims): 152 | super(PMD_features, self).__init__() 153 | # self.PMD_head = Get_curvature() 154 | self.PMD_head = Get_gradient_nopadding() 155 | # self.feature_ext = FeatureEncoder(out_dims) 156 | 157 | def forward(self, images): 158 | PMD_images = self.PMD_head(images) 159 | # PMD_feature = self.feature_ext(PMD_images) 160 | 161 | return PMD_images 162 | 163 | 164 | # class Adapter(nn.Module): 165 | # def __init__(self, out_dims): 166 | # super(Adapter, self).__init__() 167 | # self.PMD_head = Get_gradient_nopadding() 168 | # self.feature_ext = FeatureEncoder(out_dims) 169 | # 170 | # def forward(self, images): 171 | # PMD_images = self.PMD_head(images) 172 | # PMD_feature = self.feature_ext(PMD_images) 173 | # 174 | # return PMD_feature 175 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | 9 | from .image_encoder import ImageEncoderViT 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | from .IRSAM_decoder import MaskDecoder as MaskDecoder_test 14 | 15 | from .IRSAM_decoder import MaskDecoder as Decoder 16 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/IRSAM_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/IRSAM_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/IRSAM_edge.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/IRSAM_edge.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/IRSAM_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/IRSAM_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/PMD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/PMD.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/image_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/image_encoder.cpython-38.pyc -------------------------------------------------------------------------------- 
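The package `__init__.py` above re-exports the main building blocks, so they can be imported without referencing individual files. As a small, self-contained check of the gradient head defined in `modeling/PMD.py`, the snippet below runs it on a random tensor; the input values and batch size are purely illustrative.

```python
import torch

from segment_anything_training.modeling.PMD import Get_gradient_nopadding, PMD_features

grad_head = Get_gradient_nopadding()
pmd = PMD_features(out_dims=[64, 128, 160, 320])  # out_dims is unused by the gradient head

images = torch.rand(2, 3, 512, 512)   # dummy batch of normalized RGB images
edge_maps = pmd(images)               # per-channel gradient magnitude, same spatial size
print(edge_maps.shape)                # torch.Size([2, 3, 512, 512])
print(torch.allclose(edge_maps, grad_head(images)))  # True: PMD_features wraps the head
```
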
/segment_anything_training/modeling/__pycache__/mask_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/mask_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/prompt_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/prompt_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/sam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/sam.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/modeling/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
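Since `common.py` ends here, a quick illustrative check of its two helpers may be useful: `LayerNorm2d` normalizes over the channel dimension of an NCHW tensor (unlike `nn.LayerNorm`, which acts on trailing dimensions), and `MLPBlock` is a plain two-layer feed-forward applied token-wise. The tensor sizes below are arbitrary.

```python
import torch

from segment_anything_training.modeling.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 256, 64, 64)
ln = LayerNorm2d(256)
print(ln(x).mean(dim=1).abs().max())  # ~0: each pixel is normalized across its 256 channels

mlp = MLPBlock(embedding_dim=256, mlp_dim=2048)
tokens = torch.randn(2, 4096, 256)
print(mlp(tokens).shape)              # torch.Size([2, 4096, 256])
```
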
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type, Any, List 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | import math 16 | import warnings 17 | from itertools import repeat 18 | 19 | import collections.abc as container_abcs 20 | 21 | 22 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 23 | class ImageEncoderViT(nn.Module): 24 | def __init__( 25 | self, 26 | img_size: int = 1024, 27 | patch_size: int = 16, 28 | in_chans: int = 3, 29 | embed_dim: int = 768, 30 | depth: int = 12, 31 | num_heads: int = 12, 32 | mlp_ratio: float = 4.0, 33 | out_chans: int = 256, 34 | qkv_bias: bool = True, 35 | norm_layer: Type[nn.Module] = nn.LayerNorm, 36 | act_layer: Type[nn.Module] = nn.GELU, 37 | use_abs_pos: bool = True, 38 | use_rel_pos: bool = False, 39 | rel_pos_zero_init: bool = True, 40 | window_size: int = 0, 41 | global_attn_indexes: Tuple[int, ...] = (), 42 | ) -> None: 43 | """ 44 | Args: 45 | img_size (int): Input image size. 46 | patch_size (int): Patch size. 47 | in_chans (int): Number of input image channels. 48 | embed_dim (int): Patch embedding dimension. 49 | depth (int): Depth of ViT. 50 | num_heads (int): Number of attention heads in each ViT block. 51 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 52 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 53 | norm_layer (nn.Module): Normalization layer. 54 | act_layer (nn.Module): Activation layer. 55 | use_abs_pos (bool): If True, use absolute positional embeddings. 56 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 57 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 58 | window_size (int): Window size for window attention blocks. 59 | global_attn_indexes (list): Indexes for blocks using global attention. 60 | """ 61 | super().__init__() 62 | self.img_size = img_size 63 | 64 | self.patch_embed = PatchEmbed( 65 | kernel_size=(patch_size, patch_size), 66 | stride=(patch_size, patch_size), 67 | in_chans=in_chans, 68 | embed_dim=embed_dim, 69 | ) 70 | 71 | self.pos_embed: Optional[nn.Parameter] = None 72 | if use_abs_pos: 73 | # Initialize absolute positional embedding with pretrain image size. 
74 | self.pos_embed = nn.Parameter( 75 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 76 | ) 77 | 78 | self.blocks = nn.ModuleList() 79 | 80 | for i in range(depth): 81 | block = Block( 82 | dim=embed_dim, 83 | num_heads=num_heads, 84 | mlp_ratio=mlp_ratio, 85 | qkv_bias=qkv_bias, 86 | norm_layer=norm_layer, 87 | act_layer=act_layer, 88 | use_rel_pos=use_rel_pos, 89 | rel_pos_zero_init=rel_pos_zero_init, 90 | window_size=window_size if i not in global_attn_indexes else 0, 91 | input_size=(img_size // patch_size, img_size // patch_size), 92 | ) 93 | self.blocks.append(block) 94 | 95 | self.neck = nn.Sequential( 96 | nn.Conv2d( 97 | embed_dim, 98 | out_chans, 99 | kernel_size=1, 100 | bias=False, 101 | ), 102 | LayerNorm2d(out_chans), 103 | nn.Conv2d( 104 | out_chans, 105 | out_chans, 106 | kernel_size=3, 107 | padding=1, 108 | bias=False, 109 | ), 110 | LayerNorm2d(out_chans), 111 | ) 112 | 113 | def forward(self, x: torch.Tensor) -> Tuple[Any, List[Any]]: 114 | x = self.patch_embed(x) 115 | 116 | if self.pos_embed is not None: 117 | x = x + self.pos_embed 118 | 119 | interm_embeddings = [] 120 | for i, blk in enumerate(self.blocks): 121 | x = blk(x) 122 | interm_embeddings.append(x) 123 | # if blk.window_size == 0: 124 | # interm_embeddings.append(x.permute(0, 3, 1, 2)) 125 | 126 | x = self.neck(x.permute(0, 3, 1, 2)) 127 | 128 | return x, interm_embeddings 129 | 130 | 131 | class Block(nn.Module): 132 | """Transformer blocks with support of window attention and residual propagation blocks""" 133 | 134 | def __init__( 135 | self, 136 | dim: int, 137 | num_heads: int, 138 | mlp_ratio: float = 4.0, 139 | qkv_bias: bool = True, 140 | norm_layer: Type[nn.Module] = nn.LayerNorm, 141 | act_layer: Type[nn.Module] = nn.GELU, 142 | use_rel_pos: bool = False, 143 | rel_pos_zero_init: bool = True, 144 | window_size: int = 0, 145 | input_size: Optional[Tuple[int, int]] = None, 146 | ) -> None: 147 | """ 148 | Args: 149 | dim (int): Number of input channels. 150 | num_heads (int): Number of attention heads in each ViT block. 151 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 152 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 153 | norm_layer (nn.Module): Normalization layer. 154 | act_layer (nn.Module): Activation layer. 155 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 156 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 157 | window_size (int): Window size for window attention blocks. If it equals 0, then 158 | use global attention. 159 | input_size (int or None): Input resolution for calculating the relative positional 160 | parameter size. 
161 | """ 162 | super().__init__() 163 | self.norm1 = norm_layer(dim) 164 | self.attn = Attention( 165 | dim, 166 | num_heads=num_heads, 167 | qkv_bias=qkv_bias, 168 | use_rel_pos=use_rel_pos, 169 | rel_pos_zero_init=rel_pos_zero_init, 170 | input_size=input_size if window_size == 0 else (window_size, window_size), 171 | ) 172 | 173 | self.norm2 = norm_layer(dim) 174 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 175 | 176 | self.window_size = window_size 177 | 178 | def forward(self, x: torch.Tensor) -> torch.Tensor: 179 | shortcut = x 180 | x = self.norm1(x) 181 | # Window partition 182 | if self.window_size > 0: 183 | H, W = x.shape[1], x.shape[2] 184 | x, pad_hw = window_partition(x, self.window_size) 185 | 186 | x = self.attn(x) 187 | # Reverse window partition 188 | if self.window_size > 0: 189 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 190 | 191 | x = shortcut + x 192 | x = x + self.mlp(self.norm2(x)) 193 | 194 | return x 195 | 196 | 197 | class Attention(nn.Module): 198 | """Multi-head Attention block with relative position embeddings.""" 199 | 200 | def __init__( 201 | self, 202 | dim: int, 203 | num_heads: int = 8, 204 | qkv_bias: bool = True, 205 | use_rel_pos: bool = False, 206 | rel_pos_zero_init: bool = True, 207 | input_size: Optional[Tuple[int, int]] = None, 208 | ) -> None: 209 | """ 210 | Args: 211 | dim (int): Number of input channels. 212 | num_heads (int): Number of attention heads. 213 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 214 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 215 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 216 | input_size (int or None): Input resolution for calculating the relative positional 217 | parameter size. 218 | """ 219 | super().__init__() 220 | self.num_heads = num_heads 221 | head_dim = dim // num_heads 222 | self.scale = head_dim ** -0.5 223 | 224 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 225 | self.proj = nn.Linear(dim, dim) 226 | 227 | self.use_rel_pos = use_rel_pos 228 | if self.use_rel_pos: 229 | assert ( 230 | input_size is not None 231 | ), "Input size must be provided if using relative positional encoding." 232 | # initialize relative positional embeddings 233 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 234 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 235 | 236 | def forward(self, x: torch.Tensor) -> torch.Tensor: 237 | B, H, W, _ = x.shape 238 | # qkv with shape (3, B, nHead, H * W, C) 239 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 240 | # q, k, v with shape (B * nHead, H * W, C) 241 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 242 | 243 | attn = (q * self.scale) @ k.transpose(-2, -1) 244 | 245 | if self.use_rel_pos: 246 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 247 | 248 | attn = attn.softmax(dim=-1) 249 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 250 | x = self.proj(x) 251 | return x 252 | 253 | 254 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 255 | """ 256 | Partition into non-overlapping windows with padding if needed. 257 | Args: 258 | x (tensor): input tokens with [B, H, W, C]. 259 | window_size (int): window size. 
260 | 261 | Returns: 262 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 263 | (Hp, Wp): padded height and width before partition 264 | """ 265 | B, H, W, C = x.shape 266 | 267 | pad_h = (window_size - H % window_size) % window_size 268 | pad_w = (window_size - W % window_size) % window_size 269 | if pad_h > 0 or pad_w > 0: 270 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 271 | Hp, Wp = H + pad_h, W + pad_w 272 | 273 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 274 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 275 | return windows, (Hp, Wp) 276 | 277 | 278 | def window_unpartition( 279 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 280 | ) -> torch.Tensor: 281 | """ 282 | Window unpartition into original sequences and removing padding. 283 | Args: 284 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 285 | window_size (int): window size. 286 | pad_hw (Tuple): padded height and width (Hp, Wp). 287 | hw (Tuple): original height and width (H, W) before padding. 288 | 289 | Returns: 290 | x: unpartitioned sequences with [B, H, W, C]. 291 | """ 292 | Hp, Wp = pad_hw 293 | H, W = hw 294 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 295 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 296 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 297 | 298 | if Hp > H or Wp > W: 299 | x = x[:, :H, :W, :].contiguous() 300 | return x 301 | 302 | 303 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Get relative positional embeddings according to the relative positions of 306 | query and key sizes. 307 | Args: 308 | q_size (int): size of query q. 309 | k_size (int): size of key k. 310 | rel_pos (Tensor): relative position embeddings (L, C). 311 | 312 | Returns: 313 | Extracted positional embeddings according to relative positions. 314 | """ 315 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 316 | # Interpolate rel pos if needed. 317 | if rel_pos.shape[0] != max_rel_dist: 318 | # Interpolate rel pos. 319 | rel_pos_resized = F.interpolate( 320 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 321 | size=max_rel_dist, 322 | mode="linear", 323 | ) 324 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 325 | else: 326 | rel_pos_resized = rel_pos 327 | 328 | # Scale the coords with short length if shapes for q and k are different. 329 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 330 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 331 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 332 | 333 | return rel_pos_resized[relative_coords.long()] 334 | 335 | 336 | def add_decomposed_rel_pos( 337 | attn: torch.Tensor, 338 | q: torch.Tensor, 339 | rel_pos_h: torch.Tensor, 340 | rel_pos_w: torch.Tensor, 341 | q_size: Tuple[int, int], 342 | k_size: Tuple[int, int], 343 | ) -> torch.Tensor: 344 | """ 345 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 346 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 347 | Args: 348 | attn (Tensor): attention map. 349 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
350 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 351 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 352 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 353 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 354 | 355 | Returns: 356 | attn (Tensor): attention map with added relative positional embeddings. 357 | """ 358 | q_h, q_w = q_size 359 | k_h, k_w = k_size 360 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 361 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 362 | 363 | B, _, dim = q.shape 364 | r_q = q.reshape(B, q_h, q_w, dim) 365 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 366 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 367 | 368 | attn = ( 369 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 370 | ).view(B, q_h * q_w, k_h * k_w) 371 | 372 | return attn 373 | 374 | 375 | class PatchEmbed(nn.Module): 376 | """ 377 | Image to Patch Embedding. 378 | """ 379 | 380 | def __init__( 381 | self, 382 | kernel_size: Tuple[int, int] = (16, 16), 383 | stride: Tuple[int, int] = (16, 16), 384 | padding: Tuple[int, int] = (0, 0), 385 | in_chans: int = 3, 386 | embed_dim: int = 768, 387 | ) -> None: 388 | """ 389 | Args: 390 | kernel_size (Tuple): kernel size of the projection layer. 391 | stride (Tuple): stride of the projection layer. 392 | padding (Tuple): padding size of the projection layer. 393 | in_chans (int): Number of input image channels. 394 | embed_dim (int): embed_dim (int): Patch embedding dimension. 395 | """ 396 | super().__init__() 397 | 398 | self.proj = nn.Conv2d( 399 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 400 | ) 401 | 402 | def forward(self, x: torch.Tensor) -> torch.Tensor: 403 | x = self.proj(x) 404 | # B C H W -> B H W C 405 | x = x.permute(0, 2, 3, 1) 406 | return x 407 | 408 | 409 | 410 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | # patch_embeddings: torch.Tensor, 75 | image_pe: torch.Tensor, 76 | sparse_prompt_embeddings: torch.Tensor, 77 | dense_prompt_embeddings: torch.Tensor, 78 | multimask_output: bool = None, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Predict masks given image and prompt embeddings. 82 | 83 | Arguments: 84 | image_embeddings (torch.Tensor): the embeddings from the image encoder 85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 88 | multimask_output (bool): Whether to return multiple masks or a single 89 | mask. 90 | 91 | Returns: 92 | torch.Tensor: batched predicted masks 93 | torch.Tensor: batched predictions of mask quality 94 | """ 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | # patch_embeddings=patch_embeddings, 98 | image_pe=image_pe, 99 | sparse_prompt_embeddings=sparse_prompt_embeddings, 100 | dense_prompt_embeddings=dense_prompt_embeddings, 101 | ) 102 | 103 | # Select the correct mask or masks for outptu 104 | if multimask_output: 105 | mask_slice = slice(1, None) 106 | else: 107 | mask_slice = slice(0, 1) 108 | masks = masks[:, mask_slice, :, :] 109 | iou_pred = iou_pred[:, mask_slice] 110 | 111 | # Prepare output 112 | return masks, iou_pred 113 | 114 | def predict_masks( 115 | self, 116 | image_embeddings: torch.Tensor, 117 | # patch_embeddings: torch.Tensor, 118 | image_pe: torch.Tensor, 119 | sparse_prompt_embeddings: torch.Tensor, 120 | dense_prompt_embeddings: torch.Tensor, 121 | ) -> Tuple[torch.Tensor, torch.Tensor]: 122 | """Predicts masks. 
See 'forward' for more details.""" 123 | # Concatenate output tokens 124 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 125 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 126 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 127 | 128 | # Expand per-image data in batch direction to be per-mask 129 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 130 | src = src + dense_prompt_embeddings 131 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 132 | b, c, h, w = src.shape 133 | 134 | # Run the transformer 135 | hs, src = self.transformer(src, pos_src, tokens) 136 | iou_token_out = hs[:, 0, :] 137 | mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] 138 | 139 | # Upscale mask embeddings and predict masks using the mask tokens 140 | src = src.transpose(1, 2).view(b, c, h, w) 141 | upscaled_embedding = self.output_upscaling(src) 142 | hyper_in_list: List[torch.Tensor] = [] 143 | for i in range(self.num_mask_tokens): 144 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 145 | hyper_in = torch.stack(hyper_in_list, dim=1) 146 | b, c, h, w = upscaled_embedding.shape 147 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 148 | 149 | # Generate mask quality predictions 150 | iou_pred = self.iou_prediction_head(iou_token_out) 151 | 152 | return masks, iou_pred 153 | 154 | 155 | # Lightly adapted from 156 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 157 | class MLP(nn.Module): 158 | def __init__( 159 | self, 160 | input_dim: int, 161 | hidden_dim: int, 162 | output_dim: int, 163 | num_layers: int, 164 | sigmoid_output: bool = False, 165 | ) -> None: 166 | super().__init__() 167 | self.num_layers = num_layers 168 | h = [hidden_dim] * (num_layers - 1) 169 | self.layers = nn.ModuleList( 170 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 171 | ) 172 | self.sigmoid_output = sigmoid_output 173 | 174 | def forward(self, x): 175 | for i, layer in enumerate(self.layers): 176 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 177 | if self.sigmoid_output: 178 | x = F.sigmoid(x) 179 | return x 180 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from torch.nn import functional as F 12 | 13 | from typing import Any, Optional, Tuple, Type 14 | 15 | from .common import LayerNorm2d 16 | 17 | 18 | class PromptEncoder(nn.Module): 19 | def __init__( 20 | self, 21 | embed_dim: int, 22 | image_embedding_size: Tuple[int, int], 23 | input_image_size: Tuple[int, int], 24 | mask_in_chans: int, 25 | activation: Type[nn.Module] = nn.GELU, 26 | ) -> None: 27 | """ 28 | Encodes prompts for input to SAM's mask decoder. 29 | 30 | Arguments: 31 | embed_dim (int): The prompts' embedding dimension 32 | image_embedding_size (tuple(int, int)): The spatial size of the 33 | image embedding, as (H, W). 
34 | input_image_size (int): The padded size of the image as input 35 | to the image encoder, as (H, W). 36 | mask_in_chans (int): The number of hidden channels used for 37 | encoding input masks. 38 | activation (nn.Module): The activation to use when encoding 39 | input masks. 40 | """ 41 | super().__init__() 42 | self.embed_dim = embed_dim 43 | self.input_image_size = input_image_size 44 | self.image_embedding_size = image_embedding_size 45 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 46 | self.pe_layer_2 = PositionEmbeddingRandom(embed_dim // 8) 47 | 48 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 49 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 54 | self.mask_downscaling = nn.Sequential( 55 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans // 4), 57 | activation(), 58 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans), 60 | activation(), 61 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 62 | ) 63 | self.no_mask_embed = nn.Embedding(1, embed_dim) 64 | 65 | def get_dense_pe_2(self) -> torch.Tensor: 66 | """ 67 | Returns the positional encoding used to encode point prompts, 68 | applied to a dense set of points the shape of the image encoding. 69 | 70 | Returns: 71 | torch.Tensor: Positional encoding with shape 72 | 1x(embed_dim)x(embedding_h)x(embedding_w) 73 | """ 74 | h, w = self.image_embedding_size 75 | return self.pe_layer_2((h * 4, w * 4)).unsqueeze(0) 76 | 77 | def get_dense_pe(self) -> torch.Tensor: 78 | """ 79 | Returns the positional encoding used to encode point prompts, 80 | applied to a dense set of points the shape of the image encoding. 81 | 82 | Returns: 83 | torch.Tensor: Positional encoding with shape 84 | 1x(embed_dim)x(embedding_h)x(embedding_w) 85 | """ 86 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 87 | 88 | def get_dense_pe_large(self) -> torch.Tensor: 89 | """ 90 | Returns the positional encoding used to encode point prompts, 91 | applied to a dense set of points the shape of the image encoding. 
92 | 93 | Returns: 94 | torch.Tensor: Positional encoding with shape 95 | 1x(embed_dim)x(embedding_h)x(embedding_w) 96 | """ 97 | h, w = self.image_embedding_size 98 | 99 | return self.pe_layer((h * 2, w * 2)).unsqueeze(0) 100 | 101 | def _embed_points( 102 | self, 103 | points: torch.Tensor, 104 | labels: torch.Tensor, 105 | pad: bool, 106 | ) -> torch.Tensor: 107 | """Embeds point prompts.""" 108 | points = points + 0.5 # Shift to center of pixel 109 | if pad: 110 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 111 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 112 | points = torch.cat([points, padding_point], dim=1) 113 | labels = torch.cat([labels, padding_label], dim=1) 114 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 115 | point_embedding[labels == -1] = 0.0 116 | point_embedding[labels == -1] += self.not_a_point_embed.weight 117 | point_embedding[labels == 0] += self.point_embeddings[0].weight 118 | point_embedding[labels == 1] += self.point_embeddings[1].weight 119 | return point_embedding 120 | 121 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 122 | """Embeds box prompts.""" 123 | boxes = boxes + 0.5 # Shift to center of pixel 124 | coords = boxes.reshape(-1, 2, 2) 125 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 126 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 127 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 128 | return corner_embedding 129 | 130 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 131 | """Embeds mask inputs.""" 132 | mask_embedding = self.mask_downscaling(masks) 133 | return mask_embedding 134 | 135 | def _get_batch_size( 136 | self, 137 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 138 | boxes: Optional[torch.Tensor], 139 | masks: Optional[torch.Tensor], 140 | ) -> int: 141 | """ 142 | Gets the batch size of the output given the batch size of the input prompts. 143 | """ 144 | if points is not None: 145 | return points[0].shape[0] 146 | elif boxes is not None: 147 | return boxes.shape[0] 148 | elif masks is not None: 149 | return masks.shape[0] 150 | else: 151 | return 1 152 | 153 | def _get_device(self) -> torch.device: 154 | return self.point_embeddings[0].weight.device 155 | 156 | def forward( 157 | self, 158 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 159 | boxes: Optional[torch.Tensor], 160 | masks: Optional[torch.Tensor], 161 | ) -> Tuple[torch.Tensor, torch.Tensor]: 162 | """ 163 | Embeds different types of prompts, returning both sparse and dense 164 | embeddings. 165 | 166 | Arguments: 167 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 168 | and labels to embed. 169 | boxes (torch.Tensor or none): boxes to embed 170 | masks (torch.Tensor or none): masks to embed 171 | 172 | Returns: 173 | torch.Tensor: sparse embeddings for the points and boxes, with shape 174 | BxNx(embed_dim), where N is determined by the number of input points 175 | and boxes. 
176 | torch.Tensor: dense embeddings for the masks, in the shape 177 | Bx(embed_dim)x(embed_H)x(embed_W) 178 | """ 179 | bs = self._get_batch_size(points, boxes, masks) 180 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 181 | if points is not None: 182 | coords, labels = points 183 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 184 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 185 | if boxes is not None: 186 | box_embeddings = self._embed_boxes(boxes) 187 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 188 | 189 | if masks is not None: 190 | masks = F.interpolate(masks, self.mask_input_size, mode="bilinear") 191 | dense_embeddings = self._embed_masks(masks) 192 | else: 193 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 194 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 195 | ) 196 | 197 | return sparse_embeddings, dense_embeddings 198 | 199 | 200 | class PositionEmbeddingRandom(nn.Module): 201 | """ 202 | Positional encoding using random spatial frequencies. 203 | """ 204 | 205 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 206 | super().__init__() 207 | if scale is None or scale <= 0.0: 208 | scale = 1.0 209 | self.register_buffer( 210 | "positional_encoding_gaussian_matrix", 211 | scale * torch.randn((2, num_pos_feats)), 212 | ) 213 | 214 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 215 | """Positionally encode points that are normalized to [0,1].""" 216 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 217 | coords = 2 * coords - 1 218 | coords = coords @ self.positional_encoding_gaussian_matrix 219 | coords = 2 * np.pi * coords 220 | # outputs d_1 x ... x d_n x C shape 221 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 222 | 223 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 224 | """Generate positional encoding for a grid of the specified size.""" 225 | h, w = size 226 | device: Any = self.positional_encoding_gaussian_matrix.device 227 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 228 | y_embed = grid.cumsum(dim=0) - 0.5 229 | x_embed = grid.cumsum(dim=1) - 0.5 230 | y_embed = y_embed / h 231 | x_embed = x_embed / w 232 | 233 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 234 | return pe.permute(2, 0, 1) # C x H x W 235 | 236 | def forward_with_coords( 237 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 238 | ) -> torch.Tensor: 239 | """Positionally encode points that are not normalized to [0,1].""" 240 | coords = coords_input.clone() 241 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 242 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 243 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 244 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
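# --- A minimal usage sketch for the PromptEncoder defined above. The sizes used here
# --- (embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024),
# --- mask_in_chans=16, GELU activation) are assumed SAM-style defaults, not values
# --- read from this repository.
import torch
from torch import nn
from segment_anything_training.modeling.prompt_encoder import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
    activation=nn.GELU,
)
coords = torch.tensor([[[512.0, 512.0]]])  # one (x, y) point per image, in input-frame pixels
labels = torch.tensor([[1]])               # 1 = foreground point, 0 = background point
sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
# Under these assumed sizes: sparse is 1 x 2 x 256 (the point plus a padding token),
# dense is 1 x 256 x 64 x 64 (the learned no-mask embedding broadcast over the grid).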
6 | 7 | import torch 8 | from torch import nn, Tensor 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple, Union 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | class Sam(nn.Module): 21 | mask_threshold: float = 0.0 22 | image_format: str = "RGB" 23 | 24 | def __init__( 25 | self, 26 | image_encoder: ImageEncoderViT, 27 | prompt_encoder: PromptEncoder, 28 | mask_decoder: MaskDecoder, 29 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 30 | pixel_std: List[float] = [58.395, 57.12, 57.375], 31 | ) -> None: 32 | """ 33 | SAM predicts object masks from an image and input prompts. 34 | 35 | Arguments: 36 | image_encoder (ImageEncoderViT): The backbone used to encode the 37 | image into image embeddings that allow for efficient mask prediction. 38 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 39 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 40 | and encoded prompts. 41 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 42 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 43 | """ 44 | super().__init__() 45 | self.image_encoder = image_encoder 46 | self.prompt_encoder = prompt_encoder 47 | 48 | for n, p in self.named_parameters(): 49 | p.requires_grad = False 50 | 51 | self.mask_decoder = mask_decoder 52 | 53 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 54 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 55 | 56 | @property 57 | def device(self) -> Any: 58 | return self.pixel_mean.device 59 | 60 | def forward( 61 | self, 62 | batched_input: List[Dict[str, Any]], 63 | multimask_output: bool = None, 64 | ) -> [Tensor, Tensor]: 65 | """ 66 | Predicts masks end-to-end from provided images and prompts. 67 | If prompts are not known in advance, using SamPredictor is 68 | recommended over calling the model directly. 69 | 70 | Arguments: 71 | batched_input (list(dict)): A list over input images, each a 72 | dictionary with the following keys. A prompt key can be 73 | excluded if it is not present. 74 | 'image': The image as a torch tensor in 3xHxW format, 75 | already transformed for input to the model. 76 | 'original_size': (tuple(int, int)) The original size of 77 | the image before transformation, as (H, W). 78 | 'point_coords': (torch.Tensor) Batched point prompts for 79 | this image, with shape BxNx2. Already transformed to the 80 | input frame of the model. 81 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 82 | with shape BxN. 83 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 84 | Already transformed to the input frame of the model. 85 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 86 | in the form Bx1xHxW. 87 | multimask_output (bool): Whether the model should predict multiple 88 | disambiguating masks, or return a single mask. 89 | 90 | Returns: 91 | (list(dict)): A list over input images, where each element is 92 | as dictionary with the following keys. 93 | 'masks': (torch.Tensor) Batched binary mask predictions, 94 | with shape BxCxHxW, where B is the number of input promts, 95 | C is determiend by multimask_output, and (H, W) is the 96 | original size of the image. 
97 | 'iou_predictions': (torch.Tensor) The model's predictions 98 | of mask quality, in shape BxC. 99 | 'low_res_logits': (torch.Tensor) Low resolution logits with 100 | shape BxCxHxW, where H=W=256. Can be passed as mask input 101 | to subsequent iterations of prediction. 102 | """ 103 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 104 | 105 | image_embeddings = self.image_encoder(input_images) 106 | outputs = [] 107 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 108 | if "point_coords" in image_record: 109 | points = (image_record["point_coords"], image_record["point_labels"]) 110 | else: 111 | points = None 112 | # plt.imshow(image_record.get("mask_inputs", None)[0].permute(1, 2, 0).cpu().detach().numpy()) 113 | # plt.show() 114 | # print(image_record.get("mask_inputs", None).shape) 115 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 116 | points=points, 117 | boxes=image_record.get("boxes", None), 118 | masks=image_record.get("mask_inputs", None), 119 | ) 120 | 121 | low_res_masks, iou_pred = self.mask_decoder( 122 | image_embeddings=curr_embedding, 123 | image_pe=self.prompt_encoder.get_dense_pe(), 124 | sparse_prompt_embeddings=sparse_embeddings, 125 | dense_prompt_embeddings=dense_embeddings, 126 | ) 127 | 128 | masks = self.postprocess_masks( 129 | low_res_masks, 130 | input_size=image_record["image"].shape[-2:], 131 | original_size=image_record["original_size"], 132 | ) 133 | 134 | # masks = masks > self.mask_threshold 135 | 136 | outputs.append( 137 | { 138 | "masks": masks, 139 | "low_res_logits": low_res_masks, 140 | "encoder_embedding": curr_embedding.unsqueeze(0), 141 | "image_pe": self.prompt_encoder.get_dense_pe(), 142 | "sparse_embeddings": sparse_embeddings, 143 | "dense_embeddings": dense_embeddings, 144 | } 145 | ) 146 | masks = torch.cat([x["masks"] for x in outputs], dim=0) 147 | 148 | return masks 149 | 150 | def postprocess_masks( 151 | self, 152 | masks: torch.Tensor, 153 | input_size: Tuple[int, ...], 154 | original_size: Tuple[int, ...], 155 | ) -> torch.Tensor: 156 | """ 157 | Remove padding and upscale masks to the original image size. 158 | 159 | Arguments: 160 | masks (torch.Tensor): Batched masks from the mask_decoder, 161 | in BxCxHxW format. 162 | input_size (tuple(int, int)): The size of the image input to the 163 | model, in (H, W) format. Used to remove padding. 164 | original_size (tuple(int, int)): The original size of the image 165 | before resizing for input to the model, in (H, W) format. 166 | 167 | Returns: 168 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 169 | is given by original_size. 
170 | """ 171 | masks = F.interpolate( 172 | masks, 173 | (self.image_encoder.img_size, self.image_encoder.img_size), 174 | mode="bilinear", 175 | ) 176 | masks = masks[..., : int(input_size[0]), : int(input_size[1])] 177 | masks = F.interpolate(masks, original_size, mode="bilinear") 178 | return masks 179 | 180 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 181 | """Normalize pixel values and pad to a square input.""" 182 | # Normalize colors 183 | x = (x - self.pixel_mean) / self.pixel_std 184 | 185 | # Pad 186 | h, w = x.shape[-2:] 187 | padh = self.image_encoder.img_size - h 188 | padw = self.image_encoder.img_size - w 189 | x = F.pad(x, (0, padw, 0, padh)) 190 | 191 | return x 192 | -------------------------------------------------------------------------------- /segment_anything_training/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | from .prompt_encoder import PositionEmbeddingRandom 13 | 14 | from .common import MLPBlock, LayerNorm2d 15 | 16 | from timm.models.vision_transformer import VisionTransformer, Block 17 | from timm.layers import Mlp 18 | 19 | 20 | class TwoWayTransformer_1(nn.Module): 21 | def __init__( 22 | self, 23 | depth: int, 24 | embedding_dim: int, 25 | num_heads: int, 26 | mlp_dim: int, 27 | activation: Type[nn.Module] = nn.ReLU, 28 | attention_downsample_rate: int = 2, 29 | ) -> None: 30 | """ 31 | A transformer decoder that attends to an input image using 32 | queries whose positional embedding is supplied. 33 | 34 | Args: 35 | depth (int): number of layers in the transformer 36 | embedding_dim (int): the channel dimension for the input embeddings 37 | num_heads (int): the number of heads for multihead attention. 
Must 38 | divide embedding_dim 39 | mlp_dim (int): the channel dimension internal to the MLP block 40 | activation (nn.Module): the activation to use in the MLP block 41 | """ 42 | super().__init__() 43 | self.depth = depth 44 | self.embedding_dim = embedding_dim 45 | self.num_heads = num_heads 46 | self.mlp_dim = mlp_dim 47 | self.layers = nn.ModuleList() 48 | 49 | for i in range(depth): 50 | self.layers.append( 51 | TwoWayAttentionBlock( 52 | embedding_dim=embedding_dim, 53 | num_heads=num_heads, 54 | mlp_dim=mlp_dim, 55 | activation=activation, 56 | attention_downsample_rate=attention_downsample_rate, 57 | skip_first_layer_pe=(i == 0), 58 | ) 59 | ) 60 | 61 | self.upsample_1 = nn.Sequential( 62 | nn.Conv2d(embedding_dim, embedding_dim * 4, 3, 1, 1), 63 | nn.BatchNorm2d(embedding_dim * 4), 64 | activation(), 65 | nn.PixelShuffle(2), 66 | ) 67 | 68 | self.upsample_2 = nn.Sequential( 69 | nn.Conv2d(embedding_dim, embedding_dim * 4, 3, 1, 1), 70 | nn.BatchNorm2d(embedding_dim * 4), 71 | activation(), 72 | nn.PixelShuffle(2), 73 | ) 74 | 75 | self.pe = PositionEmbeddingRandom(256 // 2) 76 | 77 | self.final_attn_token_to_image = Attention( 78 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 79 | ) 80 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 81 | 82 | def forward( 83 | self, 84 | image_embedding: Tensor, 85 | image_pe: Tensor, 86 | point_embedding: Tensor, 87 | ) -> Tuple[Tensor, Tensor]: 88 | """ 89 | Args: 90 | image_embedding (torch.Tensor): image to attend to. Should be shape 91 | B x embedding_dim x h x w for any h and w. 92 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 93 | have the same shape as image_embedding. 94 | point_embedding (torch.Tensor): the embedding to add to the query points. 95 | Must have shape B x N_points x embedding_dim for any N_points. 
96 | 97 | Returns: 98 | torch.Tensor: the processed point_embedding 99 | torch.Tensor: the processed image_embedding 100 | """ 101 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 102 | embedding = image_embedding 103 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 104 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 105 | 106 | # Prepare queries 107 | queries = point_embedding 108 | keys = image_embedding 109 | 110 | # Apply transformer blocks and final layernorm 111 | for layer in self.layers: 112 | queries, keys = layer( 113 | queries=queries, 114 | keys=keys, 115 | query_pe=point_embedding, 116 | key_pe=image_pe, 117 | ) 118 | 119 | # embeddings_128 120 | embedding_1 = self.upsample_1(embedding) 121 | image_embedding_1 = embedding_1.flatten(2).permute(0, 2, 1) 122 | image_pe_1 = self.pe((128, 128)).unsqueeze(0).flatten(2).permute(0, 2, 1) 123 | 124 | # Prepare queries 125 | queries_1 = point_embedding 126 | keys_1 = image_embedding_1 127 | 128 | # Apply transformer blocks and final layernorm 129 | for layer in self.layers: 130 | queries_1, keys_1 = layer( 131 | queries=queries, 132 | keys=keys_1, 133 | query_pe=point_embedding, 134 | key_pe=image_pe_1, 135 | ) 136 | 137 | # embeddings_128 138 | embedding_2 = self.upsample_1(embedding_1) 139 | image_embedding_2 = embedding_2.flatten(2).permute(0, 2, 1) 140 | image_pe_2 = self.pe((256, 256)).unsqueeze(0).flatten(2).permute(0, 2, 1) 141 | 142 | # Prepare queries 143 | queries_2 = point_embedding 144 | keys_2 = image_embedding_2 145 | 146 | # Apply transformer blocks and final layernorm 147 | for layer in self.layers: 148 | queries_2, keys_2 = layer( 149 | queries=queries, 150 | keys=keys_2, 151 | query_pe=point_embedding, 152 | key_pe=image_pe_2, 153 | ) 154 | 155 | # Apply the final attenion layer from the points to the image 156 | q = queries + queries_1 + queries_2 + point_embedding 157 | keys = keys.view(1, 64, 64, 256).permute(0, 3, 1, 2) 158 | keys_1 = keys_1.view(1, 128, 128, 256).permute(0, 3, 1, 2) 159 | keys_2 = keys_2.view(1, 256, 256, 256).permute(0, 3, 1, 2) 160 | keys = (self.upsample_2(self.upsample_1(keys) + keys_1) + keys_2).flatten(2).permute(0, 2, 1) 161 | k = keys + image_pe_2 162 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 163 | 164 | queries = queries + queries_1 + queries_2 + attn_out 165 | queries = self.norm_final_attn(queries) 166 | 167 | return queries, keys 168 | 169 | 170 | class TwoWayTransformer(nn.Module): 171 | def __init__( 172 | self, 173 | depth: int, 174 | embedding_dim: int, 175 | num_heads: int, 176 | mlp_dim: int, 177 | activation: Type[nn.Module] = nn.ReLU, 178 | attention_downsample_rate: int = 2, 179 | ) -> None: 180 | """ 181 | A transformer decoder that attends to an input image using 182 | queries whose positional embedding is supplied. 183 | 184 | Args: 185 | depth (int): number of layers in the transformer 186 | embedding_dim (int): the channel dimension for the input embeddings 187 | num_heads (int): the number of heads for multihead attention. 
Must 188 | divide embedding_dim 189 | mlp_dim (int): the channel dimension internal to the MLP block 190 | activation (nn.Module): the activation to use in the MLP block 191 | """ 192 | super().__init__() 193 | self.depth = depth 194 | self.embedding_dim = embedding_dim 195 | self.num_heads = num_heads 196 | self.mlp_dim = mlp_dim 197 | self.layers = nn.ModuleList() 198 | 199 | for i in range(depth): 200 | self.layers.append( 201 | TwoWayAttentionBlock( 202 | embedding_dim=embedding_dim, 203 | num_heads=num_heads, 204 | mlp_dim=mlp_dim, 205 | activation=activation, 206 | attention_downsample_rate=attention_downsample_rate, 207 | skip_first_layer_pe=(i == 0), 208 | ) 209 | ) 210 | 211 | self.final_attn_token_to_image = Attention( 212 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 213 | ) 214 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 215 | 216 | def forward( 217 | self, 218 | image_embedding: Tensor, 219 | image_pe: Tensor, 220 | point_embedding: Tensor, 221 | ) -> Tuple[Tensor, Tensor]: 222 | """ 223 | Args: 224 | image_embedding (torch.Tensor): image to attend to. Should be shape 225 | B x embedding_dim x h x w for any h and w. 226 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 227 | have the same shape as image_embedding. 228 | point_embedding (torch.Tensor): the embedding to add to the query points. 229 | Must have shape B x N_points x embedding_dim for any N_points. 230 | 231 | Returns: 232 | torch.Tensor: the processed point_embedding 233 | torch.Tensor: the processed image_embedding 234 | """ 235 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 236 | bs, c, h, w = image_embedding.shape 237 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 238 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 239 | 240 | # Prepare queries 241 | queries = point_embedding 242 | keys = image_embedding 243 | 244 | # Apply transformer blocks and final layernorm 245 | for layer in self.layers: 246 | queries, keys = layer( 247 | queries=queries, 248 | keys=keys, 249 | query_pe=point_embedding, 250 | key_pe=image_pe, 251 | ) 252 | # print("query.shape:", queries.shape) 253 | 254 | # Apply the final attenion layer from the points to the image 255 | q = queries + point_embedding 256 | k = keys + image_pe 257 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 258 | queries = queries + attn_out 259 | queries = self.norm_final_attn(queries) 260 | 261 | return queries, keys 262 | 263 | 264 | class TwoWayAttentionBlock(nn.Module): 265 | def __init__( 266 | self, 267 | embedding_dim: int, 268 | num_heads: int, 269 | mlp_dim: int = 2048, 270 | activation: Type[nn.Module] = nn.ReLU, 271 | attention_downsample_rate: int = 2, 272 | skip_first_layer_pe: bool = False, 273 | ) -> None: 274 | """ 275 | A transformer block with four layers: (1) self-attention of sparse 276 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 277 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 278 | inputs. 
279 | 280 | Arguments: 281 | embedding_dim (int): the channel dimension of the embeddings 282 | num_heads (int): the number of heads in the attention layers 283 | mlp_dim (int): the hidden dimension of the mlp block 284 | activation (nn.Module): the activation of the mlp block 285 | skip_first_layer_pe (bool): skip the PE on the first layer 286 | """ 287 | super().__init__() 288 | self.self_attn = Attention(embedding_dim, num_heads) 289 | self.norm1 = nn.LayerNorm(embedding_dim) 290 | 291 | self.cross_attn_token_to_image = Attention( 292 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 293 | ) 294 | self.norm2 = nn.LayerNorm(embedding_dim) 295 | 296 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 297 | self.norm3 = nn.LayerNorm(embedding_dim) 298 | 299 | self.norm4 = nn.LayerNorm(embedding_dim) 300 | self.cross_attn_image_to_token = Attention( 301 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 302 | ) 303 | 304 | self.skip_first_layer_pe = skip_first_layer_pe 305 | 306 | def forward( 307 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 308 | ) -> Tuple[Tensor, Tensor]: 309 | # Self attention block 310 | if self.skip_first_layer_pe: 311 | queries = self.self_attn(q=queries, k=queries, v=queries) 312 | else: 313 | q = queries + query_pe 314 | attn_out = self.self_attn(q=q, k=q, v=queries) 315 | queries = queries + attn_out 316 | queries = self.norm1(queries) 317 | 318 | # Cross attention block, tokens attending to image embedding 319 | q = queries + query_pe 320 | k = keys + key_pe 321 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 322 | queries = queries + attn_out 323 | queries = self.norm2(queries) 324 | 325 | # MLP block 326 | mlp_out = self.mlp(queries) 327 | queries = queries + mlp_out 328 | queries = self.norm3(queries) 329 | 330 | # Cross attention block, image embedding attending to tokens 331 | q = queries + query_pe 332 | k = keys + key_pe 333 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 334 | keys = keys + attn_out 335 | keys = self.norm4(keys) 336 | # print("max_keys:", torch.max(keys), "min_keys:", torch.min(keys)) 337 | 338 | return queries, keys 339 | 340 | 341 | class Attention(nn.Module): 342 | """ 343 | An attention layer that allows for downscaling the size of the embedding 344 | after projection to queries, keys, and values. 345 | """ 346 | 347 | def __init__( 348 | self, 349 | embedding_dim: int, 350 | num_heads: int, 351 | downsample_rate: int = 1, 352 | ) -> None: 353 | super().__init__() 354 | self.embedding_dim = embedding_dim 355 | self.internal_dim = embedding_dim // downsample_rate 356 | self.num_heads = num_heads 357 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
358 | 359 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 360 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 361 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 362 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 363 | 364 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 365 | b, n, c = x.shape 366 | x = x.reshape(b, n, num_heads, c // num_heads) 367 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 368 | 369 | def _recombine_heads(self, x: Tensor) -> Tensor: 370 | b, n_heads, n_tokens, c_per_head = x.shape 371 | x = x.transpose(1, 2) 372 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 373 | 374 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 375 | # Input projections 376 | q = self.q_proj(q) 377 | k = self.k_proj(k) 378 | v = self.v_proj(v) 379 | 380 | # Separate into heads 381 | q = self._separate_heads(q, self.num_heads) 382 | k = self._separate_heads(k, self.num_heads) 383 | v = self._separate_heads(v, self.num_heads) 384 | 385 | # Attention 386 | _, _, _, c_per_head = q.shape 387 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 388 | attn = attn / math.sqrt(c_per_head) 389 | attn = torch.softmax(attn, dim=-1) 390 | 391 | # Get output 392 | out = attn @ v 393 | out = self._recombine_heads(out) 394 | out = self.out_proj(out) 395 | 396 | return out 397 | -------------------------------------------------------------------------------- /segment_anything_training/utils/PMD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from typing import Optional, Tuple 6 | 7 | 8 | class LayerNorm2d(nn.Module): 9 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 10 | super().__init__() 11 | self.weight = nn.Parameter(torch.ones(num_channels)) 12 | self.bias = nn.Parameter(torch.zeros(num_channels)) 13 | self.eps = eps 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | u = x.mean(1, keepdim=True) 17 | s = (x - u).pow(2).mean(1, keepdim=True) 18 | x = (x - u) / torch.sqrt(s + self.eps) 19 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 20 | return x 21 | 22 | 23 | class Get_curvature(nn.Module): 24 | def __init__(self): 25 | super(Get_curvature, self).__init__() 26 | kernel_v1 = [[0, -1, 0], 27 | [0, 0, 0], 28 | [0, 1, 0]] 29 | kernel_h1 = [[0, 0, 0], 30 | [-1, 0, 1], 31 | [0, 0, 0]] 32 | kernel_h2 = [[0, 0, 0, 0, 0], 33 | [0, 0, 0, 0, 0], 34 | [1, 0, -2, 0, 1], 35 | [0, 0, 0, 0, 0], 36 | [0, 0, 0, 0, 0]] 37 | kernel_v2 = [[0, 0, 1, 0, 0], 38 | [0, 0, 0, 0, 0], 39 | [0, 0, -2, 0, 0], 40 | [0, 0, 0, 0, 0], 41 | [0, 0, 1, 0, 0]] 42 | kernel_w2 = [[1, 0, -1], 43 | [0, 0, 0], 44 | [-1, 0, 1]] 45 | kernel_h1 = torch.FloatTensor(kernel_h1).unsqueeze(0).unsqueeze(0) 46 | kernel_v1 = torch.FloatTensor(kernel_v1).unsqueeze(0).unsqueeze(0) 47 | kernel_v2 = torch.FloatTensor(kernel_v2).unsqueeze(0).unsqueeze(0) 48 | kernel_h2 = torch.FloatTensor(kernel_h2).unsqueeze(0).unsqueeze(0) 49 | kernel_w2 = torch.FloatTensor(kernel_w2).unsqueeze(0).unsqueeze(0) 50 | self.weight_h1 = nn.Parameter(data=kernel_h1, requires_grad=False) 51 | self.weight_v1 = nn.Parameter(data=kernel_v1, requires_grad=False) 52 | self.weight_v2 = nn.Parameter(data=kernel_v2, requires_grad=False) 53 | self.weight_h2 = nn.Parameter(data=kernel_h2, requires_grad=False) 54 | self.weight_w2 = nn.Parameter(data=kernel_w2, requires_grad=False) 
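# --- A quick sanity-check sketch of what the 5x5 kernels above compute: kernel_h2 is a
# --- horizontal second difference with a two-pixel spacing, so it ignores linear ramps
# --- and responds to curvature. The toy tensors below are illustrative values only.
import torch
import torch.nn.functional as F

k_h2 = torch.tensor([[0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [1., 0., -2., 0., 1.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.]]).view(1, 1, 5, 5)

cols = torch.arange(9, dtype=torch.float32)
ramp = cols.repeat(9, 1).view(1, 1, 9, 9)             # f(x) = x    -> zero response
parabola = (cols ** 2).repeat(9, 1).view(1, 1, 9, 9)  # f(x) = x**2 -> constant response

print(F.conv2d(ramp, k_h2)[0, 0, 0])      # all zeros (no curvature)
print(F.conv2d(parabola, k_h2)[0, 0, 0])  # all 8s (f'' * spacing**2 = 2 * 4)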
55 | 56 | def forward(self, x): 57 | x_list = [] 58 | for i in range(x.shape[1]): 59 | x_i = x[:, i] 60 | x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v1, padding=1) 61 | x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h1, padding=1) 62 | x_i_v2 = F.conv2d(x_i.unsqueeze(1), self.weight_v2, padding=2) 63 | x_i_h2 = F.conv2d(x_i.unsqueeze(1), self.weight_h2, padding=2) 64 | x_i_w2 = F.conv2d(x_i.unsqueeze(1), self.weight_w2, padding=1) 65 | sum = torch.pow((torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2)), 3 / 2) 66 | fg = torch.mul(torch.pow(x_i_v, 2), x_i_v2) + 2 * torch.mul(torch.mul(x_i_v, x_i_h), x_i_w2) + torch.mul( 67 | torch.pow(x_i_h, 2), x_i_h2) 68 | fh = torch.mul(torch.pow(x_i_v, 2), x_i_h2) - 2 * torch.mul(torch.mul(x_i_v, x_i_h), x_i_w2) + torch.mul( 69 | torch.pow(x_i_h, 2), x_i_v2) 70 | x_i = torch.div(torch.abs(fg - fh), sum + 1e-10) 71 | x_i = torch.div(torch.abs(fh), sum + 1e-10) 72 | x_list.append(x_i) 73 | x = torch.cat(x_list, dim=1) 74 | return x 75 | 76 | 77 | class PatchEmbed(nn.Module): 78 | """ 79 | Image to Patch Embedding. 80 | """ 81 | 82 | def __init__( 83 | self, 84 | kernel_size: Tuple[int, int] = (16, 16), 85 | stride: Tuple[int, int] = (16, 16), 86 | padding: Tuple[int, int] = (0, 0), 87 | in_chans: int = 3, 88 | embed_dim: int = 768, 89 | ) -> None: 90 | """ 91 | Args: 92 | kernel_size (Tuple): kernel size of the projection layer. 93 | stride (Tuple): stride of the projection layer. 94 | padding (Tuple): padding size of the projection layer. 95 | in_chans (int): Number of input image channels. 96 | embed_dim (int): embed_dim (int): Patch embedding dimension. 97 | """ 98 | super().__init__() 99 | 100 | self.proj = nn.Conv2d( 101 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 102 | ) 103 | 104 | def forward(self, x: torch.Tensor) -> torch.Tensor: 105 | x = self.proj(x) 106 | # B C H W -> B H W C 107 | x = x.permute(0, 2, 3, 1) 108 | return x 109 | 110 | 111 | class get_PMD_embeddings(nn.Module): 112 | def __init__(self): 113 | super(get_PMD_embeddings, self).__init__() 114 | 115 | self.PMD_Head = Get_curvature() 116 | 117 | self.patch_embed = PatchEmbed( 118 | kernel_size=(16, 16), 119 | stride=(16, 16), 120 | in_chans=3, 121 | embed_dim=1280, 122 | ) 123 | 124 | self.pos_embed: Optional[nn.Parameter] = None 125 | # Initialize absolute positional embedding with pretrain image size. 126 | self.pos_embed = nn.Parameter( 127 | torch.zeros(1, 1024 // 16, 1024 // 16, 1280) 128 | ) 129 | 130 | self.neck = nn.Sequential( 131 | nn.Conv2d( 132 | 1280, 133 | 256, 134 | kernel_size=1, 135 | bias=False, 136 | ), 137 | LayerNorm2d(256), 138 | nn.Conv2d( 139 | 256, 140 | 256, 141 | kernel_size=3, 142 | padding=1, 143 | bias=False, 144 | ), 145 | LayerNorm2d(256), 146 | ) 147 | 148 | def forward(self, images): 149 | images_PMD = self.PMD_Head(images) 150 | 151 | PMD_patch = self.patch_embed(images_PMD) 152 | PMD_pos_embeddings = PMD_patch + self.pos_embed 153 | 154 | PMD_embeddings = self.neck(PMD_pos_embeddings.permute(0, 3, 1, 2)) 155 | 156 | return(PMD_embeddings) 157 | 158 | -------------------------------------------------------------------------------- /segment_anything_training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
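The `get_PMD_embeddings` module above chains the curvature head, a 16x16 patch embedding, and a small convolutional neck into SAM-sized features. A minimal shape sketch, assuming a 1024x1024 input (the positional embedding above is sized for 1024 // 16 = 64 patches per side):
```
import torch

from segment_anything_training.utils.PMD import get_PMD_embeddings

pmd = get_PMD_embeddings()
images = torch.rand(1, 3, 1024, 1024)  # the pos_embed above assumes 1024x1024 inputs
embeddings = pmd(images)               # curvature map -> 16x16 patches -> neck
print(embeddings.shape)                # torch.Size([1, 256, 64, 64])
```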
-------------------------------------------------------------------------------- /segment_anything_training/utils/__pycache__/PMD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/utils/__pycache__/PMD.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/segment_anything_training/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /segment_anything_training/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 
62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/utils/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/utils/__pycache__/log.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss_mask.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/utils/__pycache__/loss_mask.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/utils/__pycache__/metric.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IPIC-Lab/IRSAM/ec65744f582f49fc8d83a174500b10fc5e594506/utils/__pycache__/misc.cpython-38.pyc 
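As a quick check of the resizing arithmetic in `ResizeLongestSide` above, a minimal sketch (the 1024 target length and the 480x640 image are example values only):
```
import numpy as np

from segment_anything_training.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

# Longest side 640 -> scale 1024 / 640 = 1.6, so (480, 640) becomes (768, 1024).
print(ResizeLongestSide.get_preprocess_shape(480, 640, 1024))  # (768, 1024)

# Point coordinates are scaled by the same factor.
coords = np.array([[100.0, 50.0]])                 # (x, y) in the original 480x640 image
print(transform.apply_coords(coords, (480, 640)))  # [[160.  80.]]
```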
-------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright by HQ-SAM team 2 | # All rights reserved. 3 | 4 | # data loader 5 | from __future__ import print_function, division 6 | 7 | import cv2 8 | import numpy as np 9 | import random 10 | from copy import deepcopy 11 | from skimage import io 12 | import os 13 | from glob import glob 14 | 15 | import torch 16 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 17 | from torchvision import transforms, utils 18 | from torchvision.transforms.functional import normalize 19 | import torch.nn.functional as F 20 | from torch.utils.data.distributed import DistributedSampler 21 | 22 | 23 | # --------------------- dataloader online ---------------------#### 24 | 25 | def get_im_gt_name_dict(datasets, flag='valid'): 26 | print("------------------------------", flag, "--------------------------------") 27 | name_im_gt_list = [] 28 | 29 | for i in range(len(datasets)): 30 | print("--->>>", flag, " dataset ", i, "/", len(datasets), " ", datasets[i]["name"], "<<<---") 31 | tmp_im_list, tmp_gt_list = [], [] 32 | tmp_im_list = glob(datasets[i]["im_dir"] + os.sep + '*' + datasets[i]["im_ext"]) 33 | print('-im-', datasets[i]["name"], datasets[i]["im_dir"], ': ', len(tmp_im_list)) 34 | 35 | if datasets[i]["gt_dir"] == "": 36 | print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found') 37 | tmp_gt_list = [] 38 | else: 39 | tmp_gt_list = [ 40 | datasets[i]["gt_dir"] + os.sep + x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0] + datasets[i][ 41 | "gt_ext"] for x in tmp_im_list] 42 | print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', len(tmp_gt_list)) 43 | 44 | name_im_gt_list.append({"dataset_name": datasets[i]["name"], 45 | "im_path": tmp_im_list, 46 | "gt_path": tmp_gt_list, 47 | "im_ext": datasets[i]["im_ext"], 48 | "gt_ext": datasets[i]["gt_ext"]}) 49 | 50 | return name_im_gt_list 51 | 52 | 53 | def create_dataloaders(name_im_gt_list, my_transforms=[], batch_size=1, training=False): 54 | gos_dataloaders = [] 55 | gos_datasets = [] 56 | 57 | if len(name_im_gt_list) == 0: 58 | return gos_dataloaders, gos_datasets 59 | 60 | num_workers_ = 1 61 | if batch_size > 1: 62 | num_workers_ = 2 63 | if batch_size > 4: 64 | num_workers_ = 4 65 | if batch_size > 8: 66 | num_workers_ = 8 67 | 68 | if training: 69 | for i in range(len(name_im_gt_list)): 70 | gos_dataset = OnlineDataset([name_im_gt_list[i]], transform=transforms.Compose(my_transforms)) 71 | gos_datasets.append(gos_dataset) 72 | 73 | gos_dataset = ConcatDataset(gos_datasets) 74 | dataloader = DataLoader(gos_dataset, batch_size=batch_size, shuffle=True) 75 | 76 | gos_dataloaders = dataloader 77 | gos_datasets = gos_dataset 78 | 79 | else: 80 | for i in range(len(name_im_gt_list)): 81 | gos_dataset = OnlineDataset([name_im_gt_list[i]], transform=transforms.Compose(my_transforms), 82 | eval_ori_resolution=True) 83 | dataloader = DataLoader(gos_dataset, batch_size=batch_size) 84 | 85 | gos_dataloaders.append(dataloader) 86 | gos_datasets.append(gos_dataset) 87 | 88 | return gos_dataloaders, gos_datasets 89 | 90 | 91 | class RandomHFlip(object): 92 | def __init__(self, prob=0.5): 93 | self.prob = prob 94 | 95 | def __call__(self, sample): 96 | imidx, image, label, edge, shape = sample['imidx'], sample['image'], sample['label'], sample['edge'], sample['shape'] 97 | 98 | # random horizontal flip 99 | if random.random() 
>= self.prob: 100 | image = torch.flip(image, dims=[2]) 101 | label = torch.flip(label, dims=[2]) 102 | edge = torch.flip(edge, dims=[2]) 103 | 104 | return {'imidx': imidx, 'image': image, 'label': label, 'edge': edge, 'shape': shape} 105 | 106 | 107 | class Resize(object): 108 | def __init__(self, size=[320, 320]): 109 | self.size = size 110 | 111 | def __call__(self, sample): 112 | imidx, image, label, edge, shape = sample['imidx'], sample['image'], sample['label'], sample['edge'], sample['shape'] 113 | 114 | image = torch.squeeze(F.interpolate(torch.unsqueeze(image, 0), self.size, mode='bilinear'), dim=0) 115 | label = torch.squeeze(F.interpolate(torch.unsqueeze(label, 0), self.size, mode='bilinear'), dim=0) 116 | edge = torch.squeeze(F.interpolate(torch.unsqueeze(edge, 0), self.size, mode='bilinear'), dim=0) 117 | 118 | return {'imidx': imidx, 'image': image, 'label': label, 'edge':edge, 'shape': torch.tensor(self.size)} 119 | 120 | 121 | class RandomCrop(object): 122 | def __init__(self, size=[288, 288]): 123 | self.size = size 124 | 125 | def __call__(self, sample): 126 | imidx, image, label, edge, shape = sample['imidx'], sample['image'], sample['label'], sample['edge'], sample['shape'] 127 | 128 | h, w = image.shape[1:] 129 | new_h, new_w = self.size 130 | 131 | top = np.random.randint(0, h - new_h) 132 | left = np.random.randint(0, w - new_w) 133 | 134 | image = image[:, top:top + new_h, left:left + new_w] 135 | label = label[:, top:top + new_h, left:left + new_w] 136 | edge = edge[:, top:top + new_h, left:left + new_w] 137 | 138 | return {'imidx': imidx, 'image': image, 'label': label, 'edge': edge, 'shape': torch.tensor(self.size)} 139 | 140 | 141 | class Normalize(object): 142 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 143 | self.mean = mean 144 | self.std = std 145 | 146 | def __call__(self, sample): 147 | imidx, image, label, edge, shape = sample['imidx'], sample['image'], sample['label'], sample['edge'], sample['shape'] 148 | image = normalize(image, self.mean, self.std) 149 | 150 | return {'imidx': imidx, 'image': image, 'label': label, 'edge': edge, 'shape': shape} 151 | 152 | 153 | class LargeScaleJitter(object): 154 | """ 155 | implementation of large scale jitter from copy_paste 156 | https://github.com/gaopengcuhk/Pretrained-Pix2Seq/blob/7d908d499212bfabd33aeaa838778a6bfb7b84cc/datasets/transforms.py 157 | """ 158 | 159 | def __init__(self, output_size=512, aug_scale_min=0.8, aug_scale_max=1.5): 160 | self.desired_size = torch.tensor(output_size) 161 | self.aug_scale_min = aug_scale_min 162 | self.aug_scale_max = aug_scale_max 163 | 164 | def pad_target(self, padding, target): 165 | target = target.copy() 166 | if "masks" in target: 167 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0])) 168 | return target 169 | 170 | def __call__(self, sample): 171 | imidx, image, label, edge, image_size = sample['imidx'], sample['image'], sample['label'], sample['edge'], sample['shape'] 172 | 173 | # resize keep ratio 174 | out_desired_size = (self.desired_size * image_size / max(image_size)).round().int() 175 | 176 | random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min 177 | scaled_size = (random_scale * self.desired_size).round() 178 | 179 | scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1]) 180 | scaled_size = (image_size * scale).round().long() 181 | 182 | scaled_image = torch.squeeze(F.interpolate(torch.unsqueeze(image, 0), 
scaled_size.tolist(), mode='bilinear'), 183 | dim=0) 184 | scaled_label = torch.squeeze(F.interpolate(torch.unsqueeze(label, 0), scaled_size.tolist(), mode='bilinear'), 185 | dim=0) 186 | scaled_edge = torch.squeeze(F.interpolate(torch.unsqueeze(edge, 0), scaled_size.tolist(), mode='bilinear'), 187 | dim=0) 188 | 189 | # random crop 190 | crop_size = (min(self.desired_size, scaled_size[0]), min(self.desired_size, scaled_size[1])) 191 | 192 | margin_h = max(scaled_size[0] - crop_size[0], 0).item() 193 | margin_w = max(scaled_size[1] - crop_size[1], 0).item() 194 | offset_h = np.random.randint(0, margin_h + 1) 195 | offset_w = np.random.randint(0, margin_w + 1) 196 | crop_y1, crop_y2 = offset_h, offset_h + crop_size[0].item() 197 | crop_x1, crop_x2 = offset_w, offset_w + crop_size[1].item() 198 | 199 | scaled_image = scaled_image[:, crop_y1:crop_y2, crop_x1:crop_x2] 200 | scaled_label = scaled_label[:, crop_y1:crop_y2, crop_x1:crop_x2] 201 | scaled_edge = scaled_edge[:, crop_y1:crop_y2, crop_x1:crop_x2] 202 | 203 | # pad 204 | padding_h = max(self.desired_size - scaled_image.size(1), 0).item() 205 | padding_w = max(self.desired_size - scaled_image.size(2), 0).item() 206 | image = F.pad(scaled_image, [0, padding_w, 0, padding_h], value=128) 207 | label = F.pad(scaled_label, [0, padding_w, 0, padding_h], value=0) 208 | edge = F.pad(scaled_edge, [0, padding_w, 0, padding_h], value=0) 209 | 210 | return {'imidx': imidx, 'image': image, 'label': label, 'edge': edge, 'shape': torch.tensor(image.shape[-2:])} 211 | 212 | 213 | class OnlineDataset(Dataset): 214 | def __init__(self, name_im_gt_list, transform=None, eval_ori_resolution=False): 215 | 216 | self.transform = transform 217 | self.dataset = {} 218 | # combine different datasets into one 219 | dataset_names = [] 220 | dt_name_list = [] # dataset name per image 221 | im_name_list = [] # image name 222 | im_path_list = [] # im path 223 | gt_path_list = [] # gt path 224 | im_ext_list = [] # im ext 225 | gt_ext_list = [] # gt ext 226 | for i in range(0, len(name_im_gt_list)): 227 | dataset_names.append(name_im_gt_list[i]["dataset_name"]) 228 | # dataset name repeated based on the number of images in this dataset 229 | dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]]) 230 | im_name_list.extend( 231 | [x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]]) 232 | im_path_list.extend(name_im_gt_list[i]["im_path"]) 233 | gt_path_list.extend(name_im_gt_list[i]["gt_path"]) 234 | im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]) 235 | gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]) 236 | 237 | self.dataset["data_name"] = dt_name_list 238 | self.dataset["im_name"] = im_name_list 239 | self.dataset["im_path"] = im_path_list 240 | self.dataset["ori_im_path"] = deepcopy(im_path_list) 241 | self.dataset["gt_path"] = gt_path_list 242 | self.dataset["ori_gt_path"] = deepcopy(gt_path_list) 243 | self.dataset["im_ext"] = im_ext_list 244 | self.dataset["gt_ext"] = gt_ext_list 245 | 246 | self.eval_ori_resolution = eval_ori_resolution 247 | 248 | def __len__(self): 249 | return len(self.dataset["im_path"]) 250 | 251 | def __getitem__(self, idx): 252 | im_path = self.dataset["im_path"][idx] 253 | gt_path = self.dataset["gt_path"][idx] 254 | im = io.imread(im_path) 255 | gt = io.imread(gt_path) 256 | 257 | if len(gt.shape) > 2: 258 | gt = gt[:, :, 0] 259 | if len(im.shape) < 3: 260 | im = im[:, 
:, np.newaxis] 261 | if im.shape[2] == 1: 262 | im = np.repeat(im, 3, axis=2) 263 | 264 | edge = cv2.Canny(gt, 100, 200) 265 | im = torch.tensor(im.copy(), dtype=torch.float32) 266 | im = torch.transpose(torch.transpose(im, 1, 2), 0, 1) 267 | gt = torch.unsqueeze(torch.tensor(gt, dtype=torch.float32), 0) 268 | edge = torch.unsqueeze(torch.tensor(edge, dtype=torch.float32), 0) 269 | 270 | sample = { 271 | "imidx": torch.from_numpy(np.array(idx)), 272 | "image": im, 273 | "label": gt, 274 | "edge": edge, 275 | "shape": torch.tensor(im.shape[-2:]), 276 | "path": self.dataset["im_path"][idx] 277 | } 278 | 279 | if self.transform: 280 | sample = self.transform(sample) 281 | 282 | if self.eval_ori_resolution: 283 | sample["ori_label"] = gt.type(torch.uint8) # NOTE for evaluation only. And no flip here 284 | sample['ori_im_path'] = self.dataset["im_path"][idx] 285 | sample['ori_gt_path'] = self.dataset["gt_path"][idx] 286 | 287 | return sample 288 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import gmtime, strftime 3 | import os.path 4 | 5 | 6 | def initialize_logger(output_dir): 7 | logname = strftime("%Y_%m_%d_%H_%M_%S", gmtime()) 8 | 9 | logger = logging.getLogger() 10 | logger.setLevel(logging.DEBUG) 11 | 12 | # create console handler and set level to info 13 | handler = logging.StreamHandler() 14 | handler.setLevel(logging.INFO) 15 | formatter = \ 16 | logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 17 | handler.setFormatter(formatter) 18 | logger.addHandler(handler) 19 | 20 | handler = logging.FileHandler(os.path.join(output_dir, \ 21 | "info_{}.log".format(logname)), "w", encoding=None, delay="true") 22 | handler.setLevel(logging.INFO) 23 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(message)s") 24 | handler.setFormatter(formatter) 25 | logger.addHandler(handler) 26 | 27 | # # create error file handler and set level to error 28 | # handler = logging.FileHandler(os.path.join(output_dir, "error.log"),"w", encoding=None, delay="true") 29 | # handler.setLevel(logging.ERROR) 30 | # formatter = logging.Formatter("%(levelname)s - %(message)s") 31 | # handler.setFormatter(formatter) 32 | # logger.addHandler(handler) 33 | 34 | # # create debug file handler and set level to debug 35 | # handler = logging.FileHandler(os.path.join(output_dir, "all.log"),"w") 36 | # handler.setLevel(logging.DEBUG) 37 | # formatter = logging.Formatter("%(levelname)s - %(message)s") 38 | # handler.setFormatter(formatter) 39 | # logger.addHandler(handler) 40 | -------------------------------------------------------------------------------- /utils/loss_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from typing import List, Optional 4 | 5 | 6 | def point_sample(input, point_coords, **kwargs): 7 | """ 8 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 9 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 10 | [0, 1] x [0, 1] square. 11 | Args: 12 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. 13 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 14 | [0, 1] x [0, 1] normalized point coordinates. 
15 | Returns: 16 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 17 | features for points in `point_coords`. The features are obtained via bilinear 18 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 19 | """ 20 | add_dim = False 21 | if point_coords.dim() == 3: 22 | add_dim = True 23 | point_coords = point_coords.unsqueeze(2) 24 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 25 | if add_dim: 26 | output = output.squeeze(3) 27 | return output 28 | 29 | 30 | def cat(tensors: List[torch.Tensor], dim: int = 0): 31 | """ 32 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 33 | """ 34 | assert isinstance(tensors, (list, tuple)) 35 | if len(tensors) == 1: 36 | return tensors[0] 37 | return torch.cat(tensors, dim) 38 | 39 | 40 | def get_uncertain_point_coords_with_randomness( 41 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 42 | ): 43 | """ 44 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties 45 | are calculated for each point using the 'uncertainty_func' function that takes a point's logit 46 | prediction as input. 47 | See PointRend paper for details. 48 | Args: 49 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 50 | class-specific or class-agnostic prediction. 51 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 52 | contains logit predictions for P points and returns their uncertainties as a Tensor of 53 | shape (N, 1, P). 54 | num_points (int): The number of points P to sample. 55 | oversample_ratio (int): Oversampling parameter. 56 | importance_sample_ratio (float): Ratio of points that are sampled via importance sampling. 57 | Returns: 58 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 59 | sampled points. 60 | """ 61 | assert oversample_ratio >= 1 62 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 63 | num_boxes = coarse_logits.shape[0] 64 | num_sampled = int(num_points * oversample_ratio) 65 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 66 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 67 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 68 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 69 | # to incorrect results. 70 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 71 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 72 | # However, if we calculate uncertainties for the coarse predictions first, 73 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
74 | point_uncertainties = uncertainty_func(point_logits) 75 | num_uncertain_points = int(importance_sample_ratio * num_points) 76 | num_random_points = num_points - num_uncertain_points 77 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 78 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 79 | idx += shift[:, None] 80 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 81 | num_boxes, num_uncertain_points, 2 82 | ) 83 | if num_random_points > 0: 84 | point_coords = cat( 85 | [ 86 | point_coords, 87 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 88 | ], 89 | dim=1, 90 | ) 91 | return point_coords 92 | 93 | def DICE_loss( 94 | inputs: torch.Tensor, 95 | targets: torch.Tensor, 96 | eps = 1e-6, 97 | ): 98 | """ 99 | Compute the DICE loss, similar to generalized IOU for masks 100 | Args: 101 | inputs: A float tensor of arbitrary shape. 102 | The predictions for each example. 103 | targets: A float tensor with the same shape as inputs. Stores the binary 104 | classification label for each element in inputs 105 | (0 for the negative class and 1 for the positive class). 106 | """ 107 | inputs = inputs.sigmoid() 108 | 109 | inputs = inputs.view(-1) 110 | targets = targets.view(-1) 111 | 112 | intersection = 2 * (inputs * targets).sum() 113 | denominator = inputs.sum() + targets.sum() 114 | 115 | inter = (inputs * targets).sum() 116 | comb = inputs.sum() + targets.sum() - inter 117 | 118 | loss = 1 - (intersection+eps) / (denominator+eps) 119 | loss_iou = 1 - (inter+eps) / (comb+eps) 120 | 121 | return loss.mean(), loss_iou.mean() 122 | 123 | 124 | def dice_loss( 125 | inputs: torch.Tensor, 126 | targets: torch.Tensor, 127 | num_masks: float, 128 | ): 129 | """ 130 | Compute the DICE loss, similar to generalized IOU for masks 131 | Args: 132 | inputs: A float tensor of arbitrary shape. 133 | The predictions for each example. 134 | targets: A float tensor with the same shape as inputs. Stores the binary 135 | classification label for each element in inputs 136 | (0 for the negative class and 1 for the positive class). 137 | """ 138 | inputs = inputs.sigmoid() 139 | inputs = inputs.flatten(1) 140 | numerator = 2 * (inputs * targets).sum(-1) 141 | denominator = inputs.sum(-1) + targets.sum(-1) 142 | loss = 1 - (numerator + 1) / (denominator + 1) 143 | return loss.sum() / num_masks 144 | 145 | 146 | dice_loss_jit = torch.jit.script( 147 | dice_loss 148 | ) # type: torch.jit.ScriptModule 149 | 150 | 151 | def sigmoid_ce_loss( 152 | inputs: torch.Tensor, 153 | targets: torch.Tensor, 154 | num_masks: float, 155 | ): 156 | """ 157 | Args: 158 | inputs: A float tensor of arbitrary shape. 159 | The predictions for each example. 160 | targets: A float tensor with the same shape as inputs. Stores the binary 161 | classification label for each element in inputs 162 | (0 for the negative class and 1 for the positive class). 163 | Returns: 164 | Loss tensor 165 | """ 166 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 167 | 168 | return loss.mean(1).sum() / num_masks 169 | 170 | 171 | sigmoid_ce_loss_jit = torch.jit.script( 172 | sigmoid_ce_loss 173 | ) # type: torch.jit.ScriptModule 174 | 175 | 176 | def calculate_uncertainty(logits): 177 | """ 178 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 179 | foreground class in `classes`. 180 | Args: 181 | logits (Tensor): A tensor of shape (R, 1, ...) 
for class-specific or 182 | class-agnostic, where R is the total number of predicted masks in all images and C is 183 | the number of foreground classes. The values are logits. 184 | Returns: 185 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 186 | the most uncertain locations having the highest uncertainty score. 187 | """ 188 | assert logits.shape[1] == 1 189 | gt_class_logits = logits.clone() 190 | return -(torch.abs(gt_class_logits)) 191 | 192 | 193 | def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0): 194 | """Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss. 195 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 196 | """ 197 | 198 | # No need to upsample predictions as we are using normalized coordinates :) 199 | 200 | with torch.no_grad(): 201 | # sample point_coords 202 | point_coords = get_uncertain_point_coords_with_randomness( 203 | src_masks, 204 | lambda logits: calculate_uncertainty(logits), 205 | 112 * 112, 206 | oversample_ratio, 207 | 0.75, 208 | ) 209 | # get gt labels 210 | point_labels = point_sample( 211 | target_masks, 212 | point_coords, 213 | align_corners=False, 214 | ).squeeze(1) 215 | 216 | point_logits = point_sample( 217 | src_masks, 218 | point_coords, 219 | align_corners=False, 220 | ).squeeze(1) 221 | 222 | loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks) 223 | loss_dice = dice_loss_jit(point_logits, point_labels, num_masks) 224 | 225 | del src_masks 226 | del target_masks 227 | return loss_mask, loss_dice 228 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from skimage import measure 5 | import numpy 6 | 7 | 8 | class ROCMetric(): 9 | """Computes TP/FP rates, recall, and precision over a sweep of score thresholds (for ROC curves) 10 | """ 11 | 12 | def __init__(self, nclass, bins): # bins determines how many discrete threshold values are taken along the ROC curve 13 | super(ROCMetric, self).__init__() 14 | self.nclass = nclass 15 | self.bins = bins 16 | self.tp_arr = np.zeros(self.bins + 1) 17 | self.pos_arr = np.zeros(self.bins + 1) 18 | self.fp_arr = np.zeros(self.bins + 1) 19 | self.neg_arr = np.zeros(self.bins + 1) 20 | self.class_pos = np.zeros(self.bins + 1) 21 | # self.reset() 22 | 23 | def update(self, preds, labels): 24 | for iBin in range(self.bins + 1): 25 | score_thresh = (iBin + 0.0) / self.bins 26 | # print(iBin, "-th, score_thresh: ", score_thresh) 27 | i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh) 28 | self.tp_arr[iBin] += i_tp 29 | self.pos_arr[iBin] += i_pos 30 | self.fp_arr[iBin] += i_fp 31 | self.neg_arr[iBin] += i_neg 32 | self.class_pos[iBin] += i_class_pos 33 | 34 | def get(self): 35 | tp_rates = self.tp_arr / (self.pos_arr + 0.001) 36 | fp_rates = self.fp_arr / (self.neg_arr + 0.001) 37 | 38 | recall = self.tp_arr / (self.pos_arr + 0.001) 39 | precision = self.tp_arr / (self.class_pos + 0.001) 40 | 41 | return tp_rates, fp_rates, recall, precision 42 | 43 | def reset(self): 44 | self.tp_arr = np.zeros([11]) 45 | self.pos_arr = np.zeros([11]) 46 | self.fp_arr = np.zeros([11]) 47 | self.neg_arr = np.zeros([11]) 48 | self.class_pos = np.zeros([11]) 49 | 50 | 51 | class PD_FA(): 52 | def __init__(self, nclass, bins): 53 | super(PD_FA, self).__init__() 54 | self.nclass = nclass 55 | self.bins = bins 56 | self.image_area_total = []
57 | self.image_area_match = [] 58 | self.FA = np.zeros(self.bins + 1) 59 | self.PD = np.zeros(self.bins + 1) 60 | self.target = np.zeros(self.bins + 1) 61 | 62 | def update(self, preds, labels): 63 | W = preds.shape[3] 64 | for iBin in range(self.bins + 1): 65 | score_thresh = iBin * (255 / self.bins) 66 | predits = np.array((preds > score_thresh).cpu()).astype('int64') 67 | if W == 512: 68 | predits = np.reshape(predits, (512, 512)) # 512 69 | labelss = np.array((labels).cpu()).astype('int64') # P 70 | labelss = np.reshape(labelss, (512, 512)) # 512 71 | elif W == 384: 72 | predits = np.reshape(predits, (384, 384)) # 512 73 | labelss = np.array((labels).cpu()).astype('int64') # P 74 | labelss = np.reshape(labelss, (384, 384)) # 512 75 | else: 76 | predits = np.reshape(predits, (512 // 2, 512 // 2)) # 512 77 | labelss = np.array((labels).cpu()).astype('int64') # P 78 | labelss = np.reshape(labelss, (512 // 2, 512 // 2)) # 512 79 | 80 | image = measure.label(predits, connectivity=2) 81 | coord_image = measure.regionprops(image) 82 | label = measure.label(labelss, connectivity=2) 83 | coord_label = measure.regionprops(label) 84 | 85 | self.target[iBin] += len(coord_label) 86 | self.image_area_total = [] 87 | self.image_area_match = [] 88 | self.distance_match = [] 89 | self.dismatch = [] 90 | 91 | for K in range(len(coord_image)): 92 | area_image = np.array(coord_image[K].area) 93 | self.image_area_total.append(area_image) 94 | 95 | for i in range(len(coord_label)): 96 | centroid_label = np.array(list(coord_label[i].centroid)) 97 | for m in range(len(coord_image)): 98 | centroid_image = np.array(list(coord_image[m].centroid)) 99 | distance = np.linalg.norm(centroid_image - centroid_label) 100 | area_image = np.array(coord_image[m].area) 101 | if distance < 3: 102 | self.distance_match.append(distance) 103 | self.image_area_match.append(area_image) 104 | 105 | del coord_image[m] 106 | break 107 | 108 | self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match] 109 | self.FA[iBin] += np.sum(self.dismatch) 110 | self.PD[iBin] += len(self.distance_match) 111 | # print(len(self.image_area_total)) 112 | 113 | def get(self, img_num): 114 | 115 | Final_FA = self.FA / ((512 * 512) * img_num) # 512 116 | Final_PD = self.PD / self.target 117 | 118 | return Final_FA, Final_PD 119 | 120 | def reset(self): 121 | self.FA = np.zeros([self.bins + 1]) 122 | self.PD = np.zeros([self.bins + 1]) 123 | self.target = np.zeros(self.bins + 1) 124 | 125 | 126 | class mIoU(): 127 | 128 | def __init__(self, nclass): 129 | super(mIoU, self).__init__() 130 | self.nclass = nclass 131 | self.reset() 132 | 133 | def update(self, preds, labels): 134 | # print('come_ininin') 135 | 136 | correct, labeled = batch_pix_accuracy(preds, labels) 137 | inter, union = batch_intersection_union(preds, labels, self.nclass) 138 | self.total_correct += correct 139 | self.total_label += labeled 140 | self.total_inter += inter 141 | self.total_union += union 142 | 143 | def get(self): 144 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 145 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 146 | mIoU = IoU.mean() 147 | return pixAcc, mIoU 148 | 149 | def reset(self): 150 | self.total_inter = 0 151 | self.total_union = 0 152 | self.total_correct = 0 153 | self.total_label = 0 154 | 155 | 156 | def cal_tp_pos_fp_neg(output, target, nclass, score_thresh): 157 | predict = (torch.sigmoid(output) > score_thresh).float() 158 | if len(target.shape) == 3: 159 | target = 
np.expand_dims(target.float(), axis=1) 160 | elif len(target.shape) == 4: 161 | target = target.float() 162 | else: 163 | raise ValueError("Unknown target dimension") 164 | 165 | intersection = predict * (predict == target) 166 | 167 | tp = intersection.sum() 168 | fp = (predict * (predict != target)).sum() 169 | tn = ((1 - predict) * (predict == target)).sum() 170 | fn = ((predict != target) * (1 - predict)).sum() 171 | pos = tp + fn 172 | neg = fp + tn 173 | class_pos = tp + fp 174 | 175 | return tp, pos, fp, neg, class_pos 176 | 177 | 178 | def batch_pix_accuracy(output, target): 179 | if len(target.shape) == 3: 180 | target = np.expand_dims(target.float(), axis=1) 181 | elif len(target.shape) == 4: 182 | target = target.float() 183 | else: 184 | raise ValueError("Unknown target dimension") 185 | 186 | assert output.shape == target.shape, "Predict and Label Shape Don't Match" 187 | predict = (output > 0).float() 188 | pixel_labeled = (target > 0).float().sum() 189 | pixel_correct = (((predict == target).float()) * ((target > 0)).float()).sum() 190 | 191 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 192 | return pixel_correct, pixel_labeled 193 | 194 | 195 | def batch_intersection_union(output, target, nclass): 196 | mini = 1 197 | maxi = 1 198 | nbins = 1 199 | predict = (output > 0).float() 200 | if len(target.shape) == 3: 201 | target = np.expand_dims(target.float(), axis=1) 202 | elif len(target.shape) == 4: 203 | target = target.float() 204 | else: 205 | raise ValueError("Unknown target dimension") 206 | intersection = predict * ((predict == target).float()) 207 | 208 | area_inter, _ = np.histogram(intersection.cpu(), bins=nbins, range=(mini, maxi)) 209 | area_pred, _ = np.histogram(predict.cpu(), bins=nbins, range=(mini, maxi)) 210 | area_lab, _ = np.histogram(target.cpu(), bins=nbins, range=(mini, maxi)) 211 | area_union = area_pred + area_lab - area_inter 212 | 213 | assert (area_inter <= area_union).all(), \ 214 | "Error: Intersection area should be smaller than Union area" 215 | return area_inter, area_union 216 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | class SigmoidMetric(): 9 | def __init__(self): 10 | self.reset() 11 | 12 | def update(self, pred, labels): 13 | correct, labeled = self.batch_pix_accuracy(pred, labels) 14 | inter, union = self.batch_intersection_union(pred, labels) 15 | 16 | self.total_correct += correct 17 | self.total_label += labeled 18 | self.total_inter += inter 19 | self.total_union += union 20 | 21 | def get(self): 22 | """Gets the current evaluation result.""" 23 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 24 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 25 | mIoU = IoU.mean() 26 | return pixAcc, mIoU 27 | 28 | def reset(self): 29 | """Resets the internal evaluation result to initial state.""" 30 | self.total_inter = 0 31 | self.total_union = 0 32 | self.total_correct = 0 33 | self.total_label = 0 34 | 35 | def batch_pix_accuracy(self, output, target): 36 | assert output.shape == target.shape 37 | output = output.detach().numpy() 38 | target = target.detach().numpy() 39 | 40 | predict = (output > 0).astype('int64') # P 41 | pixel_labeled = np.sum(target > 0) # T 42 | pixel_correct = np.sum((predict == target) 
* (target > 0)) # TP 43 | assert pixel_correct <= pixel_labeled 44 | return pixel_correct, pixel_labeled 45 | 46 | def batch_intersection_union(self, output, target): 47 | mini = 1 48 | maxi = 1 # nclass 49 | nbins = 1 # nclass 50 | predict = (output.detach().numpy() > 0).astype('int64') # P 51 | target = target.numpy().astype('int64') # T 52 | intersection = predict * (predict == target) # TP 53 | 54 | # areas of intersection and union 55 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) 56 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 57 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 58 | area_union = area_pred + area_lab - area_inter 59 | assert (area_inter <= area_union).all() 60 | return area_inter, area_union 61 | 62 | 63 | class SamplewiseSigmoidMetric(): 64 | def __init__(self, nclass, score_thresh=0.5): 65 | self.nclass = nclass 66 | self.score_thresh = score_thresh 67 | self.reset() 68 | 69 | def update(self, preds, labels): 70 | """Updates the internal evaluation result.""" 71 | inter_arr, union_arr = self.batch_intersection_union(preds, labels, 72 | self.nclass, self.score_thresh) 73 | self.total_inter = np.append(self.total_inter, inter_arr) 74 | self.total_union = np.append(self.total_union, union_arr) 75 | 76 | def get(self): 77 | """Gets the current evaluation result.""" 78 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 79 | mIoU = IoU.mean() 80 | return IoU, mIoU 81 | 82 | def reset(self): 83 | """Resets the internal evaluation result to initial state.""" 84 | self.total_inter = np.array([]) 85 | self.total_union = np.array([]) 86 | self.total_correct = np.array([]) 87 | self.total_label = np.array([]) 88 | 89 | def batch_intersection_union(self, output, target, nclass, score_thresh): 90 | """mIoU""" 91 | # inputs are tensor 92 | # the category 0 is ignored class, typically for background / boundary 93 | mini = 1 94 | maxi = 1 # nclass 95 | nbins = 1 # nclass 96 | 97 | predict = (F.sigmoid(output).detach().numpy() > score_thresh).astype('int64') # P 98 | target = target.detach().numpy().astype('int64') # T 99 | intersection = predict * (predict == target) # TP 100 | 101 | num_sample = intersection.shape[0] 102 | area_inter_arr = np.zeros(num_sample) 103 | area_pred_arr = np.zeros(num_sample) 104 | area_lab_arr = np.zeros(num_sample) 105 | area_union_arr = np.zeros(num_sample) 106 | 107 | for b in range(num_sample): 108 | # areas of intersection and union 109 | area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi)) 110 | area_inter_arr[b] = area_inter 111 | 112 | area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi)) 113 | area_pred_arr[b] = area_pred 114 | 115 | area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi)) 116 | area_lab_arr[b] = area_lab 117 | 118 | area_union = area_pred + area_lab - area_inter 119 | area_union_arr[b] = area_union 120 | 121 | assert (area_inter <= area_union).all() 122 | 123 | return area_inter_arr, area_union_arr 124 | 125 | 126 | class ROCMetric(): 127 | def __init__(self, nclass, bins): 128 | self.nclass = nclass 129 | self.bins = bins 130 | self.tp_arr = np.zeros(self.bins + 1) 131 | self.pos_arr = np.zeros(self.bins + 1) 132 | self.fp_arr = np.zeros(self.bins + 1) 133 | self.neg_arr = np.zeros(self.bins + 1) 134 | 135 | def update(self, preds, labels): 136 | for iBin in range(self.bins + 1): 137 | score_thresh = (iBin + 0.0) / self.bins 138 | i_tp, i_pos, i_fp, i_neg = 
cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh) 139 | 140 | self.tp_arr[iBin] += i_tp 141 | self.pos_arr[iBin] += i_pos 142 | self.fp_arr[iBin] += i_fp 143 | self.neg_arr[iBin] += i_neg 144 | 145 | def get(self): 146 | tp_rates = self.tp_arr / (self.pos_arr + 0.001) 147 | fp_rates = self.fp_arr / (self.neg_arr + 0.001) 148 | 149 | return tp_rates, fp_rates 150 | 151 | 152 | def cal_tp_pos_fp_neg(output, target, nclass, score_thresh): 153 | mini = 1 154 | maxi = 1 # nclass 155 | nbins = 1 # nclass 156 | 157 | predict = (F.sigmoid(output).detach().numpy() > score_thresh).astype('int64') # P 158 | target = target.detach().numpy().astype('int64') # T 159 | intersection = predict * (predict == target) # TP 160 | tp = intersection.sum() 161 | fp = (predict * (predict != target)).sum() # FP 162 | tn = ((1 - predict) * (predict == target)).sum() # TN 163 | fn = ((predict != target) * (1 - predict)).sum() # FN 164 | pos = tp + fn 165 | neg = fp + tn 166 | return tp, pos, fp, neg 167 | 168 | # pred = torch.randn(4,16,256,256) 169 | # target = torch.randn(4,16,256,256) 170 | # iou_m = SigmoidMetric() 171 | # iou_m.reset() 172 | # iou_m.update(pred,target) 173 | # iou = iou_m.get() 174 | # print(iou) 175 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Misc functions, including distributed helpers. 4 | 5 | Mostly copy-paste from torchvision references. 6 | """ 7 | import os 8 | import random 9 | import subprocess 10 | import time 11 | from collections import OrderedDict, defaultdict, deque 12 | import datetime 13 | import pickle 14 | from typing import Optional, List 15 | from packaging import version 16 | 17 | import json, time 18 | import numpy as np 19 | import torch 20 | import torch.distributed as dist 21 | from torch import Tensor 22 | 23 | import colorsys 24 | import torch.nn.functional as F 25 | 26 | import cv2 27 | 28 | # needed due to empty tensor bug in pytorch and torchvision 0.5 29 | import torchvision 30 | if version.parse(torchvision.__version__) < version.parse('0.7'): 31 | from torchvision.ops import _new_empty_tensor 32 | from torchvision.ops.misc import _output_size 33 | 34 | 35 | class SmoothedValue(object): 36 | """Track a series of values and provide access to smoothed values over a 37 | window or the global series average. 38 | """ 39 | 40 | def __init__(self, window_size=20, fmt=None): 41 | if fmt is None: 42 | fmt = "{median:.4f} ({global_avg:.4f})" 43 | self.deque = deque(maxlen=window_size) 44 | self.total = 0.0 45 | self.count = 0 46 | self.fmt = fmt 47 | 48 | def update(self, value, n=1): 49 | self.deque.append(value) 50 | self.count += n 51 | self.total += value * n 52 | 53 | def synchronize_between_processes(self): 54 | """ 55 | Warning: does not synchronize the deque! 
56 | """ 57 | if not is_dist_avail_and_initialized(): 58 | return 59 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 60 | dist.barrier() 61 | dist.all_reduce(t) 62 | t = t.tolist() 63 | self.count = int(t[0]) 64 | self.total = t[1] 65 | 66 | @property 67 | def median(self): 68 | d = torch.tensor(list(self.deque)) 69 | if d.shape[0] == 0: 70 | return 0 71 | return d.median().item() 72 | 73 | @property 74 | def avg(self): 75 | d = torch.tensor(list(self.deque), dtype=torch.float32) 76 | return d.mean().item() 77 | 78 | @property 79 | def global_avg(self): 80 | return self.total / self.count 81 | 82 | @property 83 | def max(self): 84 | return max(self.deque) 85 | 86 | @property 87 | def value(self): 88 | return self.deque[-1] 89 | 90 | def __str__(self): 91 | return self.fmt.format( 92 | median=self.median, 93 | avg=self.avg, 94 | global_avg=self.global_avg, 95 | max=self.max, 96 | value=self.value) 97 | 98 | 99 | def all_gather(data): 100 | """ 101 | Run all_gather on arbitrary picklable data (not necessarily tensors) 102 | Args: 103 | data: any picklable object 104 | Returns: 105 | list[data]: list of data gathered from each rank 106 | """ 107 | world_size = get_world_size() 108 | if world_size == 1: 109 | return [data] 110 | 111 | # serialized to a Tensor 112 | buffer = pickle.dumps(data) 113 | storage = torch.ByteStorage.from_buffer(buffer) 114 | tensor = torch.ByteTensor(storage).to("cuda") 115 | 116 | # obtain Tensor size of each rank 117 | local_size = torch.tensor([tensor.numel()], device="cuda") 118 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 119 | dist.all_gather(size_list, local_size) 120 | size_list = [int(size.item()) for size in size_list] 121 | max_size = max(size_list) 122 | 123 | # receiving Tensor from all ranks 124 | # we pad the tensor because torch all_gather does not support 125 | # gathering tensors of different shapes 126 | tensor_list = [] 127 | for _ in size_list: 128 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 129 | if local_size != max_size: 130 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 131 | tensor = torch.cat((tensor, padding), dim=0) 132 | dist.all_gather(tensor_list, tensor) 133 | 134 | data_list = [] 135 | for size, tensor in zip(size_list, tensor_list): 136 | buffer = tensor.cpu().numpy().tobytes()[:size] 137 | data_list.append(pickle.loads(buffer)) 138 | 139 | return data_list 140 | 141 | 142 | def reduce_dict(input_dict, average=True): 143 | """ 144 | Args: 145 | input_dict (dict): all the values will be reduced 146 | average (bool): whether to do average or sum 147 | Reduce the values in the dictionary from all processes so that all processes 148 | have the averaged results. Returns a dict with the same fields as 149 | input_dict, after reduction. 
150 | """ 151 | world_size = get_world_size() 152 | if world_size < 2: 153 | return input_dict 154 | with torch.no_grad(): 155 | names = [] 156 | values = [] 157 | # sort the keys so that they are consistent across processes 158 | for k in sorted(input_dict.keys()): 159 | names.append(k) 160 | values.append(input_dict[k]) 161 | values = torch.stack(values, dim=0) 162 | dist.all_reduce(values) 163 | if average: 164 | values /= world_size 165 | reduced_dict = {k: v for k, v in zip(names, values)} 166 | return reduced_dict 167 | 168 | 169 | class MetricLogger(object): 170 | def __init__(self, delimiter="\t"): 171 | self.meters = defaultdict(SmoothedValue) 172 | self.delimiter = delimiter 173 | 174 | def update(self, **kwargs): 175 | for k, v in kwargs.items(): 176 | if isinstance(v, torch.Tensor): 177 | v = v.item() 178 | assert isinstance(v, (float, int)) 179 | self.meters[k].update(v) 180 | 181 | def __getattr__(self, attr): 182 | if attr in self.meters: 183 | return self.meters[attr] 184 | if attr in self.__dict__: 185 | return self.__dict__[attr] 186 | raise AttributeError("'{}' object has no attribute '{}'".format( 187 | type(self).__name__, attr)) 188 | 189 | def __str__(self): 190 | loss_str = [] 191 | for name, meter in self.meters.items(): 192 | # print(name, str(meter)) 193 | # import ipdb;ipdb.set_trace() 194 | if meter.count > 0: 195 | loss_str.append( 196 | "{}: {}".format(name, str(meter)) 197 | ) 198 | return self.delimiter.join(loss_str) 199 | 200 | def synchronize_between_processes(self): 201 | for meter in self.meters.values(): 202 | meter.synchronize_between_processes() 203 | 204 | def add_meter(self, name, meter): 205 | self.meters[name] = meter 206 | 207 | def log_every(self, iterable, print_freq, header=None, logger=None): 208 | if logger is None: 209 | print_func = print 210 | else: 211 | print_func = logger.info 212 | 213 | i = 0 214 | if not header: 215 | header = '' 216 | start_time = time.time() 217 | end = time.time() 218 | iter_time = SmoothedValue(fmt='{avg:.4f}') 219 | data_time = SmoothedValue(fmt='{avg:.4f}') 220 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 221 | if torch.cuda.is_available(): 222 | log_msg = self.delimiter.join([ 223 | header, 224 | '[{0' + space_fmt + '}/{1}]', 225 | 'eta: {eta}', 226 | '{meters}', 227 | 'time: {time}', 228 | 'data: {data}', 229 | 'max mem: {memory:.0f}' 230 | ]) 231 | else: 232 | log_msg = self.delimiter.join([ 233 | header, 234 | '[{0' + space_fmt + '}/{1}]', 235 | 'eta: {eta}', 236 | '{meters}', 237 | 'time: {time}', 238 | 'data: {data}' 239 | ]) 240 | MB = 1024.0 * 1024.0 241 | for obj in iterable: 242 | data_time.update(time.time() - end) 243 | yield obj 244 | 245 | iter_time.update(time.time() - end) 246 | if i == len(iterable) - 1: # i % print_freq == 0 or i == len(iterable) - 1 247 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 248 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 249 | if torch.cuda.is_available(): 250 | print_func(log_msg.format( 251 | i, len(iterable), eta=eta_string, 252 | meters=str(self), 253 | time=str(iter_time), data=str(data_time), 254 | memory=torch.cuda.max_memory_allocated() / MB)) 255 | else: 256 | print_func(log_msg.format( 257 | i, len(iterable), eta=eta_string, 258 | meters=str(self), 259 | time=str(iter_time), data=str(data_time))) 260 | i += 1 261 | end = time.time() 262 | total_time = time.time() - start_time 263 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 264 | print_func('{} Total time: {} ({:.4f} s / it)'.format( 265 | 
header, total_time_str, total_time / len(iterable))) 266 | 267 | 268 | def get_sha(): 269 | cwd = os.path.dirname(os.path.abspath(__file__)) 270 | 271 | def _run(command): 272 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 273 | 274 | sha = 'N/A' 275 | diff = "clean" 276 | branch = 'N/A' 277 | try: 278 | sha = _run(['git', 'rev-parse', 'HEAD']) 279 | subprocess.check_output(['git', 'diff'], cwd=cwd) 280 | diff = _run(['git', 'diff-index', 'HEAD']) 281 | diff = "has uncommited changes" if diff else "clean" 282 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 283 | except Exception: 284 | pass 285 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 286 | return message 287 | 288 | 289 | def setup_for_distributed(is_master): 290 | """ 291 | This function disables printing when not in master process 292 | """ 293 | import builtins as __builtin__ 294 | builtin_print = __builtin__.print 295 | 296 | def print(*args, **kwargs): 297 | force = kwargs.pop('force', False) 298 | if is_master or force: 299 | builtin_print(*args, **kwargs) 300 | 301 | __builtin__.print = print 302 | 303 | 304 | def is_dist_avail_and_initialized(): 305 | if not dist.is_available(): 306 | return False 307 | if not dist.is_initialized(): 308 | return False 309 | return True 310 | 311 | 312 | def get_world_size(): 313 | if not is_dist_avail_and_initialized(): 314 | return 1 315 | return dist.get_world_size() 316 | 317 | 318 | def get_rank(): 319 | if not is_dist_avail_and_initialized(): 320 | return 0 321 | return dist.get_rank() 322 | 323 | 324 | def is_main_process(): 325 | return get_rank() == 0 326 | 327 | 328 | def save_on_master(*args, **kwargs): 329 | if is_main_process(): 330 | torch.save(*args, **kwargs) 331 | 332 | 333 | def init_distributed_mode(args): 334 | if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and 335 | # args.rank = int(os.environ["RANK"]) 336 | # args.world_size = int(os.environ['WORLD_SIZE']) 337 | # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) 338 | 339 | # launch by torch.distributed.launch 340 | # Single node 341 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ... 342 | # Multi nodes 343 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 344 | # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ... 
345 | local_world_size = int(os.environ['WORLD_SIZE']) 346 | args.world_size = args.world_size * local_world_size 347 | args.gpu = args.local_rank = int(os.environ['LOCAL_RANK']) 348 | args.rank = args.rank * local_world_size + args.local_rank 349 | print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank)) 350 | print(json.dumps(dict(os.environ), indent=2)) 351 | elif 'SLURM_PROCID' in os.environ: 352 | args.rank = int(os.environ['SLURM_PROCID']) 353 | args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID']) 354 | args.world_size = int(os.environ['SLURM_NPROCS']) 355 | 356 | print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, 357 | args.local_rank, 358 | torch.cuda.device_count())) 359 | else: 360 | print('Not using distributed mode') 361 | args.distributed = False 362 | args.world_size = 1 363 | args.rank = 0 364 | args.local_rank = 0 365 | return 366 | 367 | print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) 368 | args.distributed = True 369 | torch.cuda.set_device(args.local_rank) 370 | args.dist_backend = 'nccl' 371 | print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) 372 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 373 | world_size=args.world_size, rank=args.rank) 374 | print("Before torch.distributed.barrier()") 375 | torch.distributed.barrier() 376 | print("End torch.distributed.barrier()") 377 | setup_for_distributed(args.rank == 0) 378 | 379 | 380 | def masks_to_boxes(masks): 381 | """Compute the bounding boxes around the provided masks 382 | 383 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
384 | 385 | Returns a [N, 4] tensors, with the boxes in xyxy format 386 | """ 387 | if masks.numel() == 0: 388 | return torch.zeros((0, 4), device=masks.device) 389 | 390 | h, w = masks.shape[-2:] 391 | 392 | y = torch.arange(0, h, dtype=torch.float) 393 | x = torch.arange(0, w, dtype=torch.float) 394 | y, x = torch.meshgrid(y, x) 395 | y = y.to(masks) 396 | x = x.to(masks) 397 | 398 | x_mask = ((masks > 128) * x.unsqueeze(0)) 399 | x_max = x_mask.flatten(1).max(-1)[0] 400 | x_min = x_mask.masked_fill(~(masks > 128), 1e8).flatten(1).min(-1)[0] 401 | 402 | y_mask = ((masks > 128) * y.unsqueeze(0)) 403 | y_max = y_mask.flatten(1).max(-1)[0] 404 | y_min = y_mask.masked_fill(~(masks > 128), 1e8).flatten(1).min(-1)[0] 405 | 406 | return torch.stack([x_min, y_min, x_max, y_max], 1) 407 | 408 | 409 | def box_cxcywh_to_xyxy(x): 410 | x_c, y_c, w, h = x.unbind(-1) 411 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 412 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 413 | return torch.stack(b, dim=-1) 414 | 415 | 416 | def box_xyxy_to_cxcywh(x): 417 | x0, y0, x1, y1 = x.unbind(-1) 418 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 419 | (x1 - x0), (y1 - y0)] 420 | return torch.stack(b, dim=-1) 421 | 422 | 423 | def box_noise(boxes, box_noise_scale=0): 424 | known_bbox_expand = box_xyxy_to_cxcywh(boxes) 425 | 426 | diff = torch.zeros_like(known_bbox_expand) 427 | diff[:, :2] = known_bbox_expand[:, 2:] / 2 428 | diff[:, 2:] = known_bbox_expand[:, 2:] 429 | known_bbox_expand += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), diff).cuda() * box_noise_scale 430 | boxes = box_cxcywh_to_xyxy(known_bbox_expand) 431 | boxes = boxes.clamp(min=0.0, max=512) 432 | 433 | return boxes 434 | 435 | 436 | def masks_sample_points(masks, k=10): 437 | """Sample points on mask 438 | """ 439 | if masks.numel() == 0: 440 | return torch.zeros((0, 2), device=masks.device) 441 | 442 | h, w = masks.shape[-2:] 443 | 444 | y = torch.arange(0, h, dtype=torch.float) 445 | x = torch.arange(0, w, dtype=torch.float) 446 | y, x = torch.meshgrid(y, x) 447 | y = y.to(masks) 448 | x = x.to(masks) 449 | 450 | # k = 10 451 | samples = [] 452 | for b_i in range(len(masks)): 453 | select_mask = (masks[b_i] > 128) 454 | x_idx = torch.masked_select(x, select_mask) 455 | y_idx = torch.masked_select(y, select_mask) 456 | 457 | perm = torch.randperm(x_idx.size(0)) 458 | idx = perm[:k] 459 | samples_x = x_idx[idx] 460 | samples_y = y_idx[idx] 461 | samples_xy = torch.cat((samples_x[:, None], samples_y[:, None]), dim=1) 462 | samples.append(samples_xy) 463 | 464 | samples = torch.stack(samples) 465 | return samples 466 | 467 | 468 | # Add noise to mask input 469 | # From Mask Transfiner https://github.com/SysCV/transfiner 470 | def masks_noise(masks): 471 | def get_incoherent_mask(input_masks, sfact): 472 | mask = input_masks.float() 473 | w = input_masks.shape[-1] 474 | h = input_masks.shape[-2] 475 | mask_small = F.interpolate(mask, (h // sfact, w // sfact), mode='bilinear') 476 | mask_recover = F.interpolate(mask_small, (h, w), mode='bilinear') 477 | mask_residue = (mask - mask_recover).abs() 478 | mask_residue = (mask_residue >= 0.01).float() 479 | return mask_residue 480 | 481 | gt_masks_vector = masks / 255 482 | mask_noise = torch.randn(gt_masks_vector.shape, device=gt_masks_vector.device) * 1.0 483 | inc_masks = get_incoherent_mask(gt_masks_vector, 8) 484 | gt_masks_vector = ((gt_masks_vector + mask_noise * inc_masks) > 0.5).float() 485 | gt_masks_vector = gt_masks_vector * 255 486 | 487 | return gt_masks_vector 488 | 489 | 490 | def mask_iou(pred_label, 
label): 491 | ''' 492 | calculate mask iou for pred_label and gt_label 493 | ''' 494 | 495 | pred_label = (pred_label > 0)[0].int() 496 | label = (label > 128)[0].int() 497 | 498 | intersection = ((label * pred_label) > 0).sum() 499 | union = ((label + pred_label) > 0).sum() 500 | return intersection / union 501 | 502 | 503 | # General util function to get the boundary of a binary mask. 504 | # https://gist.github.com/bowenc0221/71f7a02afee92646ca05efeeb14d687d 505 | def mask_to_boundary(mask, dilation_ratio=0.02): 506 | """ 507 | Convert binary mask to boundary mask. 508 | :param mask (numpy array, uint8): binary mask 509 | :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal 510 | :return: boundary mask (numpy array) 511 | """ 512 | h, w = mask.shape 513 | img_diag = np.sqrt(h ** 2 + w ** 2) 514 | dilation = int(round(dilation_ratio * img_diag)) 515 | if dilation < 1: 516 | dilation = 1 517 | # Pad image so mask truncated by the image border is also considered as boundary. 518 | new_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) 519 | kernel = np.ones((3, 3), dtype=np.uint8) 520 | new_mask_erode = cv2.erode(new_mask, kernel, iterations=dilation) 521 | mask_erode = new_mask_erode[1: h + 1, 1: w + 1] 522 | # G_d intersects G in the paper. 523 | return mask - mask_erode 524 | 525 | 526 | def boundary_iou(gt, dt, dilation_ratio=0.02): 527 | """ 528 | Compute boundary iou between two binary masks. 529 | :param gt (numpy array, uint8): binary mask 530 | :param dt (numpy array, uint8): binary mask 531 | :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal 532 | :return: boundary iou (float) 533 | """ 534 | device = gt.device 535 | dt = (dt > 0)[0].cpu().byte().numpy() 536 | gt = (gt > 128)[0].cpu().byte().numpy() 537 | 538 | gt_boundary = mask_to_boundary(gt, dilation_ratio) 539 | dt_boundary = mask_to_boundary(dt, dilation_ratio) 540 | intersection = ((gt_boundary * dt_boundary) > 0).sum() 541 | union = ((gt_boundary + dt_boundary) > 0).sum() 542 | boundary_iou = intersection / union 543 | return torch.tensor(boundary_iou).float().to(device) 544 | 545 | 546 | def _max_by_axis(the_list): 547 | # type: (List[List[int]]) -> List[int] 548 | maxes = the_list[0] 549 | for sublist in the_list[1:]: 550 | for index, item in enumerate(sublist): 551 | maxes[index] = max(maxes[index], item) 552 | return maxes 553 | 554 | 555 | class NestedTensor(object): 556 | def __init__(self, tensors, mask: Optional[Tensor]): 557 | self.tensors = tensors 558 | self.mask = mask 559 | 560 | def to(self, device): 561 | # type: (Device) -> NestedTensor # noqa 562 | cast_tensor = self.tensors.to(device) 563 | mask = self.mask 564 | if mask is not None: 565 | assert mask is not None 566 | cast_mask = mask.to(device) 567 | else: 568 | cast_mask = None 569 | return NestedTensor(cast_tensor, cast_mask) 570 | 571 | def decompose(self): 572 | return self.tensors, self.mask 573 | 574 | def __repr__(self): 575 | return str(self.tensors) 576 | 577 | 578 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 579 | # TODO make this more general 580 | if tensor_list[0].ndim == 3: 581 | if torchvision._is_tracing(): 582 | # nested_tensor_from_tensor_list() does not export well to ONNX 583 | # call _onnx_nested_tensor_from_tensor_list() instead 584 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 585 | 586 | # TODO make it support different-sized images 587 | max_size = _max_by_axis([list(img.shape) for img in 
tensor_list]) 588 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 589 | batch_shape = [len(tensor_list)] + max_size 590 | b, c, h, w = batch_shape 591 | dtype = tensor_list[0].dtype 592 | device = tensor_list[0].device 593 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 594 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 595 | for img, pad_img, m in zip(tensor_list, tensor, mask): 596 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 597 | m[: img.shape[1], :img.shape[2]] = False 598 | else: 599 | raise ValueError('not supported') 600 | return NestedTensor(tensor, mask) 601 | 602 | 603 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 604 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 605 | @torch.jit.unused 606 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 607 | max_size = [] 608 | for i in range(tensor_list[0].dim()): 609 | max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) 610 | max_size.append(max_size_i) 611 | max_size = tuple(max_size) 612 | 613 | # work around for 614 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 615 | # m[: img.shape[1], :img.shape[2]] = False 616 | # which is not yet supported in onnx 617 | padded_imgs = [] 618 | padded_masks = [] 619 | for img in tensor_list: 620 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 621 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 622 | padded_imgs.append(padded_img) 623 | 624 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 625 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 626 | padded_masks.append(padded_mask.to(torch.bool)) 627 | 628 | tensor = torch.stack(padded_imgs) 629 | mask = torch.stack(padded_masks) 630 | 631 | return NestedTensor(tensor, mask=mask) 632 | 633 | 634 | @torch.no_grad() 635 | def accuracy(output, target, topk=(1,)): 636 | """Computes the precision@k for the specified values of k""" 637 | if target.numel() == 0: 638 | return [torch.zeros([], device=output.device)] 639 | maxk = max(topk) 640 | batch_size = target.size(0) 641 | 642 | _, pred = output.topk(maxk, 1, True, True) 643 | pred = pred.t() 644 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 645 | 646 | res = [] 647 | for k in topk: 648 | correct_k = correct[:k].view(-1).float().sum(0) 649 | res.append(correct_k.mul_(100.0 / batch_size)) 650 | return res 651 | 652 | 653 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 654 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 655 | """ 656 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 657 | This will eventually be supported natively by PyTorch, and this 658 | class can go away. 
659 | """ 660 | if version.parse(torchvision.__version__) < version.parse('0.7'): 661 | if input.numel() > 0: 662 | return torch.nn.functional.interpolate( 663 | input, size, scale_factor, mode, align_corners 664 | ) 665 | 666 | output_shape = _output_size(2, input, size, scale_factor) 667 | output_shape = list(input.shape[:-2]) + list(output_shape) 668 | return _new_empty_tensor(input, output_shape) 669 | else: 670 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 671 | --------------------------------------------------------------------------------
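For quick reference, below is a minimal evaluation sketch showing how the classes in `utils/metric.py` (`mIoU`, `PD_FA`, `ROCMetric`) can be driven. The synthetic logits/labels, the batch size of 1, the 512x512 resolution, and the scaling of sigmoid scores to 0-255 before `PD_FA.update` (whose thresholds sweep 0-255 and which reshapes one image per call) are assumptions consistent with the code above, not the exact usage in `demo.py`.

```python
import torch
from utils.metric import mIoU, PD_FA, ROCMetric

# Synthetic stand-ins for model outputs: (1, 1, 512, 512) logits and binary labels.
iou_metric = mIoU(nclass=1)
pd_fa = PD_FA(nclass=1, bins=10)
roc = ROCMetric(nclass=1, bins=10)
iou_metric.reset()
pd_fa.reset()

num_images = 4
for _ in range(num_images):
    label = torch.zeros(1, 1, 512, 512)
    label[..., 100:110, 200:212] = 1.0                 # one small "target"
    pred = label * 8.0 - 4.0 + torch.randn_like(label)  # noisy logits around the target

    iou_metric.update(pred, label)
    roc.update(pred, label)                 # thresholds sigmoid scores in [0, 1]
    # PD_FA thresholds raw values in [0, 255] and reshapes a single 512x512 map,
    # so it is fed one sigmoid score map scaled to 0-255 per call.
    pd_fa.update(torch.sigmoid(pred) * 255, label)

pix_acc, miou = iou_metric.get()
fa, pd = pd_fa.get(num_images)              # per-threshold arrays (bins + 1 entries)
tp_rates, fp_rates, recall, precision = roc.get()
print(f"pixAcc={pix_acc:.4f} mIoU={miou:.4f}")
print("Pd per threshold:", pd)
print("Fa per threshold:", fa)
```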
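A similarly minimal sketch of calling `loss_masks` from `utils/loss_mask.py` on a dummy batch. The (B, 1, H, W) layout and the 0/255 ground-truth convention are assumptions; internally the function samples 112 * 112 uncertainty-guided points (oversample ratio 3.0, importance ratio 0.75) and returns a sigmoid cross-entropy term and a dice term, whose weighting is left to the training script.

```python
import torch
from utils.loss_mask import loss_masks

# Dummy batch: (B, 1, H, W) mask logits and ground-truth masks stored as 0/255.
pred_logits = torch.randn(2, 1, 512, 512)
gt_masks = (torch.rand(2, 1, 512, 512) > 0.995).float() * 255

targets = (gt_masks > 128).float()      # binarise to {0, 1} before point sampling
num_masks = pred_logits.shape[0]        # normaliser shared by both loss terms

loss_ce, loss_dice = loss_masks(pred_logits, targets, num_masks, oversample_ratio=3.0)
total = loss_ce + loss_dice             # relative weighting is chosen by the caller
print(loss_ce.item(), loss_dice.item())
```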
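Finally, a sketch of deriving box and point prompts from ground-truth masks with the helpers in `utils/misc.py`. These helpers threshold masks at 128, so the example assumes masks stored in [0, 255] with shape (N, H, W); `masks_noise` expects an extra channel dimension, and `box_noise` moves its noise to CUDA internally, so that call is guarded.

```python
import torch
from utils.misc import masks_to_boxes, masks_sample_points, masks_noise, box_noise

# Two toy 512x512 masks with values in {0, 255}.
gt_masks = torch.zeros(2, 512, 512)
gt_masks[0, 100:140, 200:260] = 255
gt_masks[1, 300:312, 50:90] = 255

boxes = masks_to_boxes(gt_masks)               # (2, 4) xyxy boxes around each mask
points = masks_sample_points(gt_masks, k=10)   # (2, 10, 2) (x, y) points inside each mask
noisy = masks_noise(gt_masks.unsqueeze(1))     # (2, 1, 512, 512) degraded mask prompts

if torch.cuda.is_available():                  # box_noise calls .cuda() internally
    jittered = box_noise(boxes.cuda(), box_noise_scale=0.4)

print(boxes.shape, points.shape, noisy.shape)
```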