├── .ruff.toml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── Screenshots ├── sws1.png ├── sws2.png ├── sws3.png └── sws4.png ├── SegmentWithSAM ├── CMakeLists.txt ├── Resources │ ├── Icons │ │ └── SegmentWithSAM.png │ └── UI │ │ └── SegmentWithSAM.ui └── SegmentWithSAM.py ├── sam2 ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_image_predictor.py ├── sam2_video_predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── misc.py │ └── transforms.py ├── sam2_configs ├── __init__.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml └── sam2_hiera_t.yaml └── setup.py /.ruff.toml: -------------------------------------------------------------------------------- 1 | target-version = "py39" 2 | line-length = 120 3 | 4 | [lint] 5 | select = [ 6 | "B", # flake8-bugbear 7 | "C4", # flake8-comprehensions 8 | "COM", # flake8-commas 9 | "E", "F", "W", # flake8 10 | "Q", # flake8-quote 11 | "ICN", # flake8-import-conventions 12 | "ISC", # flake8-implicit-str-concat 13 | "NPY", # NumPy specific rules 14 | "PGH", # pygrep-hooks 15 | "PL", # pylint 16 | "RET", # flake8-return 17 | "RUF", # Ruff-specific 18 | "UP", # pyupgrade 19 | "YTT", # flake8-2020 20 | "W", # Warning 21 | ] 22 | 23 | extend-ignore = [ 24 | "PLW0603", # Using the global statement to update `var` is discouraged 25 | 26 | "PLR0912", # Too many branches 27 | "PLR0915", # Too many statements 28 | "PLR2004", # Magic value used in comparison 29 | 30 | "RET505", # Unnecessary `elif` after `return` statement 31 | 32 | # Disable linting rules conflicting with "ruff formatter" 33 | # See https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules 34 | "COM812", 35 | "COM819", 36 | "E111", 37 | "E114", 38 | "E117", 39 | "ISC001", 40 | "ISC002", 41 | "W191", 42 | ] 43 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16.3...3.19.7 FATAL_ERROR) 2 | 3 | project(SegmentWithSAM) 4 | 5 | #----------------------------------------------------------------------------- 6 | # Extension meta-information 7 | set(EXTENSION_HOMEPAGE "https://github.com/mazurowski-lab/SlicerSegmentWithSAM") 8 | set(EXTENSION_CATEGORY "Segmentation") 9 | set(EXTENSION_CONTRIBUTORS "Zafer Yildiz (Mazurowski Lab, Duke University)") 10 | set(EXTENSION_DESCRIPTION "SegmentWithSAM aims to asist its users in segmenting medical data on 3D Slicer by comprehensively integrating the Segment Anything Model (SAM) developed by Meta.") 11 | set(EXTENSION_ICONURL "https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/SegmentWithSAM/Resources/Icons/SegmentWithSAM.png") 12 | set(EXTENSION_SCREENSHOTURLS "https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/Screenshots/sws1.png https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/Screenshots/sws2.png https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/Screenshots/sws3.png https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/Screenshots/sws4.png") 13 
| set(EXTENSION_DEPENDS "PyTorch") # Specified as a list or "NA" if no dependencies 14 | 15 | #----------------------------------------------------------------------------- 16 | # Extension dependencies 17 | find_package(Slicer REQUIRED) 18 | include(${Slicer_USE_FILE}) 19 | 20 | #----------------------------------------------------------------------------- 21 | # Extension modules 22 | add_subdirectory(SegmentWithSAM) 23 | ## NEXT_MODULE 24 | 25 | #----------------------------------------------------------------------------- 26 | include(${Slicer_EXTENSION_GENERATE_CONFIG}) 27 | include(${Slicer_EXTENSION_CPACK}) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SlicerSegmentWithSAM 2 | 3 | [![arXiv Paper](https://img.shields.io/badge/arXiv-2401.12974-orange.svg?style=flat)](https://arxiv.org/abs/2408.15224) [**`MIDL Paper`**](https://openreview.net/pdf?id=zDOZ0IhLFF) 4 | 5 | SegmentWithSAM aims to asist its users in segmenting medical data on 3D Slicer by comprehensively integrating the Segment Anything Model (SAM) developed by Meta. 
6 | 7 | 8 | 9 | ## How to Cite 10 | 11 | If you find our work to be useful for your research, please cite our papers: 12 | 13 | 14 | ```bibtex 15 | @article{yildiz2024sam, 16 | title={SAM \& SAM 2 in 3D Slicer: SegmentWithSAM Extension for Annotating Medical Images}, 17 | author={Yildiz, Zafer and Chen, Yuwen and Mazurowski, Maciej A}, 18 | journal={arXiv preprint arXiv:2408.15224}, 19 | year={2024} 20 | } 21 | 22 | @inproceedings{yildiz2024segmentwithsam, 23 | title={SegmentWithSAM: 3D Slicer Extension for Segment Anything Model (SAM)}, 24 | author={Yildiz, Zafer and Gu, Hanxue and Zhang, Jikai and Yang, Jichen and Mazurowski, Maciej A}, 25 | booktitle={Medical Imaging with Deep Learning}, 26 | year={2024} 27 | } 28 | ``` 29 | 30 | ## Installation via Extension Manager 31 | 32 | To install this extension via 3D Slicer's Extension Manager, you need to follow the steps below: 33 | 34 | - Go to the Extension Manager of 3D Slicer (Ctrl+4) 35 | - Search for "SegmentWithSAM" 36 | - Click the "Install" button 37 | - Restart 3D Slicer 38 | 39 | ## Installation via GitHub Repository 40 | 41 | You can clone this repository by running the following command: 42 | 43 | ``` 44 | git clone https://github.com/mazurowski-lab/SlicerSegmentWithSAM.git 45 | ``` 46 | 47 | Before adding this extension to 3D Slicer, you must install some dependencies in 3D Slicer. To do this, you need to run the following commands in 3D Slicer's Python terminal. 48 | 49 | ``` 50 | slicer.util.pip_install("git+https://github.com/facebookresearch/segment-anything.git") 51 | slicer.util.pip_install("torch torchvision torchaudio") 52 | slicer.util.pip_install("opencv-python") 53 | ``` 54 | 55 | You should also download the following checkpoint of SAM into the repository directory (in the same directory as the readme file). 56 | 57 | After downloading all necessary files, you need to introduce the extension to 3D Slicer. Please go to Modules > Developer Tools > Extension Wizard on 3D Slicer and click the 'Select Extension' button. You should select the root folder that contains this repository in the pop-up. If you don't get any errors in the Python terminal, that means you are ready to use the extension! 58 | 59 | ## Usage 60 | 61 | You can watch our tutorial video to learn how to use SegmentWithSAM. 62 | 63 | First of all, make sure you open a file on 3D Slicer before you start using SegmentWithSAM. 64 | 65 | If you've added the extension to 3D Slicer, you should be able to see it under **Modules > Segmentation > SegmentWithSAM**. You can see the user interface of the extension after you click on SegmentWithSAM in this menu. 66 | 67 | Before starting the segmentation, make sure that you've created the necessary labels for your case by clicking the "Configure labels in the segment editor" button. You need to return to our extension through the Modules > Segmentation > SegmentWithSAM path after you create your labels in the segment editor. You are ready to segment now! 68 | 69 | 70 | 71 | Firstly, select the label you want to segment from the dropdown list (hip for the image below). Then, click the "Start Segmentation for Current Slice" button. 72 | 73 | 74 | 75 | If this is the first time you segment a slice of this file, you need to wait for SAM to produce some files that will be used for the segmentation. After SAM has generated these files, you can start putting **prompt points** or **prompt boxes** on the current slice. You'll be able to see the segmentation mask on 3D Slicer.
Please click "Stop Segmentation for Current Slice" whenever you finish your segmentation for the current slice. 76 | 77 | 78 | 79 | If you are not satisfied with the segmentation mask produced by SAM, you can edit it as you wish using the "Segment Editor" module of 3D Slicer. 80 | -------------------------------------------------------------------------------- /Screenshots/sws1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/c8b3f39e6864a466bb7519aa453c723a95b51b85/Screenshots/sws1.png -------------------------------------------------------------------------------- /Screenshots/sws2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/c8b3f39e6864a466bb7519aa453c723a95b51b85/Screenshots/sws2.png -------------------------------------------------------------------------------- /Screenshots/sws3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/c8b3f39e6864a466bb7519aa453c723a95b51b85/Screenshots/sws3.png -------------------------------------------------------------------------------- /Screenshots/sws4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/c8b3f39e6864a466bb7519aa453c723a95b51b85/Screenshots/sws4.png -------------------------------------------------------------------------------- /SegmentWithSAM/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | #----------------------------------------------------------------------------- 2 | set(MODULE_NAME SegmentWithSAM) 3 | 4 | #----------------------------------------------------------------------------- 5 | set(MODULE_PYTHON_SCRIPTS 6 | ${MODULE_NAME}.py 7 | ) 8 | 9 | set(MODULE_PYTHON_RESOURCES 10 | Resources/Icons/${MODULE_NAME}.png 11 | Resources/UI/${MODULE_NAME}.ui 12 | ) 13 | set(EXTENSION_HOMEPAGE "https://github.com/mazurowski-lab/SlicerSegmentWithSAM") 14 | set(EXTENSION_CATEGORY "Segmentation") 15 | set(EXTENSION_CONTRIBUTORS "Zafer Yildiz (Mazurowski Lab, Duke University)") 16 | set(EXTENSION_DESCRIPTION "SegmentWithSAM aims to asist its users in segmenting medical data on 3D Slicer by comprehensively integrating the Segment Anything Model (SAM) developed by Meta.") 17 | set(EXTENSION_ICONURL "https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/SegmentWithSAM/Resources/Icons/SegmentWithSAM.png") 18 | set(EXTENSION_SCREENSHOTURLS "https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/main/Screenshots/sws4.png") 19 | set(EXTENSION_DEPENDS "PyTorch") # Specified as a list or "NA" if no dependencies 20 | #----------------------------------------------------------------------------- 21 | slicerMacroBuildScriptedModule( 22 | NAME ${MODULE_NAME} 23 | SCRIPTS ${MODULE_PYTHON_SCRIPTS} 24 | RESOURCES ${MODULE_PYTHON_RESOURCES} 25 | WITH_GENERIC_TESTS 26 | ) 27 | 28 | #----------------------------------------------------------------------------- 29 | if(BUILD_TESTING) 30 | 31 | # Register the unittest subclass in the main script as a ctest. 32 | # Note that the test will also be available at runtime. 
33 | slicer_add_python_unittest(SCRIPT ${MODULE_NAME}.py) 34 | 35 | # Additional build-time testing 36 | # NA 37 | endif() 38 | -------------------------------------------------------------------------------- /SegmentWithSAM/Resources/Icons/SegmentWithSAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/SlicerSegmentWithSAM/c8b3f39e6864a466bb7519aa453c723a95b51b85/SegmentWithSAM/Resources/Icons/SegmentWithSAM.png -------------------------------------------------------------------------------- /SegmentWithSAM/Resources/UI/SegmentWithSAM.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | SegmentWithSAM 4 | 5 | 6 | 7 | 0 8 | 0 9 | 816 10 | 759 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | Create New Box Prompt 22 | 23 | 24 | 25 | 26 | 27 | 28 | Configure Labels 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | Start 2D Segmentation for Current Slice 40 | 41 | 42 | 43 | 44 | 45 | 46 | Stop 2D Segmentation for Current Slice 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | Propagate To Left 58 | 59 | 60 | 61 | 62 | 63 | 64 | Propagate To Right 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | If you would like to use ground truth annotation as prompt for SAM-2, it is recommended that you annotate the middle slice among the slices containing the tissue you want to segment. 76 | 77 | 78 | Propagate GT Mask Through All Slices 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 0 89 | 30 90 | 91 | 92 | 93 | Add Positive Prompt Point 94 | 95 | 96 | 97 | 98 | 99 | 100 | false 101 | 102 | 103 | false 104 | 105 | 106 | 107 | 0 108 | 255 109 | 0 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | Add Negative Prompt Point 118 | 119 | 120 | 121 | 122 | 123 | 124 | false 125 | 126 | 127 | false 128 | 129 | 130 | 131 | 255 132 | 0 133 | 0 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | Qt::Horizontal 144 | 145 | 146 | 147 | 40 148 | 20 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | Select target label you want to segment: 166 | 167 | 168 | 169 | 170 | 171 | 172 | SAM produces 3 masks for the same set of prompt inputs. Mask-1 is generally the best, but you can select the most accurate mask among them: 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 0 183 | 30 184 | 185 | 186 | 187 | Select your model, label and mask output: 188 | 189 | 190 | 191 | 192 | 193 | 194 | Qt::Horizontal 195 | 196 | 197 | 198 | 40 199 | 20 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | qMRMLWidget 209 | QWidget 210 |
qMRMLWidget.h
211 | 1 212 |
213 | 214 | qSlicerWidget 215 | QWidget 216 |
qSlicerWidget.h
217 | 1 218 |
219 | 220 | qSlicerSimpleMarkupsWidget 221 | qSlicerWidget 222 |
qSlicerSimpleMarkupsWidget.h
223 |
224 |
225 | 226 | 227 | 228 | SegmentWithSAM 229 | mrmlSceneChanged(vtkMRMLScene*) 230 | positivePrompts 231 | setMRMLScene(vtkMRMLScene*) 232 | 233 | 234 | 20 235 | 20 236 | 237 | 238 | 20 239 | 20 240 | 241 | 242 | 243 | 244 | SegmentWithSAM 245 | mrmlSceneChanged(vtkMRMLScene*) 246 | negativePrompts 247 | setMRMLScene(vtkMRMLScene*) 248 | 249 | 250 | 20 251 | 20 252 | 253 | 254 | 20 255 | 20 256 | 257 | 258 | 259 | 260 |
261 | -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | 9 | initialize_config_module("sam2_configs", version_base="1.2") 10 | -------------------------------------------------------------------------------- /sam2/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py 8 | from typing import Any, Dict, List, Optional, Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 13 | 14 | from sam2.modeling.sam2_base import SAM2Base 15 | from sam2.sam2_image_predictor import SAM2ImagePredictor 16 | from sam2.utils.amg import ( 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | MaskData, 28 | remove_small_regions, 29 | rle_to_mask, 30 | uncrop_boxes_xyxy, 31 | uncrop_masks, 32 | uncrop_points, 33 | ) 34 | 35 | 36 | class SAM2AutomaticMaskGenerator: 37 | def __init__( 38 | self, 39 | model: SAM2Base, 40 | points_per_side: Optional[int] = 32, 41 | points_per_batch: int = 64, 42 | pred_iou_thresh: float = 0.8, 43 | stability_score_thresh: float = 0.95, 44 | stability_score_offset: float = 1.0, 45 | mask_threshold: float = 0.0, 46 | box_nms_thresh: float = 0.7, 47 | crop_n_layers: int = 0, 48 | crop_nms_thresh: float = 0.7, 49 | crop_overlap_ratio: float = 512 / 1500, 50 | crop_n_points_downscale_factor: int = 1, 51 | point_grids: Optional[List[np.ndarray]] = None, 52 | min_mask_region_area: int = 0, 53 | output_mode: str = "binary_mask", 54 | use_m2m: bool = False, 55 | multimask_output: bool = True, 56 | **kwargs, 57 | ) -> None: 58 | """ 59 | Using a SAM 2 model, generates masks for the entire image. 60 | Generates a grid of point prompts over the image, then filters 61 | low quality and duplicate masks. The default settings are chosen 62 | for SAM 2 with a HieraL backbone. 63 | 64 | Arguments: 65 | model (Sam): The SAM 2 model to use for mask prediction. 66 | points_per_side (int or None): The number of points to be sampled 67 | along one side of the image. The total number of points is 68 | points_per_side**2. If None, 'point_grids' must provide explicit 69 | point sampling. 70 | points_per_batch (int): Sets the number of points run simultaneously 71 | by the model. Higher numbers may be faster but use more GPU memory. 72 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 73 | model's predicted mask quality. 74 | stability_score_thresh (float): A filtering threshold in [0,1], using 75 | the stability of the mask under changes to the cutoff used to binarize 76 | the model's mask predictions. 
77 | stability_score_offset (float): The amount to shift the cutoff when 78 | calculated the stability score. 79 | mask_threshold (float): Threshold for binarizing the mask logits 80 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 81 | suppression to filter duplicate masks. 82 | crop_n_layers (int): If >0, mask prediction will be run again on 83 | crops of the image. Sets the number of layers to run, where each 84 | layer has 2**i_layer number of image crops. 85 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 86 | suppression to filter duplicate masks between different crops. 87 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 88 | In the first crop layer, crops will overlap by this fraction of 89 | the image length. Later layers with more crops scale down this overlap. 90 | crop_n_points_downscale_factor (int): The number of points-per-side 91 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 92 | point_grids (list(np.ndarray) or None): A list over explicit grids 93 | of points used for sampling, normalized to [0,1]. The nth grid in the 94 | list is used in the nth crop layer. Exclusive with points_per_side. 95 | min_mask_region_area (int): If >0, postprocessing will be applied 96 | to remove disconnected regions and holes in masks with area smaller 97 | than min_mask_region_area. Requires opencv. 98 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 99 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 100 | For large resolutions, 'binary_mask' may consume large amounts of 101 | memory. 102 | use_m2m (bool): Whether to add a one step refinement using previous mask predictions. 103 | multimask_output (bool): Whether to output multimask at each point of the grid. 104 | """ 105 | 106 | assert (points_per_side is None) != ( 107 | point_grids is None 108 | ), "Exactly one of points_per_side or point_grid must be provided." 109 | if points_per_side is not None: 110 | self.point_grids = build_all_layer_point_grids( 111 | points_per_side, 112 | crop_n_layers, 113 | crop_n_points_downscale_factor, 114 | ) 115 | elif point_grids is not None: 116 | self.point_grids = point_grids 117 | else: 118 | raise ValueError("Can't have both points_per_side and point_grid be None.") 119 | 120 | assert output_mode in [ 121 | "binary_mask", 122 | "uncompressed_rle", 123 | "coco_rle", 124 | ], f"Unknown output_mode {output_mode}." 
125 | if output_mode == "coco_rle": 126 | try: 127 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 128 | except ImportError as e: 129 | print("Please install pycocotools") 130 | raise e 131 | 132 | self.predictor = SAM2ImagePredictor( 133 | model, 134 | max_hole_area=min_mask_region_area, 135 | max_sprinkle_area=min_mask_region_area, 136 | ) 137 | self.points_per_batch = points_per_batch 138 | self.pred_iou_thresh = pred_iou_thresh 139 | self.stability_score_thresh = stability_score_thresh 140 | self.stability_score_offset = stability_score_offset 141 | self.mask_threshold = mask_threshold 142 | self.box_nms_thresh = box_nms_thresh 143 | self.crop_n_layers = crop_n_layers 144 | self.crop_nms_thresh = crop_nms_thresh 145 | self.crop_overlap_ratio = crop_overlap_ratio 146 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 147 | self.min_mask_region_area = min_mask_region_area 148 | self.output_mode = output_mode 149 | self.use_m2m = use_m2m 150 | self.multimask_output = multimask_output 151 | 152 | @classmethod 153 | def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": 154 | """ 155 | Load a pretrained model from the Hugging Face hub. 156 | 157 | Arguments: 158 | model_id (str): The Hugging Face repository ID. 159 | **kwargs: Additional arguments to pass to the model constructor. 160 | 161 | Returns: 162 | (SAM2AutomaticMaskGenerator): The loaded model. 163 | """ 164 | from sam2.build_sam import build_sam2_hf 165 | 166 | sam_model = build_sam2_hf(model_id, **kwargs) 167 | return cls(sam_model, **kwargs) 168 | 169 | @torch.no_grad() 170 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 171 | """ 172 | Generates masks for the given image. 173 | 174 | Arguments: 175 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 176 | 177 | Returns: 178 | list(dict(str, any)): A list over records for masks. Each record is 179 | a dict containing the following keys: 180 | segmentation (dict(str, any) or np.ndarray): The mask. If 181 | output_mode='binary_mask', is an array of shape HW. Otherwise, 182 | is a dictionary containing the RLE. 183 | bbox (list(float)): The box around the mask, in XYWH format. 184 | area (int): The area in pixels of the mask. 185 | predicted_iou (float): The model's own prediction of the mask's 186 | quality. This is filtered by the pred_iou_thresh parameter. 187 | point_coords (list(list(float))): The point coordinates input 188 | to the model to generate this mask. 189 | stability_score (float): A measure of the mask's quality. This 190 | is filtered on using the stability_score_thresh parameter. 191 | crop_box (list(float)): The crop of the image used to generate 192 | the mask, given in XYWH format. 
193 | """ 194 | 195 | # Generate masks 196 | mask_data = self._generate_masks(image) 197 | 198 | # Encode masks 199 | if self.output_mode == "coco_rle": 200 | mask_data["segmentations"] = [ 201 | coco_encode_rle(rle) for rle in mask_data["rles"] 202 | ] 203 | elif self.output_mode == "binary_mask": 204 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 205 | else: 206 | mask_data["segmentations"] = mask_data["rles"] 207 | 208 | # Write mask records 209 | curr_anns = [] 210 | for idx in range(len(mask_data["segmentations"])): 211 | ann = { 212 | "segmentation": mask_data["segmentations"][idx], 213 | "area": area_from_rle(mask_data["rles"][idx]), 214 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 215 | "predicted_iou": mask_data["iou_preds"][idx].item(), 216 | "point_coords": [mask_data["points"][idx].tolist()], 217 | "stability_score": mask_data["stability_score"][idx].item(), 218 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 219 | } 220 | curr_anns.append(ann) 221 | 222 | return curr_anns 223 | 224 | def _generate_masks(self, image: np.ndarray) -> MaskData: 225 | orig_size = image.shape[:2] 226 | crop_boxes, layer_idxs = generate_crop_boxes( 227 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 228 | ) 229 | 230 | # Iterate over image crops 231 | data = MaskData() 232 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 233 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 234 | data.cat(crop_data) 235 | 236 | # Remove duplicate masks between crops 237 | if len(crop_boxes) > 1: 238 | # Prefer masks from smaller crops 239 | scores = 1 / box_area(data["crop_boxes"]) 240 | scores = scores.to(data["boxes"].device) 241 | keep_by_nms = batched_nms( 242 | data["boxes"].float(), 243 | scores, 244 | torch.zeros_like(data["boxes"][:, 0]), # categories 245 | iou_threshold=self.crop_nms_thresh, 246 | ) 247 | data.filter(keep_by_nms) 248 | data.to_numpy() 249 | return data 250 | 251 | def _process_crop( 252 | self, 253 | image: np.ndarray, 254 | crop_box: List[int], 255 | crop_layer_idx: int, 256 | orig_size: Tuple[int, ...], 257 | ) -> MaskData: 258 | # Crop the image and calculate embeddings 259 | x0, y0, x1, y1 = crop_box 260 | cropped_im = image[y0:y1, x0:x1, :] 261 | cropped_im_size = cropped_im.shape[:2] 262 | self.predictor.set_image(cropped_im) 263 | 264 | # Get points for this crop 265 | points_scale = np.array(cropped_im_size)[None, ::-1] 266 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 267 | 268 | # Generate masks for this crop in batches 269 | data = MaskData() 270 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 271 | batch_data = self._process_batch( 272 | points, cropped_im_size, crop_box, orig_size, normalize=True 273 | ) 274 | data.cat(batch_data) 275 | del batch_data 276 | self.predictor.reset_predictor() 277 | 278 | # Remove duplicates within this crop. 
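        # Clarifying note (added comment): `batched_nms` below is given a single
        # pseudo-category (the zeros tensor), so all masks in this crop compete with
        # each other and overlapping duplicates are dropped in favor of the ones with
        # higher predicted IoU.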
279 | keep_by_nms = batched_nms( 280 | data["boxes"].float(), 281 | data["iou_preds"], 282 | torch.zeros_like(data["boxes"][:, 0]), # categories 283 | iou_threshold=self.box_nms_thresh, 284 | ) 285 | data.filter(keep_by_nms) 286 | 287 | # Return to the original image frame 288 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 289 | data["points"] = uncrop_points(data["points"], crop_box) 290 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 291 | 292 | return data 293 | 294 | def _process_batch( 295 | self, 296 | points: np.ndarray, 297 | im_size: Tuple[int, ...], 298 | crop_box: List[int], 299 | orig_size: Tuple[int, ...], 300 | normalize=False, 301 | ) -> MaskData: 302 | orig_h, orig_w = orig_size 303 | 304 | # Run model on this batch 305 | points = torch.as_tensor( 306 | points, dtype=torch.float32, device=self.predictor.device 307 | ) 308 | in_points = self.predictor._transforms.transform_coords( 309 | points, normalize=normalize, orig_hw=im_size 310 | ) 311 | in_labels = torch.ones( 312 | in_points.shape[0], dtype=torch.int, device=in_points.device 313 | ) 314 | masks, iou_preds, low_res_masks = self.predictor._predict( 315 | in_points[:, None, :], 316 | in_labels[:, None], 317 | multimask_output=self.multimask_output, 318 | return_logits=True, 319 | ) 320 | 321 | # Serialize predictions and store in MaskData 322 | data = MaskData( 323 | masks=masks.flatten(0, 1), 324 | iou_preds=iou_preds.flatten(0, 1), 325 | points=points.repeat_interleave(masks.shape[1], dim=0), 326 | low_res_masks=low_res_masks.flatten(0, 1), 327 | ) 328 | del masks 329 | 330 | if not self.use_m2m: 331 | # Filter by predicted IoU 332 | if self.pred_iou_thresh > 0.0: 333 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 334 | data.filter(keep_mask) 335 | 336 | # Calculate and filter by stability score 337 | data["stability_score"] = calculate_stability_score( 338 | data["masks"], self.mask_threshold, self.stability_score_offset 339 | ) 340 | if self.stability_score_thresh > 0.0: 341 | keep_mask = data["stability_score"] >= self.stability_score_thresh 342 | data.filter(keep_mask) 343 | else: 344 | # One step refinement using previous mask predictions 345 | in_points = self.predictor._transforms.transform_coords( 346 | data["points"], normalize=normalize, orig_hw=im_size 347 | ) 348 | labels = torch.ones( 349 | in_points.shape[0], dtype=torch.int, device=in_points.device 350 | ) 351 | masks, ious = self.refine_with_m2m( 352 | in_points, labels, data["low_res_masks"], self.points_per_batch 353 | ) 354 | data["masks"] = masks.squeeze(1) 355 | data["iou_preds"] = ious.squeeze(1) 356 | 357 | if self.pred_iou_thresh > 0.0: 358 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 359 | data.filter(keep_mask) 360 | 361 | data["stability_score"] = calculate_stability_score( 362 | data["masks"], self.mask_threshold, self.stability_score_offset 363 | ) 364 | if self.stability_score_thresh > 0.0: 365 | keep_mask = data["stability_score"] >= self.stability_score_thresh 366 | data.filter(keep_mask) 367 | 368 | # Threshold masks and calculate boxes 369 | data["masks"] = data["masks"] > self.mask_threshold 370 | data["boxes"] = batched_mask_to_box(data["masks"]) 371 | 372 | # Filter boxes that touch crop boundaries 373 | keep_mask = ~is_box_near_crop_edge( 374 | data["boxes"], crop_box, [0, 0, orig_w, orig_h] 375 | ) 376 | if not torch.all(keep_mask): 377 | data.filter(keep_mask) 378 | 379 | # Compress to RLE 380 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, 
orig_w) 381 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 382 | del data["masks"] 383 | 384 | return data 385 | 386 | @staticmethod 387 | def postprocess_small_regions( 388 | mask_data: MaskData, min_area: int, nms_thresh: float 389 | ) -> MaskData: 390 | """ 391 | Removes small disconnected regions and holes in masks, then reruns 392 | box NMS to remove any new duplicates. 393 | 394 | Edits mask_data in place. 395 | 396 | Requires open-cv as a dependency. 397 | """ 398 | if len(mask_data["rles"]) == 0: 399 | return mask_data 400 | 401 | # Filter small disconnected regions and holes 402 | new_masks = [] 403 | scores = [] 404 | for rle in mask_data["rles"]: 405 | mask = rle_to_mask(rle) 406 | 407 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 408 | unchanged = not changed 409 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 410 | unchanged = unchanged and not changed 411 | 412 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 413 | # Give score=0 to changed masks and score=1 to unchanged masks 414 | # so NMS will prefer ones that didn't need postprocessing 415 | scores.append(float(unchanged)) 416 | 417 | # Recalculate boxes and remove any new duplicates 418 | masks = torch.cat(new_masks, dim=0) 419 | boxes = batched_mask_to_box(masks) 420 | keep_by_nms = batched_nms( 421 | boxes.float(), 422 | torch.as_tensor(scores), 423 | torch.zeros_like(boxes[:, 0]), # categories 424 | iou_threshold=nms_thresh, 425 | ) 426 | 427 | # Only recalculate RLEs for masks that have changed 428 | for i_mask in keep_by_nms: 429 | if scores[i_mask] == 0.0: 430 | mask_torch = masks[i_mask].unsqueeze(0) 431 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 432 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 433 | mask_data.filter(keep_by_nms) 434 | 435 | return mask_data 436 | 437 | def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): 438 | new_masks = [] 439 | new_iou_preds = [] 440 | 441 | for cur_points, cur_point_labels, low_res_mask in batch_iterator( 442 | points_per_batch, points, point_labels, low_res_masks 443 | ): 444 | best_masks, best_iou_preds, _ = self.predictor._predict( 445 | cur_points[:, None, :], 446 | cur_point_labels[:, None], 447 | mask_input=low_res_mask[:, None, :], 448 | multimask_output=False, 449 | return_logits=True, 450 | ) 451 | new_masks.append(best_masks) 452 | new_iou_preds.append(best_iou_preds) 453 | masks = torch.cat(new_masks, dim=0) 454 | return masks, torch.cat(new_iou_preds, dim=0) 455 | -------------------------------------------------------------------------------- /sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
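# A minimal, hedged usage sketch combining the builders in this file with the
# SAM2AutomaticMaskGenerator shown above. Assumptions not taken from the repository:
# a CUDA device is available, `huggingface_hub` is installed, and the Hugging Face
# hub is reachable to download "facebook/sam2-hiera-tiny".
def _example_automatic_mask_generation():
    import numpy as np

    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    # Build a SAM 2 model from the hub and wrap it in the automatic mask generator.
    mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
        "facebook/sam2-hiera-tiny", device="cuda"
    )
    # `generate` expects an HWC uint8 image; a zero image stands in for real data here.
    image = np.zeros((512, 512, 3), dtype=np.uint8)
    masks = mask_generator.generate(image)
    # Each record is a dict with "segmentation", "bbox", "area", "predicted_iou", ...
    return masks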
6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def build_sam2( 16 | config_file, 17 | ckpt_path=None, 18 | device="cuda", 19 | mode="eval", 20 | hydra_overrides_extra=[], 21 | apply_postprocessing=True, 22 | **kwargs, 23 | ): 24 | 25 | if apply_postprocessing: 26 | hydra_overrides_extra = hydra_overrides_extra.copy() 27 | hydra_overrides_extra += [ 28 | # dynamically fall back to multi-mask if the single mask is not stable 29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 30 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 31 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 32 | ] 33 | # Read config and init model 34 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 35 | OmegaConf.resolve(cfg) 36 | model = instantiate(cfg.model, _recursive_=True) 37 | _load_checkpoint(model, ckpt_path) 38 | model = model.to(device) 39 | if mode == "eval": 40 | model.eval() 41 | return model 42 | 43 | 44 | def build_sam2_video_predictor( 45 | config_file, 46 | ckpt_path=None, 47 | device="cuda", 48 | mode="eval", 49 | hydra_overrides_extra=[], 50 | apply_postprocessing=True, 51 | **kwargs, 52 | ): 53 | hydra_overrides = [ 54 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 55 | ] 56 | if apply_postprocessing: 57 | hydra_overrides_extra = hydra_overrides_extra.copy() 58 | hydra_overrides_extra += [ 59 | # dynamically fall back to multi-mask if the single mask is not stable 60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 61 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 62 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 63 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 64 | "++model.binarize_mask_from_pts_for_mem_enc=true", 65 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 66 | "++model.fill_hole_area=8", 67 | ] 68 | hydra_overrides.extend(hydra_overrides_extra) 69 | 70 | # Read config and init model 71 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 72 | OmegaConf.resolve(cfg) 73 | model = instantiate(cfg.model, _recursive_=True) 74 | _load_checkpoint(model, ckpt_path) 75 | model = model.to(device) 76 | if mode == "eval": 77 | model.eval() 78 | return model 79 | 80 | 81 | def build_sam2_hf(model_id, **kwargs): 82 | 83 | from huggingface_hub import hf_hub_download 84 | 85 | model_id_to_filenames = { 86 | "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), 87 | "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), 88 | "facebook/sam2-hiera-base-plus": ( 89 | "sam2_hiera_b+.yaml", 90 | "sam2_hiera_base_plus.pt", 91 | ), 92 | "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), 93 | } 94 | config_name, checkpoint_name = model_id_to_filenames[model_id] 95 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 96 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) 97 | 98 | 99 | def build_sam2_video_predictor_hf(model_id, **kwargs): 100 | 101 | from huggingface_hub import hf_hub_download 102 | 103 | model_id_to_filenames = { 104 | "facebook/sam2-hiera-tiny": 
("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), 105 | "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), 106 | "facebook/sam2-hiera-base-plus": ( 107 | "sam2_hiera_b+.yaml", 108 | "sam2_hiera_base_plus.pt", 109 | ), 110 | "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), 111 | } 112 | config_name, checkpoint_name = model_id_to_filenames[model_id] 113 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 114 | return build_sam2_video_predictor( 115 | config_file=config_name, ckpt_path=ckpt_path, **kwargs 116 | ) 117 | 118 | 119 | def _load_checkpoint(model, ckpt_path): 120 | if ckpt_path is not None: 121 | sd = torch.load(ckpt_path, map_location="cpu")["model"] 122 | missing_keys, unexpected_keys = model.load_state_dict(sd) 123 | if missing_keys: 124 | logging.error(missing_keys) 125 | raise RuntimeError() 126 | if unexpected_keys: 127 | logging.error(unexpected_keys) 128 | raise RuntimeError() 129 | logging.info("Loaded checkpoint sucessfully") 130 | -------------------------------------------------------------------------------- /sam2/csrc/connected_components.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch 8 | // with license found in the LICENSE_cctorch file in the root directory. 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // 2d 17 | #define BLOCK_ROWS 16 18 | #define BLOCK_COLS 16 19 | 20 | namespace cc2d { 21 | 22 | template 23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { 24 | return (bitmap >> pos) & 1; 25 | } 26 | 27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) { 28 | while (s_buf[n] != n) 29 | n = s_buf[n]; 30 | return n; 31 | } 32 | 33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { 34 | const int32_t id = n; 35 | while (s_buf[n] != n) { 36 | n = s_buf[n]; 37 | s_buf[id] = n; 38 | } 39 | return n; 40 | } 41 | 42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { 43 | bool done; 44 | do { 45 | a = find(s_buf, a); 46 | b = find(s_buf, b); 47 | 48 | if (a < b) { 49 | int32_t old = atomicMin(s_buf + b, a); 50 | done = (old == b); 51 | b = old; 52 | } else if (b < a) { 53 | int32_t old = atomicMin(s_buf + a, b); 54 | done = (old == a); 55 | a = old; 56 | } else 57 | done = true; 58 | 59 | } while (!done); 60 | } 61 | 62 | __global__ void 63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { 64 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 65 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 66 | const uint32_t idx = row * W + col; 67 | 68 | if (row < H && col < W) 69 | label[idx] = idx; 70 | } 71 | 72 | __global__ void 73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { 74 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 75 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 76 | const uint32_t idx = row * W + col; 77 | 78 | if (row >= H || col >= W) 79 | return; 80 | 81 | uint32_t P = 0; 82 | 83 | if (img[idx]) 84 | P |= 0x777; 85 | if (row + 1 < H && img[idx + W]) 86 | P |= 0x777 << 4; 87 | if (col + 1 < W && img[idx + 1]) 88 | P |= 0x777 
<< 1; 89 | 90 | if (col == 0) 91 | P &= 0xEEEE; 92 | if (col + 1 >= W) 93 | P &= 0x3333; 94 | else if (col + 2 >= W) 95 | P &= 0x7777; 96 | 97 | if (row == 0) 98 | P &= 0xFFF0; 99 | if (row + 1 >= H) 100 | P &= 0xFF; 101 | 102 | if (P > 0) { 103 | // If need check about top-left pixel(if flag the first bit) and hit the 104 | // top-left pixel 105 | if (hasBit(P, 0) && img[idx - W - 1]) { 106 | union_(label, idx, idx - 2 * W - 2); // top left block 107 | } 108 | 109 | if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) 110 | union_(label, idx, idx - 2 * W); // top bottom block 111 | 112 | if (hasBit(P, 3) && img[idx + 2 - W]) 113 | union_(label, idx, idx - 2 * W + 2); // top right block 114 | 115 | if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) 116 | union_(label, idx, idx - 2); // just left block 117 | } 118 | } 119 | 120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) { 121 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 122 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 123 | const uint32_t idx = row * W + col; 124 | 125 | if (row < H && col < W) 126 | find_n_compress(label, idx); 127 | } 128 | 129 | __global__ void final_labeling( 130 | const uint8_t* img, 131 | int32_t* label, 132 | const int32_t W, 133 | const int32_t H) { 134 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 135 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 136 | const uint32_t idx = row * W + col; 137 | 138 | if (row >= H || col >= W) 139 | return; 140 | 141 | int32_t y = label[idx] + 1; 142 | 143 | if (img[idx]) 144 | label[idx] = y; 145 | else 146 | label[idx] = 0; 147 | 148 | if (col + 1 < W) { 149 | if (img[idx + 1]) 150 | label[idx + 1] = y; 151 | else 152 | label[idx + 1] = 0; 153 | 154 | if (row + 1 < H) { 155 | if (img[idx + W + 1]) 156 | label[idx + W + 1] = y; 157 | else 158 | label[idx + W + 1] = 0; 159 | } 160 | } 161 | 162 | if (row + 1 < H) { 163 | if (img[idx + W]) 164 | label[idx + W] = y; 165 | else 166 | label[idx + W] = 0; 167 | } 168 | } 169 | 170 | __global__ void init_counting( 171 | const int32_t* label, 172 | int32_t* count_init, 173 | const int32_t W, 174 | const int32_t H) { 175 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 176 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 177 | const uint32_t idx = row * W + col; 178 | 179 | if (row >= H || col >= W) 180 | return; 181 | 182 | int32_t y = label[idx]; 183 | if (y > 0) { 184 | int32_t count_idx = y - 1; 185 | atomicAdd(count_init + count_idx, 1); 186 | } 187 | } 188 | 189 | __global__ void final_counting( 190 | const int32_t* label, 191 | const int32_t* count_init, 192 | int32_t* count_final, 193 | const int32_t W, 194 | const int32_t H) { 195 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 196 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 197 | const uint32_t idx = row * W + col; 198 | 199 | if (row >= H || col >= W) 200 | return; 201 | 202 | int32_t y = label[idx]; 203 | if (y > 0) { 204 | int32_t count_idx = y - 1; 205 | count_final[idx] = count_init[count_idx]; 206 | } else { 207 | count_final[idx] = 0; 208 | } 209 | } 210 | 211 | } // namespace cc2d 212 | 213 | std::vector get_connected_componnets( 214 | const torch::Tensor& inputs) { 215 | AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); 216 | AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); 217 | AT_ASSERTM( 218 | 
inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); 219 | 220 | const uint32_t N = inputs.size(0); 221 | const uint32_t C = inputs.size(1); 222 | const uint32_t H = inputs.size(2); 223 | const uint32_t W = inputs.size(3); 224 | 225 | AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); 226 | AT_ASSERTM((H % 2) == 0, "height must be an even number"); 227 | AT_ASSERTM((W % 2) == 0, "width must be an even number"); 228 | 229 | // label must be uint32_t 230 | auto label_options = 231 | torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); 232 | torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); 233 | torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); 234 | torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); 235 | 236 | dim3 grid = dim3( 237 | ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, 238 | ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); 239 | dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); 240 | dim3 grid_count = 241 | dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); 242 | dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); 243 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 244 | 245 | for (int n = 0; n < N; n++) { 246 | uint32_t offset = n * H * W; 247 | 248 | cc2d::init_labeling<<>>( 249 | labels.data_ptr() + offset, W, H); 250 | cc2d::merge<<>>( 251 | inputs.data_ptr() + offset, 252 | labels.data_ptr() + offset, 253 | W, 254 | H); 255 | cc2d::compression<<>>( 256 | labels.data_ptr() + offset, W, H); 257 | cc2d::final_labeling<<>>( 258 | inputs.data_ptr() + offset, 259 | labels.data_ptr() + offset, 260 | W, 261 | H); 262 | 263 | // get the counting of each pixel 264 | cc2d::init_counting<<>>( 265 | labels.data_ptr() + offset, 266 | counts_init.data_ptr() + offset, 267 | W, 268 | H); 269 | cc2d::final_counting<<>>( 270 | labels.data_ptr() + offset, 271 | counts_init.data_ptr() + offset, 272 | counts_final.data_ptr() + offset, 273 | W, 274 | H); 275 | } 276 | 277 | // returned values are [labels, counts] 278 | std::vector outputs; 279 | outputs.push_back(labels); 280 | outputs.push_back(counts_final); 281 | return outputs; 282 | } 283 | 284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 285 | m.def( 286 | "get_connected_componnets", 287 | &get_connected_componnets, 288 | "get_connected_componnets"); 289 | } 290 | -------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from functools import partial 8 | from typing import List, Tuple, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.backbones.utils import ( 15 | PatchEmbed, 16 | window_partition, 17 | window_unpartition, 18 | ) 19 | 20 | from sam2.modeling.sam2_utils import DropPath, MLP 21 | 22 | 23 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 24 | if pool is None: 25 | return x 26 | # (B, H, W, C) -> (B, C, H, W) 27 | x = x.permute(0, 3, 1, 2) 28 | x = pool(x) 29 | # (B, C, H', W') -> (B, H', W', C) 30 | x = x.permute(0, 2, 3, 1) 31 | if norm: 32 | x = norm(x) 33 | 34 | return x 35 | 36 | 37 | class MultiScaleAttention(nn.Module): 38 | def __init__( 39 | self, 40 | dim: int, 41 | dim_out: int, 42 | num_heads: int, 43 | q_pool: nn.Module = None, 44 | ): 45 | super().__init__() 46 | 47 | self.dim = dim 48 | self.dim_out = dim_out 49 | self.num_heads = num_heads 50 | self.q_pool = q_pool 51 | self.qkv = nn.Linear(dim, dim_out * 3) 52 | self.proj = nn.Linear(dim_out, dim_out) 53 | 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | B, H, W, _ = x.shape 56 | # qkv with shape (B, H * W, 3, nHead, C) 57 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 58 | # q, k, v with shape (B, H * W, nheads, C) 59 | q, k, v = torch.unbind(qkv, 2) 60 | 61 | # Q pooling (for downsample at stage changes) 62 | if self.q_pool: 63 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 64 | H, W = q.shape[1:3] # downsampled shape 65 | q = q.reshape(B, H * W, self.num_heads, -1) 66 | 67 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 68 | x = F.scaled_dot_product_attention( 69 | q.transpose(1, 2), 70 | k.transpose(1, 2), 71 | v.transpose(1, 2), 72 | ) 73 | # Transpose back 74 | x = x.transpose(1, 2) 75 | x = x.reshape(B, H, W, -1) 76 | 77 | x = self.proj(x) 78 | 79 | return x 80 | 81 | 82 | class MultiScaleBlock(nn.Module): 83 | def __init__( 84 | self, 85 | dim: int, 86 | dim_out: int, 87 | num_heads: int, 88 | mlp_ratio: float = 4.0, 89 | drop_path: float = 0.0, 90 | norm_layer: Union[nn.Module, str] = "LayerNorm", 91 | q_stride: Tuple[int, int] = None, 92 | act_layer: nn.Module = nn.GELU, 93 | window_size: int = 0, 94 | ): 95 | super().__init__() 96 | 97 | if isinstance(norm_layer, str): 98 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 99 | 100 | self.dim = dim 101 | self.dim_out = dim_out 102 | self.norm1 = norm_layer(dim) 103 | 104 | self.window_size = window_size 105 | 106 | self.pool, self.q_stride = None, q_stride 107 | if self.q_stride: 108 | self.pool = nn.MaxPool2d( 109 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 110 | ) 111 | 112 | self.attn = MultiScaleAttention( 113 | dim, 114 | dim_out, 115 | num_heads=num_heads, 116 | q_pool=self.pool, 117 | ) 118 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 119 | 120 | self.norm2 = norm_layer(dim_out) 121 | self.mlp = MLP( 122 | dim_out, 123 | int(dim_out * mlp_ratio), 124 | dim_out, 125 | num_layers=2, 126 | activation=act_layer, 127 | ) 128 | 129 | if dim != dim_out: 130 | self.proj = nn.Linear(dim, dim_out) 131 | 132 | def forward(self, x: torch.Tensor) -> torch.Tensor: 133 | shortcut = x # B, H, W, C 134 | x = self.norm1(x) 135 | 136 | # Skip connection 137 | if self.dim != self.dim_out: 138 | shortcut = do_pool(self.proj(x), self.pool) 
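        # Added note: when the channel dim changes at this block, the residual branch is
        # projected to dim_out and, if the block also pools queries (self.pool is set),
        # downsampled by the same pooling so it matches the attention output added below.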
139 | 140 | # Window partition 141 | window_size = self.window_size 142 | if window_size > 0: 143 | H, W = x.shape[1], x.shape[2] 144 | x, pad_hw = window_partition(x, window_size) 145 | 146 | # Window Attention + Q Pooling (if stage change) 147 | x = self.attn(x) 148 | if self.q_stride: 149 | # Shapes have changed due to Q pooling 150 | window_size = self.window_size // self.q_stride[0] 151 | H, W = shortcut.shape[1:3] 152 | 153 | pad_h = (window_size - H % window_size) % window_size 154 | pad_w = (window_size - W % window_size) % window_size 155 | pad_hw = (H + pad_h, W + pad_w) 156 | 157 | # Reverse window partition 158 | if self.window_size > 0: 159 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 160 | 161 | x = shortcut + self.drop_path(x) 162 | # MLP 163 | x = x + self.drop_path(self.mlp(self.norm2(x))) 164 | return x 165 | 166 | 167 | class Hiera(nn.Module): 168 | """ 169 | Reference: https://arxiv.org/abs/2306.00989 170 | """ 171 | 172 | def __init__( 173 | self, 174 | embed_dim: int = 96, # initial embed dim 175 | num_heads: int = 1, # initial number of heads 176 | drop_path_rate: float = 0.0, # stochastic depth 177 | q_pool: int = 3, # number of q_pool stages 178 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 179 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 180 | dim_mul: float = 2.0, # dim_mul factor at stage shift 181 | head_mul: float = 2.0, # head_mul factor at stage shift 182 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 183 | # window size per stage, when not using global att. 184 | window_spec: Tuple[int, ...] = ( 185 | 8, 186 | 4, 187 | 14, 188 | 7, 189 | ), 190 | # global attn in these blocks 191 | global_att_blocks: Tuple[int, ...] = ( 192 | 12, 193 | 16, 194 | 20, 195 | ), 196 | return_interm_layers=True, # return feats from every stage 197 | ): 198 | super().__init__() 199 | 200 | assert len(stages) == len(window_spec) 201 | self.window_spec = window_spec 202 | 203 | depth = sum(stages) 204 | self.q_stride = q_stride 205 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 206 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 207 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 208 | self.return_interm_layers = return_interm_layers 209 | 210 | self.patch_embed = PatchEmbed( 211 | embed_dim=embed_dim, 212 | ) 213 | # Which blocks have global att? 
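        # Added note: blocks listed here are given window_size = 0 in the loop below,
        # i.e. they use global attention instead of windowed attention.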
214 | self.global_att_blocks = global_att_blocks 215 | 216 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 217 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 218 | self.pos_embed = nn.Parameter( 219 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 220 | ) 221 | self.pos_embed_window = nn.Parameter( 222 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 223 | ) 224 | 225 | dpr = [ 226 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 227 | ] # stochastic depth decay rule 228 | 229 | cur_stage = 1 230 | self.blocks = nn.ModuleList() 231 | 232 | for i in range(depth): 233 | dim_out = embed_dim 234 | # lags by a block, so first block of 235 | # next stage uses an initial window size 236 | # of previous stage and final window size of current stage 237 | window_size = self.window_spec[cur_stage - 1] 238 | 239 | if self.global_att_blocks is not None: 240 | window_size = 0 if i in self.global_att_blocks else window_size 241 | 242 | if i - 1 in self.stage_ends: 243 | dim_out = int(embed_dim * dim_mul) 244 | num_heads = int(num_heads * head_mul) 245 | cur_stage += 1 246 | 247 | block = MultiScaleBlock( 248 | dim=embed_dim, 249 | dim_out=dim_out, 250 | num_heads=num_heads, 251 | drop_path=dpr[i], 252 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 253 | window_size=window_size, 254 | ) 255 | 256 | embed_dim = dim_out 257 | self.blocks.append(block) 258 | 259 | self.channel_list = ( 260 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 261 | if return_interm_layers 262 | else [self.blocks[-1].dim_out] 263 | ) 264 | 265 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 266 | h, w = hw 267 | window_embed = self.pos_embed_window 268 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 269 | pos_embed = pos_embed + window_embed.tile( 270 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 271 | ) 272 | pos_embed = pos_embed.permute(0, 2, 3, 1) 273 | return pos_embed 274 | 275 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 276 | x = self.patch_embed(x) 277 | # x: (B, H, W, C) 278 | 279 | # Add pos embed 280 | x = x + self._get_pos_embed(x.shape[1:3]) 281 | 282 | outputs = [] 283 | for i, blk in enumerate(self.blocks): 284 | x = blk(x) 285 | if (i == self.stage_ends[-1]) or ( 286 | i in self.stage_ends and self.return_interm_layers 287 | ): 288 | feats = x.permute(0, 3, 1, 2) 289 | outputs.append(feats) 290 | 291 | return outputs 292 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
96 | if fpn_top_down_levels is None: 97 | # default is to have top-down features on all levels 98 | fpn_top_down_levels = range(len(self.convs)) 99 | self.fpn_top_down_levels = list(fpn_top_down_levels) 100 | 101 | def forward(self, xs: List[torch.Tensor]): 102 | 103 | out = [None] * len(self.convs) 104 | pos = [None] * len(self.convs) 105 | assert len(xs) == len(self.convs) 106 | # fpn forward pass 107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 108 | prev_features = None 109 | # forward in top-down order (from low to high resolution) 110 | n = len(self.convs) - 1 111 | for i in range(n, -1, -1): 112 | x = xs[i] 113 | lateral_features = self.convs[n - i](x) 114 | if i in self.fpn_top_down_levels and prev_features is not None: 115 | top_down_features = F.interpolate( 116 | prev_features.to(dtype=torch.float32), 117 | scale_factor=2.0, 118 | mode=self.fpn_interp_model, 119 | align_corners=( 120 | None if self.fpn_interp_model == "nearest" else False 121 | ), 122 | antialias=False, 123 | ) 124 | prev_features = lateral_features + top_down_features 125 | if self.fuse_type == "avg": 126 | prev_features /= 2 127 | else: 128 | prev_features = lateral_features 129 | x_out = prev_features 130 | out[i] = x_out 131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 132 | 133 | return out, pos 134 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | 
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
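# Editor's note (added, not part of the upstream header): MemoryEncoder below fuses the
# current frame's image features with the predicted mask (sigmoid-ed and downsampled by
# MaskDownSampler) and returns memory features plus their positional encoding, which the
# memory-attention module cross-attends to.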
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /sam2/modeling/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention Is All You Need paper, generalized to work on images. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_pos_feats, 25 | temperature: int = 10000, 26 | normalize: bool = True, 27 | scale: Optional[float] = None, 28 | ): 29 | super().__init__() 30 | assert num_pos_feats % 2 == 0, "Expecting even model width" 31 | self.num_pos_feats = num_pos_feats // 2 32 | self.temperature = temperature 33 | self.normalize = normalize 34 | if scale is not None and normalize is False: 35 | raise ValueError("normalize should be True if scale is passed") 36 | if scale is None: 37 | scale = 2 * math.pi 38 | self.scale = scale 39 | 40 | self.cache = {} 41 | 42 | def _encode_xy(self, x, y): 43 | # The positions are expected to be normalized 44 | assert len(x) == len(y) and x.ndim == y.ndim == 1 45 | x_embed = x * self.scale 46 | y_embed = y * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, None] / dim_t 52 | pos_y = y_embed[:, None] / dim_t 53 | pos_x = torch.stack( 54 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 55 | ).flatten(1) 56 | pos_y = torch.stack( 57 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 58 | ).flatten(1) 59 | return pos_x, pos_y 60 | 61 | @torch.no_grad() 62 | def encode_boxes(self, x, y, w, h): 63 | pos_x, pos_y = self._encode_xy(x, y) 64 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 65 | return pos 66 | 67 | encode = encode_boxes # Backwards compatibility 68 | 69 | @torch.no_grad() 70 | def encode_points(self, x, y, labels): 71 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 72 | assert bx == by and nx == ny and bx == bl and nx == nl 73 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 74 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 75 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 76 | return pos 77 | 78 | @torch.no_grad() 79 | def forward(self, x: torch.Tensor): 80 | cache_key = (x.shape[-2], x.shape[-1]) 81 | if cache_key in self.cache: 82 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) 83 | y_embed = ( 84 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) 85 | .view(1, -1, 1) 86 | .repeat(x.shape[0], 1, x.shape[-1]) 87 | ) 88 | x_embed = ( 89 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) 90 | .view(1, 1, -1) 91 | .repeat(x.shape[0], x.shape[-2], 1) 92 | ) 93 | 94 | if self.normalize: 95 | eps = 1e-6 96 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 97 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 98 | 99 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 100 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 101 | 102 | pos_x = x_embed[:, :, :, None] / dim_t 103 | pos_y = y_embed[:, :, :, None] / dim_t 104 | pos_x = torch.stack( 105 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 106 | ).flatten(3) 107 | pos_y = torch.stack( 108 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 109 | ).flatten(3) 110 | pos = torch.cat((pos_y, pos_x), 
dim=3).permute(0, 3, 1, 2) 111 | self.cache[cache_key] = pos[0] 112 | return pos 113 | 114 | 115 | class PositionEmbeddingRandom(nn.Module): 116 | """ 117 | Positional encoding using random spatial frequencies. 118 | """ 119 | 120 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 121 | super().__init__() 122 | if scale is None or scale <= 0.0: 123 | scale = 1.0 124 | self.register_buffer( 125 | "positional_encoding_gaussian_matrix", 126 | scale * torch.randn((2, num_pos_feats)), 127 | ) 128 | 129 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 130 | """Positionally encode points that are normalized to [0,1].""" 131 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 132 | coords = 2 * coords - 1 133 | coords = coords @ self.positional_encoding_gaussian_matrix 134 | coords = 2 * np.pi * coords 135 | # outputs d_1 x ... x d_n x C shape 136 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 137 | 138 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 139 | """Generate positional encoding for a grid of the specified size.""" 140 | h, w = size 141 | device: Any = self.positional_encoding_gaussian_matrix.device 142 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 143 | y_embed = grid.cumsum(dim=0) - 0.5 144 | x_embed = grid.cumsum(dim=1) - 0.5 145 | y_embed = y_embed / h 146 | x_embed = x_embed / w 147 | 148 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 149 | return pe.permute(2, 0, 1) # C x H x W 150 | 151 | def forward_with_coords( 152 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 153 | ) -> torch.Tensor: 154 | """Positionally encode points that are not normalized to [0,1].""" 155 | coords = coords_input.clone() 156 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 157 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 158 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 159 | 160 | 161 | # Rotary Positional Encoding, adapted from: 162 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 163 | # 2. https://github.com/naver-ai/rope-vit 164 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 165 | 166 | 167 | def init_t_xy(end_x: int, end_y: int): 168 | t = torch.arange(end_x * end_y, dtype=torch.float32) 169 | t_x = (t % end_x).float() 170 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 171 | return t_x, t_y 172 | 173 | 174 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 175 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 176 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 177 | 178 | t_x, t_y = init_t_xy(end_x, end_y) 179 | freqs_x = torch.outer(t_x, freqs_x) 180 | freqs_y = torch.outer(t_y, freqs_y) 181 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 182 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 183 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 184 | 185 | 186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 187 | ndim = x.ndim 188 | assert 0 <= 1 < ndim 189 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 190 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 191 | return freqs_cis.view(*shape) 192 | 193 | 194 | def apply_rotary_enc( 195 | xq: torch.Tensor, 196 | xk: torch.Tensor, 197 | freqs_cis: torch.Tensor, 198 | repeat_freqs_k: bool = False, 199 | ): 200 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 201 | xk_ = ( 202 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 203 | if xk.shape[-2] != 0 204 | else None 205 | ) 206 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 207 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 208 | if xk_ is None: 209 | # no keys to rotate, due to dropout 210 | return xq_out.type_as(xq).to(xq.device), xk 211 | # repeat freqs along seq_len dim to match k seq_len 212 | if repeat_freqs_k: 213 | r = xk_.shape[-2] // xq_.shape[-2] 214 | if freqs_cis.is_cuda: 215 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 216 | else: 217 | # torch.repeat on complex numbers may not be supported on non-CUDA devices 218 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten 219 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) 220 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 221 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 222 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/sam/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
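# Added summary: this is the SAM-style mask decoder. A two-way transformer mixes prompt
# tokens with the image embedding, the embedding is upscaled, and one hypernetwork MLP
# per mask token produces weights that are multiplied with the upscaled embedding to give
# mask logits, while a separate MLP head predicts a per-mask IoU (quality) score.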
6 | 7 | from typing import List, Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.sam2_utils import LayerNorm2d, MLP 13 | 14 | 15 | class MaskDecoder(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | transformer_dim: int, 20 | transformer: nn.Module, 21 | num_multimask_outputs: int = 3, 22 | activation: Type[nn.Module] = nn.GELU, 23 | iou_head_depth: int = 3, 24 | iou_head_hidden_dim: int = 256, 25 | use_high_res_features: bool = False, 26 | iou_prediction_use_sigmoid=False, 27 | dynamic_multimask_via_stability=False, 28 | dynamic_multimask_stability_delta=0.05, 29 | dynamic_multimask_stability_thresh=0.98, 30 | pred_obj_scores: bool = False, 31 | pred_obj_scores_mlp: bool = False, 32 | use_multimask_token_for_obj_ptr: bool = False, 33 | ) -> None: 34 | """ 35 | Predicts masks given an image and prompt embeddings, using a 36 | transformer architecture. 37 | 38 | Arguments: 39 | transformer_dim (int): the channel dimension of the transformer 40 | transformer (nn.Module): the transformer used to predict masks 41 | num_multimask_outputs (int): the number of masks to predict 42 | when disambiguating masks 43 | activation (nn.Module): the type of activation to use when 44 | upscaling masks 45 | iou_head_depth (int): the depth of the MLP used to predict 46 | mask quality 47 | iou_head_hidden_dim (int): the hidden dimension of the MLP 48 | used to predict mask quality 49 | """ 50 | super().__init__() 51 | self.transformer_dim = transformer_dim 52 | self.transformer = transformer 53 | 54 | self.num_multimask_outputs = num_multimask_outputs 55 | 56 | self.iou_token = nn.Embedding(1, transformer_dim) 57 | self.num_mask_tokens = num_multimask_outputs + 1 58 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 59 | 60 | self.pred_obj_scores = pred_obj_scores 61 | if self.pred_obj_scores: 62 | self.obj_score_token = nn.Embedding(1, transformer_dim) 63 | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr 64 | 65 | self.output_upscaling = nn.Sequential( 66 | nn.ConvTranspose2d( 67 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 68 | ), 69 | LayerNorm2d(transformer_dim // 4), 70 | activation(), 71 | nn.ConvTranspose2d( 72 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 73 | ), 74 | activation(), 75 | ) 76 | self.use_high_res_features = use_high_res_features 77 | if use_high_res_features: 78 | self.conv_s0 = nn.Conv2d( 79 | transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 80 | ) 81 | self.conv_s1 = nn.Conv2d( 82 | transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 83 | ) 84 | 85 | self.output_hypernetworks_mlps = nn.ModuleList( 86 | [ 87 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 88 | for i in range(self.num_mask_tokens) 89 | ] 90 | ) 91 | 92 | self.iou_prediction_head = MLP( 93 | transformer_dim, 94 | iou_head_hidden_dim, 95 | self.num_mask_tokens, 96 | iou_head_depth, 97 | sigmoid_output=iou_prediction_use_sigmoid, 98 | ) 99 | if self.pred_obj_scores: 100 | self.pred_obj_score_head = nn.Linear(transformer_dim, 1) 101 | if pred_obj_scores_mlp: 102 | self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) 103 | 104 | # When outputting a single mask, optionally we can dynamically fall back to the best 105 | # multimask output token if the single mask output token gives low stability scores. 
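        # Added note: the delta below is the +/- logit threshold used by
        # _get_stability_scores(), and thresh is the minimum stability score; when the
        # single-mask output scores lower than thresh at inference time, forward() swaps
        # in the multimask output with the highest predicted IoU
        # (see _dynamic_multimask_via_stability).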
106 | self.dynamic_multimask_via_stability = dynamic_multimask_via_stability 107 | self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta 108 | self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh 109 | 110 | def forward( 111 | self, 112 | image_embeddings: torch.Tensor, 113 | image_pe: torch.Tensor, 114 | sparse_prompt_embeddings: torch.Tensor, 115 | dense_prompt_embeddings: torch.Tensor, 116 | multimask_output: bool, 117 | repeat_image: bool, 118 | high_res_features: Optional[List[torch.Tensor]] = None, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """ 121 | Predict masks given image and prompt embeddings. 122 | 123 | Arguments: 124 | image_embeddings (torch.Tensor): the embeddings from the image encoder 125 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 126 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 127 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 128 | multimask_output (bool): Whether to return multiple masks or a single 129 | mask. 130 | 131 | Returns: 132 | torch.Tensor: batched predicted masks 133 | torch.Tensor: batched predictions of mask quality 134 | torch.Tensor: batched SAM token for mask output 135 | """ 136 | masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( 137 | image_embeddings=image_embeddings, 138 | image_pe=image_pe, 139 | sparse_prompt_embeddings=sparse_prompt_embeddings, 140 | dense_prompt_embeddings=dense_prompt_embeddings, 141 | repeat_image=repeat_image, 142 | high_res_features=high_res_features, 143 | ) 144 | 145 | # Select the correct mask or masks for output 146 | if multimask_output: 147 | masks = masks[:, 1:, :, :] 148 | iou_pred = iou_pred[:, 1:] 149 | elif self.dynamic_multimask_via_stability and not self.training: 150 | masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) 151 | else: 152 | masks = masks[:, 0:1, :, :] 153 | iou_pred = iou_pred[:, 0:1] 154 | 155 | if multimask_output and self.use_multimask_token_for_obj_ptr: 156 | sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape 157 | else: 158 | # Take the mask output token. Here we *always* use the token for single mask output. 159 | # At test time, even if we track after 1-click (and using multimask_output=True), 160 | # we still take the single mask token here. The rationale is that we always track 161 | # after multiple clicks during training, so the past tokens seen during training 162 | # are always the single mask token (and we'll let it be the object-memory token). 163 | sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape 164 | 165 | # Prepare output 166 | return masks, iou_pred, sam_tokens_out, object_score_logits 167 | 168 | def predict_masks( 169 | self, 170 | image_embeddings: torch.Tensor, 171 | image_pe: torch.Tensor, 172 | sparse_prompt_embeddings: torch.Tensor, 173 | dense_prompt_embeddings: torch.Tensor, 174 | repeat_image: bool, 175 | high_res_features: Optional[List[torch.Tensor]] = None, 176 | ) -> Tuple[torch.Tensor, torch.Tensor]: 177 | """Predicts masks. 
See 'forward' for more details.""" 178 | # Concatenate output tokens 179 | s = 0 180 | if self.pred_obj_scores: 181 | output_tokens = torch.cat( 182 | [ 183 | self.obj_score_token.weight, 184 | self.iou_token.weight, 185 | self.mask_tokens.weight, 186 | ], 187 | dim=0, 188 | ) 189 | s = 1 190 | else: 191 | output_tokens = torch.cat( 192 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 193 | ) 194 | output_tokens = output_tokens.unsqueeze(0).expand( 195 | sparse_prompt_embeddings.size(0), -1, -1 196 | ) 197 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 198 | 199 | # Expand per-image data in batch direction to be per-mask 200 | if repeat_image: 201 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 202 | else: 203 | assert image_embeddings.shape[0] == tokens.shape[0] 204 | src = image_embeddings 205 | src = src + dense_prompt_embeddings 206 | assert ( 207 | image_pe.size(0) == 1 208 | ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" 209 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 210 | b, c, h, w = src.shape 211 | 212 | # Run the transformer 213 | hs, src = self.transformer(src, pos_src, tokens) 214 | iou_token_out = hs[:, s, :] 215 | mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] 216 | 217 | # Upscale mask embeddings and predict masks using the mask tokens 218 | src = src.transpose(1, 2).view(b, c, h, w) 219 | if not self.use_high_res_features: 220 | upscaled_embedding = self.output_upscaling(src) 221 | else: 222 | dc1, ln1, act1, dc2, act2 = self.output_upscaling 223 | feat_s0, feat_s1 = high_res_features 224 | upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) 225 | upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) 226 | 227 | hyper_in_list: List[torch.Tensor] = [] 228 | for i in range(self.num_mask_tokens): 229 | hyper_in_list.append( 230 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 231 | ) 232 | hyper_in = torch.stack(hyper_in_list, dim=1) 233 | b, c, h, w = upscaled_embedding.shape 234 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 235 | 236 | # Generate mask quality predictions 237 | iou_pred = self.iou_prediction_head(iou_token_out) 238 | if self.pred_obj_scores: 239 | assert s == 1 240 | object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) 241 | else: 242 | # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 243 | object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) 244 | 245 | return masks, iou_pred, mask_tokens_out, object_score_logits 246 | 247 | def _get_stability_scores(self, mask_logits): 248 | """ 249 | Compute stability scores of the mask logits based on the IoU between upper and 250 | lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. 
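        The score for each mask is area(logits > delta) / area(logits > -delta), computed
        over the flattened spatial dimensions, and defaults to 1.0 when the union area is
        zero. (Added clarification.)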
251 | """ 252 | mask_logits = mask_logits.flatten(-2) 253 | stability_delta = self.dynamic_multimask_stability_delta 254 | area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() 255 | area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() 256 | stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) 257 | return stability_scores 258 | 259 | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): 260 | """ 261 | When outputting a single mask, if the stability score from the current single-mask 262 | output (based on output token 0) falls below a threshold, we instead select from 263 | multi-mask outputs (based on output token 1~3) the mask with the highest predicted 264 | IoU score. This is intended to ensure a valid mask for both clicking and tracking. 265 | """ 266 | # The best mask from multimask output tokens (1~3) 267 | multimask_logits = all_mask_logits[:, 1:, :, :] 268 | multimask_iou_scores = all_iou_scores[:, 1:] 269 | best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) 270 | batch_inds = torch.arange( 271 | multimask_iou_scores.size(0), device=all_iou_scores.device 272 | ) 273 | best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] 274 | best_multimask_logits = best_multimask_logits.unsqueeze(1) 275 | best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] 276 | best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) 277 | 278 | # The mask from singlemask output token 0 and its stability score 279 | singlemask_logits = all_mask_logits[:, 0:1, :, :] 280 | singlemask_iou_scores = all_iou_scores[:, 0:1] 281 | stability_scores = self._get_stability_scores(singlemask_logits) 282 | is_stable = stability_scores >= self.dynamic_multimask_stability_thresh 283 | 284 | # Dynamically fall back to best multimask output upon low stability scores. 285 | mask_logits_out = torch.where( 286 | is_stable[..., None, None].expand_as(singlemask_logits), 287 | singlemask_logits, 288 | best_multimask_logits, 289 | ) 290 | iou_scores_out = torch.where( 291 | is_stable.expand_as(singlemask_iou_scores), 292 | singlemask_iou_scores, 293 | best_multimask_iou_scores, 294 | ) 295 | return mask_logits_out, iou_scores_out 296 | -------------------------------------------------------------------------------- /sam2/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom 13 | 14 | from sam2.modeling.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 
35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | point_embedding[labels == 2] += self.point_embeddings[2].weight 100 | point_embedding[labels == 3] += self.point_embeddings[3].weight 101 | return point_embedding 102 | 103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 104 | """Embeds box prompts.""" 105 | boxes = boxes + 0.5 # Shift to center of pixel 106 | coords = boxes.reshape(-1, 2, 2) 107 | corner_embedding = self.pe_layer.forward_with_coords( 108 | coords, self.input_image_size 109 | ) 110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 112 | return corner_embedding 113 | 114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 115 | """Embeds mask inputs.""" 116 | mask_embedding = self.mask_downscaling(masks) 117 | return mask_embedding 118 | 119 | def _get_batch_size( 120 | self, 121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 122 | boxes: Optional[torch.Tensor], 123 | masks: Optional[torch.Tensor], 124 | ) -> int: 125 | """ 126 | Gets the batch size of the output given the batch size of the 
input prompts. 127 | """ 128 | if points is not None: 129 | return points[0].shape[0] 130 | elif boxes is not None: 131 | return boxes.shape[0] 132 | elif masks is not None: 133 | return masks.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | ) -> Tuple[torch.Tensor, torch.Tensor]: 146 | """ 147 | Embeds different types of prompts, returning both sparse and dense 148 | embeddings. 149 | 150 | Arguments: 151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 152 | and labels to embed. 153 | boxes (torch.Tensor or none): boxes to embed 154 | masks (torch.Tensor or none): masks to embed 155 | 156 | Returns: 157 | torch.Tensor: sparse embeddings for the points and boxes, with shape 158 | BxNx(embed_dim), where N is determined by the number of input points 159 | and boxes. 160 | torch.Tensor: dense embeddings for the masks, in the shape 161 | Bx(embed_dim)x(embed_H)x(embed_W) 162 | """ 163 | bs = self._get_batch_size(points, boxes, masks) 164 | sparse_embeddings = torch.empty( 165 | (bs, 0, self.embed_dim), device=self._get_device() 166 | ) 167 | if points is not None: 168 | coords, labels = points 169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 171 | if boxes is not None: 172 | box_embeddings = self._embed_boxes(boxes) 173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 174 | 175 | if masks is not None: 176 | dense_embeddings = self._embed_masks(masks) 177 | else: 178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 180 | ) 181 | 182 | return sparse_embeddings, dense_embeddings 183 | -------------------------------------------------------------------------------- /sam2/modeling/sam/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import contextlib 8 | import math 9 | import warnings 10 | from functools import partial 11 | from typing import Tuple, Type 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis 18 | from sam2.modeling.sam2_utils import MLP 19 | from sam2.utils.misc import get_sdpa_settings 20 | 21 | warnings.simplefilter(action="ignore", category=FutureWarning) 22 | # Check whether Flash Attention is available (and use it by default) 23 | OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() 24 | # A fallback setting to allow all available kernels if Flash Attention fails 25 | ALLOW_ALL_KERNELS = False 26 | 27 | 28 | def sdp_kernel_context(dropout_p): 29 | """ 30 | Get the context for the attention scaled dot-product kernel. We use Flash Attention 31 | by default, but fall back to all available kernels if Flash Attention fails. 
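    Once a Flash Attention failure has been observed (ALLOW_ALL_KERNELS is set), an empty
    null context is returned so that PyTorch may pick any available SDPA kernel. (Added
    clarification.)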
32 | """ 33 | if ALLOW_ALL_KERNELS: 34 | return contextlib.nullcontext() 35 | 36 | return torch.backends.cuda.sdp_kernel( 37 | enable_flash=USE_FLASH_ATTN, 38 | # if Flash attention kernel is off, then math kernel needs to be enabled 39 | enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, 40 | enable_mem_efficient=OLD_GPU, 41 | ) 42 | 43 | 44 | class TwoWayTransformer(nn.Module): 45 | def __init__( 46 | self, 47 | depth: int, 48 | embedding_dim: int, 49 | num_heads: int, 50 | mlp_dim: int, 51 | activation: Type[nn.Module] = nn.ReLU, 52 | attention_downsample_rate: int = 2, 53 | ) -> None: 54 | """ 55 | A transformer decoder that attends to an input image using 56 | queries whose positional embedding is supplied. 57 | 58 | Args: 59 | depth (int): number of layers in the transformer 60 | embedding_dim (int): the channel dimension for the input embeddings 61 | num_heads (int): the number of heads for multihead attention. Must 62 | divide embedding_dim 63 | mlp_dim (int): the channel dimension internal to the MLP block 64 | activation (nn.Module): the activation to use in the MLP block 65 | """ 66 | super().__init__() 67 | self.depth = depth 68 | self.embedding_dim = embedding_dim 69 | self.num_heads = num_heads 70 | self.mlp_dim = mlp_dim 71 | self.layers = nn.ModuleList() 72 | 73 | for i in range(depth): 74 | self.layers.append( 75 | TwoWayAttentionBlock( 76 | embedding_dim=embedding_dim, 77 | num_heads=num_heads, 78 | mlp_dim=mlp_dim, 79 | activation=activation, 80 | attention_downsample_rate=attention_downsample_rate, 81 | skip_first_layer_pe=(i == 0), 82 | ) 83 | ) 84 | 85 | self.final_attn_token_to_image = Attention( 86 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 87 | ) 88 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 89 | 90 | def forward( 91 | self, 92 | image_embedding: Tensor, 93 | image_pe: Tensor, 94 | point_embedding: Tensor, 95 | ) -> Tuple[Tensor, Tensor]: 96 | """ 97 | Args: 98 | image_embedding (torch.Tensor): image to attend to. Should be shape 99 | B x embedding_dim x h x w for any h and w. 100 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 101 | have the same shape as image_embedding. 102 | point_embedding (torch.Tensor): the embedding to add to the query points. 103 | Must have shape B x N_points x embedding_dim for any N_points. 
104 | 105 | Returns: 106 | torch.Tensor: the processed point_embedding 107 | torch.Tensor: the processed image_embedding 108 | """ 109 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 110 | bs, c, h, w = image_embedding.shape 111 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 112 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 113 | 114 | # Prepare queries 115 | queries = point_embedding 116 | keys = image_embedding 117 | 118 | # Apply transformer blocks and final layernorm 119 | for layer in self.layers: 120 | queries, keys = layer( 121 | queries=queries, 122 | keys=keys, 123 | query_pe=point_embedding, 124 | key_pe=image_pe, 125 | ) 126 | 127 | # Apply the final attention layer from the points to the image 128 | q = queries + point_embedding 129 | k = keys + image_pe 130 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 131 | queries = queries + attn_out 132 | queries = self.norm_final_attn(queries) 133 | 134 | return queries, keys 135 | 136 | 137 | class TwoWayAttentionBlock(nn.Module): 138 | def __init__( 139 | self, 140 | embedding_dim: int, 141 | num_heads: int, 142 | mlp_dim: int = 2048, 143 | activation: Type[nn.Module] = nn.ReLU, 144 | attention_downsample_rate: int = 2, 145 | skip_first_layer_pe: bool = False, 146 | ) -> None: 147 | """ 148 | A transformer block with four layers: (1) self-attention of sparse 149 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 150 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 151 | inputs. 152 | 153 | Arguments: 154 | embedding_dim (int): the channel dimension of the embeddings 155 | num_heads (int): the number of heads in the attention layers 156 | mlp_dim (int): the hidden dimension of the mlp block 157 | activation (nn.Module): the activation of the mlp block 158 | skip_first_layer_pe (bool): skip the PE on the first layer 159 | """ 160 | super().__init__() 161 | self.self_attn = Attention(embedding_dim, num_heads) 162 | self.norm1 = nn.LayerNorm(embedding_dim) 163 | 164 | self.cross_attn_token_to_image = Attention( 165 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 166 | ) 167 | self.norm2 = nn.LayerNorm(embedding_dim) 168 | 169 | self.mlp = MLP( 170 | embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation 171 | ) 172 | self.norm3 = nn.LayerNorm(embedding_dim) 173 | 174 | self.norm4 = nn.LayerNorm(embedding_dim) 175 | self.cross_attn_image_to_token = Attention( 176 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 177 | ) 178 | 179 | self.skip_first_layer_pe = skip_first_layer_pe 180 | 181 | def forward( 182 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 183 | ) -> Tuple[Tensor, Tensor]: 184 | # Self attention block 185 | if self.skip_first_layer_pe: 186 | queries = self.self_attn(q=queries, k=queries, v=queries) 187 | else: 188 | q = queries + query_pe 189 | attn_out = self.self_attn(q=q, k=q, v=queries) 190 | queries = queries + attn_out 191 | queries = self.norm1(queries) 192 | 193 | # Cross attention block, tokens attending to image embedding 194 | q = queries + query_pe 195 | k = keys + key_pe 196 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 197 | queries = queries + attn_out 198 | queries = self.norm2(queries) 199 | 200 | # MLP block 201 | mlp_out = self.mlp(queries) 202 | queries = queries + mlp_out 203 | queries = self.norm3(queries) 204 | 205 | # Cross attention block, image embedding attending to tokens 206 | q = queries + query_pe 207 | k = keys + 
key_pe 208 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 209 | keys = keys + attn_out 210 | keys = self.norm4(keys) 211 | 212 | return queries, keys 213 | 214 | 215 | class Attention(nn.Module): 216 | """ 217 | An attention layer that allows for downscaling the size of the embedding 218 | after projection to queries, keys, and values. 219 | """ 220 | 221 | def __init__( 222 | self, 223 | embedding_dim: int, 224 | num_heads: int, 225 | downsample_rate: int = 1, 226 | dropout: float = 0.0, 227 | kv_in_dim: int = None, 228 | ) -> None: 229 | super().__init__() 230 | self.embedding_dim = embedding_dim 231 | self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim 232 | self.internal_dim = embedding_dim // downsample_rate 233 | self.num_heads = num_heads 234 | assert ( 235 | self.internal_dim % num_heads == 0 236 | ), "num_heads must divide embedding_dim." 237 | 238 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 239 | self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 240 | self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 241 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 242 | 243 | self.dropout_p = dropout 244 | 245 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 246 | b, n, c = x.shape 247 | x = x.reshape(b, n, num_heads, c // num_heads) 248 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 249 | 250 | def _recombine_heads(self, x: Tensor) -> Tensor: 251 | b, n_heads, n_tokens, c_per_head = x.shape 252 | x = x.transpose(1, 2) 253 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 254 | 255 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 256 | # Input projections 257 | q = self.q_proj(q) 258 | k = self.k_proj(k) 259 | v = self.v_proj(v) 260 | 261 | # Separate into heads 262 | q = self._separate_heads(q, self.num_heads) 263 | k = self._separate_heads(k, self.num_heads) 264 | v = self._separate_heads(v, self.num_heads) 265 | 266 | dropout_p = self.dropout_p if self.training else 0.0 267 | # Attention 268 | try: 269 | with sdp_kernel_context(dropout_p): 270 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 271 | except Exception as e: 272 | # Fall back to all kernels if the Flash attention kernel fails 273 | warnings.warn( 274 | f"Flash Attention kernel failed due to: {e}\nFalling back to all available " 275 | f"kernels for scaled_dot_product_attention (which may have a slower speed).", 276 | category=UserWarning, 277 | stacklevel=2, 278 | ) 279 | global ALLOW_ALL_KERNELS 280 | ALLOW_ALL_KERNELS = True 281 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 282 | 283 | out = self._recombine_heads(out) 284 | out = self.out_proj(out) 285 | 286 | return out 287 | 288 | 289 | class RoPEAttention(Attention): 290 | """Attention with rotary position encoding.""" 291 | 292 | def __init__( 293 | self, 294 | *args, 295 | rope_theta=10000.0, 296 | # whether to repeat q rope to match k length 297 | # this is needed for cross-attention to memories 298 | rope_k_repeat=False, 299 | feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution 300 | **kwargs, 301 | ): 302 | super().__init__(*args, **kwargs) 303 | 304 | self.compute_cis = partial( 305 | compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta 306 | ) 307 | freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) 308 | self.freqs_cis = freqs_cis 309 | self.rope_k_repeat = rope_k_repeat 310 | 311 | def forward( 312 | self, q: 
Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 313 | ) -> Tensor: 314 | # Input projections 315 | q = self.q_proj(q) 316 | k = self.k_proj(k) 317 | v = self.v_proj(v) 318 | 319 | # Separate into heads 320 | q = self._separate_heads(q, self.num_heads) 321 | k = self._separate_heads(k, self.num_heads) 322 | v = self._separate_heads(v, self.num_heads) 323 | 324 | # Apply rotary position encoding 325 | w = h = math.sqrt(q.shape[-2]) 326 | self.freqs_cis = self.freqs_cis.to(q.device) 327 | if self.freqs_cis.shape[0] != q.shape[-2]: 328 | self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) 329 | if q.shape[-2] != k.shape[-2]: 330 | assert self.rope_k_repeat 331 | 332 | num_k_rope = k.size(-2) - num_k_exclude_rope 333 | q, k[:, :, :num_k_rope] = apply_rotary_enc( 334 | q, 335 | k[:, :, :num_k_rope], 336 | freqs_cis=self.freqs_cis, 337 | repeat_freqs_k=self.rope_k_repeat, 338 | ) 339 | 340 | dropout_p = self.dropout_p if self.training else 0.0 341 | # Attention 342 | try: 343 | with sdp_kernel_context(dropout_p): 344 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 345 | except Exception as e: 346 | # Fall back to all kernels if the Flash attention kernel fails 347 | warnings.warn( 348 | f"Flash Attention kernel failed due to: {e}\nFalling back to all available " 349 | f"kernels for scaled_dot_product_attention (which may have a slower speed).", 350 | category=UserWarning, 351 | stacklevel=2, 352 | ) 353 | global ALLOW_ALL_KERNELS 354 | ALLOW_ALL_KERNELS = True 355 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 356 | 357 | out = self._recombine_heads(out) 358 | out = self.out_proj(out) 359 | 360 | return out 361 | -------------------------------------------------------------------------------- /sam2/modeling/sam2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import copy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): 16 | """ 17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` 18 | that are temporally closest to the current frame at `frame_idx`. Here, we take 19 | - a) the closest conditioning frame before `frame_idx` (if any); 20 | - b) the closest conditioning frame after `frame_idx` (if any); 21 | - c) any other temporally closest conditioning frames until reaching a total 22 | of `max_cond_frame_num` conditioning frames. 23 | 24 | Outputs: 25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. 26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 
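    Example (hypothetical frame indices and values, for illustration only):

        >>> outs = {0: "a", 10: "b", 20: "c", 30: "d"}
        >>> sel, unsel = select_closest_cond_frames(18, outs, max_cond_frame_num=2)
        >>> sorted(sel), sorted(unsel)
        ([10, 20], [0, 30])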
27 | """ 28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: 29 | selected_outputs = cond_frame_outputs 30 | unselected_outputs = {} 31 | else: 32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" 33 | selected_outputs = {} 34 | 35 | # the closest conditioning frame before `frame_idx` (if any) 36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) 37 | if idx_before is not None: 38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before] 39 | 40 | # the closest conditioning frame after `frame_idx` (if any) 41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) 42 | if idx_after is not None: 43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after] 44 | 45 | # add other temporally closest conditioning frames until reaching a total 46 | # of `max_cond_frame_num` conditioning frames. 47 | num_remain = max_cond_frame_num - len(selected_outputs) 48 | inds_remain = sorted( 49 | (t for t in cond_frame_outputs if t not in selected_outputs), 50 | key=lambda x: abs(x - frame_idx), 51 | )[:num_remain] 52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) 53 | unselected_outputs = { 54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs 55 | } 56 | 57 | return selected_outputs, unselected_outputs 58 | 59 | 60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000): 61 | """ 62 | Get 1D sine positional embedding as in the original Transformer paper. 63 | """ 64 | pe_dim = dim // 2 65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) 66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) 67 | 68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t 69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) 70 | return pos_embed 71 | 72 | 73 | def get_activation_fn(activation): 74 | """Return an activation function given a string""" 75 | if activation == "relu": 76 | return F.relu 77 | if activation == "gelu": 78 | return F.gelu 79 | if activation == "glu": 80 | return F.glu 81 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 82 | 83 | 84 | def get_clones(module, N): 85 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 86 | 87 | 88 | class DropPath(nn.Module): 89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py 90 | def __init__(self, drop_prob=0.0, scale_by_keep=True): 91 | super(DropPath, self).__init__() 92 | self.drop_prob = drop_prob 93 | self.scale_by_keep = scale_by_keep 94 | 95 | def forward(self, x): 96 | if self.drop_prob == 0.0 or not self.training: 97 | return x 98 | keep_prob = 1 - self.drop_prob 99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 101 | if keep_prob > 0.0 and self.scale_by_keep: 102 | random_tensor.div_(keep_prob) 103 | return x * random_tensor 104 | 105 | 106 | # Lightly adapted from 107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 108 | class MLP(nn.Module): 109 | def __init__( 110 | self, 111 | input_dim: int, 112 | hidden_dim: int, 113 | output_dim: int, 114 | num_layers: int, 115 | activation: nn.Module = nn.ReLU, 116 | sigmoid_output: bool = False, 117 | ) -> None: 118 | super().__init__() 119 | self.num_layers = num_layers 120 | h = [hidden_dim] * (num_layers - 1) 121 | self.layers = nn.ModuleList( 122 | nn.Linear(n, k) for n, k 
in zip([input_dim] + h, h + [output_dim]) 123 | ) 124 | self.sigmoid_output = sigmoid_output 125 | self.act = activation() 126 | 127 | def forward(self, x): 128 | for i, layer in enumerate(self.layers): 129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) 130 | if self.sigmoid_output: 131 | x = F.sigmoid(x) 132 | return x 133 | 134 | 135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 137 | class LayerNorm2d(nn.Module): 138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 139 | super().__init__() 140 | self.weight = nn.Parameter(torch.ones(num_channels)) 141 | self.bias = nn.Parameter(torch.zeros(num_channels)) 142 | self.eps = eps 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | u = x.mean(1, keepdim=True) 146 | s = (x - u).pow(2).mean(1, keepdim=True) 147 | x = (x - u) / torch.sqrt(s + self.eps) 148 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 149 | return x 150 | -------------------------------------------------------------------------------- /sam2/sam2_image_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | from typing import List, Optional, Tuple, Union 10 | 11 | import numpy as np 12 | import torch 13 | from PIL.Image import Image 14 | 15 | from sam2.modeling.sam2_base import SAM2Base 16 | 17 | from sam2.utils.transforms import SAM2Transforms 18 | 19 | 20 | class SAM2ImagePredictor: 21 | def __init__( 22 | self, 23 | sam_model: SAM2Base, 24 | mask_threshold=0.0, 25 | max_hole_area=0.0, 26 | max_sprinkle_area=0.0, 27 | **kwargs, 28 | ) -> None: 29 | """ 30 | Uses SAM-2 to calculate the image embedding for an image, and then 31 | allow repeated, efficient mask prediction given prompts. 32 | 33 | Arguments: 34 | sam_model (Sam-2): The model to use for mask prediction. 35 | mask_threshold (float): The threshold to use when converting mask logits 36 | to binary masks. Masks are thresholded at 0 by default. 37 | max_hole_area (int): If max_hole_area > 0, we fill small holes in up to 38 | the maximum area of max_hole_area in low_res_masks. 39 | max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to 40 | the maximum area of max_sprinkle_area in low_res_masks. 41 | """ 42 | super().__init__() 43 | self.model = sam_model 44 | self._transforms = SAM2Transforms( 45 | resolution=self.model.image_size, 46 | mask_threshold=mask_threshold, 47 | max_hole_area=max_hole_area, 48 | max_sprinkle_area=max_sprinkle_area, 49 | ) 50 | 51 | # Predictor state 52 | self._is_image_set = False 53 | self._features = None 54 | self._orig_hw = None 55 | # Whether the predictor is set for single image or a batch of images 56 | self._is_batch = False 57 | 58 | # Predictor config 59 | self.mask_threshold = mask_threshold 60 | 61 | # Spatial dim for backbone feature maps 62 | self._bb_feat_sizes = [ 63 | (256, 256), 64 | (128, 128), 65 | (64, 64), 66 | ] 67 | 68 | @classmethod 69 | def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": 70 | """ 71 | Load a pretrained model from the Hugging Face hub. 
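        Example (sketch; the repository id below is an assumption used for
        illustration and must refer to an actually published checkpoint):

            predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")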
72 | 73 | Arguments: 74 | model_id (str): The Hugging Face repository ID. 75 | **kwargs: Additional arguments to pass to the model constructor. 76 | 77 | Returns: 78 | (SAM2ImagePredictor): The loaded model. 79 | """ 80 | from sam2.build_sam import build_sam2_hf 81 | 82 | sam_model = build_sam2_hf(model_id, **kwargs) 83 | return cls(sam_model, **kwargs) 84 | 85 | @torch.no_grad() 86 | def set_image( 87 | self, 88 | image: Union[np.ndarray, Image], 89 | ) -> None: 90 | """ 91 | Calculates the image embeddings for the provided image, allowing 92 | masks to be predicted with the 'predict' method. 93 | 94 | Arguments: 95 | image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image 96 | with pixel values in [0, 255]. 97 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 98 | """ 99 | self.reset_predictor() 100 | # Transform the image to the form expected by the model 101 | if isinstance(image, np.ndarray): 102 | logging.info("For numpy array image, we assume (HxWxC) format") 103 | self._orig_hw = [image.shape[:2]] 104 | elif isinstance(image, Image): 105 | w, h = image.size 106 | self._orig_hw = [(h, w)] 107 | else: 108 | raise NotImplementedError("Image format not supported") 109 | 110 | input_image = self._transforms(image) 111 | input_image = input_image[None, ...].to(self.device) 112 | 113 | assert ( 114 | len(input_image.shape) == 4 and input_image.shape[1] == 3 115 | ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" 116 | logging.info("Computing image embeddings for the provided image...") 117 | backbone_out = self.model.forward_image(input_image) 118 | _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) 119 | # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos 120 | if self.model.directly_add_no_mem_embed: 121 | vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed 122 | 123 | feats = [ 124 | feat.permute(1, 2, 0).view(1, -1, *feat_size) 125 | for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) 126 | ][::-1] 127 | self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} 128 | self._is_image_set = True 129 | logging.info("Image embeddings computed.") 130 | 131 | @torch.no_grad() 132 | def set_image_batch( 133 | self, 134 | image_list: List[Union[np.ndarray]], 135 | ) -> None: 136 | """ 137 | Calculates the image embeddings for the provided image batch, allowing 138 | masks to be predicted with the 'predict_batch' method. 139 | 140 | Arguments: 141 | image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray 142 | with pixel values in [0, 255]. 
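        Example (sketch; assumes `predictor` is an already-constructed
        SAM2ImagePredictor, and the arrays and point prompts below are made up
        for illustration):

            import numpy as np
            imgs = [
                np.zeros((480, 640, 3), dtype=np.uint8),
                np.zeros((512, 512, 3), dtype=np.uint8),
            ]
            predictor.set_image_batch(imgs)
            masks, ious, low_res = predictor.predict_batch(
                point_coords_batch=[np.array([[100, 200]]), np.array([[50, 60]])],
                point_labels_batch=[np.array([1]), np.array([1])],
            )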
143 | """ 144 | self.reset_predictor() 145 | assert isinstance(image_list, list) 146 | self._orig_hw = [] 147 | for image in image_list: 148 | assert isinstance( 149 | image, np.ndarray 150 | ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" 151 | self._orig_hw.append(image.shape[:2]) 152 | # Transform the image to the form expected by the model 153 | img_batch = self._transforms.forward_batch(image_list) 154 | img_batch = img_batch.to(self.device) 155 | batch_size = img_batch.shape[0] 156 | assert ( 157 | len(img_batch.shape) == 4 and img_batch.shape[1] == 3 158 | ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" 159 | logging.info("Computing image embeddings for the provided images...") 160 | backbone_out = self.model.forward_image(img_batch) 161 | _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) 162 | # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos 163 | if self.model.directly_add_no_mem_embed: 164 | vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed 165 | 166 | feats = [ 167 | feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) 168 | for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) 169 | ][::-1] 170 | self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} 171 | self._is_image_set = True 172 | self._is_batch = True 173 | logging.info("Image embeddings computed.") 174 | 175 | def predict_batch( 176 | self, 177 | point_coords_batch: List[np.ndarray] = None, 178 | point_labels_batch: List[np.ndarray] = None, 179 | box_batch: List[np.ndarray] = None, 180 | mask_input_batch: List[np.ndarray] = None, 181 | multimask_output: bool = True, 182 | return_logits: bool = False, 183 | normalize_coords=True, 184 | ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: 185 | """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. 186 | It returns a tuple of lists of masks, ious, and low_res_masks_logits. 187 | """ 188 | assert self._is_batch, "This function should only be used when in batched mode" 189 | if not self._is_image_set: 190 | raise RuntimeError( 191 | "An image must be set with .set_image_batch(...) before mask prediction." 
192 | ) 193 | num_images = len(self._features["image_embed"]) 194 | all_masks = [] 195 | all_ious = [] 196 | all_low_res_masks = [] 197 | for img_idx in range(num_images): 198 | # Transform input prompts 199 | point_coords = ( 200 | point_coords_batch[img_idx] if point_coords_batch is not None else None 201 | ) 202 | point_labels = ( 203 | point_labels_batch[img_idx] if point_labels_batch is not None else None 204 | ) 205 | box = box_batch[img_idx] if box_batch is not None else None 206 | mask_input = ( 207 | mask_input_batch[img_idx] if mask_input_batch is not None else None 208 | ) 209 | mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( 210 | point_coords, 211 | point_labels, 212 | box, 213 | mask_input, 214 | normalize_coords, 215 | img_idx=img_idx, 216 | ) 217 | masks, iou_predictions, low_res_masks = self._predict( 218 | unnorm_coords, 219 | labels, 220 | unnorm_box, 221 | mask_input, 222 | multimask_output, 223 | return_logits=return_logits, 224 | img_idx=img_idx, 225 | ) 226 | masks_np = masks.squeeze(0).float().detach().cpu().numpy() 227 | iou_predictions_np = ( 228 | iou_predictions.squeeze(0).float().detach().cpu().numpy() 229 | ) 230 | low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() 231 | all_masks.append(masks_np) 232 | all_ious.append(iou_predictions_np) 233 | all_low_res_masks.append(low_res_masks_np) 234 | 235 | return all_masks, all_ious, all_low_res_masks 236 | 237 | def predict( 238 | self, 239 | point_coords: Optional[np.ndarray] = None, 240 | point_labels: Optional[np.ndarray] = None, 241 | box: Optional[np.ndarray] = None, 242 | mask_input: Optional[np.ndarray] = None, 243 | multimask_output: bool = True, 244 | return_logits: bool = False, 245 | normalize_coords=True, 246 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 247 | """ 248 | Predict masks for the given input prompts, using the currently set image. 249 | 250 | Arguments: 251 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 252 | model. Each point is in (X,Y) in pixels. 253 | point_labels (np.ndarray or None): A length N array of labels for the 254 | point prompts. 1 indicates a foreground point and 0 indicates a 255 | background point. 256 | box (np.ndarray or None): A length 4 array given a box prompt to the 257 | model, in XYXY format. 258 | mask_input (np.ndarray): A low resolution mask input to the model, typically 259 | coming from a previous prediction iteration. Has form 1xHxW, where 260 | for SAM, H=W=256. 261 | multimask_output (bool): If true, the model will return three masks. 262 | For ambiguous input prompts (such as a single click), this will often 263 | produce better masks than a single prediction. If only a single 264 | mask is needed, the model's predicted quality score can be used 265 | to select the best mask. For non-ambiguous prompts, such as multiple 266 | input prompts, multimask_output=False can give better results. 267 | return_logits (bool): If true, returns un-thresholded masks logits 268 | instead of a binary mask. 269 | normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. 270 | 271 | Returns: 272 | (np.ndarray): The output masks in CxHxW format, where C is the 273 | number of masks, and (H, W) is the original image size. 274 | (np.ndarray): An array of length C containing the model's 275 | predictions for the quality of each mask. 
276 | (np.ndarray): An array of shape CxHxW, where C is the number 277 | of masks and H=W=256. These low resolution logits can be passed to 278 | a subsequent iteration as mask input. 279 | """ 280 | if not self._is_image_set: 281 | raise RuntimeError( 282 | "An image must be set with .set_image(...) before mask prediction." 283 | ) 284 | 285 | # Transform input prompts 286 | 287 | mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( 288 | point_coords, point_labels, box, mask_input, normalize_coords 289 | ) 290 | 291 | masks, iou_predictions, low_res_masks = self._predict( 292 | unnorm_coords, 293 | labels, 294 | unnorm_box, 295 | mask_input, 296 | multimask_output, 297 | return_logits=return_logits, 298 | ) 299 | 300 | masks_np = masks.squeeze(0).float().detach().cpu().numpy() 301 | iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() 302 | low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() 303 | return masks_np, iou_predictions_np, low_res_masks_np 304 | 305 | def _prep_prompts( 306 | self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 307 | ): 308 | 309 | unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None 310 | if point_coords is not None: 311 | assert ( 312 | point_labels is not None 313 | ), "point_labels must be supplied if point_coords is supplied." 314 | point_coords = torch.as_tensor( 315 | point_coords, dtype=torch.float, device=self.device 316 | ) 317 | unnorm_coords = self._transforms.transform_coords( 318 | point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] 319 | ) 320 | labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 321 | if len(unnorm_coords.shape) == 2: 322 | unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] 323 | if box is not None: 324 | box = torch.as_tensor(box, dtype=torch.float, device=self.device) 325 | unnorm_box = self._transforms.transform_boxes( 326 | box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] 327 | ) # Bx2x2 328 | if mask_logits is not None: 329 | mask_input = torch.as_tensor( 330 | mask_logits, dtype=torch.float, device=self.device 331 | ) 332 | if len(mask_input.shape) == 3: 333 | mask_input = mask_input[None, :, :, :] 334 | return mask_input, unnorm_coords, labels, unnorm_box 335 | 336 | @torch.no_grad() 337 | def _predict( 338 | self, 339 | point_coords: Optional[torch.Tensor], 340 | point_labels: Optional[torch.Tensor], 341 | boxes: Optional[torch.Tensor] = None, 342 | mask_input: Optional[torch.Tensor] = None, 343 | multimask_output: bool = True, 344 | return_logits: bool = False, 345 | img_idx: int = -1, 346 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 347 | """ 348 | Predict masks for the given input prompts, using the currently set image. 349 | Input prompts are batched torch tensors and are expected to already be 350 | transformed to the input frame using SAM2Transforms. 351 | 352 | Arguments: 353 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 354 | model. Each point is in (X,Y) in pixels. 355 | point_labels (torch.Tensor or None): A BxN array of labels for the 356 | point prompts. 1 indicates a foreground point and 0 indicates a 357 | background point. 358 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 359 | model, in XYXY format. 360 | mask_input (np.ndarray): A low resolution mask input to the model, typically 361 | coming from a previous prediction iteration. 
Has form Bx1xHxW, where 362 | for SAM, H=W=256. Masks returned by a previous iteration of the 363 | predict method do not need further transformation. 364 | multimask_output (bool): If true, the model will return three masks. 365 | For ambiguous input prompts (such as a single click), this will often 366 | produce better masks than a single prediction. If only a single 367 | mask is needed, the model's predicted quality score can be used 368 | to select the best mask. For non-ambiguous prompts, such as multiple 369 | input prompts, multimask_output=False can give better results. 370 | return_logits (bool): If true, returns un-thresholded masks logits 371 | instead of a binary mask. 372 | 373 | Returns: 374 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 375 | number of masks, and (H, W) is the original image size. 376 | (torch.Tensor): An array of shape BxC containing the model's 377 | predictions for the quality of each mask. 378 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 379 | of masks and H=W=256. These low res logits can be passed to 380 | a subsequent iteration as mask input. 381 | """ 382 | if not self._is_image_set: 383 | raise RuntimeError( 384 | "An image must be set with .set_image(...) before mask prediction." 385 | ) 386 | 387 | if point_coords is not None: 388 | concat_points = (point_coords, point_labels) 389 | else: 390 | concat_points = None 391 | 392 | # Embed prompts 393 | if boxes is not None: 394 | box_coords = boxes.reshape(-1, 2, 2) 395 | box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) 396 | box_labels = box_labels.repeat(boxes.size(0), 1) 397 | # we merge "boxes" and "points" into a single "concat_points" input (where 398 | # boxes are added at the beginning) to sam_prompt_encoder 399 | if concat_points is not None: 400 | concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) 401 | concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) 402 | concat_points = (concat_coords, concat_labels) 403 | else: 404 | concat_points = (box_coords, box_labels) 405 | 406 | sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( 407 | points=concat_points, 408 | boxes=None, 409 | masks=mask_input, 410 | ) 411 | 412 | # Predict masks 413 | batched_mode = ( 414 | concat_points is not None and concat_points[0].shape[0] > 1 415 | ) # multi object prediction 416 | high_res_features = [ 417 | feat_level[img_idx].unsqueeze(0) 418 | for feat_level in self._features["high_res_feats"] 419 | ] 420 | low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( 421 | image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), 422 | image_pe=self.model.sam_prompt_encoder.get_dense_pe(), 423 | sparse_prompt_embeddings=sparse_embeddings, 424 | dense_prompt_embeddings=dense_embeddings, 425 | multimask_output=multimask_output, 426 | repeat_image=batched_mode, 427 | high_res_features=high_res_features, 428 | ) 429 | 430 | # Upscale the masks to the original image resolution 431 | masks = self._transforms.postprocess_masks( 432 | low_res_masks, self._orig_hw[img_idx] 433 | ) 434 | low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) 435 | if not return_logits: 436 | masks = masks > self.mask_threshold 437 | 438 | return masks, iou_predictions, low_res_masks 439 | 440 | def get_image_embedding(self) -> torch.Tensor: 441 | """ 442 | Returns the image embeddings for the currently set image, with 443 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 444 | the embedding 
spatial dimension of SAM (typically C=256, H=W=64). 445 | """ 446 | if not self._is_image_set: 447 | raise RuntimeError( 448 | "An image must be set with .set_image(...) to generate an embedding." 449 | ) 450 | assert ( 451 | self._features is not None 452 | ), "Features must exist if an image has been set." 453 | return self._features["image_embed"] 454 | 455 | @property 456 | def device(self) -> torch.device: 457 | return self.model.device 458 | 459 | def reset_predictor(self) -> None: 460 | """ 461 | Resets the image embeddings and other state variables. 462 | """ 463 | self._is_image_set = False 464 | self._features = None 465 | self._orig_hw = None 466 | self._is_batch = False 467 | -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from copy import deepcopy 9 | from itertools import product 10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 11 | 12 | import numpy as np 13 | import torch 14 | 15 | # Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py 16 | 17 | 18 | class MaskData: 19 | """ 20 | A structure for storing masks and their related data in batched format. 21 | Implements basic filtering and concatenation. 22 | """ 23 | 24 | def __init__(self, **kwargs) -> None: 25 | for v in kwargs.values(): 26 | assert isinstance( 27 | v, (list, np.ndarray, torch.Tensor) 28 | ), "MaskData only supports list, numpy arrays, and torch tensors." 29 | self._stats = dict(**kwargs) 30 | 31 | def __setitem__(self, key: str, item: Any) -> None: 32 | assert isinstance( 33 | item, (list, np.ndarray, torch.Tensor) 34 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
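        # Note: this simply stores (or overwrites) the batched field under `key`;
        # no check is made that its length matches the fields already held in _stats.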
35 | self._stats[key] = item 36 | 37 | def __delitem__(self, key: str) -> None: 38 | del self._stats[key] 39 | 40 | def __getitem__(self, key: str) -> Any: 41 | return self._stats[key] 42 | 43 | def items(self) -> ItemsView[str, Any]: 44 | return self._stats.items() 45 | 46 | def filter(self, keep: torch.Tensor) -> None: 47 | for k, v in self._stats.items(): 48 | if v is None: 49 | self._stats[k] = None 50 | elif isinstance(v, torch.Tensor): 51 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 52 | elif isinstance(v, np.ndarray): 53 | self._stats[k] = v[keep.detach().cpu().numpy()] 54 | elif isinstance(v, list) and keep.dtype == torch.bool: 55 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 56 | elif isinstance(v, list): 57 | self._stats[k] = [v[i] for i in keep] 58 | else: 59 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 60 | 61 | def cat(self, new_stats: "MaskData") -> None: 62 | for k, v in new_stats.items(): 63 | if k not in self._stats or self._stats[k] is None: 64 | self._stats[k] = deepcopy(v) 65 | elif isinstance(v, torch.Tensor): 66 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 67 | elif isinstance(v, np.ndarray): 68 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 69 | elif isinstance(v, list): 70 | self._stats[k] = self._stats[k] + deepcopy(v) 71 | else: 72 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 73 | 74 | def to_numpy(self) -> None: 75 | for k, v in self._stats.items(): 76 | if isinstance(v, torch.Tensor): 77 | self._stats[k] = v.float().detach().cpu().numpy() 78 | 79 | 80 | def is_box_near_crop_edge( 81 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 82 | ) -> torch.Tensor: 83 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 84 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 85 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 86 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 87 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 88 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 89 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 90 | return torch.any(near_crop_edge, dim=1) 91 | 92 | 93 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 94 | box_xywh = deepcopy(box_xyxy) 95 | box_xywh[2] = box_xywh[2] - box_xywh[0] 96 | box_xywh[3] = box_xywh[3] - box_xywh[1] 97 | return box_xywh 98 | 99 | 100 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 101 | assert len(args) > 0 and all( 102 | len(a) == len(args[0]) for a in args 103 | ), "Batched iteration must have inputs of all the same size." 104 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 105 | for b in range(n_batches): 106 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 107 | 108 | 109 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 110 | """ 111 | Encodes masks to an uncompressed RLE, in the format expected by 112 | pycoco tools. 
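    Example (tiny hand-made mask, for illustration only):

        >>> m = torch.zeros(1, 2, 3, dtype=torch.bool)  # B=1, H=2, W=3
        >>> m[0, 0, 1] = True
        >>> rle = mask_to_rle_pytorch(m)[0]
        >>> rle["size"], rle["counts"]
        ([2, 3], [2, 1, 3])
        >>> bool(rle_to_mask(rle)[0, 1])
        True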
113 | """ 114 | # Put in fortran order and flatten h,w 115 | b, h, w = tensor.shape 116 | tensor = tensor.permute(0, 2, 1).flatten(1) 117 | 118 | # Compute change indices 119 | diff = tensor[:, 1:] ^ tensor[:, :-1] 120 | change_indices = diff.nonzero() 121 | 122 | # Encode run length 123 | out = [] 124 | for i in range(b): 125 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 126 | cur_idxs = torch.cat( 127 | [ 128 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | cur_idxs + 1, 130 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 131 | ] 132 | ) 133 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 134 | counts = [] if tensor[i, 0] == 0 else [0] 135 | counts.extend(btw_idxs.detach().cpu().tolist()) 136 | out.append({"size": [h, w], "counts": counts}) 137 | return out 138 | 139 | 140 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 141 | """Compute a binary mask from an uncompressed RLE.""" 142 | h, w = rle["size"] 143 | mask = np.empty(h * w, dtype=bool) 144 | idx = 0 145 | parity = False 146 | for count in rle["counts"]: 147 | mask[idx : idx + count] = parity 148 | idx += count 149 | parity ^= True 150 | mask = mask.reshape(w, h) 151 | return mask.transpose() # Put in C order 152 | 153 | 154 | def area_from_rle(rle: Dict[str, Any]) -> int: 155 | return sum(rle["counts"][1::2]) 156 | 157 | 158 | def calculate_stability_score( 159 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 160 | ) -> torch.Tensor: 161 | """ 162 | Computes the stability score for a batch of masks. The stability 163 | score is the IoU between the binary masks obtained by thresholding 164 | the predicted mask logits at high and low values. 165 | """ 166 | # One mask is always contained inside the other. 167 | # Save memory by preventing unnecessary cast to torch.int64 168 | intersections = ( 169 | (masks > (mask_threshold + threshold_offset)) 170 | .sum(-1, dtype=torch.int16) 171 | .sum(-1, dtype=torch.int32) 172 | ) 173 | unions = ( 174 | (masks > (mask_threshold - threshold_offset)) 175 | .sum(-1, dtype=torch.int16) 176 | .sum(-1, dtype=torch.int32) 177 | ) 178 | return intersections / unions 179 | 180 | 181 | def build_point_grid(n_per_side: int) -> np.ndarray: 182 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 183 | offset = 1 / (2 * n_per_side) 184 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 185 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 186 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 187 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 188 | return points 189 | 190 | 191 | def build_all_layer_point_grids( 192 | n_per_side: int, n_layers: int, scale_per_layer: int 193 | ) -> List[np.ndarray]: 194 | """Generates point grids for all crop layers.""" 195 | points_by_layer = [] 196 | for i in range(n_layers + 1): 197 | n_points = int(n_per_side / (scale_per_layer**i)) 198 | points_by_layer.append(build_point_grid(n_points)) 199 | return points_by_layer 200 | 201 | 202 | def generate_crop_boxes( 203 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 204 | ) -> Tuple[List[List[int]], List[int]]: 205 | """ 206 | Generates a list of crop boxes of different sizes. Each layer 207 | has (2**i)**2 boxes for the ith layer. 
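    Example (arbitrary image size and overlap ratio, for illustration only):

        >>> boxes, layer_idxs = generate_crop_boxes((100, 100), n_layers=1, overlap_ratio=0.25)
        >>> boxes[0], layer_idxs[0]     # the first entry is always the full image
        ([0, 0, 100, 100], 0)
        >>> len(boxes)                  # 1 + (2**1)**2 crops
        5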
208 | """ 209 | crop_boxes, layer_idxs = [], [] 210 | im_h, im_w = im_size 211 | short_side = min(im_h, im_w) 212 | 213 | # Original image 214 | crop_boxes.append([0, 0, im_w, im_h]) 215 | layer_idxs.append(0) 216 | 217 | def crop_len(orig_len, n_crops, overlap): 218 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 219 | 220 | for i_layer in range(n_layers): 221 | n_crops_per_side = 2 ** (i_layer + 1) 222 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 223 | 224 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 225 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 226 | 227 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 228 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 229 | 230 | # Crops in XYWH format 231 | for x0, y0 in product(crop_box_x0, crop_box_y0): 232 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 233 | crop_boxes.append(box) 234 | layer_idxs.append(i_layer + 1) 235 | 236 | return crop_boxes, layer_idxs 237 | 238 | 239 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 240 | x0, y0, _, _ = crop_box 241 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 242 | # Check if boxes has a channel dimension 243 | if len(boxes.shape) == 3: 244 | offset = offset.unsqueeze(1) 245 | return boxes + offset 246 | 247 | 248 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 249 | x0, y0, _, _ = crop_box 250 | offset = torch.tensor([[x0, y0]], device=points.device) 251 | # Check if points has a channel dimension 252 | if len(points.shape) == 3: 253 | offset = offset.unsqueeze(1) 254 | return points + offset 255 | 256 | 257 | def uncrop_masks( 258 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 259 | ) -> torch.Tensor: 260 | x0, y0, x1, y1 = crop_box 261 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 262 | return masks 263 | # Coordinate transform masks 264 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 265 | pad = (x0, pad_x - x0, y0, pad_y - y0) 266 | return torch.nn.functional.pad(masks, pad, value=0) 267 | 268 | 269 | def remove_small_regions( 270 | mask: np.ndarray, area_thresh: float, mode: str 271 | ) -> Tuple[np.ndarray, bool]: 272 | """ 273 | Removes small disconnected regions and holes in a mask. Returns the 274 | mask and an indicator of if the mask has been modified. 
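    Example (sketch; requires OpenCV, and the mask below is made up):

        import numpy as np
        mask = np.zeros((64, 64), dtype=bool)
        mask[10:30, 10:30] = True   # one large island (400 px)
        mask[50, 50] = True         # one single-pixel speckle
        cleaned, changed = remove_small_regions(mask, area_thresh=10, mode="islands")
        # `changed` is True and only the large island remains in `cleaned`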
275 | """ 276 | import cv2 # type: ignore 277 | 278 | assert mode in ["holes", "islands"] 279 | correct_holes = mode == "holes" 280 | working_mask = (correct_holes ^ mask).astype(np.uint8) 281 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 282 | sizes = stats[:, -1][1:] # Row 0 is background label 283 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 284 | if len(small_regions) == 0: 285 | return mask, False 286 | fill_labels = [0] + small_regions 287 | if not correct_holes: 288 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 289 | # If every region is below threshold, keep largest 290 | if len(fill_labels) == 0: 291 | fill_labels = [int(np.argmax(sizes)) + 1] 292 | mask = np.isin(regions, fill_labels) 293 | return mask, True 294 | 295 | 296 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 297 | from pycocotools import mask as mask_utils # type: ignore 298 | 299 | h, w = uncompressed_rle["size"] 300 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 301 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 302 | return rle 303 | 304 | 305 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 306 | """ 307 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 308 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 309 | """ 310 | # torch.max below raises an error on empty inputs, just skip in this case 311 | if torch.numel(masks) == 0: 312 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 313 | 314 | # Normalize shape to CxHxW 315 | shape = masks.shape 316 | h, w = shape[-2:] 317 | if len(shape) > 2: 318 | masks = masks.flatten(0, -3) 319 | else: 320 | masks = masks.unsqueeze(0) 321 | 322 | # Get top and bottom edges 323 | in_height, _ = torch.max(masks, dim=-1) 324 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 325 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 326 | in_height_coords = in_height_coords + h * (~in_height) 327 | top_edges, _ = torch.min(in_height_coords, dim=-1) 328 | 329 | # Get left and right edges 330 | in_width, _ = torch.max(masks, dim=-2) 331 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 332 | right_edges, _ = torch.max(in_width_coords, dim=-1) 333 | in_width_coords = in_width_coords + w * (~in_width) 334 | left_edges, _ = torch.min(in_width_coords, dim=-1) 335 | 336 | # If the mask is empty the right edge will be to the left of the left edge. 337 | # Replace these boxes with [0, 0, 0, 0] 338 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 339 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 340 | out = out * (~empty_filter).unsqueeze(-1) 341 | 342 | # Return to original shape 343 | if len(shape) > 2: 344 | out = out.reshape(*shape[:-2], 4) 345 | else: 346 | out = out[0] 347 | 348 | return out 349 | -------------------------------------------------------------------------------- /sam2/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import warnings 9 | from threading import Thread 10 | 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | 17 | def get_sdpa_settings(): 18 | if torch.cuda.is_available(): 19 | old_gpu = torch.cuda.get_device_properties(0).major < 7 20 | # only use Flash Attention on Ampere (8.0) or newer GPUs 21 | use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 22 | if not use_flash_attn: 23 | warnings.warn( 24 | "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", 25 | category=UserWarning, 26 | stacklevel=2, 27 | ) 28 | # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only 29 | # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) 30 | pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) 31 | if pytorch_version < (2, 2): 32 | warnings.warn( 33 | f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " 34 | "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", 35 | category=UserWarning, 36 | stacklevel=2, 37 | ) 38 | math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn 39 | else: 40 | old_gpu = True 41 | use_flash_attn = False 42 | math_kernel_on = True 43 | 44 | return old_gpu, use_flash_attn, math_kernel_on 45 | 46 | 47 | def get_connected_components(mask): 48 | """ 49 | Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). 50 | 51 | Inputs: 52 | - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is 53 | background. 54 | 55 | Outputs: 56 | - labels: A tensor of shape (N, 1, H, W) containing the connected component labels 57 | for foreground pixels and 0 for background pixels. 58 | - counts: A tensor of shape (N, 1, H, W) containing the area of the connected 59 | components for foreground pixels and 0 for background pixels. 
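    Example (sketch; requires the compiled CUDA extension and a CUDA device, and
    the random mask below is made up):

        mask = torch.rand(2, 1, 64, 64, device="cuda") > 0.5
        labels, counts = get_connected_components(mask)
        # labels[i, 0] assigns a positive id to each 8-connected foreground component
        # of mask[i, 0]; counts gives the area of the component each pixel belongs to.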
60 | """ 61 | from sam2 import _C 62 | 63 | return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) 64 | 65 | 66 | def mask_to_box(masks: torch.Tensor): 67 | """ 68 | compute bounding box given an input mask 69 | 70 | Inputs: 71 | - masks: [B, 1, H, W] masks, dtype=torch.Tensor 72 | 73 | Returns: 74 | - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor 75 | """ 76 | B, _, h, w = masks.shape 77 | device = masks.device 78 | xs = torch.arange(w, device=device, dtype=torch.int32) 79 | ys = torch.arange(h, device=device, dtype=torch.int32) 80 | grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") 81 | grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) 82 | grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) 83 | min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) 84 | max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) 85 | min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) 86 | max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) 87 | bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) 88 | 89 | return bbox_coords 90 | 91 | 92 | def _load_img_as_tensor(img_path, image_size): 93 | img_pil = Image.open(img_path) 94 | img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) 95 | if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images 96 | img_np = img_np / 255.0 97 | else: 98 | raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") 99 | img = torch.from_numpy(img_np).permute(2, 0, 1) 100 | video_width, video_height = img_pil.size # the original video size 101 | return img, video_height, video_width 102 | 103 | 104 | class AsyncVideoFrameLoader: 105 | """ 106 | A list of video frames to be load asynchronously without blocking session start. 
107 | """ 108 | 109 | def __init__( 110 | self, 111 | img_paths, 112 | image_size, 113 | offload_video_to_cpu, 114 | img_mean, 115 | img_std, 116 | compute_device, 117 | ): 118 | self.img_paths = img_paths 119 | self.image_size = image_size 120 | self.offload_video_to_cpu = offload_video_to_cpu 121 | self.img_mean = img_mean 122 | self.img_std = img_std 123 | # items in `self.images` will be loaded asynchronously 124 | self.images = [None] * len(img_paths) 125 | # catch and raise any exceptions in the async loading thread 126 | self.exception = None 127 | # video_height and video_width be filled when loading the first image 128 | self.video_height = None 129 | self.video_width = None 130 | self.compute_device = compute_device 131 | 132 | # load the first frame to fill video_height and video_width and also 133 | # to cache it (since it's most likely where the user will click) 134 | self.__getitem__(0) 135 | 136 | # load the rest of frames asynchronously without blocking the session start 137 | def _load_frames(): 138 | try: 139 | for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): 140 | self.__getitem__(n) 141 | except Exception as e: 142 | self.exception = e 143 | 144 | self.thread = Thread(target=_load_frames, daemon=True) 145 | self.thread.start() 146 | 147 | def __getitem__(self, index): 148 | if self.exception is not None: 149 | raise RuntimeError("Failure in frame loading thread") from self.exception 150 | 151 | img = self.images[index] 152 | if img is not None: 153 | return img 154 | 155 | img, video_height, video_width = _load_img_as_tensor( 156 | self.img_paths[index], self.image_size 157 | ) 158 | self.video_height = video_height 159 | self.video_width = video_width 160 | # normalize by mean and std 161 | img -= self.img_mean 162 | img /= self.img_std 163 | if not self.offload_video_to_cpu: 164 | img = img.to(self.compute_device, non_blocking=True) 165 | self.images[index] = img 166 | return img 167 | 168 | def __len__(self): 169 | return len(self.images) 170 | 171 | 172 | def load_video_frames( 173 | video_path, 174 | image_size, 175 | offload_video_to_cpu, 176 | img_mean=(0.485, 0.456, 0.406), 177 | img_std=(0.229, 0.224, 0.225), 178 | async_loading_frames=False, 179 | compute_device=torch.device("cuda"), 180 | ): 181 | """ 182 | Load the video frames from a directory of JPEG files (".jpg" format). 183 | 184 | The frames are resized to image_size x image_size and are loaded to GPU if 185 | `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. 186 | 187 | You can load a frame asynchronously by setting `async_loading_frames` to `True`. 188 | """ 189 | if isinstance(video_path, str) and os.path.isdir(video_path): 190 | jpg_folder = video_path 191 | else: 192 | raise NotImplementedError( 193 | "Only JPEG frames are supported at this moment. For video files, you may use " 194 | "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" 195 | "```\n" 196 | "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" 197 | "```\n" 198 | "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " 199 | "ffmpeg to start the JPEG file from 00000.jpg." 
200 | ) 201 | 202 | frame_names = [ 203 | p 204 | for p in os.listdir(jpg_folder) 205 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 206 | ] 207 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 208 | num_frames = len(frame_names) 209 | if num_frames == 0: 210 | raise RuntimeError(f"no images found in {jpg_folder}") 211 | img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] 212 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] 213 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] 214 | 215 | if async_loading_frames: 216 | lazy_images = AsyncVideoFrameLoader( 217 | img_paths, 218 | image_size, 219 | offload_video_to_cpu, 220 | img_mean, 221 | img_std, 222 | compute_device, 223 | ) 224 | return lazy_images, lazy_images.video_height, lazy_images.video_width 225 | 226 | images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) 227 | for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): 228 | images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) 229 | if not offload_video_to_cpu: 230 | images = images.to(compute_device) 231 | img_mean = img_mean.to(compute_device) 232 | img_std = img_std.to(compute_device) 233 | # normalize by mean and std 234 | images -= img_mean 235 | images /= img_std 236 | return images, video_height, video_width 237 | 238 | 239 | def fill_holes_in_mask_scores(mask, max_area): 240 | """ 241 | A post processor to fill small holes in mask scores with area under `max_area`. 242 | """ 243 | # Holes are those connected components in background with area <= self.max_area 244 | # (background regions are those with mask scores <= 0) 245 | assert max_area > 0, "max_area must be positive" 246 | 247 | input_mask = mask 248 | try: 249 | labels, areas = get_connected_components(mask <= 0) 250 | is_hole = (labels > 0) & (areas <= max_area) 251 | # We fill holes with a small positive mask score (0.1) to change them to foreground. 252 | mask = torch.where(is_hole, 0.1, mask) 253 | except Exception as e: 254 | # Skip the post-processing step on removing small holes if the CUDA kernel fails 255 | warnings.warn( 256 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 257 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 258 | "functionality may be limited (which doesn't affect the results in most cases; see " 259 | "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", 260 | category=UserWarning, 261 | stacklevel=2, 262 | ) 263 | mask = input_mask 264 | 265 | return mask 266 | 267 | 268 | def concat_points(old_point_inputs, new_points, new_labels): 269 | """Add new points and labels to previous point inputs (add at the end).""" 270 | if old_point_inputs is None: 271 | points, labels = new_points, new_labels 272 | else: 273 | points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) 274 | labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) 275 | 276 | return {"point_coords": points, "point_labels": labels} 277 | -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates. 51 | If the coords are in absolute image coordinates, `normalize` should be set to True and the original image size (`orig_hw`) is required. 52 | 53 | Returns 54 | Coordinates scaled to the model's input resolution, i.e. in the range [0, `resolution`], which is what the SAM2 model expects. 55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates. 71 | If the coords are in absolute image coordinates, `normalize` should be set to True and the original image size (`orig_hw`) is required. 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform post-processing on the output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.max_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a positive mask score (mask_threshold + 10.0) to change them to foreground.
95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # We fill holes with negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks 119 | -------------------------------------------------------------------------------- /sam2_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: 
sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | 
rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 
10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | 
pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 9 | 10 | # Package metadata 11 | NAME = "SAM 2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" 14 | URL = "https://github.com/facebookresearch/segment-anything-2" 15 | AUTHOR = "Meta AI" 16 | AUTHOR_EMAIL = "segment-anything@meta.com" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of the README file 20 | with open("README.md", encoding="utf-8") as f: 21 | LONG_DESCRIPTION = f.read() 22 | # Required dependencies 23 | REQUIRED_PACKAGES = [ 24 | "torch>=2.3.1", 25 | "torchvision>=0.18.1", 26 | "numpy>=1.24.4", 27 | "tqdm>=4.66.1", 28 | "hydra-core>=1.3.2", 29 | "iopath>=0.1.10", 30 | "pillow>=9.4.0", 31 | ] 32 | 33 | EXTRA_PACKAGES = { 34 | "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"], 35 | "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"], 36 | } 37 | 38 | 39 | def get_extensions(): 40 | srcs = ["sam2/csrc/connected_components.cu"] 41 | compile_args = { 42 | "cxx": [], 43 | "nvcc": [ 44 | "-DCUDA_HAS_FP16=1", 45 | "-D__CUDA_NO_HALF_OPERATORS__", 46 | "-D__CUDA_NO_HALF_CONVERSIONS__", 47 | "-D__CUDA_NO_HALF2_OPERATORS__", 48 | "-allow-unsupported-compiler" 49 | ], 50 | } 51 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 52 | return ext_modules 53 | 54 | 55 | # Setup configuration 56 | setup( 57 | name=NAME, 58 | version=VERSION, 59 | description=DESCRIPTION, 60 | long_description=LONG_DESCRIPTION, 61 | long_description_content_type="text/markdown", 62 | url=URL, 63 | author=AUTHOR, 64 | author_email=AUTHOR_EMAIL, 65 | license=LICENSE, 66 | packages=find_packages(exclude=["notebooks"]), 67 | install_requires=REQUIRED_PACKAGES, 68 | extras_require=EXTRA_PACKAGES, 69 | python_requires=">=3.9.0", 70 | ext_modules=get_extensions(), 71 | cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)}, 72 | ) 73 | --------------------------------------------------------------------------------
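A minimal sketch of how `load_video_frames` from sam2/utils/misc.py (listed above) might be called. The directory `./video_frames` is hypothetical; it is assumed to hold frames named 00000.jpg, 00001.jpg, ... (integer file names are required by the sort key in the function). This is an illustration, not an official example from the repository.

```python
# Sketch only: "./video_frames" is a hypothetical folder of numbered JPEG frames.
import torch

from sam2.utils.misc import load_video_frames

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synchronous loading: returns a (num_frames, 3, image_size, image_size) float tensor
# (normalized by ImageNet mean/std) plus the original frame height and width.
images, video_height, video_width = load_video_frames(
    video_path="./video_frames",   # hypothetical directory of numbered JPEGs
    image_size=1024,               # matches `image_size` in the model configs above
    offload_video_to_cpu=True,     # keep frames on the CPU to save GPU memory
    compute_device=device,
)
print(images.shape, video_height, video_width)

# Asynchronous loading: returns an AsyncVideoFrameLoader; frame 0 is loaded eagerly,
# the rest are filled in by a background thread and can be indexed as they arrive.
frames, h, w = load_video_frames(
    video_path="./video_frames",
    image_size=1024,
    offload_video_to_cpu=False,
    async_loading_frames=True,
    compute_device=device,
)
first_frame = frames[0]
```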
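Similarly, a small sketch of the `SAM2Transforms` API from sam2/utils/transforms.py (listed above). The image array, click/box coordinates, and mask logits are dummy values chosen only to show the expected shapes; with the default `max_hole_area=0.0` and `max_sprinkle_area=0.0`, `postprocess_masks` only resizes the logits and does not need the `sam2._C` CUDA kernel.

```python
import numpy as np
import torch

from sam2.utils.transforms import SAM2Transforms

# resolution=1024 mirrors the `image_size` in the configs; mask_threshold=0.0 is the
# usual logit cutoff. Non-zero max_hole_area / max_sprinkle_area would enable the
# connected-component hole/speck filling shown above (CUDA kernel required).
transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

image = np.zeros((480, 640, 3), dtype=np.uint8)        # dummy HxWx3 frame
model_input = transforms(image)                        # (3, 1024, 1024), resized + normalized

points = torch.tensor([[[320.0, 240.0]]])              # (B, N, 2) clicks in (x, y) pixels
points_1024 = transforms.transform_coords(points, normalize=True, orig_hw=(480, 640))

boxes = torch.tensor([[100.0, 80.0, 400.0, 300.0]])    # (B, 4) boxes as (x1, y1, x2, y2)
boxes_1024 = transforms.transform_boxes(boxes, normalize=True, orig_hw=(480, 640))  # (B, 2, 2)

low_res_logits = torch.zeros(1, 1, 256, 256)           # stand-in for mask-decoder output
masks = transforms.postprocess_masks(low_res_logits, orig_hw=(480, 640))  # (1, 1, 480, 640)
```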
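The four YAML files above are Hydra configs: each `_target_` entry names a class that Hydra imports and constructs recursively with the keyword arguments listed under it. The repository's sam2/build_sam.py is the actual entry point for this; the snippet below is only a sketch of the underlying pattern, and the `sam2_configs` module name and `version_base="1.2"` value are assumptions following the upstream convention rather than anything stated in this listing. No checkpoint weights are loaded here.

```python
from hydra import compose, initialize_config_module
from hydra.utils import instantiate
from omegaconf import OmegaConf

# `sam2_configs` is the package that ships the YAML files shown above (assumed module name).
with initialize_config_module("sam2_configs", version_base="1.2"):
    cfg = compose(config_name="sam2_hiera_t.yaml")

OmegaConf.resolve(cfg)
# Because of `# @package _global_`, the model definition sits at cfg.model;
# instantiate() builds sam2.modeling.sam2_base.SAM2Base and all nested _target_ objects.
model = instantiate(cfg.model, _recursive_=True)
print(type(model))
```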