├── .github └── workflows │ └── test_install.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── amg_example ├── README.md ├── amg_example.py ├── amg_example_trace.json.gz ├── dog.jpg ├── dog_mask.png └── dog_mask_fast.png ├── experiments ├── README.md ├── bar_chart.svg ├── data.py ├── eval_combo.py ├── metrics.py ├── p4d_results │ ├── results_bs1.csv │ ├── results_bs1_vit_h.csv │ ├── results_bs32.csv │ ├── results_bs32_vit_h.csv │ ├── results_bs8.csv │ └── results_bs8_vit_h.csv ├── requirements.txt ├── results.csv ├── run_experiments.py └── summary_chart.py ├── segment_anything_fast ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── configs │ ├── __init__.py │ └── flash_4_configs_a100.p ├── flash_4.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py ├── sparse.py ├── tools.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── setup.py └── test ├── test_flash_4.py └── test_mask_to_rle.py /.github/workflows/test_install.yml: -------------------------------------------------------------------------------- 1 | name: Test Installation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test-installation: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Python 3.8 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.8 23 | 24 | - name: Install dependencies 25 | run: | 26 | pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 27 | pip install -e . 28 | 29 | - name: Test import 30 | run: | 31 | python -c "import segment_anything_fast; from segment_anything_fast import sam_model_registry" 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # READ THIS BEFORE YOU REFACTOR ME 2 | # 3 | # setup.py uses the list of patterns in this file to decide 4 | # what to delete, but it's not 100% sound. So, for example, 5 | # if you delete aten/build/ because it's redundant with build/, 6 | # aten/build/ will stop being cleaned. So be careful when 7 | # refactoring this file! 
8 | 9 | experiments/experiments_data/checkpoints/* 10 | experiments/experiments_data/tmp/* 11 | experiments/experiments_data/datasets/* 12 | amg_example/checkpoints/* 13 | 14 | ## PyTorch 15 | 16 | .coverage 17 | coverage.xml 18 | .dmypy.json 19 | .gradle 20 | .hypothesis 21 | .mypy_cache 22 | /.extracted_scripts/ 23 | **/.pytorch_specified_test_cases.csv 24 | **/.pytorch-disabled-tests.json 25 | **/.pytorch-slow-tests.json 26 | **/.pytorch-test-times.json 27 | **/.pytorch-test-file-ratings.json 28 | */*.pyc 29 | */*.so* 30 | */**/__pycache__ 31 | */**/*.dylib* 32 | */**/*.pyc 33 | */**/*.pyd 34 | */**/*.so* 35 | */**/**/*.pyc 36 | */**/**/**/*.pyc 37 | */**/**/**/**/*.pyc 38 | aten/build/ 39 | aten/src/ATen/Config.h 40 | aten/src/ATen/cuda/CUDAConfig.h 41 | benchmarks/.data 42 | caffe2/cpp_test/ 43 | dist/ 44 | docs/build/ 45 | docs/cpp/src 46 | docs/src/**/* 47 | docs/cpp/build 48 | docs/cpp/source/api 49 | docs/cpp/source/html/ 50 | docs/cpp/source/latex/ 51 | docs/source/compile/generated/ 52 | docs/source/generated/ 53 | docs/source/compile/generated/ 54 | log 55 | usage_log.txt 56 | test-reports/ 57 | test/*.bak 58 | test/**/*.bak 59 | test/.coverage 60 | test/.hypothesis/ 61 | test/cpp/api/mnist 62 | test/custom_operator/model.pt 63 | test/jit_hooks/*.pt 64 | test/data/legacy_modules.t7 65 | test/data/*.pt 66 | test/forward_backward_compatibility/nightly_schemas.txt 67 | dropout_model.pt 68 | test/generated_type_hints_smoketest.py 69 | test/htmlcov 70 | test/cpp_extensions/install/ 71 | third_party/build/ 72 | tools/coverage_plugins_package/pip-wheel-metadata/ 73 | tools/shared/_utils_internal.py 74 | tools/fast_nvcc/wrap_nvcc.sh 75 | tools/fast_nvcc/wrap_nvcc.bat 76 | tools/fast_nvcc/tmp/ 77 | torch.egg-info/ 78 | torch/_C/__init__.pyi 79 | torch/_C/_nn.pyi 80 | torch/_C/_VariableFunctions.pyi 81 | torch/_VF.pyi 82 | torch/return_types.pyi 83 | torch/nn/functional.pyi 84 | torch/utils/data/datapipes/datapipe.pyi 85 | torch/csrc/autograd/generated/* 86 | torch/csrc/lazy/generated/*.[!m]* 87 | torch_compile_debug/ 88 | # Listed manually because some files in this directory are not generated 89 | torch/testing/_internal/generated/annotated_fn_args.py 90 | torch/testing/_internal/data/*.pt 91 | torch/csrc/api/include/torch/version.h 92 | torch/csrc/cudnn/cuDNN.cpp 93 | torch/csrc/generated 94 | torch/csrc/generic/TensorMethods.cpp 95 | torch/csrc/jit/generated/* 96 | torch/csrc/jit/fuser/config.h 97 | torch/csrc/nn/THCUNN.cpp 98 | torch/csrc/nn/THCUNN.cwrap 99 | torch/bin/ 100 | torch/cmake/ 101 | torch/lib/*.a* 102 | torch/lib/*.dll* 103 | torch/lib/*.exe* 104 | torch/lib/*.dylib* 105 | torch/lib/*.h 106 | torch/lib/*.lib 107 | torch/lib/*.pdb 108 | torch/lib/*.so* 109 | torch/lib/protobuf*.pc 110 | torch/lib/build 111 | torch/lib/caffe2/ 112 | torch/lib/cmake 113 | torch/lib/include 114 | torch/lib/pkgconfig 115 | torch/lib/protoc 116 | torch/lib/protobuf/ 117 | torch/lib/tmp_install 118 | torch/lib/torch_shm_manager 119 | torch/lib/site-packages/ 120 | torch/lib/python* 121 | torch/lib64 122 | torch/include/ 123 | torch/share/ 124 | torch/test/ 125 | torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h 126 | torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h 127 | torch/version.py 128 | minifier_launcher.py 129 | # Root level file used in CI to specify certain env configs. 
130 | # E.g., see .circleci/config.yaml 131 | env 132 | .circleci/scripts/COMMIT_MSG 133 | scripts/release_notes/*.json 134 | sccache-stats*.json 135 | 136 | # These files get copied over on invoking setup.py 137 | torchgen/packaged/* 138 | !torchgen/packaged/README.md 139 | 140 | # IPython notebook checkpoints 141 | .ipynb_checkpoints 142 | 143 | # Editor temporaries 144 | *.swa 145 | *.swb 146 | *.swc 147 | *.swd 148 | *.swe 149 | *.swf 150 | *.swg 151 | *.swh 152 | *.swi 153 | *.swj 154 | *.swk 155 | *.swl 156 | *.swm 157 | *.swn 158 | *.swo 159 | *.swp 160 | *~ 161 | .~lock.* 162 | 163 | # macOS dir files 164 | .DS_Store 165 | 166 | # Ninja files 167 | .ninja_deps 168 | .ninja_log 169 | compile_commands.json 170 | *.egg-info/ 171 | docs/source/scripts/activation_images/ 172 | docs/source/scripts/quantization_backend_configs/ 173 | 174 | ## General 175 | 176 | # Compiled Object files 177 | *.slo 178 | *.lo 179 | *.o 180 | *.cuo 181 | *.obj 182 | 183 | # Compiled Dynamic libraries 184 | *.so 185 | *.dylib 186 | *.dll 187 | 188 | # Compiled Static libraries 189 | *.lai 190 | *.la 191 | *.a 192 | *.lib 193 | 194 | # Compiled protocol buffers 195 | *.pb.h 196 | *.pb.cc 197 | *_pb2.py 198 | 199 | # Compiled python 200 | *.pyc 201 | *.pyd 202 | 203 | # Compiled MATLAB 204 | *.mex* 205 | 206 | # IPython notebook checkpoints 207 | .ipynb_checkpoints 208 | 209 | # Editor temporaries 210 | *.swn 211 | *.swo 212 | *.swp 213 | *~ 214 | 215 | # NFS handle files 216 | **/.nfs* 217 | 218 | # Sublime Text settings 219 | *.sublime-workspace 220 | *.sublime-project 221 | 222 | # Eclipse Project settings 223 | *.*project 224 | .settings 225 | 226 | # QtCreator files 227 | *.user 228 | 229 | # PyCharm files 230 | .idea 231 | 232 | # GDB history 233 | .gdb_history 234 | 235 | ## Caffe2 236 | 237 | # build, distribute, and bins (+ python proto bindings) 238 | build/ 239 | # Allow tools/build/ for build support. 240 | !tools/build/ 241 | build_host_protoc 242 | build_android 243 | build_ios 244 | .build_debug/* 245 | .build_release/* 246 | .build_profile/* 247 | distribute/* 248 | *.testbin 249 | *.bin 250 | cmake_build 251 | .cmake_build 252 | gen 253 | .setuptools-cmake-build 254 | .pytest_cache 255 | aten/build/* 256 | 257 | # Bram 258 | plsdontbreak 259 | 260 | # Generated documentation 261 | docs/_site 262 | docs/gathered 263 | _site 264 | doxygen 265 | docs/dev 266 | 267 | # LevelDB files 268 | *.sst 269 | *.ldb 270 | LOCK 271 | CURRENT 272 | MANIFEST-* 273 | 274 | # generated version file 275 | caffe2/version.py 276 | 277 | # setup.py intermediates 278 | .eggs 279 | caffe2.egg-info 280 | MANIFEST 281 | 282 | # Atom/Watchman required file 283 | .watchmanconfig 284 | 285 | # Files generated by CLion 286 | cmake-build-debug 287 | 288 | # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.) 289 | # 290 | # Below files are not deleted by "setup.py clean". 
291 | 292 | # Downloaded bazel 293 | tools/bazel 294 | 295 | # Visual Studio Code files 296 | .vs 297 | /.vscode/* 298 | !/.vscode/extensions.json 299 | !/.vscode/settings_recommended.json 300 | 301 | # YouCompleteMe config file 302 | .ycm_extra_conf.py 303 | 304 | # Files generated when a patch is rejected 305 | *.orig 306 | *.rej 307 | 308 | # Files generated by ctags 309 | CTAGS 310 | GTAGS 311 | GRTAGS 312 | GSYMS 313 | GPATH 314 | tags 315 | TAGS 316 | 317 | 318 | # ccls file 319 | .ccls-cache/ 320 | 321 | # clang tooling storage location 322 | .clang-format-bin 323 | .clang-tidy-bin 324 | .lintbin 325 | 326 | # clangd background index 327 | .clangd/ 328 | .cache/ 329 | 330 | # bazel symlinks 331 | bazel-* 332 | 333 | # xla repo 334 | xla/ 335 | 336 | # direnv, posh-direnv 337 | .env 338 | .envrc 339 | .psenvrc 340 | 341 | # generated shellcheck directories 342 | .shellcheck_generated*/ 343 | 344 | # zip archives 345 | *.zip 346 | 347 | # core dump files 348 | **/core.[1-9]* 349 | 350 | # Generated if you use the pre-commit script for clang-tidy 351 | pr.diff 352 | 353 | # coverage files 354 | */**/.coverage.* 355 | 356 | # buck generated files 357 | .buckd/ 358 | .lsp-buck-out/ 359 | .lsp.buckd/ 360 | buck-out/ 361 | 362 | # Downloaded libraries 363 | third_party/ruy/ 364 | third_party/glog/ 365 | 366 | # Virtualenv 367 | venv/ 368 | 369 | # Log files 370 | *.log 371 | sweep/ 372 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything-fast 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `main`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Meta's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## License 31 | By contributing to `segment-anything-fast`, you agree that your contributions will be licensed 32 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [2023] Lightning AI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Segment anything ... Fast

This work is based on a fork of https://github.com/facebookresearch/segment-anything

The corresponding blog post is https://pytorch.org/blog/accelerating-generative-ai/


## Installation

Step 1: Get the latest PyTorch nightly.

For example:
```
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
```
or
```
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
```

Installation instructions vary by platform. Please see https://pytorch.org/ for details.

Step 2: Install the package.

```
pip install git+https://github.com/pytorch-labs/segment-anything-fast.git
```

## Usage

The package acts as a drop-in replacement for segment-anything.

So, for example, if you're currently doing `from segment_anything import sam_model_registry` you should be able to do `from segment_anything_fast import sam_model_registry`.

However, you're likely here because you want to try a fast inference version. So we also created a `sam_model_fast_registry` that automatically
- sets `eval` mode
- uses `bfloat16`
- enables torch.compile with max-autotune
- uses a custom Triton kernel that implements SDPA for relative positional encodings for long sequence lengths

The custom Triton kernel in particular was written for the A100. If you're not using an A100, we will try to rerun autotuning on your device and save the best configs locally.
You might still run into performance issues, so you can disable the kernel by setting the environment variable `SEGMENT_ANYTHING_FAST_USE_FLASH_4=0`.

Please also note that the first time you run this model you'll likely need to wait a bit for it to compile.

If you'd like to see the details on how to reproduce all results, please see the README in the experiments folder.

Please don't hesitate to open a GitHub issue if you're missing functionality or find a problem. Thank you.
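As a concrete starting point, here is a minimal sketch of the fast path, adapted from the script in `amg_example/amg_example.py`. It assumes the `vit_h` checkpoint has been downloaded into a local `checkpoints` folder (see `amg_example/README.md`), that `opencv-python` is installed, and that a CUDA device is available; adjust paths and the model type to your setup.

```
import cv2
from segment_anything_fast import sam_model_fast_registry, SamAutomaticMaskGenerator

# Build the fast model variant: eval mode, bfloat16, torch.compile and the
# custom kernels are applied automatically by the fast registry.
sam = sam_model_fast_registry["vit_h"](checkpoint="checkpoints/sam_vit_h_4b8939.pth")
sam.to(device="cuda")
mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)

# Load an RGB image; the mask generator expects an HxWx3 uint8 array.
image = cv2.imread("dog.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# The first call triggers compilation, so expect a noticeable warmup delay.
masks = mask_generator.generate(image)
```

If the custom kernel gives you trouble on non-A100 hardware, the same code can be run with `SEGMENT_ANYTHING_FAST_USE_FLASH_4=0` set in the environment, as described above.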
56 | 57 | ## Results 58 | 59 | The results show a waterfall of techniques. 60 | 61 | Left to right these techniques are combined. 62 | 63 | That means the very last bar is the combination of 64 | - bfloat16 65 | - torch.compile with max-autotune 66 | - [torch.scaled_dot_product_attention](https://pytorch.org/docs/main/generated/torch.nn.functional.scaled_dot_product_attention.html) 67 | - A custom Triton kernel that implements SDPA for relative positional encodings for long sequence lengths 68 | - NestedTensors 69 | - Dynamic int8 symmetric quantization 70 | - 2:4 sparse format 71 | 72 | ![High level results](experiments/bar_chart.svg) 73 | 74 | ## License 75 | 76 | `segment-anything-fast` is released under the [Apache 2.0](https://github.com/pytorch-labs/segment-anything-fast/main/LICENSE) license. 77 | -------------------------------------------------------------------------------- /amg_example/README.md: -------------------------------------------------------------------------------- 1 | To run this example you need to download the vit_h checkpoint and put it into a local folder named checkpoints 2 | 3 | You can find the checkpoint for vit_h here: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 4 | 5 | To read the image you also need to install opencv-python: https://pypi.org/project/opencv-python/ 6 | -------------------------------------------------------------------------------- /amg_example/amg_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | import torch.utils.benchmark as benchmark 6 | 7 | def profiler_runner(path, fn, *args, **kwargs): 8 | with torch.profiler.profile( 9 | activities=[torch.profiler.ProfilerActivity.CPU, 10 | torch.profiler.ProfilerActivity.CUDA], 11 | record_shapes=True) as prof: 12 | result = fn(*args, **kwargs) 13 | print(f"Saving trace under {path}") 14 | prof.export_chrome_trace(path) 15 | return result 16 | 17 | def show_anns(anns): 18 | if len(anns) == 0: 19 | return 20 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 21 | ax = plt.gca() 22 | ax.set_autoscale_on(False) 23 | 24 | img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) 25 | img[:,:,3] = 0 26 | for ann in sorted_anns: 27 | m = ann['segmentation'] 28 | color_mask = np.concatenate([np.random.random(3), [0.35]]) 29 | img[m] = color_mask 30 | ax.imshow(img) 31 | 32 | image = cv2.imread('dog.jpg') 33 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 34 | 35 | 36 | from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator 37 | 38 | sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth" 39 | model_type = "vit_h" 40 | device = "cuda" 41 | 42 | sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint) 43 | sam.to(device=device) 44 | mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8) 45 | 46 | # Run thrice for warmup 47 | masks = mask_generator.generate(image) 48 | masks = mask_generator.generate(image) 49 | masks = mask_generator.generate(image) 50 | 51 | # Save an example 52 | plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100) 53 | plt.imshow(image) 54 | show_anns(masks) 55 | plt.axis('off') 56 | plt.tight_layout() 57 | plt.savefig('dog_mask_fast.png', format='png') 58 | 59 | # Benchmark 60 | torch.cuda.synchronize() 61 | start_event = torch.cuda.Event(enable_timing=True) 62 | end_event 
= torch.cuda.Event(enable_timing=True) 63 | start_event.record() 64 | for _ in range(10): 65 | masks = mask_generator.generate(image) 66 | end_event.record() 67 | torch.cuda.synchronize() 68 | print(start_event.elapsed_time(end_event) / 10.) 69 | 70 | # Save a GPU trace 71 | profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image) 72 | 73 | # Write out memory usage 74 | max_memory_allocated_bytes = torch.cuda.max_memory_allocated() 75 | _, total_memory = torch.cuda.mem_get_info() 76 | max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) 77 | max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 78 | print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}") 79 | -------------------------------------------------------------------------------- /amg_example/amg_example_trace.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/segment-anything-fast/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/amg_example/amg_example_trace.json.gz -------------------------------------------------------------------------------- /amg_example/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/segment-anything-fast/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/amg_example/dog.jpg -------------------------------------------------------------------------------- /amg_example/dog_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/segment-anything-fast/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/amg_example/dog_mask.png -------------------------------------------------------------------------------- /amg_example/dog_mask_fast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/segment-anything-fast/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/amg_example/dog_mask_fast.png -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | To run the experiments you need to update the script paths and install fire, pandas and tqdm 2 | 3 | ## Model Checkpoints 4 | 5 | Need checkpoints from https://github.com/facebookresearch/segment-anything 6 | 7 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 8 | 9 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 10 | 11 | ## COCO2017 dataset 12 | 13 | Need to download 14 | 15 | wget http://images.cocodataset.org/zips/val2017.zip 16 | 17 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 18 | 19 | ## Folder structure of experimental data 20 | ``` 21 | experiments_data/tmp 22 | experiments_data/tmp/sam_coco_mask_center_cache 23 | experiments_data/tmp/sam_eval_masks_out 24 | experiments_data/datasets 25 | experiments_data/datasets/coco2017 26 | experiments_data/datasets/coco2017/val2017 27 | experiments_data/datasets/coco2017/annotations 28 | experiments_data/checkpoints 29 | ``` 30 | ## Environment details 31 | 32 | ### Hardware 33 | These experiments were run on an Amazon p4d.24xlarge instance. See the Product details of the EC2 website for the exact details. 
A few key highlights are:

- 8 A100 GPUs with 40960MiB running at 400W
- 96 vCPUs
- 1152 GiB of RAM

### Software

These experiments (fp32, bf16, compile, SDPA, Triton, NT) can also run on a CPU platform. The corresponding experiment results will be shown in the near future.

### Versions

- PyTorch nightly and Python 3.10
- The https://github.com/cpuhrsch/segment-anything fork of https://github.com/facebookresearch/segment-anything with additional commits, if you want to reproduce the baseline and the first few experiments
- This repository: https://github.com/pytorch-labs/segment-anything-fast

### Installation instructions

```
$ conda create -n nightlypy310
$ conda activate nightlypy310
$ conda install python=3.10
For GPU,
- $ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
- $ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
For CPU,
- $ pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240530%2Bcpu-cp310-cp310-linux_x86_64.whl
- $ pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240530%2Bcpu-cp310-cp310-linux_x86_64.whl
- $ install triton based on https://github.com/triton-lang/triton?tab=readme-ov-file#quick-installation

$ git clone https://github.com/cpuhrsch/segment-anything.git
$ cd segment-anything
$ pip install -e .
$ cd ..
$ git clone https://github.com/pytorch-labs/segment-anything-fast.git
$ cd segment-anything-fast
$ pip install -e .
```

If you plan to run the scripts that run the experiments from segment-anything-fast, it is important to install the segment-anything fork in editable mode so that the scripts can switch between different commits of the fork automatically.


### How to run experiments

For the GPU platform,
```
$ python run_experiments.py 16 vit_b --run-experiments --num-workers 32
```

For the CPU platform, set SEGMENT_ANYTHING_FAST_USE_FLASH_4 to 0, since the custom flash attention kernel was written specifically for the A100.
```
$ SEGMENT_ANYTHING_FAST_USE_FLASH_4=0 python run_experiments.py 16 vit_b --run-experiments --num-workers 32 --device cpu
```

If at any point you run into an issue, note that you can increase verbosity by adding `--capture_output False` to the above command. Also, please don't hesitate to open an issue.


### Data

We are using the COCO2017 Validation (Val images) dataset. It serves as a somewhat realistic distribution of input images, on which we aim to measure a) accuracy and b) performance.

### Measurement

### Accuracy

Our main goal is to verify that our performance optimizations do not degrade the accuracy of the model. We do not aim to reproduce any paper results or to make statements about the accuracy of this model on the dataset. This measurement serves as an additional integration test in conjunction with numerous unit and other separate integration tests.

We calculate the center points of the mask annotations using a rudimentary version of https://arxiv.org/pdf/2304.02643.pdf, section D.1, Point Sampling ([code](https://github.com/pytorch-labs/segment-anything-fast/blob/67d5c894569e99b9fdba55cfcf2f724be9f68994/experiments/data.py#L10-L120)). These center points serve as annotations per image. Note that the number of masks, and thus the number of annotations, varies per image.

These images and annotations are given to the predict_torch method of an instance of SamPredictor to predict masks. The predictions are then compared to the ground truth masks using the Intersection over Union (IoU) metric ([code](https://github.com/pytorch-labs/segment-anything-fast/blob/67d5c894569e99b9fdba55cfcf2f724be9f68994/experiments/metrics.py#L4-L22)). We calculate the mean IoU (mIoU) metric over the entire 5000 images of the validation dataset to track accuracy.

### Performance

Our goal is to measure the runtime of PyTorch models. We purposefully exclude data movement and metric calculation. Specifically, we measure the execution time on the GPU of running the image encoder (e.g. vit_h) and SamPredictor.predict_torch ([code](https://github.com/pytorch-labs/segment-anything-fast/blob/67d5c894569e99b9fdba55cfcf2f724be9f68994/experiments/eval_combo.py#L127-L165), [code](https://github.com/pytorch-labs/segment-anything-fast/blob/67d5c894569e99b9fdba55cfcf2f724be9f68994/experiments/eval_combo.py#L68-L99)).

Each experiment is run in a separate Python process created from scratch. We run three batches of warmup before each experiment. This also implies that we exclude compilation time from benchmarking.

We measure the execution time and calculate the number of images that can be processed per second (img/s). We also measure the maximum amount of memory allocated at the end of the process using torch.cuda.max_memory_allocated.

### Tracing

We collect kernel and memory traces using PyTorch native tooling and analyze them with [Perfetto UI](https://perfetto.dev/). When collecting these traces and profiles we typically limit ourselves to only a few batches; otherwise the files can become very large and difficult to load.

### Kernel traces

One can write a simple wrapper that runs a function under the tracer context and writes out the result to a compressed JSON file. The resulting Chrome trace can then be analyzed with Perfetto UI.

```
def profiler_runner(path, fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result
```

It can be very useful to annotate certain regions in these traces to map (pieces of) the code to the overall traces. For this we frequently use record_function. Consider the following as an example.

```
with torch.autograd.profiler.record_function("timed region"):
    with torch.autograd.profiler.record_function("image encoder"):
        features_batch = encoder(input_image_batch)
        features_batch = features_batch[:orig_input_image_batch_size]

    with torch.autograd.profiler.record_function("nt predict_torch"):
        predictor.reset_image()
        [...]
```

### Memory profiles

We record the memory history and use memory_viz.py to convert the result into a human-readable HTML file.
139 | 140 | ``` 141 | def memory_runner(path, fn, *args, **kwargs): 142 | print("Start memory recording") 143 | torch.cuda.synchronize() 144 | torch.cuda.memory._record_memory_history( 145 | True, 146 | trace_alloc_max_entries=100000, 147 | trace_alloc_record_context=True 148 | ) 149 | result = fn(*args, **kwargs) 150 | torch.cuda.synchronize() 151 | snapshot = torch.cuda.memory._snapshot() 152 | print("Finish memory recording") 153 | import pickle 154 | with open(path, 'wb') as f: 155 | pickle.dump(snapshot, f) 156 | # Use to convert pickle file into html 157 | # python torch/cuda/_memory_viz.py trace_plot .pickle -o .html 158 | return result 159 | ``` 160 | -------------------------------------------------------------------------------- /experiments/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import diskcache 3 | from pycocotools.coco import COCO 4 | import numpy as np 5 | from scipy import ndimage 6 | import skimage.io as io 7 | import skimage.color as color 8 | 9 | 10 | def _get_center_point(mask, ann_id, cache): 11 | """ 12 | This is a rudimentary version of https://arxiv.org/pdf/2304.02643.pdf, 13 | section D.1.Point Sampling 14 | 15 | From the paper: "The first point is chosen deterministically as the point 16 | farthest from the object boundary." 17 | 18 | The code below is an approximation of this. 19 | 20 | First, we try to calculate the center of mass. If it's inside the mask, we 21 | stop here. 22 | 23 | The centroid may be outside of the mask for some mask shapes. In this case 24 | we do a slow hack, specifically, we check for the 25 | minumum of the maximum distance from the boundary in four directions 26 | (up, right, down, left), and take the point with the maximum of these 27 | minimums. Note: this is not performant for large masks. 28 | 29 | Returns the center point in (x, y) format 30 | """ 31 | if ann_id in cache: 32 | return cache[ann_id] 33 | 34 | # try the center of mass, keep it if it's inside the mask 35 | com_y, com_x = ndimage.center_of_mass(mask) 36 | com_y, com_x = int(round(com_y, 0)), int(round(com_x, 0)) 37 | if mask[com_y][com_x]: 38 | cache[ann_id] = (com_x, com_y) 39 | return (com_x, com_y) 40 | 41 | # if center of mass didn't work, do the slow manual approximation 42 | 43 | # up, right, down, left 44 | # TODO(future): approximate better by adding more directions 45 | distances_to_check_deg = [0, 90, 180, 270] 46 | 47 | global_min_max_distance = float('-inf') 48 | global_coords = None 49 | # For now, terminate early to speed up the calculation as long as 50 | # the point sample is gooe enough. This sacrifices the quality of point 51 | # sampling for speed. In the future we can make this more accurate. 52 | DISTANCE_GOOD_ENOUGH_THRESHOLD = 20 53 | 54 | # Note: precalculating the bounding box could be somewhat 55 | # helpful, but checked the performance gain and it's not much 56 | # so leaving it out to keep the code simple. 57 | # Note: tried binary search instead of incrementing by one to 58 | # travel up/right/left/down, but that does not handle masks 59 | # with all shapes properly (there could be multiple boundaries). 
60 | for row_idx in range(mask.shape[0]): 61 | for col_idx in range(mask.shape[1]): 62 | cur_point = mask[row_idx, col_idx] 63 | 64 | # skip points inside bounding box but outside mask 65 | if not cur_point: 66 | continue 67 | 68 | max_distances = [] 69 | for direction in distances_to_check_deg: 70 | # TODO(future) binary search instead of brute forcing it if we 71 | # need a speedup, with the cache it doesn't really matter though 72 | if direction == 0: 73 | # UP 74 | cur_row_idx = row_idx 75 | 76 | while cur_row_idx >= 0 and mask[cur_row_idx, col_idx]: 77 | cur_row_idx = cur_row_idx - 1 78 | cur_row_idx += 1 79 | distance = row_idx - cur_row_idx 80 | max_distances.append(distance) 81 | 82 | elif direction == 90: 83 | # RIGHT 84 | cur_col_idx = col_idx 85 | 86 | while cur_col_idx <= mask.shape[1] - 1 and \ 87 | mask[row_idx, cur_col_idx]: 88 | cur_col_idx += 1 89 | cur_col_idx -= 1 90 | distance = cur_col_idx - col_idx 91 | max_distances.append(distance) 92 | 93 | elif direction == 180: 94 | # DOWN 95 | cur_row_idx = row_idx 96 | while cur_row_idx <= mask.shape[0] - 1 and \ 97 | mask[cur_row_idx, col_idx]: 98 | cur_row_idx = cur_row_idx + 1 99 | cur_row_idx -= 1 100 | distance = cur_row_idx - row_idx 101 | max_distances.append(distance) 102 | 103 | elif direction == 270: 104 | # LEFT 105 | cur_col_idx = col_idx 106 | while cur_col_idx >= 0 and mask[row_idx, cur_col_idx]: 107 | cur_col_idx -= 1 108 | cur_col_idx += 1 109 | distance = col_idx - cur_col_idx 110 | max_distances.append(distance) 111 | 112 | min_max_distance = min(max_distances) 113 | if min_max_distance > global_min_max_distance: 114 | global_min_max_distance = min_max_distance 115 | global_coords = (col_idx, row_idx) 116 | if global_min_max_distance >= DISTANCE_GOOD_ENOUGH_THRESHOLD: 117 | break 118 | 119 | cache[ann_id] = global_coords 120 | return global_coords 121 | 122 | 123 | def build_datapoint(imgId, 124 | coco, 125 | pixel_mean, 126 | pixel_std, 127 | coco_root_dir, 128 | coco_slice_name, 129 | catIds, 130 | cache, 131 | predictor, 132 | pad_input_image_batch): 133 | img = coco.loadImgs(imgId)[0] 134 | 135 | file_location = f'{coco_root_dir}/{coco_slice_name}/{img["file_name"]}' 136 | I = io.imread(file_location) 137 | if len(I.shape) == 2: 138 | # some images, like img_id==61418, are grayscale 139 | # convert to RGB to ensure the rest of the pipeline works 140 | I = color.gray2rgb(I) 141 | 142 | # load and display instance annotations 143 | annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None) 144 | anns = coco.loadAnns(annIds) 145 | 146 | # approximate the center point of each mask 147 | coords_list = [] 148 | gt_masks_list = [] 149 | for ann in anns: 150 | ann_id = ann['id'] 151 | mask = coco.annToMask(ann) 152 | gt_masks_list.append(torch.tensor(mask)) 153 | coords = _get_center_point(mask, ann_id, cache) 154 | coords_list.append(coords) 155 | 156 | image = I 157 | 158 | # predictor_set_image begin 159 | # Transform the image to the form expected by the model 160 | input_image = predictor.transform.apply_image(image) 161 | input_image_torch = torch.as_tensor(input_image) 162 | input_image_torch = input_image_torch.permute( 163 | 2, 0, 1).contiguous()[None, :, :, :] 164 | predictor_input_size = input_image_torch.shape[-2:] 165 | 166 | # Preprocess 167 | x = input_image_torch 168 | # Normalize colors 169 | x = (x - pixel_mean) / pixel_std 170 | 171 | if pad_input_image_batch: 172 | # Pad 173 | h, w = x.shape[-2:] 174 | padh = predictor.model.image_encoder.img_size - h 175 | padw = 
predictor.model.image_encoder.img_size - w 176 | x = torch.nn.functional.pad(x, (0, padw, 0, padh)) 177 | else: 178 | x = x.squeeze(0) 179 | 180 | gt_masks_list = torch.stack(gt_masks_list) if len(gt_masks_list) else None 181 | return image, coords_list, gt_masks_list, anns, x, predictor_input_size 182 | 183 | 184 | def build_data(coco_img_ids, 185 | coco, 186 | catIds, 187 | coco_root_dir, 188 | coco_slice_name, 189 | point_sampling_cache_dir, 190 | predictor, 191 | use_half, 192 | use_nested_tensor, 193 | pad_input_image_batch): 194 | cache = diskcache.Cache(point_sampling_cache_dir) 195 | # make sure you clear the cache if you change the point sampling algorithm 196 | # cache.clear() 197 | 198 | pixel_mean = predictor.model.pixel_mean.cpu() 199 | pixel_std = predictor.model.pixel_std.cpu() 200 | 201 | def build_batch(indicies): 202 | batch = [[], [], [], [], [], [], [], [], [], [], []] 203 | batch[3] = [0] 204 | batch[6] = [0] 205 | for img_idx in indicies: 206 | imgId = coco_img_ids[img_idx] 207 | 208 | datapoint = build_datapoint(imgId, 209 | coco, 210 | pixel_mean, 211 | pixel_std, 212 | coco_root_dir, 213 | coco_slice_name, 214 | catIds, 215 | cache, 216 | predictor, 217 | pad_input_image_batch) 218 | I, coords_list, gt_masks_list, anns, x, predictor_input_size = datapoint 219 | if len(coords_list) == 0: 220 | continue 221 | batch[0].append(x) 222 | # batch[0].append(x[0]) 223 | coords_list = predictor.transform.apply_coords( 224 | np.array(coords_list), I.shape[:2]) 225 | coords_list = torch.tensor(coords_list, dtype=torch.float) 226 | 227 | batch[1].append(coords_list.reshape(-1)) 228 | batch[2].append(coords_list.size()) 229 | batch[3].append(coords_list.numel() + batch[3][-1]) 230 | 231 | batch[4].append(gt_masks_list.reshape(-1)) 232 | batch[5].append(gt_masks_list.size()) 233 | batch[6].append(gt_masks_list.numel() + batch[6][-1]) 234 | 235 | batch[7].append(anns) 236 | batch[8].append(I) 237 | batch[9].append(predictor_input_size) 238 | batch[10].append(img_idx) 239 | 240 | def cat_and_cast(b, use_half): 241 | b = torch.cat(b) if len(b) > 0 else None 242 | if use_half is not None and b is not None: 243 | return b.to(use_half) 244 | return b 245 | 246 | def to_nested_tensor(data, sizes=None, use_half=None): 247 | if len(data) == 0: 248 | return None 249 | dtype = use_half if use_half is not None else torch.float32 250 | 251 | if sizes is not None: 252 | data = [d.view(s) for (d, s) in zip(data, sizes)] 253 | 254 | return torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged) 255 | 256 | if pad_input_image_batch: 257 | batch[0] = cat_and_cast(batch[0], use_half) 258 | else: 259 | batch[0] = to_nested_tensor(batch[0], use_half=use_half) 260 | 261 | if use_nested_tensor: 262 | batch[1] = to_nested_tensor(batch[1], batch[2], use_half) 263 | batch[2] = None 264 | batch[3] = None 265 | else: 266 | batch[1] = cat_and_cast(batch[1], use_half) 267 | 268 | batch[4] = cat_and_cast(batch[4], False) 269 | 270 | return batch 271 | 272 | return build_batch 273 | 274 | 275 | def setup_coco_img_ids(coco_root_dir, coco_slice_name, coco_category_names, img_id): 276 | annFile = '{}/annotations/instances_{}.json'.format( 277 | coco_root_dir, coco_slice_name) 278 | 279 | # initialize COCO api for instance annotations 280 | coco = COCO(annFile) 281 | 282 | # display COCO categories and supercategories 283 | cats = coco.loadCats(coco.getCatIds()) 284 | cat_id_to_cat = {cat['id']: cat for cat in cats} 285 | nms = [cat['name'] for cat in cats] 286 | # print('COCO categories: 
\n{}\n'.format(' '.join(nms))) 287 | 288 | # nms = set([cat['supercategory'] for cat in cats]) 289 | # print('COCO supercategories: \n{}'.format(' '.join(nms))) 290 | 291 | if coco_category_names is not None: 292 | catIds = coco.getCatIds(catNms=coco_category_names) 293 | else: 294 | catIds = coco.getCatIds() 295 | 296 | if img_id is not None: 297 | coco_img_ids = [img_id] 298 | elif coco_category_names is None: 299 | coco_img_ids = coco.getImgIds() 300 | else: 301 | coco_img_ids = coco.getImgIds(catIds=catIds) 302 | 303 | return coco_img_ids, cat_id_to_cat, catIds, coco 304 | -------------------------------------------------------------------------------- /experiments/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | 4 | def create_result_entry(anns, gt_masks_list, masks, scores, img_idx): 5 | argmax_scores = torch.argmax(scores, dim=1) 6 | inference_masks = masks.gather(1, argmax_scores.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand( 7 | (masks.size(0), 1, masks.size(2), masks.size(3)))).squeeze(1) 8 | 9 | def _iou(mask1, mask2): 10 | assert mask1.dim() == 3 11 | assert mask2.dim() == 3 12 | intersection = torch.logical_and(mask1, mask2) 13 | union = torch.logical_or(mask1, mask2) 14 | return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2))) 15 | 16 | top_score_ious = _iou(inference_masks, gt_masks_list) 17 | 18 | entry = [] 19 | for idx in range(top_score_ious.size(0)): 20 | entry.append( 21 | [img_idx, anns[idx]['id'], anns[idx]['category_id'], top_score_ious[idx]]) 22 | return entry 23 | 24 | 25 | def calculate_miou(results, mask_debug_out_dir, silent, cat_id_to_cat): 26 | df = pd.DataFrame(results, columns=['img_id', 'ann_id', 'cat_id', 'iou']) 27 | df.to_csv(f'{mask_debug_out_dir}/df.csv') 28 | df['supercategory'] = df['cat_id'].map( 29 | lambda cat_id: cat_id_to_cat[cat_id]['supercategory']) 30 | df['category'] = df['cat_id'].map( 31 | lambda cat_id: cat_id_to_cat[cat_id]['name']) 32 | 33 | # TODO: cross reference the specifics of how we calculate mIoU with 34 | # the SAM folks (should it be per dataset, per category, per image, etc) 35 | # currently, just calculate them all 36 | 37 | # TODO: QOL save the summaries to file 38 | 39 | # per category 40 | per_category = pd.pivot_table( 41 | df, values='iou', index=['cat_id', 'supercategory', 'category'], 42 | aggfunc=('mean', 'count')) 43 | if not silent: 44 | print('\nmIoU averaged per category') 45 | print(per_category) 46 | 47 | # per super-category 48 | per_supercategory = pd.pivot_table( 49 | df, values='iou', index=['supercategory'], 50 | aggfunc=('mean', 'count')) 51 | if not silent: 52 | print('\nmIoU averaged per supercategory') 53 | print(per_supercategory) 54 | 55 | # per all selected masks 56 | per_all_masks_agg = df['iou'].agg(['mean', 'count']) 57 | if not silent: 58 | print('\nmIoU averaged per all selected masks') 59 | print(per_all_masks_agg) 60 | 61 | return df['iou'].agg(['mean', 'count'])['mean'] 62 | -------------------------------------------------------------------------------- /experiments/p4d_results/results_bs1.csv: -------------------------------------------------------------------------------- 1 | 
technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | fp32,9.094957113265991,default,2.2.0.dev20231023+cu121,vit_b,1,2789,6,11.067140075172762,90.3575804776637,0.5335705502003885,False,None,None,False,False,False,True,True,32,4952,4952,None,None 3 | bf16,3.1106399655342103,codesign,2.2.0.dev20231023+cu121,vit_b,1,1416,3,36.64852561007727,27.286227299824287,0.5415806118803995,False,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 4 | compile,2.823416900634766,codesign,2.2.0.dev20231023+cu121,vit_b,1,1153,2,51.26125899723373,19.507909473194253,0.5413443029413655,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 5 | SDPA,2.3619819323221845,sdpa-decoder,2.2.0.dev20231023+cu121,vit_b,1,1104,2,66.46056735286892,15.046516149802816,0.5354021280520971,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 6 | Triton,2.1993035833040873,local-fork,2.2.0.dev20231023+cu121,vit_b,1,1104,2,72.37143115549199,13.817607086579134,0.5216650318919089,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 7 | int8,2.495520865917206,local-fork,2.2.0.dev20231023+cu121,vit_b,1,1022,2,68.93271871445695,14.50689917138397,0.5200620686121938,max-autotune,torch.bfloat16,dynamic_quant,False,False,False,True,True,32,4952,4952,None,None 8 | sparse,3.8284019509951275,local-fork,2.2.0.dev20231023+cu121,vit_b,1,1530,3,46.319219168299384,21.5893103976242,0.45554241210035895,max-autotune,torch.bfloat16,sparse,False,False,False,True,True,32,4952,4952,None,None 9 | -------------------------------------------------------------------------------- /experiments/p4d_results/results_bs1_vit_h.csv: -------------------------------------------------------------------------------- 1 | technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | fp32,35.23240319093068,default,2.2.0.dev20231024+cu121,vit_h,1,5758,14,2.65870772568268,376.1225765209783,0.584173340367447,False,None,None,False,False,False,True,True,32,4952,4952,None,None 3 | bf16,7.502882464726766,codesign,2.2.0.dev20231024+cu121,vit_h,1,2886,7,13.45076982255068,74.34518716716612,0.580705424101227,False,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 4 | compile,6.757023398081461,codesign,2.2.0.dev20231024+cu121,vit_h,1,2832,6,18.488662613774757,54.087200404369,0.5812414398978033,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 5 | SDPA,6.168465121587118,sdpa-decoder,2.2.0.dev20231024+cu121,vit_h,1,2517,6,20.26718162443607,49.340851556503715,0.5811542724237428,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 6 | Triton,6.0584907333056135,local-fork,2.2.0.dev20231024+cu121,vit_h,1,2517,6,20.76397793545598,48.160328580027446,0.5821862905931964,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,4952,4952,None,None 7 | int8,1.8060946822166444,local-fork,2.2.0.dev20231024+cu121,ERROR 8 | 
sparse,6.413145796457926,local-fork,2.2.0.dev20231024+cu121,vit_h,1,5286,13,22.45630994267693,44.53091369653557,0.5294066301420467,max-autotune,torch.bfloat16,sparse,False,False,False,True,True,32,4952,4952,None,None 9 | -------------------------------------------------------------------------------- /experiments/p4d_results/results_bs32.csv: -------------------------------------------------------------------------------- 1 | technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | fp32,0.23617535432179768,default,2.2.0.dev20231024+cu121,ERROR 3 | bf16,0.21061046520868937,codesign,2.2.0.dev20231024+cu121,ERROR 4 | compile,2.914792573451996,codesign,2.2.0.dev20231024+cu121,vit_b,32,31077,76,50.12497270213896,19.950135553037967,0.5407576752390846,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 5 | SDPA,2.4382545590400695,sdpa-decoder,2.2.0.dev20231024+cu121,vit_b,32,18128,44,65.21504434452841,15.333885149522262,0.5355346808697282,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 6 | Triton,2.0670506795247396,local-fork,2.2.0.dev20231024+cu121,vit_b,32,6224,15,84.64604662944608,11.813900823716994,0.5339075529136259,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 7 | NT,1.9937538544336955,local-fork,2.2.0.dev20231024+cu121,vit_b,32,6963,17,94.91964930359119,10.535226450337992,0.533777680508926,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,154,4928,None,None 8 | int8,3.1097598036130267,local-fork,2.2.0.dev20231026+cu121,vit_b,32,6878,16,93.73750482853725,10.668088529017064,0.5332023666166303,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,154,4928,None,None 9 | sparse,2.0192723433176676,local-fork,2.2.0.dev20231024+cu121,vit_b,32,7397,18,95.40203078844067,10.481957163129534,0.4781413896120807,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,154,4928,None,None 10 | -------------------------------------------------------------------------------- /experiments/p4d_results/results_bs32_vit_h.csv: -------------------------------------------------------------------------------- 1 | technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | fp32,0.3408697446187337,default,2.2.0.dev20231024+cu121,ERROR 3 | bf16,0.3232296347618103,codesign,2.2.0.dev20231024+cu121,ERROR 4 | compile,1.4304784536361694,codesign,2.2.0.dev20231024+cu121,ERROR 5 | SDPA,5.974866807460785,sdpa-decoder,2.2.0.dev20231024+cu121,vit_h,32,27276,67,21.754509457085913,45.96748099389014,0.581191777206921,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 6 | Triton,5.73175394932429,local-fork,2.2.0.dev20231024+cu121,vit_h,32,14424,35,22.71537227352625,44.02305134859952,0.5820036887609843,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 7 | 
NT,5.5431528210639955,local-fork,2.2.0.dev20231024+cu121,vit_h,32,14424,35,23.333015236230114,42.85772712509353,0.5807765231991617,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,154,4928,None,None 8 | int8,6.999959905942281,local-fork,2.2.0.dev20231026+cu121,vit_h,32,14783,36,25.167496549454327,39.73378909717906,0.581791527840287,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,154,4928,None,None 9 | sparse,5.685418633619944,local-fork,2.2.0.dev20231024+cu121,vit_h,32,15108,37,24.31964924223424,41.11901409595043,0.5286787009095467,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,154,4928,None,None 10 | -------------------------------------------------------------------------------- /experiments/p4d_results/results_bs8.csv: -------------------------------------------------------------------------------- 1 | technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | fp32,8.32037784655889,default,2.2.0.dev20231023+cu121,vit_b,8,19934,49,12.123363560576202,82.4853593644495,0.5335705502003885,False,None,None,False,False,False,True,True,32,619,4952,None,None 3 | bf16,2.861330274740855,codesign,2.2.0.dev20231023+cu121,vit_b,8,10003,24,39.7795247545384,25.13856075884645,0.5415806118803995,False,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 4 | compile,2.652463134129842,codesign,2.2.0.dev20231023+cu121,vit_b,8,7916,19,54.71426032562412,18.276770882922342,0.5407576752390846,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 5 | SDPA,2.148758562405904,sdpa-decoder,2.2.0.dev20231023+cu121,vit_b,8,4679,11,73.1570663251564,13.669219533153035,0.5355346808697282,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 6 | Triton,2.0386854648590087,local-fork,2.2.0.dev20231023+cu121,vit_b,8,1703,4,85.53658249838097,11.690904298391018,0.5339075529136259,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 7 | NT,1.9225259701410928,local-fork,2.2.0.dev20231023+cu121,vit_b,8,2797,6,92.11983959049361,10.85542489484747,0.5337810700594795,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,619,4952,None,None 8 | int8,3.0942590634028115,local-fork,2.2.0.dev20231026+cu121,vit_b,8,2712,6,90.34645860305645,11.068502467745533,0.5330263012027209,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None 9 | sparse,3.841790223121643,local-fork,2.2.0.dev20231023+cu121,vit_b,8,3217,7,81.4912293589238,12.271259224665185,0.4783508911148021,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,619,4952,None,None 10 | -------------------------------------------------------------------------------- /experiments/p4d_results/results_bs8_vit_h.csv: -------------------------------------------------------------------------------- 1 | technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | 
fp32,33.78124016523361,default,2.2.0.dev20231110+cu121,vit_h,8,28806,71,2.7815123804581523,359.51664534215973,0.584173340367447,False,None,None,False,False,False,True,True,32,619,4952,None,None 3 | bf16,6.781990921497345,codesign,2.2.0.dev20231110+cu121,vit_h,8,14424,35,14.949423825567571,66.89221013921143,0.5809121174676433,False,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 4 | compile,6.374551324049632,codesign,2.2.0.dev20231110+cu121,vit_h,8,12358,30,19.174280981156954,52.153194217959204,0.5809779984197878,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 5 | SDPA,5.68677978515625,sdpa-decoder,2.2.0.dev20231110+cu121,vit_h,8,7947,19,21.702906054763186,46.07677872616178,0.5810025119549971,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 6 | Triton,5.398879257837931,local-fork,2.2.0.dev20231110+cu121,vit_h,8,4550,11,23.150324132692813,43.19593947230325,0.5821156148026875,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 7 | NT,5.341391940911611,local-fork,2.2.0.dev20231110+cu121,vit_h,8,4550,11,23.571117205832582,42.42480283253413,0.580580762025661,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,619,4952,None,None 8 | int8,5.64138038555781,local-fork,2.2.0.dev20231110+cu121,vit_h,8,4167,10,24.98278914178334,40.02755634387976,0.5821159807587736,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None 9 | sparse,5.2578640898068745,local-fork,2.2.0.dev20231110+cu121,vit_h,8,7055,17,24.876148779272,40.19914854477989,0.5287772086382821,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,619,4952,None,None 10 | -------------------------------------------------------------------------------- /experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | pandas 3 | tqdm 4 | -------------------------------------------------------------------------------- /experiments/results.csv: -------------------------------------------------------------------------------- 1 | technique,time,sam_commit_name,pytorch_version,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path 2 | fp32,8.32037784655889,default,2.2.0.dev20231023+cu121,vit_b,8,19934,49,12.123363560576202,82.4853593644495,0.5335705502003885,False,None,None,False,False,False,True,True,32,619,4952,None,None 3 | bf16,2.861330274740855,codesign,2.2.0.dev20231023+cu121,vit_b,8,10003,24,39.7795247545384,25.13856075884645,0.5415806118803995,False,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 4 | compile,2.652463134129842,codesign,2.2.0.dev20231023+cu121,vit_b,8,7916,19,54.71426032562412,18.276770882922342,0.5407576752390846,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 5 | SDPA,2.148758562405904,sdpa-decoder,2.2.0.dev20231023+cu121,vit_b,8,4679,11,73.1570663251564,13.669219533153035,0.5355346808697282,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 6 | Triton,2.0386854648590087,local-fork,2.2.0.dev20231023+cu121,vit_b,8,1703,4,85.53658249838097,11.690904298391018,0.5339075529136259,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 7 | 
NT,1.9225259701410928,local-fork,2.2.0.dev20231023+cu121,vit_b,8,2797,6,92.11983959049361,10.85542489484747,0.5337810700594795,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,619,4952,None,None 8 | int8,3.0942590634028115,local-fork,2.2.0.dev20231026+cu121,vit_b,8,2712,6,90.34645860305645,11.068502467745533,0.5330263012027209,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None 9 | sparse,3.841790223121643,local-fork,2.2.0.dev20231023+cu121,vit_b,8,3217,7,81.4912293589238,12.271259224665185,0.4783508911148021,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,619,4952,None,None 10 | fp32,33.77061091264089,default,2.2.0.dev20231024+cu121,vit_h,8,28806,71,2.7820335945039893,359.44928989194705,0.584173340367447,False,None,None,False,False,False,True,True,32,619,4952,None,None 11 | bf16,6.822473649183909,codesign,2.2.0.dev20231024+cu121,vit_h,8,14424,35,14.850424350893103,67.33814309756475,0.5809121174676433,False,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 12 | compile,7.968364950021108,codesign,2.2.0.dev20231024+cu121,vit_h,8,12358,30,19.69605657526638,50.77158446304247,0.5811320849834102,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 13 | SDPA,5.843019040425618,sdpa-decoder,2.2.0.dev20231024+cu121,vit_h,8,7947,19,21.92026495560376,45.61988653081299,0.581191777206921,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 14 | Triton,9.09047209819158,local-fork,2.2.0.dev20231024+cu121,vit_h,8,4550,11,22.874989934428537,43.71586623060877,0.5820036887609843,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,619,4952,None,None 15 | NT,5.455243261655172,local-fork,2.2.0.dev20231024+cu121,vit_h,8,4550,11,23.206823845253847,43.09077393219044,0.5809004559961229,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,619,4952,None,None 16 | int8,6.994769084453583,local-fork,2.2.0.dev20231026+cu121,vit_h,8,4167,10,24.87583443921619,40.19965651578395,0.5819033780783904,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,619,4952,None,None 17 | sparse,5.597406772772471,local-fork,2.2.0.dev20231024+cu121,vit_h,8,7055,17,24.900183397177024,40.16034677533225,0.5289167514647479,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,619,4952,None,None 18 | compile,2.914792573451996,codesign,2.2.0.dev20231024+cu121,vit_b,32,31077,76,50.12497270213896,19.950135553037967,0.5407576752390846,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 19 | SDPA,2.4382545590400695,sdpa-decoder,2.2.0.dev20231024+cu121,vit_b,32,18128,44,65.21504434452841,15.333885149522262,0.5355346808697282,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 20 | Triton,2.0670506795247396,local-fork,2.2.0.dev20231024+cu121,vit_b,32,6224,15,84.64604662944608,11.813900823716994,0.5339075529136259,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 21 | NT,1.9937538544336955,local-fork,2.2.0.dev20231024+cu121,vit_b,32,6963,17,94.91964930359119,10.535226450337992,0.533777680508926,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,154,4928,None,None 22 | int8,3.1097598036130267,local-fork,2.2.0.dev20231026+cu121,vit_b,32,6878,16,93.73750482853725,10.668088529017064,0.5332023666166303,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,154,4928,None,None 23 | 
sparse,2.0192723433176676,local-fork,2.2.0.dev20231024+cu121,vit_b,32,7397,18,95.40203078844067,10.481957163129534,0.4781413896120807,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,154,4928,None,None 24 | SDPA,5.974866807460785,sdpa-decoder,2.2.0.dev20231024+cu121,vit_h,32,27276,67,21.754509457085913,45.96748099389014,0.581191777206921,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 25 | Triton,5.73175394932429,local-fork,2.2.0.dev20231024+cu121,vit_h,32,14424,35,22.71537227352625,44.02305134859952,0.5820036887609843,max-autotune,torch.bfloat16,None,False,False,False,True,True,32,154,4928,None,None 26 | NT,5.5431528210639955,local-fork,2.2.0.dev20231024+cu121,vit_h,32,14424,35,23.333015236230114,42.85772712509353,0.5807765231991617,max-autotune,torch.bfloat16,None,False,False,True,True,True,32,154,4928,None,None 27 | int8,6.999959905942281,local-fork,2.2.0.dev20231026+cu121,vit_h,32,14783,36,25.167496549454327,39.73378909717906,0.581791527840287,max-autotune,torch.bfloat16,dynamic_quant,False,False,True,True,True,32,154,4928,None,None 28 | sparse,5.685418633619944,local-fork,2.2.0.dev20231024+cu121,vit_h,32,15108,37,24.31964924223424,41.11901409595043,0.5286787009095467,max-autotune,torch.bfloat16,sparse,False,False,True,True,True,32,154,4928,None,None 29 | -------------------------------------------------------------------------------- /experiments/run_experiments.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import fire 3 | import itertools 4 | import functools 5 | 6 | sam_commits = { 7 | "default": "6fdee8f2727f4506cfbbe553e23b895e27956588", 8 | "graphbreaks": "55f772f77864752f2e98a6fc7713b45a1843c167", 9 | "codesign": "50cb459d080bcd783a4b481d3bde4150d35ac497", 10 | "sdpa": "22f654553bbe7aa28337ce34a25f1a9d27cee111", 11 | "sdpa-decoder": "7dc75fdf283693f73606f2fe7fdcb693afcb16b9", 12 | "predict-masks-nested": "187e2359f9eb3b00d43487a1ec3db849964753e4", 13 | "use-rel-pos": "d2fa29d580eaf7928eef702cd71d133b943c30cf", 14 | "hacky-nested-encoder": "8f2fc3cc90b222a2431d4c43379282e36f021b69", 15 | "wip-flash-nested": "e01edb904a49c449425fca9e48902824b22cf764", 16 | "wip-flash-sdpa-decoder": "bb1c8b6f3749b1a5f31635f5d2f26bcafa9d94f9"} 17 | 18 | 19 | 20 | def change_sam_commit(sam_path, commit_name): 21 | assert commit_name in sam_commits 22 | root_cmd = ["git", "-C", sam_path] 23 | result = subprocess.run( 24 | root_cmd + ["checkout", sam_commits[commit_name]], capture_output=True) 25 | assert result.returncode == 0 26 | result = subprocess.run( 27 | root_cmd + ["rev-parse", "HEAD"], capture_output=True) 28 | assert result.returncode == 0 29 | 30 | 31 | def run_experiment(experiments_data, 32 | sam_path, 33 | model_type, 34 | idx, 35 | sam_commit_name, 36 | batch_size=1, 37 | num_workers=0, 38 | use_half=None, 39 | use_compile="False", 40 | compress=None, 41 | use_nested_tensor=False, 42 | extra_args=None, 43 | print_header=False, 44 | capture_output=True, 45 | limit=None, 46 | profile_path=None, 47 | profile_top=False, 48 | memory_path=None, 49 | device="cuda"): 50 | root_cmd = ["python", "eval_combo.py", 51 | "--coco_root_dir", 52 | f"{experiments_data}/datasets/coco2017", 53 | "--coco_slice_name", 54 | "val2017", 55 | "--sam_checkpoint_base_path", 56 | f"{experiments_data}/checkpoints", 57 | "--sam_model_type", 58 | "vit_b", 59 | "--point_sampling_cache_dir", 60 | f"{experiments_data}/tmp/sam_coco_mask_center_cache", 61 | "--mask_debug_out_dir", 62 | 
f"{experiments_data}/tmp/sam_eval_masks_out"] 63 | args = root_cmd 64 | args = args + ["--sam_model_type", model_type] 65 | args = args + ["--batch_size", str(batch_size)] 66 | args = args + ["--num_workers", str(num_workers)] 67 | args = args + ["--use_compile", use_compile] 68 | if sam_commit_name == "local-fork": 69 | args = args + ["--use_local_sam_fork", "True"] 70 | else: 71 | change_sam_commit(sam_path, sam_commit_name) 72 | if use_half: 73 | args = args + ["--use_half", use_half] 74 | if compress is not None: 75 | args = args + ["--compress", compress] 76 | if use_nested_tensor: 77 | args = args + ["--use_nested_tensor", str(use_nested_tensor)] 78 | if limit is not None: 79 | args = args + ["--limit", str(limit)] 80 | if profile_path is not None: 81 | args = args + ["--profile-path", profile_path] 82 | if profile_top: 83 | args = args + ["--profile-top", "True"] 84 | if memory_path is not None: 85 | args = args + ["--memory-path", memory_path] 86 | if extra_args is None: 87 | extra_args = [] 88 | args = args + ["--device", device] 89 | args = args + extra_args 90 | if print_header: 91 | args = args + ["--print_header", "True"] 92 | import time 93 | t0 = time.time() 94 | result = subprocess.run(args, capture_output=capture_output) 95 | if not capture_output: 96 | return 97 | t1 = time.time() 98 | import torch 99 | pytorch_version = torch.__version__ 100 | prefix = ",".join( 101 | map(str, [idx, (t1 - t0)/60.0, sam_commit_name, pytorch_version])) 102 | if result.returncode != 0: 103 | print(prefix + ",ERROR") 104 | return 105 | if print_header: 106 | header = result.stdout.decode().split("\n")[-3] 107 | print("technique,time,sam_commit_name,pytorch_version," + header) 108 | print(prefix + "," + result.stdout.decode().split("\n")[-2]) 109 | 110 | 111 | def run_traces_fn(traces_dir, pytorch_path, rexp, *args, **kwargs): 112 | # Limit to 10 batches 113 | kwargs['limit'] = 160 114 | 115 | # Create kernel traces 116 | profile_path = f"{traces_dir}/{args[0]}.json.gz" 117 | kwargs['profile_path'] = profile_path 118 | rexp(*args, **kwargs) 119 | kwargs['profile_path'] = None 120 | 121 | # Don't print header again if already printed 122 | kwargs['print_header'] = False 123 | 124 | # Create memory trace 125 | if 'use_compile' in kwargs and kwargs['use_compile'] == "max-autotune": 126 | # Memory traces don't seem to support CUDA graphs 127 | kwargs['use_compile'] = "max-autotune-no-cudagraphs" 128 | 129 | memory_path = f"{traces_dir}/{args[0]}" 130 | kwargs['memory_path'] = memory_path + ".pickle" 131 | rexp(*args, **kwargs) 132 | kwargs['memory_path'] = None 133 | 134 | # Convert memory trace to html page 135 | conversion_cmd = ["python", f"{pytorch_path}/torch/cuda/_memory_viz.py", 136 | "trace_plot", memory_path + ".pickle", "-o", memory_path + ".html"] 137 | result = subprocess.run(conversion_cmd, capture_output=True) 138 | 139 | def run(batch_size, 140 | model, 141 | pytorch_path, 142 | sam_path, 143 | experiments_data, 144 | run_traces=False, 145 | run_experiments=False, 146 | traces_dir=None, 147 | num_workers=32, 148 | print_header=True, 149 | capture_output=True, 150 | local_fork_only=False, 151 | device="cuda"): 152 | 153 | assert model == "vit_b" or model == "vit_h" 154 | 155 | rexp = functools.partial(run_experiment, 156 | experiments_data, 157 | sam_path, 158 | model, 159 | batch_size=batch_size, 160 | num_workers=num_workers, 161 | capture_output=capture_output, 162 | device=device) 163 | 164 | print_header = True 165 | if run_traces: 166 | assert traces_dir is not None 167 | rt = 
functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp) 168 | 169 | if local_fork_only: 170 | rt("fp32", "local-fork", print_header=print_header) 171 | rt("fp16", "local-fork", use_half="bfloat16") 172 | rt("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune") 173 | # The local fork already uses SDPA + Triton for all of the above experiments. 174 | # local_fork_only mainly exists to ablate the order in which we apply 175 | # techniques and cannot be used to reproduce the experimental results 176 | else: 177 | rt("fp32", "default", print_header=print_header) 178 | rt("fp16", "codesign", use_half="bfloat16") 179 | rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune") 180 | rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune") 181 | rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune") 182 | if batch_size > 1: 183 | rt("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True) 184 | rt("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant") 185 | rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse") 186 | 187 | if run_experiments: 188 | if local_fork_only: 189 | rexp("fp32", "local-fork", print_header=print_header) 190 | rexp("bf16", "local-fork", use_half="bfloat16") 191 | rexp("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune") 192 | # The local fork already uses SDPA + Triton for all of the above experiments. 193 | # local_fork_only mainly exists to ablate the order in which we apply 194 | # techniques and cannot be used to reproduce the experimental results 195 | else: 196 | rexp("fp32", "default", print_header=print_header) 197 | rexp("bf16", "codesign", use_half="bfloat16") 198 | rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune") 199 | rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune") 200 | rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune") 201 | if batch_size > 1: 202 | rexp("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1)) 203 | rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="dynamic_quant") 204 | rexp("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="sparse") 205 | 206 | 207 | if __name__ == '__main__': 208 | fire.Fire(run) 209 | -------------------------------------------------------------------------------- /experiments/summary_chart.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import fire 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | 6 | COLORS = list(matplotlib.colors.TABLEAU_COLORS.values()) 7 | 8 | def make_sub_chart(batch_size_idx, techniques, df, ax, title, category_column, value_column, ylim_low, ylim_high, data_format, label, va): 9 | x_values = [] 10 | y_values = [] 11 | bar_colors = [] 12 | x_idx = 0 13 | for key in techniques.keys(): 14 | if key in df[category_column].tolist(): 15 | x_values.append(key) 16 | y_values.append(df[value_column].tolist()[x_idx]) 17 | bar_colors.append(COLORS[batch_size_idx]) 18 | x_idx += 1 19 | else: 20 | x_values.append(key) 21 | y_values.append(0) 22 | x_coords = [] 23 | for name in df[category_column]: 24 | 
if name in techniques: 25 | x_coords.append(techniques[name]) 26 | ax.bar(x_values, y_values, label=label, color=bar_colors) 27 | 28 | # Customize the chart labels and title 29 | ax.set_xlabel(category_column) 30 | ax.set_ylabel(value_column) 31 | ax.set_title(title) 32 | if ylim_low is None: 33 | assert ylim_high is None 34 | else: 35 | ax.set_ylim(ylim_low, ylim_high) 36 | 37 | tick_positions = ax.get_yticks() 38 | for tick in tick_positions: 39 | ax.axhline(y=tick, color='gray', linestyle='--', alpha=0.7) 40 | 41 | # Add data labels or data points above the bars 42 | for x, value in zip(x_coords, df[value_column]): 43 | ax.text(x, value, data_format.format(value), ha='center', va=va) 44 | 45 | 46 | def make_row_chart(batch_size_idx, techniques, df, value_column, ax1, ax2, label, ylim_low, ylim_high, va, title="", relative=False, data_format=None): 47 | category_column = "technique" 48 | if not isinstance(ylim_low, tuple): 49 | ylim_low = (ylim_low, ylim_low) 50 | if not isinstance(ylim_high, tuple): 51 | ylim_high = (ylim_high, ylim_high) 52 | 53 | def helper(sam_model_type, ax1, ylim_low, ylim_high, va): 54 | vit_b_df = df[df['sam_model_type'] == sam_model_type] 55 | 56 | vit_b_df = vit_b_df.copy() 57 | 58 | if relative: 59 | vit_b_df[value_column] = vit_b_df[value_column].div( 60 | vit_b_df[value_column].iloc[0]) 61 | 62 | make_sub_chart(batch_size_idx, techniques, vit_b_df, ax1, f"{title} for {sam_model_type}", 63 | category_column, value_column, ylim_low, ylim_high, data_format, label, va) 64 | helper("vit_b", ax1, ylim_low[0], ylim_high[0], va) 65 | helper("vit_h", ax2, ylim_low[1], ylim_high[1], va) 66 | 67 | def run(csv_file, 68 | fig_format): 69 | matplotlib.rcParams.update({'font.size': 12}) 70 | 71 | mdf_ = pd.read_csv(csv_file) 72 | mdf = mdf_.dropna(subset=["batch_size"]) 73 | techniques = {'fp32': 0, 'bf16': 1, 'compile': 2, 'SDPA': 3, 'Triton': 4, 'NT': 5, 'int8': 6, 'sparse': 7} 74 | print("techniques: ", techniques) 75 | 76 | fig, axs = plt.subplots(3, 2, figsize=(20, 20)) 77 | 78 | for batch_size_idx, (batch_size, hlim, va) in enumerate(zip([32, 8], [100, 100], ["bottom", "top"])): 79 | df = mdf[mdf["batch_size"] == batch_size] 80 | make_row_chart(batch_size_idx, techniques, df, "img_s(avg)", *axs[0], f"Batch size {batch_size}", (0.0, 0.0), (100.0, 25.0), va, 81 | "Images per second", data_format="{:.2f}") 82 | make_row_chart(batch_size_idx, techniques, df, "memory(MiB)", *axs[1], f"Batch size {batch_size}", 0, 40000, va, 83 | title="Memory savings", data_format="{:.0f}") 84 | make_row_chart(batch_size_idx, techniques, df, "mIoU", *axs[2], f"Batch size {batch_size}", 0.0, 1.0, va, 85 | title="Accuracy", data_format="{:.2f}") 86 | for ax in axs: 87 | ax[0].legend() 88 | ax[1].legend() 89 | # plt.tick_params(axis='both', which='both', length=10) 90 | plt.tight_layout() 91 | 92 | fig.savefig(f'bar_chart.{fig_format}', format=fig_format) 93 | 94 | if __name__ == '__main__': 95 | fire.Fire(run) 96 | -------------------------------------------------------------------------------- /segment_anything_fast/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
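#
# Hedged usage sketch (not part of the package source): the checkpoint path,
# device, and process_batch_size below are illustrative assumptions only; the
# imported names are the ones this package exports.
#
#   from segment_anything_fast import sam_model_fast_registry, SamAutomaticMaskGenerator
#   sam = sam_model_fast_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")
#   mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)
#   masks = mask_generator.generate(image)  # image: HWC uint8 numpy array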
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | build_sam_fast, 14 | build_sam_fast_vit_h, 15 | build_sam_fast_vit_l, 16 | build_sam_fast_vit_b, 17 | sam_model_fast_registry, 18 | ) 19 | from .predictor import SamPredictor 20 | from .automatic_mask_generator import SamAutomaticMaskGenerator 21 | -------------------------------------------------------------------------------- /segment_anything_fast/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | mask_to_rle_pytorch_2, 28 | remove_small_regions, 29 | rle_to_mask, 30 | uncrop_boxes_xyxy, 31 | uncrop_masks, 32 | uncrop_points, 33 | ) 34 | 35 | 36 | class SamAutomaticMaskGenerator: 37 | def __init__( 38 | self, 39 | model: Sam, 40 | points_per_side: Optional[int] = 32, 41 | points_per_batch: int = 64, 42 | pred_iou_thresh: float = 0.88, 43 | stability_score_thresh: float = 0.95, 44 | stability_score_offset: float = 1.0, 45 | box_nms_thresh: float = 0.7, 46 | crop_n_layers: int = 0, 47 | crop_nms_thresh: float = 0.7, 48 | crop_overlap_ratio: float = 512 / 1500, 49 | crop_n_points_downscale_factor: int = 1, 50 | point_grids: Optional[List[np.ndarray]] = None, 51 | min_mask_region_area: int = 0, 52 | output_mode: str = "binary_mask", 53 | process_batch_size: Optional[int] = None, 54 | ) -> None: 55 | """ 56 | Using a SAM model, generates masks for the entire image. 57 | Generates a grid of point prompts over the image, then filters 58 | low quality and duplicate masks. The default settings are chosen 59 | for SAM with a ViT-H backbone. 60 | 61 | Arguments: 62 | model (Sam): The SAM model to use for mask prediction. 63 | points_per_side (int or None): The number of points to be sampled 64 | along one side of the image. The total number of points is 65 | points_per_side**2. If None, 'point_grids' must provide explicit 66 | point sampling. 67 | points_per_batch (int): Sets the number of points run simultaneously 68 | by the model. Higher numbers may be faster but use more GPU memory. 69 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 70 | model's predicted mask quality. 71 | stability_score_thresh (float): A filtering threshold in [0,1], using 72 | the stability of the mask under changes to the cutoff used to binarize 73 | the model's mask predictions. 74 | stability_score_offset (float): The amount to shift the cutoff when 75 | calculated the stability score. 76 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 77 | suppression to filter duplicate masks. 78 | crop_n_layers (int): If >0, mask prediction will be run again on 79 | crops of the image. 
Sets the number of layers to run, where each 80 | layer has 2**i_layer number of image crops. 81 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 82 | suppression to filter duplicate masks between different crops. 83 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 84 | In the first crop layer, crops will overlap by this fraction of 85 | the image length. Later layers with more crops scale down this overlap. 86 | crop_n_points_downscale_factor (int): The number of points-per-side 87 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 88 | point_grids (list(np.ndarray) or None): A list over explicit grids 89 | of points used for sampling, normalized to [0,1]. The nth grid in the 90 | list is used in the nth crop layer. Exclusive with points_per_side. 91 | min_mask_region_area (int): If >0, postprocessing will be applied 92 | to remove disconnected regions and holes in masks with area smaller 93 | than min_mask_region_area. Requires opencv. 94 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 95 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 96 | For large resolutions, 'binary_mask' may consume large amounts of 97 | memory. 98 | process_batch_size (int or None): Set a batch size for the decoding step. 99 | If None, all points will be batched up at once. Set a small number here 100 | to decrease memory footprint. A smaller number will likely decrease 101 | latency, but also decrease memory usage. 102 | """ 103 | 104 | assert (points_per_side is None) != ( 105 | point_grids is None 106 | ), "Exactly one of points_per_side or point_grid must be provided." 107 | if points_per_side is not None: 108 | self.point_grids = build_all_layer_point_grids( 109 | points_per_side, 110 | crop_n_layers, 111 | crop_n_points_downscale_factor, 112 | ) 113 | elif point_grids is not None: 114 | self.point_grids = point_grids 115 | else: 116 | raise ValueError("Can't have both points_per_side and point_grid be None.") 117 | 118 | assert output_mode in [ 119 | "binary_mask", 120 | "uncompressed_rle", 121 | "coco_rle", 122 | ], f"Unknown output_mode {output_mode}." 123 | if output_mode == "coco_rle": 124 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 125 | 126 | if min_mask_region_area > 0: 127 | import cv2 # type: ignore # noqa: F401 128 | 129 | self.predictor = SamPredictor(model) 130 | self.points_per_batch = points_per_batch 131 | self.pred_iou_thresh = pred_iou_thresh 132 | self.stability_score_thresh = stability_score_thresh 133 | self.stability_score_offset = stability_score_offset 134 | self.box_nms_thresh = box_nms_thresh 135 | self.crop_n_layers = crop_n_layers 136 | self.crop_nms_thresh = crop_nms_thresh 137 | self.crop_overlap_ratio = crop_overlap_ratio 138 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 139 | self.min_mask_region_area = min_mask_region_area 140 | self.output_mode = output_mode 141 | self.process_batch_size = process_batch_size 142 | 143 | @torch.no_grad() 144 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 145 | """ 146 | Generates masks for the given image. 147 | 148 | Arguments: 149 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 150 | 151 | Returns: 152 | list(dict(str, any)): A list over records for masks. Each record is 153 | a dict containing the following keys: 154 | segmentation (dict(str, any) or np.ndarray): The mask. If 155 | output_mode='binary_mask', is an array of shape HW. 
Otherwise, 156 | is a dictionary containing the RLE. 157 | bbox (list(float)): The box around the mask, in XYWH format. 158 | area (int): The area in pixels of the mask. 159 | predicted_iou (float): The model's own prediction of the mask's 160 | quality. This is filtered by the pred_iou_thresh parameter. 161 | point_coords (list(list(float))): The point coordinates input 162 | to the model to generate this mask. 163 | stability_score (float): A measure of the mask's quality. This 164 | is filtered on using the stability_score_thresh parameter. 165 | crop_box (list(float)): The crop of the image used to generate 166 | the mask, given in XYWH format. 167 | """ 168 | 169 | # Generate masks 170 | mask_data = self._generate_masks(image) 171 | 172 | # Filter small disconnected regions and holes in masks 173 | if self.min_mask_region_area > 0: 174 | mask_data = self.postprocess_small_regions( 175 | mask_data, 176 | self.min_mask_region_area, 177 | max(self.box_nms_thresh, self.crop_nms_thresh), 178 | ) 179 | 180 | # Encode masks 181 | if self.output_mode == "coco_rle": 182 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 183 | elif self.output_mode == "binary_mask": 184 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 185 | else: 186 | mask_data["segmentations"] = mask_data["rles"] 187 | 188 | # Write mask records 189 | curr_anns = [] 190 | for idx in range(len(mask_data["segmentations"])): 191 | ann = { 192 | "segmentation": mask_data["segmentations"][idx], 193 | "area": area_from_rle(mask_data["rles"][idx]), 194 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 195 | "predicted_iou": mask_data["iou_preds"][idx].item(), 196 | "point_coords": [mask_data["points"][idx].tolist()], 197 | "stability_score": mask_data["stability_score"][idx].item(), 198 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 199 | } 200 | curr_anns.append(ann) 201 | 202 | return curr_anns 203 | 204 | def _generate_masks(self, image: np.ndarray) -> MaskData: 205 | orig_size = image.shape[:2] 206 | crop_boxes, layer_idxs = generate_crop_boxes( 207 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 208 | ) 209 | 210 | # Iterate over image crops 211 | data = MaskData() 212 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 213 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 214 | data.cat(crop_data) 215 | 216 | # Remove duplicate masks between crops 217 | if len(crop_boxes) > 1: 218 | # Prefer masks from smaller crops 219 | scores = 1 / box_area(data["crop_boxes"]) 220 | scores = scores.to(data["boxes"].device) 221 | keep_by_nms = batched_nms( 222 | data["boxes"].float(), 223 | scores, 224 | torch.zeros_like(data["boxes"][:, 0]), # categories 225 | iou_threshold=self.crop_nms_thresh, 226 | ) 227 | data.filter(keep_by_nms) 228 | 229 | data.to_numpy() 230 | return data 231 | 232 | def _process_crop( 233 | self, 234 | image: np.ndarray, 235 | crop_box: List[int], 236 | crop_layer_idx: int, 237 | orig_size: Tuple[int, ...], 238 | ) -> MaskData: 239 | # Crop the image and calculate embeddings 240 | x0, y0, x1, y1 = crop_box 241 | cropped_im = image[y0:y1, x0:x1, :] 242 | cropped_im_size = cropped_im.shape[:2] 243 | self.predictor.set_image(cropped_im) 244 | 245 | # Get points for this crop 246 | points_scale = np.array(cropped_im_size)[None, ::-1] 247 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 248 | 249 | # Generate masks for this crop in batches 250 | data = MaskData() 251 | 
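# Point prompts are first chunked into groups of `points_per_batch`; those
# chunks are then grouped into super-batches of `process_batch_size` (all of
# them at once when it is None), and each super-batch is decoded in a single
# nested-tensor call inside _process_batch below.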
all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)] 252 | process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size 253 | for i in range(0, len(all_points), process_batch_size): 254 | some_points = all_points[i:i+process_batch_size] 255 | batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size) 256 | data.cat(batch_data) 257 | data["rles"] = mask_to_rle_pytorch_2(data["masks"]) 258 | self.predictor.reset_image() 259 | 260 | # Remove duplicates within this crop. 261 | keep_by_nms = batched_nms( 262 | data["boxes"].float(), 263 | data["iou_preds"], 264 | torch.zeros_like(data["boxes"][:, 0]), # categories 265 | iou_threshold=self.box_nms_thresh, 266 | ) 267 | data.filter(keep_by_nms) 268 | 269 | # Return to the original image frame 270 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 271 | data["points"] = uncrop_points(data["points"], crop_box) 272 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 273 | 274 | return data 275 | 276 | def _process_batch( 277 | self, 278 | all_points: List[np.ndarray], 279 | im_size: Tuple[int, ...], 280 | crop_box: List[int], 281 | orig_size: Tuple[int, ...], 282 | ) -> MaskData: 283 | orig_h, orig_w = orig_size 284 | nt_in_points = [] 285 | for points in all_points: 286 | # Run model on this batch 287 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 288 | in_points = torch.as_tensor(transformed_points) #, device=self.predictor.device) 289 | nt_in_points.append(in_points) 290 | 291 | nt_in_points = torch.nested.nested_tensor(nt_in_points, layout=torch.jagged, pin_memory=True).to(device=self.predictor.device, non_blocking=True) 292 | # The call to prod is a workaround to share jagged sizes between two NestedTensors. 
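# ones_like matches the jagged layout of nt_in_points, and the product over
# the trailing coordinate dimension keeps every entry equal to 1 (the
# foreground label for these point prompts) while preserving the per-chunk
# lengths.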
293 | nt_in_labels = torch.ones_like(nt_in_points, dtype=torch.int).prod(dim=-1, keepdim=True) 294 | nt_in_points = nt_in_points.unsqueeze(2) 295 | 296 | self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))] 297 | self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))] 298 | nt_masks, nt_iou_preds, _ = self.predictor.predict_torch( 299 | point_coords=nt_in_points, 300 | point_labels=nt_in_labels, 301 | multimask_output=True, 302 | return_logits=True, 303 | ) 304 | 305 | data = MaskData() 306 | for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points): 307 | batch_data = self._process_batch_2(masks, iou_preds, points, im_size, crop_box, orig_size) 308 | data.cat(batch_data) 309 | return data 310 | 311 | # TODO: Batch this up 312 | def _process_batch_2( 313 | self, 314 | masks: torch.Tensor, 315 | iou_preds: torch.Tensor, 316 | points: torch.Tensor, 317 | im_size: Tuple[int, ...], 318 | crop_box: List[int], 319 | orig_size: Tuple[int, ...], 320 | ) -> MaskData: 321 | orig_h, orig_w = orig_size 322 | # Serialize predictions and store in MaskData 323 | data = MaskData( 324 | masks=masks.flatten(0, 1), 325 | iou_preds=iou_preds.flatten(0, 1), 326 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 327 | ) 328 | del masks 329 | 330 | # Filter by predicted IoU 331 | if self.pred_iou_thresh > 0.0: 332 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 333 | data.filter(keep_mask) 334 | 335 | # Calculate stability score 336 | data["stability_score"] = calculate_stability_score( 337 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 338 | ) 339 | if self.stability_score_thresh > 0.0: 340 | keep_mask = data["stability_score"] >= self.stability_score_thresh 341 | data.filter(keep_mask) 342 | 343 | # Threshold masks and calculate boxes 344 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 345 | data["boxes"] = batched_mask_to_box(data["masks"]) 346 | 347 | # Filter boxes that touch crop boundaries 348 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 349 | if not torch.all(keep_mask): 350 | data.filter(keep_mask) 351 | 352 | # Compress to RLE 353 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 354 | # Doing this once at the end across all masks. 355 | # data["rles"] = mask_to_rle_pytorch(data["masks"].cpu()) 356 | # Keeping the masks around is faster, even though it uses more memory. 357 | # del data["masks"] 358 | 359 | return data 360 | 361 | @staticmethod 362 | def postprocess_small_regions( 363 | mask_data: MaskData, min_area: int, nms_thresh: float 364 | ) -> MaskData: 365 | """ 366 | Removes small disconnected regions and holes in masks, then reruns 367 | box NMS to remove any new duplicates. 368 | 369 | Edits mask_data in place. 370 | 371 | Requires open-cv as a dependency. 
372 | """ 373 | if len(mask_data["rles"]) == 0: 374 | return mask_data 375 | 376 | # Filter small disconnected regions and holes 377 | new_masks = [] 378 | scores = [] 379 | for rle in mask_data["rles"]: 380 | mask = rle_to_mask(rle) 381 | 382 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 383 | unchanged = not changed 384 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 385 | unchanged = unchanged and not changed 386 | 387 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 388 | # Give score=0 to changed masks and score=1 to unchanged masks 389 | # so NMS will prefer ones that didn't need postprocessing 390 | scores.append(float(unchanged)) 391 | 392 | # Recalculate boxes and remove any new duplicates 393 | masks = torch.cat(new_masks, dim=0) 394 | boxes = batched_mask_to_box(masks) 395 | keep_by_nms = batched_nms( 396 | boxes.float(), 397 | torch.as_tensor(scores), 398 | torch.zeros_like(boxes[:, 0]), # categories 399 | iou_threshold=nms_thresh, 400 | ) 401 | 402 | # Only recalculate RLEs for masks that have changed 403 | for i_mask in keep_by_nms: 404 | if scores[i_mask] == 0.0: 405 | mask_torch = masks[i_mask].unsqueeze(0) 406 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 407 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 408 | mask_data.filter(keep_by_nms) 409 | 410 | return mask_data 411 | -------------------------------------------------------------------------------- /segment_anything_fast/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
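#
# Note (hedged sketch, not part of the module): the build_sam_fast_* helpers
# defined below differ from the plain builders by casting the image encoder,
# prompt encoder, and mask decoder to the requested eval dtype (bfloat16 by
# default) and wrapping the image encoder in torch.compile (mode
# 'max-autotune' by default). For example, assuming a locally downloaded
# checkpoint:
#
#   sam = sam_model_fast_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")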
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | def _apply_eval_dtype_sam(model, dtype): 55 | 56 | def prep_model(model, dtype): 57 | if dtype is not None: 58 | return model.eval().to(dtype) 59 | return model.eval() 60 | 61 | model.image_encoder = prep_model(model.image_encoder, dtype) 62 | model.prompt_encoder = prep_model(model.prompt_encoder, dtype) 63 | model.mask_decoder = prep_model(model.mask_decoder, dtype) 64 | 65 | return model 66 | 67 | def build_sam_fast_vit_h(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16): 68 | sam = build_sam_vit_h(checkpoint) 69 | sam = _apply_eval_dtype_sam(sam, dtype) 70 | sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode) 71 | return sam 72 | 73 | build_sam_fast = build_sam_fast_vit_h 74 | 75 | def build_sam_fast_vit_l(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16): 76 | sam = build_sam_vit_l(checkpoint) 77 | sam = _apply_eval_dtype_sam(sam, dtype) 78 | sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode) 79 | return sam 80 | 81 | def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16): 82 | sam = build_sam_vit_b(checkpoint) 83 | sam = _apply_eval_dtype_sam(sam, dtype) 84 | sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode) 85 | return sam 86 | 87 | sam_model_fast_registry = { 88 | "default": build_sam_fast_vit_h, 89 | "vit_h": build_sam_fast_vit_h, 90 | "vit_l": build_sam_fast_vit_l, 91 | "vit_b": build_sam_fast_vit_b, 92 | } 93 | 94 | 95 | def _build_sam( 96 | encoder_embed_dim, 97 | encoder_depth, 98 | encoder_num_heads, 99 | encoder_global_attn_indexes, 100 | checkpoint=None, 101 | ): 102 | prompt_embed_dim = 256 103 | image_size = 1024 104 | vit_patch_size = 16 105 | image_embedding_size = image_size // vit_patch_size 106 | sam = Sam( 107 | image_encoder=ImageEncoderViT( 108 | depth=encoder_depth, 109 | embed_dim=encoder_embed_dim, 110 | img_size=image_size, 111 | mlp_ratio=4, 112 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 113 | num_heads=encoder_num_heads, 114 | patch_size=vit_patch_size, 115 | qkv_bias=True, 116 | use_rel_pos=True, 117 | global_attn_indexes=encoder_global_attn_indexes, 118 | window_size=14, 119 | out_chans=prompt_embed_dim, 120 | ), 121 | prompt_encoder=PromptEncoder( 122 | embed_dim=prompt_embed_dim, 123 | image_embedding_size=(image_embedding_size, image_embedding_size), 124 | input_image_size=(image_size, image_size), 125 | mask_in_chans=16, 126 | 
), 127 | mask_decoder=MaskDecoder( 128 | num_multimask_outputs=3, 129 | transformer=TwoWayTransformer( 130 | depth=2, 131 | embedding_dim=prompt_embed_dim, 132 | mlp_dim=2048, 133 | num_heads=8, 134 | ), 135 | transformer_dim=prompt_embed_dim, 136 | iou_head_depth=3, 137 | iou_head_hidden_dim=256, 138 | ), 139 | pixel_mean=[123.675, 116.28, 103.53], 140 | pixel_std=[58.395, 57.12, 57.375], 141 | ) 142 | sam.eval() 143 | if checkpoint is not None: 144 | with open(checkpoint, "rb") as f: 145 | state_dict = torch.load(f, weights_only=True) 146 | sam.load_state_dict(state_dict) 147 | return sam 148 | -------------------------------------------------------------------------------- /segment_anything_fast/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/segment-anything-fast/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/segment_anything_fast/configs/__init__.py -------------------------------------------------------------------------------- /segment_anything_fast/configs/flash_4_configs_a100.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/segment-anything-fast/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/segment_anything_fast/configs/flash_4_configs_a100.p -------------------------------------------------------------------------------- /segment_anything_fast/flash_4.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fused Attention 3 | =============== 4 | 5 | This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) 6 | 7 | Extra Credits: 8 | - Original flash attention paper (https://arxiv.org/abs/2205.14135) 9 | - Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) 10 | - Adam P. Goucher for simplified vector math 11 | 12 | This version was modified to fuse an addition of two attention masks into one 13 | attn_bias = (rel_h_ + rel_w_).view(q_.size(0), q_.size(1), rel_h_.size(2), rel_h_.size(3) * rel_w_.size(4)) 14 | 15 | We use attn_mask and attn_bias interchangeably. 
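Fusing the bias into the kernel means the decomposed rel_h and rel_w terms are
added to the attention logits block by block inside the inner loop, instead of
first materializing the full (batch, heads, q_len, k_len) bias tensor in
global memory.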
16 | 17 | This modification was designed by Christian Puhrsch and Daniel Haziza 18 | 19 | """ 20 | 21 | import torch 22 | 23 | import triton 24 | import triton.language as tl 25 | 26 | import os 27 | import pathlib 28 | 29 | 30 | @triton.jit 31 | def _fwd_kernel_aligned( 32 | Q, K, V, B0, sm_scale, 33 | Out, 34 | stride_qh, stride_qm, stride_qk, 35 | stride_kh, stride_kn, stride_kk, 36 | stride_vh, stride_vk, stride_vn, 37 | stride_oh, stride_om, stride_on, 38 | stride_b0h, stride_b0m, 39 | Z, 40 | H, 41 | N_CTX, 42 | P_SEQ, 43 | OUT_DTYPE: tl.constexpr, 44 | BIAS_LAST_SIZE: tl.constexpr, 45 | B0_NUMEL: tl.constexpr, 46 | BLOCK_DMODEL: tl.constexpr, 47 | BLOCK_M: tl.constexpr, 48 | BLOCK_N: tl.constexpr, 49 | ): 50 | start_m = tl.program_id(0) 51 | off_hz = tl.program_id(1) 52 | q_offset = off_hz * stride_qh 53 | kv_offset = off_hz * stride_kh 54 | Q_block_ptr = tl.make_block_ptr( 55 | base=Q + q_offset, 56 | shape=(N_CTX, BLOCK_DMODEL), 57 | strides=(stride_qm, stride_qk), 58 | offsets=(start_m * BLOCK_M, 0), 59 | block_shape=(BLOCK_M, BLOCK_DMODEL), 60 | order=(1, 0) 61 | ) 62 | K_block_ptr = tl.make_block_ptr( 63 | base=K + kv_offset, 64 | shape=(BLOCK_DMODEL, N_CTX + P_SEQ), 65 | strides=(stride_kk, stride_kn), 66 | offsets=(0, 0), 67 | block_shape=(BLOCK_DMODEL, BLOCK_N), 68 | order=(0, 1) 69 | ) 70 | V_block_ptr = tl.make_block_ptr( 71 | base=V + kv_offset, 72 | shape=(N_CTX + P_SEQ, BLOCK_DMODEL), 73 | strides=(stride_vk, stride_vn), 74 | offsets=(0, 0), 75 | block_shape=(BLOCK_N, BLOCK_DMODEL), 76 | order=(1, 0) 77 | ) 78 | 79 | # initialize offsets 80 | # initialize pointer to m and l 81 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 82 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 83 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 84 | # scale sm_scale by log_2(e) and use 85 | # 2^x instead of exp in the loop because CSE and LICM 86 | # don't work as expected with `exp` in the loop 87 | qk_scale = sm_scale * 1.44269504 88 | # load q: it will stay in SRAM throughout 89 | q = tl.load(Q_block_ptr) # , boundary_check=(1, 0), padding_option="zero") 90 | q = (q * qk_scale).to(OUT_DTYPE) 91 | # loop over k, v and update accumulator 92 | lo = 0 93 | hi = N_CTX + P_SEQ 94 | 95 | b_ptr_offsets_m = tl.arange(0, BLOCK_M) 96 | 97 | b_offset = off_hz * stride_b0h 98 | b_ptr_offsets_n_1 = (tl.arange(0, BLOCK_N) % 99 | BIAS_LAST_SIZE) + BIAS_LAST_SIZE 100 | b1 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m) 101 | * stride_b0m)[:, None] + b_ptr_offsets_n_1[None, :]) 102 | for start_n in range(lo, hi, BLOCK_N): 103 | # -- load k, v -- 104 | # , boundary_check=(0, 1), padding_option="zero") 105 | k = tl.load(K_block_ptr) 106 | # , boundary_check=(1, 0), padding_option="zero") 107 | v = tl.load(V_block_ptr) 108 | # -- compute qk --- 109 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=OUT_DTYPE) 110 | qk += tl.dot(q, k) #, out_dtype=OUT_DTYPE) 111 | 112 | # -- compute rel_h[:, None] + rel_w[None, :] bias --- 113 | 114 | # Bias 115 | b0 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m) 116 | * stride_b0m)[:, None] + start_n // BLOCK_N) 117 | qk += ((b0 + b1) * 1.44269504) 118 | 119 | # -- compute scaling constant --- 120 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 121 | alpha = tl.math.exp2(m_i - m_i_new) 122 | p = tl.math.exp2(qk - m_i_new[:, None]) 123 | # -- scale and update acc -- 124 | acc *= alpha[:, None] 125 | acc += tl.dot(p.to(OUT_DTYPE), v) 126 | # -- update m_i and l_i -- 127 | l_i = l_i * alpha + tl.sum(p, 1) 128 | m_i = m_i_new 129 | # 
update pointers 130 | K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) 131 | V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 132 | 133 | # write back l and m 134 | acc = acc / l_i[:, None] 135 | 136 | # write back O 137 | O_block_ptr = tl.make_block_ptr( 138 | base=Out + q_offset, 139 | shape=(N_CTX, BLOCK_DMODEL), 140 | strides=(stride_om, stride_on), 141 | offsets=(start_m * BLOCK_M, 0), 142 | block_shape=(BLOCK_M, BLOCK_DMODEL), 143 | order=(1, 0) 144 | ) 145 | tl.store(O_block_ptr, acc.to(OUT_DTYPE)) 146 | 147 | 148 | def _autotune(configs, function): 149 | import torch.utils.benchmark as benchmark 150 | 151 | def benchmark_torch_function_in_microseconds(f, *args, **kwargs): 152 | try: 153 | f(*args, **kwargs) 154 | t0 = benchmark.Timer( 155 | stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} 156 | ) 157 | except: 158 | return None 159 | return t0.blocked_autorange().mean * 1e6 160 | 161 | best = None 162 | best_config = None 163 | for config in configs: 164 | BLOCK_M, BLOCK_N, num_warps, num_stages = config 165 | t_config = benchmark_torch_function_in_microseconds( 166 | function, BLOCK_M, BLOCK_N, num_warps, num_stages) 167 | if t_config is not None: 168 | if best is not None: 169 | if t_config < best: 170 | best = t_config 171 | best_config = config 172 | else: 173 | best = t_config 174 | best_config = config 175 | print(str(config), " :", str(t_config)) 176 | return best, best_config 177 | 178 | 179 | def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o, 180 | BLOCK_M, 181 | BLOCK_N, 182 | num_warps, 183 | num_stages): 184 | _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] 185 | assert q.size() == k.size() 186 | assert q.size() == v.size() 187 | assert q.size(-2) == rel_h_w.size(-2) 188 | assert (q.dtype == torch.bfloat16 or q.dtype == torch.float16) 189 | assert k.dtype == q.dtype 190 | assert v.dtype == k.dtype 191 | assert o.dtype == v.dtype 192 | assert rel_h_w.dtype == q.dtype 193 | assert rel_h_w.size(-1) == 128 194 | # assert rel_h_w.size(-1) == 2 * BLOCK_N 195 | 196 | grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) 197 | # print("q.shape[0] * q.shape[1]: ", q.shape[0] * q.shape[1]) 198 | P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2] 199 | assert P_SEQ == 0 200 | assert rel_h_w.is_contiguous(), str(rel_h_w.stride()) 201 | OUT_DTYPE = tl.float16 if q.dtype == torch.float16 else tl.bfloat16 202 | _fwd_kernel_aligned[grid]( 203 | q, k, v, 204 | rel_h_w, 205 | sm_scale, 206 | o, 207 | q.stride(1), q.stride(2), q.stride(3), 208 | k.stride(1), k.stride(2), k.stride(3), 209 | v.stride(1), v.stride(2), v.stride(3), 210 | o.stride(1), o.stride(2), o.stride(3), 211 | rel_h_w.stride(1), rel_h_w.stride(2), 212 | q.shape[0], 213 | q.shape[1], 214 | q.shape[2], 215 | P_SEQ, 216 | OUT_DTYPE=OUT_DTYPE, 217 | BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2), 218 | B0_NUMEL=rel_h_w.size(-1), 219 | BLOCK_M=BLOCK_M, 220 | BLOCK_N=BLOCK_N, 221 | BLOCK_DMODEL=Lk, 222 | num_warps=num_warps, 223 | num_stages=num_stages) 224 | 225 | 226 | def _load_best_configs(): 227 | device_name = torch.cuda.get_device_name() 228 | if not device_name.startswith('NVIDIA A100'): 229 | print("Warning: Custom flash attention kernels were written specifically for A100.") 230 | import importlib 231 | saved_configs = importlib.resources.files("segment_anything_fast") 232 | saved_configs = saved_configs / "configs" / "flash_4_configs_a100.p" 233 | if not device_name.startswith('NVIDIA A100'): 234 | cwd = pathlib.Path.cwd() 235 
| saved_configs = cwd / "flash_4_configs.p" 236 | print(f"We will try to read previously created kernel configurations from {saved_configs}.") 237 | print("You can disable this kernel by setting SEGMENT_ANYTHING_FAST_USE_FLASH_4=0") 238 | if saved_configs.is_file(): 239 | import pickle 240 | with open(saved_configs, 'rb') as f: 241 | print(f"Loading best configs from file {saved_configs}") 242 | return pickle.load(f) 243 | 244 | 245 | def _save_best_configs(best_configs): 246 | import importlib 247 | saved_configs = importlib.resources.files("segment_anything_fast") 248 | saved_configs = saved_configs / "configs" / "flash_4_configs_a100.p" 249 | device_name = torch.cuda.get_device_name() 250 | if not device_name.startswith('NVIDIA A100'): 251 | saved_configs = pathlib.Path.cwd() / "flash_4_configs.p" 252 | print("Warning: Custom flash attention kernels were written specifically for A100.") 253 | print(f"Storing configs for {device_name} locally under {saved_configs}") 254 | with open(saved_configs, 'wb') as f: 255 | import pickle 256 | print(f"Saving best configs to file {saved_configs}") 257 | pickle.dump(best_configs, f) 258 | 259 | 260 | def _create_best_configs_key(q, k, v, rel_h_w, o): 261 | key = (q.size(), k.size(), v.size(), rel_h_w.size(), o.size(), 262 | q.stride(), k.stride(), v.stride(), rel_h_w.stride(), o.stride()) 263 | return key 264 | 265 | 266 | BEST_CONFIGS = None 267 | 268 | lib = torch.library.Library("customflash", "FRAGMENT") 269 | lib.define("custom_flash_aligned(Tensor q, Tensor k, Tensor v, Tensor rel_h_w, float sm_scale) -> Tensor") 270 | 271 | 272 | # All that's needed for torch.compile support 273 | @torch.library.impl(lib, "custom_flash_aligned", "Meta") 274 | def _attention_rel_h_rel_w_kernel_aligned_meta(q, k, v, rel_h_w, sm_scale): 275 | return q.contiguous() 276 | 277 | 278 | @torch.library.impl(lib, "custom_flash_aligned", "CUDA") 279 | def _attention_rel_h_rel_w_kernel_aligned(q, k, v, rel_h_w, sm_scale): 280 | # This is likely not needed, but without it the kernel 281 | # is guaranteed to fail. If the inputs are already contiguous 282 | # these are cheap checks via is_contiguous and do nothing. 283 | q = q.contiguous() 284 | k = k.contiguous() 285 | v = v.contiguous() 286 | # shape constraints 287 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 288 | assert Lq == Lk and Lk == Lv 289 | assert Lk in {16, 32, 64, 128} 290 | o = torch.empty_like(q, memory_format=torch.contiguous_format) 291 | 292 | global BEST_CONFIGS 293 | if BEST_CONFIGS is None: 294 | BEST_CONFIGS = _load_best_configs() 295 | # Loading must have not been successful. Let's create a new dictionary. 296 | if BEST_CONFIGS is None: 297 | BEST_CONFIGS = {} 298 | key = _create_best_configs_key(q, k, v, rel_h_w, o) 299 | if key not in BEST_CONFIGS: 300 | print("key ", key, " not found. Running autotune. 
This might take a while.") 301 | import functools 302 | import itertools 303 | configs = [] 304 | for (BLOCK_M, BLOCK_N, num_warps) in itertools.product([64, 128], [64, 128], [1, 2, 4, 8]): 305 | for num_stages in range(1, num_warps + 1): 306 | configs.append((BLOCK_M, BLOCK_N, num_warps, num_stages)) 307 | print("all configs len: ", len(configs)) 308 | best, best_config = _autotune(configs, functools.partial(_attention_rel_h_rel_w_kernel_aligned_device, 309 | q, k, v, rel_h_w, sm_scale, o)) 310 | BEST_CONFIGS[key] = best_config 311 | print("Found best_config ", best_config, 312 | " with time ", best, " for key ", key) 313 | _save_best_configs(BEST_CONFIGS) 314 | best_config = BEST_CONFIGS[key] 315 | if best_config is None: 316 | return torch.tensor([]) 317 | 318 | _attention_rel_h_rel_w_kernel_aligned_device(q, 319 | k, 320 | v, 321 | rel_h_w, 322 | sm_scale, 323 | o, 324 | best_config[0], 325 | best_config[1], 326 | best_config[2], 327 | best_config[3]) 328 | 329 | return o 330 | 331 | 332 | USE_CUSTOM_KERNEL = bool(int(os.environ.get('SEGMENT_ANYTHING_FAST_USE_FLASH_4', 1))) 333 | 334 | 335 | def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_): 336 | """ 337 | Writing this as a composite allows torch.compile to fuse 338 | the needed padding into previous operations and memory 339 | allocations. 340 | """ 341 | 342 | import math 343 | sm_scale = 1. / math.sqrt(q_.size(-1)) 344 | # Check if second last dimension is multiple of 256 345 | q_size_2_padded = (((q_.size(-2) + 256 - 1) // 256) * 256) - q_.size(-2) 346 | 347 | def kernel_guards(q_, k_, v_): 348 | return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL 349 | # vit_b and vit_l 350 | # TODO: This kernel currently does not produce correct results for batch size 1 for this case 351 | if q_.size(0) > 1 and q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_): 352 | rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1) 353 | o = torch.ops.customflash.custom_flash_aligned( 354 | q_, k_, v_, rel_h_w, sm_scale) 355 | if o.numel() > 0: 356 | return o 357 | # vit_h 358 | if q_size_2_padded == 0 and q_.size(-1) == 80 and kernel_guards(q_, k_, v_): 359 | # Only support multiples of 64, so need to pad 360 | q = torch.nn.functional.pad(q_, (0, 128 - 80, 0, 0), "constant", 0) 361 | k = torch.nn.functional.pad(k_, (0, 128 - 80, 0, 0), "constant", 0) 362 | v = torch.nn.functional.pad(v_, (0, 128 - 80, 0, 0), "constant", 0) 363 | rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1) 364 | o = torch.ops.customflash.custom_flash_aligned( 365 | q, k, v, rel_h_w, sm_scale) 366 | if o.numel() > 0: 367 | return o[:, :, :, :80] 368 | attn_bias = (rel_h_ + rel_w_).view(q_.size(0), q_.size(1), 369 | rel_h_.size(2), rel_h_.size(3) * rel_w_.size(4)) 370 | return torch.nn.functional.scaled_dot_product_attention(q_, k_, v_, attn_mask=attn_bias) 371 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
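# This package re-exports the core SAM building blocks (the full Sam module,
# the ViT image encoder, the mask decoder, the prompt encoder and the
# two-way transformer) so they can be imported directly from
# segment_anything_fast.modeling rather than from the individual submodules.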
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(-3, keepdim=True) 40 | s = (x - u).pow(2).mean(-3, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | from segment_anything_fast.flash_4 import _attention_rel_h_rel_w 16 | 17 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 18 | class ImageEncoderViT(nn.Module): 19 | def __init__( 20 | self, 21 | img_size: int = 1024, 22 | patch_size: int = 16, 23 | in_chans: int = 3, 24 | embed_dim: int = 768, 25 | depth: int = 12, 26 | num_heads: int = 12, 27 | mlp_ratio: float = 4.0, 28 | out_chans: int = 256, 29 | qkv_bias: bool = True, 30 | norm_layer: Type[nn.Module] = nn.LayerNorm, 31 | act_layer: Type[nn.Module] = nn.GELU, 32 | use_abs_pos: bool = True, 33 | use_rel_pos: bool = False, 34 | rel_pos_zero_init: bool = True, 35 | window_size: int = 0, 36 | global_attn_indexes: Tuple[int, ...] = (), 37 | ) -> None: 38 | """ 39 | Args: 40 | img_size (int): Input image size. 41 | patch_size (int): Patch size. 
42 | in_chans (int): Number of input image channels. 43 | embed_dim (int): Patch embedding dimension. 44 | depth (int): Depth of ViT. 45 | num_heads (int): Number of attention heads in each ViT block. 46 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 47 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 48 | norm_layer (nn.Module): Normalization layer. 49 | act_layer (nn.Module): Activation layer. 50 | use_abs_pos (bool): If True, use absolute positional embeddings. 51 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 52 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 53 | window_size (int): Window size for window attention blocks. 54 | global_attn_indexes (list): Indexes for blocks using global attention. 55 | """ 56 | super().__init__() 57 | self.img_size = img_size 58 | 59 | self.patch_embed = PatchEmbed( 60 | kernel_size=(patch_size, patch_size), 61 | stride=(patch_size, patch_size), 62 | in_chans=in_chans, 63 | embed_dim=embed_dim, 64 | ) 65 | 66 | self.pos_embed: Optional[nn.Parameter] = None 67 | if use_abs_pos: 68 | # Initialize absolute positional embedding with pretrain image size. 69 | self.pos_embed = nn.Parameter( 70 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 71 | ) 72 | 73 | self.blocks = nn.ModuleList() 74 | for i in range(depth): 75 | block = Block( 76 | dim=embed_dim, 77 | num_heads=num_heads, 78 | mlp_ratio=mlp_ratio, 79 | qkv_bias=qkv_bias, 80 | norm_layer=norm_layer, 81 | act_layer=act_layer, 82 | use_rel_pos=use_rel_pos, 83 | rel_pos_zero_init=rel_pos_zero_init, 84 | window_size=window_size if i not in global_attn_indexes else 0, 85 | input_size=(img_size // patch_size, img_size // patch_size), 86 | ) 87 | self.blocks.append(block) 88 | 89 | self.neck = nn.Sequential( 90 | nn.Conv2d( 91 | embed_dim, 92 | out_chans, 93 | kernel_size=1, 94 | bias=False, 95 | ), 96 | LayerNorm2d(out_chans), 97 | nn.Conv2d( 98 | out_chans, 99 | out_chans, 100 | kernel_size=3, 101 | padding=1, 102 | bias=False, 103 | ), 104 | LayerNorm2d(out_chans), 105 | ) 106 | 107 | def forward(self, x: torch.Tensor) -> torch.Tensor: 108 | x = self.patch_embed(x) 109 | if self.pos_embed is not None: 110 | x = x + self.pos_embed 111 | 112 | for blk in self.blocks: 113 | x = blk(x) 114 | 115 | x = self.neck(x.permute(0, 3, 1, 2)) 116 | 117 | return x 118 | 119 | 120 | class Block(nn.Module): 121 | """Transformer blocks with support of window attention and residual propagation blocks""" 122 | 123 | def __init__( 124 | self, 125 | dim: int, 126 | num_heads: int, 127 | mlp_ratio: float = 4.0, 128 | qkv_bias: bool = True, 129 | norm_layer: Type[nn.Module] = nn.LayerNorm, 130 | act_layer: Type[nn.Module] = nn.GELU, 131 | use_rel_pos: bool = False, 132 | rel_pos_zero_init: bool = True, 133 | window_size: int = 0, 134 | input_size: Optional[Tuple[int, int]] = None, 135 | ) -> None: 136 | """ 137 | Args: 138 | dim (int): Number of input channels. 139 | num_heads (int): Number of attention heads in each ViT block. 140 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 141 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 142 | norm_layer (nn.Module): Normalization layer. 143 | act_layer (nn.Module): Activation layer. 144 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 145 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 
146 | window_size (int): Window size for window attention blocks. If it equals 0, then 147 | use global attention. 148 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 149 | positional parameter size. 150 | """ 151 | super().__init__() 152 | self.norm1 = norm_layer(dim) 153 | self.attn = Attention( 154 | dim, 155 | num_heads=num_heads, 156 | qkv_bias=qkv_bias, 157 | use_rel_pos=use_rel_pos, 158 | rel_pos_zero_init=rel_pos_zero_init, 159 | input_size=input_size if window_size == 0 else (window_size, window_size), 160 | ) 161 | 162 | self.norm2 = norm_layer(dim) 163 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 164 | 165 | self.window_size = window_size 166 | 167 | def forward(self, x: torch.Tensor) -> torch.Tensor: 168 | shortcut = x 169 | x = self.norm1(x) 170 | # Window partition 171 | if self.window_size > 0: 172 | H, W = x.shape[1], x.shape[2] 173 | x, pad_hw = window_partition(x, self.window_size) 174 | 175 | x = self.attn(x) 176 | # Reverse window partition 177 | if self.window_size > 0: 178 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 179 | 180 | x = shortcut + x 181 | x = x + self.mlp(self.norm2(x)) 182 | 183 | return x 184 | 185 | 186 | class Attention(nn.Module): 187 | """Multi-head Attention block with relative position embeddings.""" 188 | 189 | def __init__( 190 | self, 191 | dim: int, 192 | num_heads: int = 8, 193 | qkv_bias: bool = True, 194 | use_rel_pos: bool = False, 195 | rel_pos_zero_init: bool = True, 196 | input_size: Optional[Tuple[int, int]] = None, 197 | ) -> None: 198 | """ 199 | Args: 200 | dim (int): Number of input channels. 201 | num_heads (int): Number of attention heads. 202 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 203 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 204 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 205 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 206 | positional parameter size. 207 | """ 208 | super().__init__() 209 | self.num_heads = num_heads 210 | head_dim = dim // num_heads 211 | self.scale = head_dim**-0.5 212 | 213 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 214 | self.proj = nn.Linear(dim, dim) 215 | 216 | self.use_rel_pos = use_rel_pos 217 | if self.use_rel_pos: 218 | assert ( 219 | input_size is not None 220 | ), "Input size must be provided if using relative positional encoding." 
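# Decomposed relative position embeddings: one learnable table per spatial
# axis, with 2*S - 1 rows for an axis of size S (one row per possible
# relative offset in [-(S-1), S-1]) and head_dim columns. They are gathered
# per query/key offset by get_rel_pos/add_decomposed_rel_pos below and folded
# into the attention scores as the rel_h/rel_w bias terms consumed by
# _attention_rel_h_rel_w.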
221 | # initialize relative positional embeddings 222 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 223 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 224 | 225 | def forward(self, x: torch.Tensor) -> torch.Tensor: 226 | B, H, W, _ = x.shape 227 | # qkv with shape (3, B, nHead, H * W, C) 228 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 229 | # q, k, v with shape (B * nHead, H * W, C) 230 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 231 | 232 | rel_h, rel_w = None, None 233 | if self.use_rel_pos: 234 | rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | q = q.view(B, self.num_heads, H * W, -1) 237 | k = k.view(B, self.num_heads, H * W, -1) 238 | v = v.view(B, self.num_heads, H * W, -1) 239 | 240 | if self.use_rel_pos: 241 | rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)) 242 | rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)) 243 | # attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)) 244 | # x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) 245 | x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w) 246 | else: 247 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 248 | 249 | x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 250 | 251 | x = self.proj(x) 252 | 253 | return x 254 | 255 | 256 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 257 | """ 258 | Partition into non-overlapping windows with padding if needed. 259 | Args: 260 | x (tensor): input tokens with [B, H, W, C]. 261 | window_size (int): window size. 262 | 263 | Returns: 264 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 265 | (Hp, Wp): padded height and width before partition 266 | """ 267 | B, H, W, C = x.shape 268 | 269 | pad_h = (window_size - H % window_size) % window_size 270 | pad_w = (window_size - W % window_size) % window_size 271 | if pad_h > 0 or pad_w > 0: 272 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 273 | Hp, Wp = H + pad_h, W + pad_w 274 | 275 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 276 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 277 | return windows, (Hp, Wp) 278 | 279 | 280 | def window_unpartition( 281 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 282 | ) -> torch.Tensor: 283 | """ 284 | Window unpartition into original sequences and removing padding. 285 | Args: 286 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 287 | window_size (int): window size. 288 | pad_hw (Tuple): padded height and width (Hp, Wp). 289 | hw (Tuple): original height and width (H, W) before padding. 290 | 291 | Returns: 292 | x: unpartitioned sequences with [B, H, W, C]. 
293 | """ 294 | Hp, Wp = pad_hw 295 | H, W = hw 296 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 297 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 298 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 299 | 300 | if Hp > H or Wp > W: 301 | x = x[:, :H, :W, :].contiguous() 302 | return x 303 | 304 | 305 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 306 | """ 307 | Get relative positional embeddings according to the relative positions of 308 | query and key sizes. 309 | Args: 310 | q_size (int): size of query q. 311 | k_size (int): size of key k. 312 | rel_pos (Tensor): relative position embeddings (L, C). 313 | 314 | Returns: 315 | Extracted positional embeddings according to relative positions. 316 | """ 317 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 318 | # Interpolate rel pos if needed. 319 | if rel_pos.shape[0] != max_rel_dist: 320 | # Interpolate rel pos. 321 | rel_pos_resized = F.interpolate( 322 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 323 | size=max_rel_dist, 324 | mode="linear", 325 | ) 326 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 327 | else: 328 | rel_pos_resized = rel_pos 329 | 330 | # Scale the coords with short length if shapes for q and k are different. 331 | q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0) 332 | k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0) 333 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 334 | 335 | return rel_pos_resized[relative_coords.long()] 336 | 337 | 338 | def add_decomposed_rel_pos( 339 | q: torch.Tensor, 340 | rel_pos_h: torch.Tensor, 341 | rel_pos_w: torch.Tensor, 342 | q_size: Tuple[int, int], 343 | k_size: Tuple[int, int], 344 | ) -> torch.Tensor: 345 | """ 346 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 347 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 348 | Args: 349 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 350 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 351 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 352 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 353 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 354 | 355 | Returns: 356 | attn (Tensor): attention map with added relative positional embeddings. 357 | """ 358 | q_h, q_w = q_size 359 | k_h, k_w = k_size 360 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 361 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 362 | 363 | B, _, dim = q.shape 364 | r_q = q.reshape(B, q_h, q_w, dim) 365 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 366 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 367 | rel_h = rel_h.unsqueeze(-1) 368 | rel_w = rel_w.unsqueeze(-2) 369 | rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1) 370 | rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w) 371 | 372 | return rel_h, rel_w 373 | 374 | 375 | class PatchEmbed(nn.Module): 376 | """ 377 | Image to Patch Embedding. 
378 | """ 379 | 380 | def __init__( 381 | self, 382 | kernel_size: Tuple[int, int] = (16, 16), 383 | stride: Tuple[int, int] = (16, 16), 384 | padding: Tuple[int, int] = (0, 0), 385 | in_chans: int = 3, 386 | embed_dim: int = 768, 387 | ) -> None: 388 | """ 389 | Args: 390 | kernel_size (Tuple): kernel size of the projection layer. 391 | stride (Tuple): stride of the projection layer. 392 | padding (Tuple): padding size of the projection layer. 393 | in_chans (int): Number of input image channels. 394 | embed_dim (int): Patch embedding dimension. 395 | """ 396 | super().__init__() 397 | 398 | self.proj = nn.Conv2d( 399 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 400 | ) 401 | 402 | def forward(self, x: torch.Tensor) -> torch.Tensor: 403 | x = self.proj(x) 404 | # B C H W -> B H W C 405 | x = x.permute(0, 2, 3, 1) 406 | return x 407 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 
89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | 95 | self_dtype = self.iou_prediction_head.layers[0].weight.dtype 96 | if sparse_prompt_embeddings.is_nested: 97 | assert dense_prompt_embeddings.is_nested 98 | assert multimask_output 99 | masks, iou_pred = self.predict_masks_nested( 100 | image_embeddings=image_embeddings, 101 | image_pe=image_pe, 102 | sparse_prompt_embeddings=sparse_prompt_embeddings.to(self_dtype), 103 | dense_prompt_embeddings=dense_prompt_embeddings.to(self_dtype), 104 | ) 105 | return masks, iou_pred 106 | else: 107 | masks, iou_pred = self.predict_masks( 108 | image_embeddings=image_embeddings, 109 | image_pe=image_pe, 110 | sparse_prompt_embeddings=sparse_prompt_embeddings.to(self_dtype), 111 | dense_prompt_embeddings=dense_prompt_embeddings.to(self_dtype), 112 | ) 113 | 114 | # Select the correct mask or masks for output 115 | if multimask_output: 116 | mask_slice = slice(1, None) 117 | else: 118 | mask_slice = slice(0, 1) 119 | masks = masks[:, mask_slice, :, :] 120 | iou_pred = iou_pred[:, mask_slice] 121 | 122 | if sparse_prompt_embeddings.is_nested: 123 | return masks, iou_pred, offsets 124 | 125 | if sparse_prompt_embeddings.dtype != self_dtype: 126 | return masks.to(sparse_prompt_embeddings.dtype), iou_pred.to(sparse_prompt_embeddings.dtype) 127 | # Prepare output 128 | return masks, iou_pred 129 | 130 | def predict_masks( 131 | self, 132 | image_embeddings: torch.Tensor, 133 | image_pe: torch.Tensor, 134 | sparse_prompt_embeddings: torch.Tensor, 135 | dense_prompt_embeddings: torch.Tensor, 136 | ) -> Tuple[torch.Tensor, torch.Tensor]: 137 | """Predicts masks. See 'forward' for more details.""" 138 | # Concatenate output tokens 139 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 140 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 141 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 142 | 143 | # Expand per-image data in batch direction to be per-mask 144 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 145 | src = src + dense_prompt_embeddings 146 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 147 | b, c, h, w = src.shape 148 | 149 | # Run the transformer 150 | hs, src = self.transformer(src, pos_src, tokens) 151 | iou_token_out = hs[:, 0, :] 152 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 153 | 154 | # Upscale mask embeddings and predict masks using the mask tokens 155 | src = src.transpose(1, 2).view(b, c, h, w) 156 | upscaled_embedding = self.output_upscaling(src) 157 | hyper_in_list: List[torch.Tensor] = [] 158 | for i in range(self.num_mask_tokens): 159 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 160 | hyper_in = torch.stack(hyper_in_list, dim=1) 161 | b, c, h, w = upscaled_embedding.shape 162 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 163 | 164 | # Generate mask quality predictions 165 | iou_pred = self.iou_prediction_head(iou_token_out) 166 | 167 | return masks, iou_pred 168 | 169 | def predict_masks_nested( 170 | self, 171 | image_embeddings: torch.Tensor, 172 | image_pe: torch.Tensor, 173 | sparse_prompt_embeddings: torch.Tensor, 174 | dense_prompt_embeddings: torch.Tensor, 175 | ) -> Tuple[torch.Tensor, torch.Tensor]: 176 | """Predicts masks. 
See 'forward' for more details.""" 177 | # Concatenate output tokens 178 | output_tokens = ( 179 | torch.zeros_like(sparse_prompt_embeddings).prod(dim=2, keepdim=True) + 180 | torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)) 181 | tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) 182 | 183 | src = dense_prompt_embeddings + image_embeddings.unsqueeze(1) 184 | pos_src = torch.zeros_like(src) + image_pe 185 | h, w = src.shape[-2:] 186 | 187 | # Run the transformer 188 | hs, src = self.transformer(src, pos_src, tokens) 189 | iou_token_out = hs[..., 0, :] 190 | mask_tokens_out = hs[..., 1 : (1 + self.num_mask_tokens), :] 191 | 192 | # Upscale mask embeddings and predict masks using the mask tokens 193 | src = src.transpose(-2, -1).unflatten(-1, (h, w)) 194 | upscaled_embedding = self.output_upscaling(src) 195 | hyper_in_list: List[torch.Tensor] = [] 196 | for i in range(self.num_mask_tokens): 197 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[..., i, :])) 198 | hyper_in = torch.stack(hyper_in_list, dim=-2) 199 | h, w = upscaled_embedding.shape[-2:] 200 | masks = (hyper_in @ upscaled_embedding.flatten(-2)).unflatten(-1, (h, w)) 201 | 202 | # Generate mask quality predictions 203 | iou_pred = self.iou_prediction_head(iou_token_out) 204 | 205 | return masks, iou_pred 206 | 207 | 208 | # Lightly adapted from 209 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 210 | class MLP(nn.Module): 211 | def __init__( 212 | self, 213 | input_dim: int, 214 | hidden_dim: int, 215 | output_dim: int, 216 | num_layers: int, 217 | sigmoid_output: bool = False, 218 | ) -> None: 219 | super().__init__() 220 | self.num_layers = num_layers 221 | h = [hidden_dim] * (num_layers - 1) 222 | self.layers = nn.ModuleList( 223 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 224 | ) 225 | self.sigmoid_output = sigmoid_output 226 | 227 | def forward(self, x): 228 | for i, layer in enumerate(self.layers): 229 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 230 | if self.sigmoid_output: 231 | x = F.sigmoid(x) 232 | return x 233 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 
36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | if points.is_nested: 83 | points = torch.cat([points, torch.zeros_like(points)], dim=2) 84 | labels = torch.cat([labels, -torch.ones_like(labels)], dim=2) 85 | else: 86 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device, dtype=points.dtype) 87 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device, dtype=points.dtype) 88 | points = torch.cat([points, padding_point], dim=1) 89 | labels = torch.cat([labels, padding_label], dim=1) 90 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 91 | point_embedding = torch.where((labels == -1).unsqueeze(-1).expand_as(point_embedding), 92 | torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, 93 | point_embedding) 94 | point_embedding = torch.where((labels == 0).unsqueeze(-1).expand_as(point_embedding), 95 | point_embedding + self.point_embeddings[0].weight, 96 | point_embedding) 97 | point_embedding = torch.where((labels == 1).unsqueeze(-1).expand_as(point_embedding), 98 | point_embedding + self.point_embeddings[1].weight, 99 | point_embedding) 100 | return point_embedding 101 | 102 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 103 | """Embeds box prompts.""" 104 | boxes = boxes + 0.5 # Shift to center of pixel 105 | coords = boxes.reshape(-1, 2, 2) 106 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 107 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 108 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 109 | return corner_embedding 110 | 111 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 112 | """Embeds mask inputs.""" 113 | mask_embedding = self.mask_downscaling(masks) 114 | return mask_embedding 115 | 116 | def _get_batch_size( 117 | self, 118 | points: 
Optional[Tuple[torch.Tensor, torch.Tensor]], 119 | boxes: Optional[torch.Tensor], 120 | masks: Optional[torch.Tensor], 121 | ) -> int: 122 | """ 123 | Gets the batch size of the output given the batch size of the input prompts. 124 | """ 125 | if points is not None: 126 | return points[0].shape[0] 127 | elif boxes is not None: 128 | return boxes.shape[0] 129 | elif masks is not None: 130 | return masks.shape[0] 131 | else: 132 | return 1 133 | 134 | def _get_device(self) -> torch.device: 135 | return self.point_embeddings[0].weight.device 136 | 137 | def forward( 138 | self, 139 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 140 | boxes: Optional[torch.Tensor], 141 | masks: Optional[torch.Tensor], 142 | ) -> Tuple[torch.Tensor, torch.Tensor]: 143 | """ 144 | Embeds different types of prompts, returning both sparse and dense 145 | embeddings. 146 | 147 | Arguments: 148 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 149 | and labels to embed. 150 | boxes (torch.Tensor or none): boxes to embed 151 | masks (torch.Tensor or none): masks to embed 152 | 153 | Returns: 154 | torch.Tensor: sparse embeddings for the points and boxes, with shape 155 | BxNx(embed_dim), where N is determined by the number of input points 156 | and boxes. 157 | torch.Tensor: dense embeddings for the masks, in the shape 158 | Bx(embed_dim)x(embed_H)x(embed_W) 159 | """ 160 | bs = self._get_batch_size(points, boxes, masks) 161 | if points is not None: 162 | coords, labels = points 163 | sparse_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 164 | if boxes is not None: 165 | sparse_embeddings = self._embed_boxes(boxes) 166 | 167 | if masks is not None: 168 | dense_embeddings = self._embed_masks(masks) 169 | else: 170 | if sparse_embeddings.is_nested: 171 | embed_weight = ( 172 | self.no_mask_embed.weight.squeeze(0).unsqueeze(-1).unsqueeze(-1).expand( 173 | -1, self.image_embedding_size[0], self.image_embedding_size[1]) 174 | ) 175 | dense_embeddings = ( 176 | torch.zeros_like(sparse_embeddings.unsqueeze(2).prod(dim=-1, keepdim=True).prod( 177 | dim=-2, keepdim=True)) + embed_weight 178 | ) 179 | else: 180 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 181 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]) 182 | 183 | return sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings 184 | 185 | 186 | class PositionEmbeddingRandom(nn.Module): 187 | """ 188 | Positional encoding using random spatial frequencies. 189 | """ 190 | 191 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 192 | super().__init__() 193 | if scale is None or scale <= 0.0: 194 | scale = 1.0 195 | self.register_buffer( 196 | "positional_encoding_gaussian_matrix", 197 | scale * torch.randn((2, num_pos_feats)), 198 | ) 199 | 200 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 201 | """Positionally encode points that are normalized to [0,1].""" 202 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 203 | coords = 2 * coords - 1 204 | coords = coords.to(self.positional_encoding_gaussian_matrix.dtype) @ self.positional_encoding_gaussian_matrix 205 | coords = 2 * np.pi * coords 206 | # outputs d_1 x ... 
x d_n x C shape 207 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 208 | 209 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 210 | """Generate positional encoding for a grid of the specified size.""" 211 | h, w = size 212 | device: Any = self.positional_encoding_gaussian_matrix.device 213 | grid = torch.ones((h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype) 214 | y_embed = grid.cumsum(dim=0) - 0.5 215 | x_embed = grid.cumsum(dim=1) - 0.5 216 | y_embed = y_embed / h 217 | x_embed = x_embed / w 218 | 219 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 220 | return pe.permute(2, 0, 1) # C x H x W 221 | 222 | def forward_with_coords( 223 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 224 | ) -> torch.Tensor: 225 | # Take advantage of square image size to simplify normalization 226 | assert image_size[1] == image_size[0] 227 | return self._pe_encoding(coords_input / image_size[1]) # B x N x C 228 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. 
A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 
153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything_fast/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | image_embedding = image_embedding.flatten(-2).transpose(-1, -2) 83 | image_pe = image_pe.flatten(-2).transpose(-1, -2) 84 | 85 | # Prepare queries 86 | queries = point_embedding 87 | keys = image_embedding 88 | 89 | # Apply transformer blocks and final layernorm 90 | for layer in self.layers: 91 | queries, keys = layer( 92 | queries=queries, 93 | keys=keys, 94 | query_pe=point_embedding, 95 | key_pe=image_pe, 96 | ) 97 | 98 | # Apply the final attention layer from the points to the image 99 | q = queries + point_embedding 100 | k = keys + image_pe 101 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 102 | queries = queries + attn_out 103 | queries = self.norm_final_attn(queries) 104 | 105 | return queries, keys 106 | 107 | 108 | class TwoWayAttentionBlock(nn.Module): 109 | def __init__( 110 | self, 111 | embedding_dim: int, 112 | num_heads: int, 113 | mlp_dim: int = 2048, 114 | activation: Type[nn.Module] = nn.ReLU, 115 | attention_downsample_rate: int = 2, 116 | skip_first_layer_pe: bool = False, 117 | ) -> None: 118 | """ 119 | A transformer block with four layers: (1) self-attention of sparse 120 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 121 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 122 | inputs. 123 | 124 | Arguments: 125 | embedding_dim (int): the channel dimension of the embeddings 126 | num_heads (int): the number of heads in the attention layers 127 | mlp_dim (int): the hidden dimension of the mlp block 128 | activation (nn.Module): the activation of the mlp block 129 | skip_first_layer_pe (bool): skip the PE on the first layer 130 | """ 131 | super().__init__() 132 | self.self_attn = Attention(embedding_dim, num_heads) 133 | self.norm1 = nn.LayerNorm(embedding_dim) 134 | 135 | self.cross_attn_token_to_image = Attention( 136 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 137 | ) 138 | self.norm2 = nn.LayerNorm(embedding_dim) 139 | 140 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 141 | self.norm3 = nn.LayerNorm(embedding_dim) 142 | 143 | self.norm4 = nn.LayerNorm(embedding_dim) 144 | self.cross_attn_image_to_token = Attention( 145 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 146 | ) 147 | 148 | self.skip_first_layer_pe = skip_first_layer_pe 149 | 150 | def forward( 151 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 152 | ) -> Tuple[Tensor, Tensor]: 153 | # Self attention block 154 | if self.skip_first_layer_pe: 155 | queries = self.self_attn(q=queries, k=queries, v=queries) 156 | else: 157 | q = queries + query_pe 158 | attn_out = self.self_attn(q=q, k=q, v=queries) 159 | queries = queries + attn_out 160 | queries = self.norm1(queries) 161 | 162 | # Cross attention block, tokens attending to image embedding 163 | q = queries + query_pe 164 | k = keys + key_pe 165 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 166 | queries = queries + attn_out 167 | queries = self.norm2(queries) 168 | 169 | # MLP block 170 | mlp_out = self.mlp(queries) 171 | queries = queries + mlp_out 172 | queries = self.norm3(queries) 173 | 174 | # Cross attention block, image embedding attending to tokens 175 | q = queries + query_pe 176 | k = keys + key_pe 177 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 178 | keys = keys + attn_out 179 | 
keys = self.norm4(keys) 180 | 181 | return queries, keys 182 | 183 | 184 | class Attention(nn.Module): 185 | """ 186 | An attention layer that allows for downscaling the size of the embedding 187 | after projection to queries, keys, and values. 188 | """ 189 | 190 | def __init__( 191 | self, 192 | embedding_dim: int, 193 | num_heads: int, 194 | downsample_rate: int = 1, 195 | ) -> None: 196 | super().__init__() 197 | self.embedding_dim = embedding_dim 198 | self.internal_dim = embedding_dim // downsample_rate 199 | self.num_heads = num_heads 200 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 201 | 202 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 203 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 206 | 207 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 208 | x = x.unflatten(-1, (num_heads, -1)) 209 | return x.transpose(-3, -2) # B... x N_heads x N_tokens x C_per_head 210 | 211 | def _recombine_heads(self, x: Tensor) -> Tensor: 212 | x = x.transpose(-3, -2) 213 | return x.flatten(-2) 214 | 215 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 216 | # Input projections 217 | q = self.q_proj(q) 218 | k = self.k_proj(k) 219 | v = self.v_proj(v) 220 | 221 | # Separate into heads 222 | q = self._separate_heads(q, self.num_heads) 223 | k = self._separate_heads(k, self.num_heads) 224 | v = self._separate_heads(v, self.num_heads) 225 | 226 | # Attention 227 | out = torch.nn.functional.scaled_dot_product_attention(q, k, v) 228 | 229 | # Get output 230 | out = self._recombine_heads(out) 231 | out = self.out_proj(out) 232 | 233 | return out 234 | -------------------------------------------------------------------------------- /segment_anything_fast/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 
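# If the caller's channel order differs from the order the model was
# configured with (self.model.image_format, "RGB" by default), reverse the
# channel axis so that the pixel_mean/pixel_std normalization inside
# Sam.preprocess lines up with the correct channels.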
52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | model_dtype = self.model.mask_decoder.iou_prediction_head.layers[0].weight.dtype 90 | self.features = self.model.image_encoder(input_image.to(model_dtype)) 91 | self.is_image_set = True 92 | 93 | def predict( 94 | self, 95 | point_coords: Optional[np.ndarray] = None, 96 | point_labels: Optional[np.ndarray] = None, 97 | box: Optional[np.ndarray] = None, 98 | mask_input: Optional[np.ndarray] = None, 99 | multimask_output: bool = True, 100 | return_logits: bool = False, 101 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 102 | """ 103 | Predict masks for the given input prompts, using the currently set image. 104 | 105 | Arguments: 106 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 107 | model. Each point is in (X,Y) in pixels. 108 | point_labels (np.ndarray or None): A length N array of labels for the 109 | point prompts. 1 indicates a foreground point and 0 indicates a 110 | background point. 111 | box (np.ndarray or None): A length 4 array given a box prompt to the 112 | model, in XYXY format. 113 | mask_input (np.ndarray): A low resolution mask input to the model, typically 114 | coming from a previous prediction iteration. Has form 1xHxW, where 115 | for SAM, H=W=256. 116 | multimask_output (bool): If true, the model will return three masks. 117 | For ambiguous input prompts (such as a single click), this will often 118 | produce better masks than a single prediction. If only a single 119 | mask is needed, the model's predicted quality score can be used 120 | to select the best mask. For non-ambiguous prompts, such as multiple 121 | input prompts, multimask_output=False can give better results. 122 | return_logits (bool): If true, returns un-thresholded masks logits 123 | instead of a binary mask. 124 | 125 | Returns: 126 | (np.ndarray): The output masks in CxHxW format, where C is the 127 | number of masks, and (H, W) is the original image size. 
128 | (np.ndarray): An array of length C containing the model's 129 | predictions for the quality of each mask. 130 | (np.ndarray): An array of shape CxHxW, where C is the number 131 | of masks and H=W=256. These low resolution logits can be passed to 132 | a subsequent iteration as mask input. 133 | """ 134 | if not self.is_image_set: 135 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 136 | 137 | # Transform input prompts 138 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 139 | if point_coords is not None: 140 | assert ( 141 | point_labels is not None 142 | ), "point_labels must be supplied if point_coords is supplied." 143 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 144 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 145 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 146 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 147 | if box is not None: 148 | box = self.transform.apply_boxes(box, self.original_size) 149 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 150 | box_torch = box_torch[None, :] 151 | if mask_input is not None: 152 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 153 | mask_input_torch = mask_input_torch[None, :, :, :] 154 | 155 | masks, iou_predictions, low_res_masks = self.predict_torch( 156 | coords_torch, 157 | labels_torch, 158 | box_torch, 159 | mask_input_torch, 160 | multimask_output, 161 | return_logits=return_logits, 162 | ) 163 | 164 | masks_np = masks[0].detach().cpu().numpy() 165 | iou_predictions_np = iou_predictions[0].detach().cpu().float().numpy() 166 | low_res_masks_np = low_res_masks[0].detach().cpu().float().numpy() 167 | return masks_np, iou_predictions_np, low_res_masks_np 168 | 169 | @torch.no_grad() 170 | def predict_torch( 171 | self, 172 | point_coords: Optional[torch.Tensor], 173 | point_labels: Optional[torch.Tensor], 174 | boxes: Optional[torch.Tensor] = None, 175 | mask_input: Optional[torch.Tensor] = None, 176 | multimask_output: bool = True, 177 | return_logits: bool = False, 178 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 179 | """ 180 | Predict masks for the given input prompts, using the currently set image. 181 | Input prompts are batched torch tensors and are expected to already be 182 | transformed to the input frame using ResizeLongestSide. 183 | 184 | Arguments: 185 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 186 | model. Each point is in (X,Y) in pixels. 187 | point_labels (torch.Tensor or None): A BxN array of labels for the 188 | point prompts. 1 indicates a foreground point and 0 indicates a 189 | background point. 190 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 191 | model, in XYXY format. 192 | mask_input (np.ndarray): A low resolution mask input to the model, typically 193 | coming from a previous prediction iteration. Has form Bx1xHxW, where 194 | for SAM, H=W=256. Masks returned by a previous iteration of the 195 | predict method do not need further transformation. 196 | multimask_output (bool): If true, the model will return three masks. 197 | For ambiguous input prompts (such as a single click), this will often 198 | produce better masks than a single prediction. 
If only a single 199 | mask is needed, the model's predicted quality score can be used 200 | to select the best mask. For non-ambiguous prompts, such as multiple 201 | input prompts, multimask_output=False can give better results. 202 | return_logits (bool): If true, returns un-thresholded masks logits 203 | instead of a binary mask. 204 | 205 | Returns: 206 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 207 | number of masks, and (H, W) is the original image size. 208 | (torch.Tensor): An array of shape BxC containing the model's 209 | predictions for the quality of each mask. 210 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 211 | of masks and H=W=256. These low res logits can be passed to 212 | a subsequent iteration as mask input. 213 | """ 214 | if not self.is_image_set: 215 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 216 | 217 | if point_coords is not None: 218 | points = (point_coords, point_labels) 219 | else: 220 | points = None 221 | 222 | # Embed prompts 223 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 224 | points=points, 225 | boxes=boxes, 226 | masks=mask_input, 227 | ) 228 | 229 | # Predict masks 230 | low_res_masks, iou_predictions = self.model.mask_decoder( 231 | image_embeddings=self.features, 232 | image_pe=self.model.prompt_encoder.get_dense_pe(), 233 | sparse_prompt_embeddings=sparse_embeddings, 234 | dense_prompt_embeddings=dense_embeddings, 235 | multimask_output=multimask_output, 236 | ) 237 | 238 | if low_res_masks.is_nested: 239 | masks = [] 240 | for lrm, input_size, original_size in zip(low_res_masks.unbind(), self.input_sizes, self.original_sizes, strict=True): 241 | # Upscale the masks to the original image resolution 242 | m = self.model.postprocess_masks(lrm, input_size, original_size) 243 | masks.append(m) 244 | masks = torch.nested.nested_tensor(masks, layout=torch.strided) 245 | else: 246 | # Upscale the masks to the original image resolution 247 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 248 | 249 | if not return_logits: 250 | masks = masks > self.model.mask_threshold 251 | 252 | return masks, iou_predictions, low_res_masks 253 | 254 | def get_image_embedding(self) -> torch.Tensor: 255 | """ 256 | Returns the image embeddings for the currently set image, with 257 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 258 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 259 | """ 260 | if not self.is_image_set: 261 | raise RuntimeError( 262 | "An image must be set with .set_image(...) to generate an embedding." 263 | ) 264 | assert self.features is not None, "Features must exist if an image has been set." 
265 | return self.features 266 | 267 | @property 268 | def device(self) -> torch.device: 269 | return self.model.device 270 | 271 | def reset_image(self) -> None: 272 | """Resets the currently set image.""" 273 | self.is_image_set = False 274 | self.features = None 275 | self.orig_h = None 276 | self.orig_w = None 277 | self.input_h = None 278 | self.input_w = None 279 | -------------------------------------------------------------------------------- /segment_anything_fast/sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor 3 | SparseSemiStructuredTensor._FORCE_CUTLASS = True 4 | 5 | # Sparsity helper functions 6 | def apply_fake_sparsity(model): 7 | """ 8 | This function simulates 2:4 sparsity on all linear layers in a model. 9 | It uses the torch.ao.pruning flow. 10 | """ 11 | # torch.ao.pruning flow 12 | from torch.ao.pruning import WeightNormSparsifier 13 | sparse_config = [] 14 | for name, mod in model.named_modules(): 15 | if isinstance(mod, torch.nn.Linear): 16 | sparse_config.append({"tensor_fqn": f"{name}.weight"}) 17 | 18 | sparsifier = WeightNormSparsifier(sparsity_level=1.0, 19 | sparse_block_shape=(1,4), 20 | zeros_per_block=2) 21 | sparsifier.prepare(model, sparse_config) 22 | sparsifier.step() 23 | 24 | sparsifier.step() 25 | sparsifier.squash_mask() 26 | 27 | 28 | def apply_sparse(model): 29 | apply_fake_sparsity(model) 30 | for name, mod in model.named_modules(): 31 | if isinstance(mod, torch.nn.Linear): 32 | mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) 33 | -------------------------------------------------------------------------------- /segment_anything_fast/tools.py: -------------------------------------------------------------------------------- 1 | def replace_with_custom_fn_if_matches_filter( 2 | model, replacement_fn, filter_fn, cur_fqn='' 3 | ) -> None: 4 | """ 5 | For each `child` in `model`, replaces it with `replacement_fn(child)` 6 | if `filter_fn(child)` is `True` 7 | """ 8 | name_to_child = dict(model.named_children()) 9 | for name, child in name_to_child.items(): 10 | if cur_fqn == '': 11 | new_fqn = name 12 | else: 13 | new_fqn = f'{cur_fqn}.{name}' 14 | if filter_fn(child, new_fqn): 15 | new_child = replacement_fn(child) 16 | setattr(model, name, new_child) 17 | else: 18 | replace_with_custom_fn_if_matches_filter( 19 | child, replacement_fn, filter_fn, new_fqn) 20 | 21 | 22 | def apply_eval_dtype_predictor(predictor, dtype=None): 23 | 24 | def prep_model(model, dtype): 25 | if dtype is not None: 26 | return model.eval().to(dtype) 27 | return model.eval() 28 | 29 | predictor.model.image_encoder = prep_model( 30 | predictor.model.image_encoder, dtype) 31 | predictor.model.prompt_encoder = prep_model( 32 | predictor.model.prompt_encoder, dtype) 33 | predictor.model.mask_decoder = prep_model( 34 | predictor.model.mask_decoder, dtype) 35 | 36 | return predictor 37 | -------------------------------------------------------------------------------- /segment_anything_fast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /segment_anything_fast/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().float().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = 
torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]: 107 | """ 108 | Encodes masks to an uncompressed RLE, in the format expected by 109 | pycoco tools. 110 | """ 111 | # Put in fortran order and flatten h,w 112 | b, h, w = tensor.shape 113 | tensor = tensor.permute(0, 2, 1).flatten(1) 114 | 115 | # Compute change indices 116 | diff = tensor[:, 1:] ^ tensor[:, :-1] 117 | a = torch.tensor([[True]]) 118 | if diff.is_cuda: 119 | a = a.pin_memory().cuda() 120 | a = a.expand_as(diff.narrow(1, 0, 1)) 121 | diff = torch.cat([a, diff, a], dim=1) 122 | change_indices = diff.nonzero() 123 | 124 | alt_lens = diff.sum(dim=1).tolist() 125 | 126 | all_cur_idx = change_indices[:, 1] 127 | all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx 128 | all_btw_idx = all_btw_idx.detach().cpu().tolist() 129 | 130 | # Encode run length 131 | out = [] 132 | counts_init = (tensor[:, 0] == 0).tolist() 133 | offset = 0 134 | for i, ci in zip(range(b), counts_init): 135 | btw_idxs = all_btw_idx[offset:offset + alt_lens[i]][:-1] 136 | offset += alt_lens[i] 137 | counts = [] if ci else [0] 138 | counts.extend(btw_idxs) 139 | out.append({"size": [h, w], "counts": counts}) 140 | 141 | return out 142 | 143 | 144 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 145 | """ 146 | Encodes masks to an uncompressed RLE, in the format expected by 147 | pycoco tools. 
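    For example, a 2x2 boolean mask whose second column is foreground flattens
    to [0, 0, 1, 1] in column-major order and encodes as
    {"size": [2, 2], "counts": [2, 2]}. The counts always begin with the length
    of the initial background run, which is 0 when the mask starts with
    foreground.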
148 | """ 149 | # Put in fortran order and flatten h,w 150 | b, h, w = tensor.shape 151 | tensor = tensor.permute(0, 2, 1).flatten(1) 152 | 153 | # Compute change indices 154 | diff = tensor[:, 1:] ^ tensor[:, :-1] 155 | change_indices = diff.nonzero() 156 | 157 | # Encode run length 158 | out = [] 159 | for i in range(b): 160 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 161 | cur_idxs = torch.cat( 162 | [ 163 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 164 | cur_idxs + 1, 165 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 166 | ] 167 | ) 168 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 169 | counts = [] if tensor[i, 0] == 0 else [0] 170 | counts.extend(btw_idxs.detach().cpu().tolist()) 171 | out.append({"size": [h, w], "counts": counts}) 172 | return out 173 | 174 | 175 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 176 | """Compute a binary mask from an uncompressed RLE.""" 177 | h, w = rle["size"] 178 | mask = np.empty(h * w, dtype=bool) 179 | idx = 0 180 | parity = False 181 | for count in rle["counts"]: 182 | mask[idx : idx + count] = parity 183 | idx += count 184 | parity ^= True 185 | mask = mask.reshape(w, h) 186 | return mask.transpose() # Put in C order 187 | 188 | 189 | def area_from_rle(rle: Dict[str, Any]) -> int: 190 | return sum(rle["counts"][1::2]) 191 | 192 | 193 | def calculate_stability_score( 194 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 195 | ) -> torch.Tensor: 196 | """ 197 | Computes the stability score for a batch of masks. The stability 198 | score is the IoU between the binary masks obtained by thresholding 199 | the predicted mask logits at high and low values. 200 | """ 201 | # One mask is always contained inside the other. 202 | # Save memory by preventing unnecessary cast to torch.int64 203 | intersections = ( 204 | (masks > (mask_threshold + threshold_offset)) 205 | .sum(-1, dtype=torch.int16) 206 | .sum(-1, dtype=torch.int32) 207 | ) 208 | unions = ( 209 | (masks > (mask_threshold - threshold_offset)) 210 | .sum(-1, dtype=torch.int16) 211 | .sum(-1, dtype=torch.int32) 212 | ) 213 | return intersections / unions 214 | 215 | 216 | def build_point_grid(n_per_side: int) -> np.ndarray: 217 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 218 | offset = 1 / (2 * n_per_side) 219 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 220 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 221 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 222 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 223 | return points 224 | 225 | 226 | def build_all_layer_point_grids( 227 | n_per_side: int, n_layers: int, scale_per_layer: int 228 | ) -> List[np.ndarray]: 229 | """Generates point grids for all crop layers.""" 230 | points_by_layer = [] 231 | for i in range(n_layers + 1): 232 | n_points = int(n_per_side / (scale_per_layer**i)) 233 | points_by_layer.append(build_point_grid(n_points)) 234 | return points_by_layer 235 | 236 | 237 | def generate_crop_boxes( 238 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 239 | ) -> Tuple[List[List[int]], List[int]]: 240 | """ 241 | Generates a list of crop boxes of different sizes. Each layer 242 | has (2**i)**2 boxes for the ith layer. 
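    For example, n_layers=2 yields 1 + 4 + 16 = 21 crop boxes: the full image,
    a 2x2 grid of overlapping crops, and a 4x4 grid of overlapping crops.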
243 | """ 244 | crop_boxes, layer_idxs = [], [] 245 | im_h, im_w = im_size 246 | short_side = min(im_h, im_w) 247 | 248 | # Original image 249 | crop_boxes.append([0, 0, im_w, im_h]) 250 | layer_idxs.append(0) 251 | 252 | def crop_len(orig_len, n_crops, overlap): 253 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 254 | 255 | for i_layer in range(n_layers): 256 | n_crops_per_side = 2 ** (i_layer + 1) 257 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 258 | 259 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 260 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 261 | 262 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 263 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 264 | 265 | # Crops in XYWH format 266 | for x0, y0 in product(crop_box_x0, crop_box_y0): 267 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 268 | crop_boxes.append(box) 269 | layer_idxs.append(i_layer + 1) 270 | 271 | return crop_boxes, layer_idxs 272 | 273 | 274 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 275 | x0, y0, _, _ = crop_box 276 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 277 | # Check if boxes has a channel dimension 278 | if len(boxes.shape) == 3: 279 | offset = offset.unsqueeze(1) 280 | return boxes + offset 281 | 282 | 283 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 284 | x0, y0, _, _ = crop_box 285 | offset = torch.tensor([[x0, y0]], device=points.device) 286 | # Check if points has a channel dimension 287 | if len(points.shape) == 3: 288 | offset = offset.unsqueeze(1) 289 | return points + offset 290 | 291 | 292 | def uncrop_masks( 293 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 294 | ) -> torch.Tensor: 295 | x0, y0, x1, y1 = crop_box 296 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 297 | return masks 298 | # Coordinate transform masks 299 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 300 | pad = (x0, pad_x - x0, y0, pad_y - y0) 301 | return torch.nn.functional.pad(masks, pad, value=0) 302 | 303 | 304 | def remove_small_regions( 305 | mask: np.ndarray, area_thresh: float, mode: str 306 | ) -> Tuple[np.ndarray, bool]: 307 | """ 308 | Removes small disconnected regions and holes in a mask. Returns the 309 | mask and an indicator of if the mask has been modified. 
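    With mode="holes", background regions smaller than area_thresh (in pixels)
    are filled in; with mode="islands", foreground regions smaller than
    area_thresh are removed.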
310 | """ 311 | import cv2 # type: ignore 312 | 313 | assert mode in ["holes", "islands"] 314 | correct_holes = mode == "holes" 315 | working_mask = (correct_holes ^ mask).astype(np.uint8) 316 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 317 | sizes = stats[:, -1][1:] # Row 0 is background label 318 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 319 | if len(small_regions) == 0: 320 | return mask, False 321 | fill_labels = [0] + small_regions 322 | if not correct_holes: 323 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 324 | # If every region is below threshold, keep largest 325 | if len(fill_labels) == 0: 326 | fill_labels = [int(np.argmax(sizes)) + 1] 327 | mask = np.isin(regions, fill_labels) 328 | return mask, True 329 | 330 | 331 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 332 | from pycocotools import mask as mask_utils # type: ignore 333 | 334 | h, w = uncompressed_rle["size"] 335 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 336 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 337 | return rle 338 | 339 | 340 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 341 | """ 342 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 343 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 344 | """ 345 | # torch.max below raises an error on empty inputs, just skip in this case 346 | if torch.numel(masks) == 0: 347 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 348 | 349 | # Normalize shape to CxHxW 350 | shape = masks.shape 351 | h, w = shape[-2:] 352 | if len(shape) > 2: 353 | masks = masks.flatten(0, -3) 354 | else: 355 | masks = masks.unsqueeze(0) 356 | 357 | # Get top and bottom edges 358 | in_height, _ = torch.max(masks, dim=-1) 359 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 360 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 361 | in_height_coords = in_height_coords + h * (~in_height) 362 | top_edges, _ = torch.min(in_height_coords, dim=-1) 363 | 364 | # Get left and right edges 365 | in_width, _ = torch.max(masks, dim=-2) 366 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 367 | right_edges, _ = torch.max(in_width_coords, dim=-1) 368 | in_width_coords = in_width_coords + w * (~in_width) 369 | left_edges, _ = torch.min(in_width_coords, dim=-1) 370 | 371 | # If the mask is empty the right edge will be to the left of the left edge. 372 | # Replace these boxes with [0, 0, 0, 0] 373 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 374 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 375 | out = out * (~empty_filter).unsqueeze(-1) 376 | 377 | # Return to original shape 378 | if len(shape) > 2: 379 | out = out.reshape(*shape[:-2], 4) 380 | else: 381 | out = out[0] 382 | 383 | return out 384 | -------------------------------------------------------------------------------- /segment_anything_fast/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
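# Minimal usage sketch (the checkpoint path and export arguments below are
# placeholders, not part of this module):
#
#   from segment_anything_fast import sam_model_registry
#   from segment_anything_fast.utils.onnx import SamOnnxModel
#
#   sam = sam_model_registry["vit_b"](checkpoint="path/to/checkpoint.pth")
#   onnx_model = SamOnnxModel(sam, return_single_mask=True)
#   # ...construct example tensors for the six forward() inputs, then:
#   # torch.onnx.export(onnx_model, example_inputs, "sam_decoder.onnx", ...)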
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
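        # With few points, (num_points - 2.5) is negative, so the 1000 weight drives the
        # first mask token's score down and argmax selects the best of the remaining
        # (multimask) outputs; with more points the first token is boosted instead and
        # its single-output mask is returned.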
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything_fast/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
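        For example, with original_size=(600, 800) and target_length=1024 both
        axes scale by 1024 / 800 = 1.28, so a point at (400, 300) maps to
        (512, 384).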
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
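        For example, get_preprocess_shape(1200, 800, 1024) returns (1024, 683):
        the long side becomes the target length and the short side is scaled
        proportionally and rounded to the nearest integer.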
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | packages = find_packages() 4 | print("packages: ", packages) 5 | setup( 6 | name='pytorch-labs-segment-anything-fast', 7 | version='0.2', 8 | packages=packages, 9 | install_requires=[ 10 | 'torch>=2.2.0.dev20231026', 11 | 'torchvision>=0.17.0.dev20231026', 12 | 'diskcache', 13 | 'pycocotools', 14 | 'scipy', 15 | 'scikit-image', 16 | 'torchao', 17 | ], 18 | include_package_data=True, 19 | package_data={ 20 | 'segment_anything_fast.configs': ["*.p"], 21 | }, 22 | description='A pruned, quantized, compiled, nested and batched implementation of segment-anything', 23 | long_description_content_type='text/markdown', 24 | url='https://github.com/pytorch-labs/segment-anything-fast', 25 | classifiers=[ 26 | 'Programming Language :: Python :: 3', 27 | 'License :: OSI Approved :: MIT License', 28 | 'Operating System :: OS Independent', 29 | ], 30 | python_requires='>=3.8', 31 | ) 32 | -------------------------------------------------------------------------------- /test/test_flash_4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from segment_anything_fast.flash_4 import _attention_rel_h_rel_w 4 | 5 | def test_op(batch, head, seq_len, hidden_dim, dtype): 6 | import math 7 | 8 | sm_scale = 1.0 / math.sqrt(hidden_dim) 9 | device = "cuda" 10 | torch.manual_seed(20) 11 | q = torch.empty( 12 | (batch, head, seq_len, hidden_dim), dtype=dtype, device=device 13 | ).normal_(mean=0.0, std=0.5) 14 | k = torch.empty( 15 | (batch, head, seq_len, hidden_dim), dtype=dtype, device=device 16 | ).normal_(mean=0.0, std=0.5) 17 | v = torch.empty( 18 | (batch, head, seq_len, hidden_dim), dtype=dtype, device=device 19 | ).normal_(mean=0.0, std=0.5) 20 | w = int((seq_len) ** 0.5) 21 | assert w * w == seq_len, "seq_len must be a perfect square" 22 | 23 | rel_h = torch.empty( 24 | (batch, head, seq_len, w, 1), dtype=dtype, device=device 25 | ).normal_(mean=0, std=0.5) 26 | rel_w = torch.empty( 27 | (batch, head, seq_len, 1, w), dtype=dtype, device=device 28 | ).normal_(mean=0, std=0.5) 29 | 30 | tri_out = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w) 31 | # reference implementation 32 | attn_bias = (rel_h + rel_w).view( 33 | q.size(0), q.size(1), rel_h.size(2), rel_h.size(3) * rel_w.size(4) 34 | ) 35 | ref_out = torch.nn.functional.scaled_dot_product_attention( 36 | q, k, v, attn_mask=attn_bias 37 | ) 38 | 39 | torch.testing.assert_close(ref_out, tri_out, rtol=1e-3, atol=1e-3) 40 | 41 | for batch, (head, seq_len), dtype in itertools.product([1, 8], [(16, 80), (12, 64)], [torch.float16, torch.bfloat16]): 42 | print(f"batch: {batch} head: {head} seq_len: {seq_len} dtype: {dtype}") 43 | test_op(batch, head, 4096, seq_len, dtype) 44 | -------------------------------------------------------------------------------- /test/test_mask_to_rle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from segment_anything_fast.utils.amg import ( 4 | mask_to_rle_pytorch, 5 | mask_to_rle_pytorch_2, 6 | ) 7 | 8 | def test_masks(masks): 9 | rles_0 = mask_to_rle_pytorch(masks) 
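    # The optimized variant below must produce the same uncompressed RLE counts
    # as the reference implementation above.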
10 | rles_2 = mask_to_rle_pytorch_2(masks) 11 | 12 | for i in range(len(rles_0)): 13 | torch.testing.assert_close(torch.tensor(rles_0[i]['counts']), torch.tensor(rles_2[i]['counts'])) 14 | 15 | for b, w, h in itertools.product([1, 5], [50, 128], [50, 128]): 16 | test_masks(torch.randn(b, w, h).clamp(min=0).bool().cuda()) 17 | test_masks(torch.randn(b, w, h).mul(0).bool().cuda()) 18 | test_masks(torch.randn(b, w, h).mul(0).add(1).bool().cuda()) 19 | --------------------------------------------------------------------------------