├── .watchmanconfig ├── sam2 ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml ├── sam2_hiera_t.yaml ├── sam2_hiera_b+.yaml ├── modeling │ ├── __init__.py │ ├── sam │ │ ├── __init__.py │ │ ├── prompt_encoder.py │ │ ├── mask_decoder.py │ │ └── transformer.py │ ├── backbones │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── image_encoder.py │ │ └── hieradet.py │ ├── memory_encoder.py │ ├── memory_attention.py │ └── position_encoding.py ├── utils │ ├── __init__.py │ ├── transforms.py │ └── amg.py ├── __init__.py ├── configs │ ├── sam2 │ │ ├── sam2_hiera_b+.yaml │ │ ├── sam2_hiera_s.yaml │ │ ├── sam2_hiera_l.yaml │ │ └── sam2_hiera_t.yaml │ ├── sam2.1 │ │ ├── sam2.1_hiera_b+.yaml │ │ ├── sam2.1_hiera_s.yaml │ │ ├── sam2.1_hiera_l.yaml │ │ └── sam2.1_hiera_t.yaml │ └── sam2.1_training │ │ └── sam2.1_hiera_b+_MOSE_finetune.yaml ├── build_sam.py └── csrc │ └── connected_components.cu ├── img ├── logo.png ├── teaser.png └── pipeline.png ├── sav_dataset ├── example │ └── sav_000001.mp4 ├── requirements.txt ├── LICENSE_VOS_BENCHMARK ├── LICENSE ├── LICENSE_DAVIS ├── sav_evaluator.py ├── utils │ └── sav_utils.py └── README.md ├── pyproject.toml ├── .gitignore ├── docker-compose.yaml ├── backend.Dockerfile ├── .clang-format ├── CODE_OF_CONDUCT.md ├── tools └── README.md ├── setup.py └── README.md /.watchmanconfig: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_l.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_s.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_t.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_b+.yaml -------------------------------------------------------------------------------- /img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mark12Ding/SAM2Long/HEAD/img/logo.png -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mark12Ding/SAM2Long/HEAD/img/teaser.png -------------------------------------------------------------------------------- /img/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mark12Ding/SAM2Long/HEAD/img/pipeline.png -------------------------------------------------------------------------------- /sav_dataset/example/sav_000001.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mark12Ding/SAM2Long/HEAD/sav_dataset/example/sav_000001.mp4 -------------------------------------------------------------------------------- /sav_dataset/requirements.txt: -------------------------------------------------------------------------------- 1 | pycocoevalcap 
2 | scikit-image 3 | opencv-python 4 | tqdm 5 | pillow 6 | numpy 7 | matplotlib -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=61.0", 4 | "torch>=2.3.1", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | __pycache__/ 4 | *-checkpoint.ipynb 5 | .venv 6 | *.egg* 7 | build/* 8 | _C.* 9 | outputs/* 10 | checkpoints/*.pt 11 | -------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE_VOS_BENCHMARK: -------------------------------------------------------------------------------- 1 | Copyright 2023 Rex Cheng 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | frontend: 3 | image: sam2/frontend 4 | build: 5 | context: ./demo/frontend 6 | dockerfile: frontend.Dockerfile 7 | ports: 8 | - 7262:80 9 | 10 | backend: 11 | image: sam2/backend 12 | build: 13 | context: . 14 | dockerfile: backend.Dockerfile 15 | ports: 16 | - 7263:5000 17 | volumes: 18 | - ./demo/data/:/data/:rw 19 | environment: 20 | - SERVER_ENVIRONMENT=DEV 21 | - GUNICORN_WORKERS=1 22 | # Inference API needs to have at least 2 threads to handle an incoming 23 | # parallel cancel propagation request 24 | - GUNICORN_THREADS=2 25 | - GUNICORN_PORT=5000 26 | - API_URL=http://localhost:7263 27 | - DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 28 | # # ffmpeg/video encode settings 29 | - FFMPEG_NUM_THREADS=1 30 | - VIDEO_ENCODE_CODEC=libx264 31 | - VIDEO_ENCODE_CRF=23 32 | - VIDEO_ENCODE_FPS=24 33 | - VIDEO_ENCODE_MAX_WIDTH=1280 34 | - VIDEO_ENCODE_MAX_HEIGHT=720 35 | - VIDEO_ENCODE_VERBOSE=False 36 | deploy: 37 | resources: 38 | reservations: 39 | devices: 40 | - driver: nvidia 41 | count: 1 42 | capabilities: [gpu] 43 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For SAM 2 Eval software 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 
16 | 17 | * Neither the name Meta nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE_DAVIS: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /backend.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime 2 | ARG MODEL_SIZE=base_plus 3 | 4 | FROM ${BASE_IMAGE} 5 | 6 | # Gunicorn environment variables 7 | ENV GUNICORN_WORKERS=1 8 | ENV GUNICORN_THREADS=2 9 | ENV GUNICORN_PORT=5000 10 | 11 | # SAM 2 environment variables 12 | ENV APP_ROOT=/opt/sam2 13 | ENV PYTHONUNBUFFERED=1 14 | ENV SAM2_BUILD_CUDA=0 15 | ENV MODEL_SIZE=${MODEL_SIZE} 16 | 17 | # Install system requirements 18 | RUN apt-get update && apt-get install -y --no-install-recommends \ 19 | ffmpeg \ 20 | libavutil-dev \ 21 | libavcodec-dev \ 22 | libavformat-dev \ 23 | libswscale-dev \ 24 | pkg-config \ 25 | build-essential \ 26 | libffi-dev 27 | 28 | COPY setup.py . 29 | COPY README.md . 30 | 31 | RUN pip install --upgrade pip setuptools 32 | RUN pip install -e ".[interactive-demo]" 33 | 34 | # https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/issues/69#issuecomment-1826764707 35 | RUN rm /opt/conda/bin/ffmpeg && ln -s /bin/ffmpeg /opt/conda/bin/ffmpeg 36 | 37 | # Make app directory. This directory will host all files required for the 38 | # backend and SAM 2 inference files. 39 | RUN mkdir ${APP_ROOT} 40 | 41 | # Copy backend server files 42 | COPY demo/backend/server ${APP_ROOT}/server 43 | 44 | # Copy SAM 2 inference files 45 | COPY sam2 ${APP_ROOT}/server/sam2 46 | 47 | # Download SAM 2.1 checkpoints 48 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_tiny.pt 49 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_small.pt 50 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_base_plus.pt 51 | ADD https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ${APP_ROOT}/checkpoints/sam2.1_hiera_large.pt 52 | 53 | WORKDIR ${APP_ROOT}/server 54 | 55 | # https://pythonspeed.com/articles/gunicorn-in-docker/ 56 | CMD gunicorn --worker-tmp-dir /dev/shm \ 57 | --worker-class gthread app:app \ 58 | --log-level info \ 59 | --access-logfile /dev/stdout \ 60 | --log-file /dev/stderr \ 61 | --workers ${GUNICORN_WORKERS} \ 62 | --threads ${GUNICORN_THREADS} \ 63 | --bind 0.0.0.0:${GUNICORN_PORT} \ 64 | --timeout 60 65 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | AccessModifierOffset: -1 2 | AlignAfterOpenBracket: AlwaysBreak 3 | AlignConsecutiveAssignments: false 4 | AlignConsecutiveDeclarations: false 5 | AlignEscapedNewlinesLeft: true 6 | AlignOperands: false 7 | AlignTrailingComments: false 8 | AllowAllParametersOfDeclarationOnNextLine: false 9 | AllowShortBlocksOnASingleLine: false 10 | AllowShortCaseLabelsOnASingleLine: false 11 | AllowShortFunctionsOnASingleLine: Empty 12 | AllowShortIfStatementsOnASingleLine: false 13 | AllowShortLoopsOnASingleLine: false 14 | AlwaysBreakAfterReturnType: None 15 | AlwaysBreakBeforeMultilineStrings: true 16 | AlwaysBreakTemplateDeclarations: true 17 | BinPackArguments: false 18 | BinPackParameters: false 19 | BraceWrapping: 20 | AfterClass: false 21 | AfterControlStatement: false 22 | AfterEnum: false 23 | AfterFunction: false 24 | AfterNamespace: false 25 | AfterObjCDeclaration: false 
26 | AfterStruct: false 27 | AfterUnion: false 28 | BeforeCatch: false 29 | BeforeElse: false 30 | IndentBraces: false 31 | BreakBeforeBinaryOperators: None 32 | BreakBeforeBraces: Attach 33 | BreakBeforeTernaryOperators: true 34 | BreakConstructorInitializersBeforeComma: false 35 | BreakAfterJavaFieldAnnotations: false 36 | BreakStringLiterals: false 37 | ColumnLimit: 80 38 | CommentPragmas: '^ IWYU pragma:' 39 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 40 | ConstructorInitializerIndentWidth: 4 41 | ContinuationIndentWidth: 4 42 | Cpp11BracedListStyle: true 43 | DerivePointerAlignment: false 44 | DisableFormat: false 45 | ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] 46 | IncludeCategories: 47 | - Regex: '^<.*\.h(pp)?>' 48 | Priority: 1 49 | - Regex: '^<.*' 50 | Priority: 2 51 | - Regex: '.*' 52 | Priority: 3 53 | IndentCaseLabels: true 54 | IndentWidth: 2 55 | IndentWrappedFunctionNames: false 56 | KeepEmptyLinesAtTheStartOfBlocks: false 57 | MacroBlockBegin: '' 58 | MacroBlockEnd: '' 59 | MaxEmptyLinesToKeep: 1 60 | NamespaceIndentation: None 61 | ObjCBlockIndentWidth: 2 62 | ObjCSpaceAfterProperty: false 63 | ObjCSpaceBeforeProtocolList: false 64 | PenaltyBreakBeforeFirstCallParameter: 1 65 | PenaltyBreakComment: 300 66 | PenaltyBreakFirstLessLess: 120 67 | PenaltyBreakString: 1000 68 | PenaltyExcessCharacter: 1000000 69 | PenaltyReturnTypeOnItsOwnLine: 200 70 | PointerAlignment: Left 71 | ReflowComments: true 72 | SortIncludes: true 73 | SpaceAfterCStyleCast: false 74 | SpaceBeforeAssignmentOperators: true 75 | SpaceBeforeParens: ControlStatements 76 | SpaceInEmptyParentheses: false 77 | SpacesBeforeTrailingComments: 1 78 | SpacesInAngles: false 79 | SpacesInContainerLiterals: true 80 | SpacesInCStyleCastParentheses: false 81 | SpacesInParentheses: false 82 | SpacesInSquareBrackets: false 83 | Standard: Cpp11 84 | TabWidth: 8 85 | UseTab: Never 86 | -------------------------------------------------------------------------------- /sav_dataset/sav_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the sav_dataset directory of this source tree. 6 | 7 | # adapted from https://github.com/hkchengrex/vos-benchmark 8 | # and https://github.com/davisvideochallenge/davis2017-evaluation 9 | # with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files 10 | # in the sav_dataset directory. 11 | from argparse import ArgumentParser 12 | 13 | from utils.sav_benchmark import benchmark 14 | 15 | """ 16 | The structure of the {GT_ROOT} can be either of the following two structures. 17 | {GT_ROOT} and {PRED_ROOT} should be of the same format. 18 | 19 | 1. SA-V val/test structure 20 | {GT_ROOT} # gt root folder 21 | ├── {video_id} 22 | │ ├── 000 # all masks associated with obj 000 23 | │ │ ├── {frame_id}.png # mask for object 000 in {frame_id} (binary mask) 24 | │ │ └── ... 25 | │ ├── 001 # all masks associated with obj 001 26 | │ ├── 002 # all masks associated with obj 002 27 | │ └── ... 28 | ├── {video_id} 29 | ├── {video_id} 30 | └── ... 31 | 32 | 2. Similar to DAVIS structure: 33 | 34 | {GT_ROOT} # gt root folder 35 | ├── {video_id} 36 | │ ├── {frame_id}.png # annotation in {frame_id} (may contain multiple objects) 37 | │ └── ... 38 | ├── {video_id} 39 | ├── {video_id} 40 | └── ...
41 | """ 42 | 43 | 44 | parser = ArgumentParser() 45 | parser.add_argument( 46 | "--gt_root", 47 | required=True, 48 | help="Path to the GT folder. For SA-V, it's sav_val/Annotations_6fps or sav_test/Annotations_6fps", 49 | ) 50 | parser.add_argument( 51 | "--pred_root", 52 | required=True, 53 | help="Path to a folder containing folders of masks to be evaluated, with exactly the same structure as gt_root", 54 | ) 55 | parser.add_argument( 56 | "-n", "--num_processes", default=16, type=int, help="Number of concurrent processes" 57 | ) 58 | parser.add_argument( 59 | "-s", 60 | "--strict", 61 | help="Make sure every video in the gt_root folder has a corresponding video in the prediction", 62 | action="store_true", 63 | ) 64 | parser.add_argument( 65 | "-q", 66 | "--quiet", 67 | help="Quietly run evaluation without printing the information out", 68 | action="store_true", 69 | ) 70 | 71 | # https://github.com/davisvideochallenge/davis2017-evaluation/blob/d34fdef71ce3cb24c1a167d860b707e575b3034c/davis2017/evaluation.py#L85 72 | parser.add_argument( 73 | "--do_not_skip_first_and_last_frame", 74 | help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. " 75 | "Set this to true for evaluation on settings that doesn't skip first and last frames", 76 | action="store_true", 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | args = parser.parse_args() 82 | benchmark( 83 | [args.gt_root], 84 | [args.pred_root], 85 | args.strict, 86 | args.num_processes, 87 | verbose=not args.quiet, 88 | skip_first_and_last=not args.do_not_skip_first_and_last_frame, 89 | ) 90 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 
49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior.
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from 
other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | no_obj_embed_spatial: true 93 | # use high-resolution feature map in the SAM mask decoder 94 | 
use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | add_tpos_enc_to_obj_ptrs: true 102 | proj_tpos_enc_in_obj_ptrs: true 103 | use_signed_tpos_enc_to_obj_ptrs: true 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | 
num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: 
sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | 
memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | 
downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | no_obj_embed_spatial: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: true 105 | proj_tpos_enc_in_obj_ptrs: true 106 | use_signed_tpos_enc_to_obj_ptrs: true 107 | only_obj_ptrs_in_the_past_for_eval: true 108 | # object occlusion prediction 109 | pred_obj_scores: true 110 | pred_obj_scores_mlp: true 111 | fixed_no_obj_ptr: true 112 | # multimask tracking settings 113 | multimask_output_for_tracking: true 114 | use_multimask_token_for_obj_ptr: true 115 | multimask_min_pt_num: 0 116 | multimask_max_pt_num: 1 117 | use_mlp_for_obj_ptr_proj: true 118 | # Compilation flag 119 | compile_image_encoder: False 120 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 
| _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: 
sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | -------------------------------------------------------------------------------- /tools/README.md: 
-------------------------------------------------------------------------------- 1 | # Semi-supervised VOS Inference 2 | 3 | The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [LVOS](https://lingyihongfd.github.io/lvos.github.io/), [MOSE](https://henghuiding.github.io/MOSE/), or the SA-V dataset. 4 | 5 | After installing SAM 2 and its dependencies, the script can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`. 6 | ```bash 7 | python ./tools/vos_inference.py \ 8 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 9 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 10 | --base_video_dir /path-to-davis-2017/JPEGImages/480p \ 11 | --input_mask_dir /path-to-davis-2017/Annotations/480p \ 12 | --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ 13 | --output_mask_dir ./outputs/davis_2017_pred_pngs \ 14 | --num_pathway 3 \ 15 | --iou_thre 0.1 \ 16 | --uncertainty 2 17 | ``` 18 | (replace `/path-to-davis-2017` with the path to the DAVIS 2017 dataset) 19 | 20 | `--num_pathway`: sets the number of segmentation pathways (hypotheses) to maintain. 21 | 22 | `--iou_thre`: sets the IoU threshold used to filter out low-confidence masks. 23 | 24 | `--uncertainty`: sets the uncertainty threshold used when selecting masks. 25 | 26 | 27 | To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). With this flag, the script also saves the output masks as per-object PNG files. 28 | ```bash 29 | python ./tools/vos_inference.py \ 30 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 31 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 32 | --base_video_dir /path-to-sav-val/JPEGImages_24fps \ 33 | --input_mask_dir /path-to-sav-val/Annotations_6fps \ 34 | --video_list_file /path-to-sav-val/sav_val.txt \ 35 | --per_obj_png_file \ 36 | --output_mask_dir ./outputs/sav_val_pred_pngs \ 37 | --num_pathway 3 \ 38 | --iou_thre 0.1 \ 39 | --uncertainty 2 40 | ``` 41 | (replace `/path-to-sav-val` with the path to the SA-V val set) 42 | 43 | Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. 44 | 45 | Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**. 46 | 47 | 48 | ## Multi-Node Inference for Accelerated Processing 49 | By default, you can run the above commands to perform inference on a single GPU. 50 | We also provide multi-node inference to speed up the process. 51 | 52 | The following SLURM script runs inference in parallel across multiple GPUs. It assumes each node has 8 GPUs and evenly distributes video processing tasks across these GPUs. You can adjust the number of nodes by specifying the `--num_nodes` argument.
53 | 54 | ```bash 55 | set -x 56 | 57 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 58 | IFS=',' read -ra GPULIST <<< "$gpu_list" 59 | NODE_ID=${SLURM_PROCID} 60 | echo "NODE ID: $NODE_ID" 61 | CHUNKS=8 62 | 63 | for IDX in $(seq 0 $((CHUNKS-1))); do 64 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python ./tools/vos_inference.py \ 65 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 66 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 67 | --base_video_dir /path-to-sav-val/JPEGImages_24fps \ 68 | --input_mask_dir /path-to-sav-val/Annotations_6fps \ 69 | --video_list_file /path-to-sav-val/sav_val.txt \ 70 | --per_obj_png_file \ 71 | --output_mask_dir ./outputs/sav_val_pred_pngs \ 72 | --num_pathway 3 \ 73 | --iou_thre 0.1 \ 74 | --uncertainty 2 \ 75 | --num_nodes $1 \ 76 | --node_id $NODE_ID \ 77 | --num_chunks $CHUNKS \ 78 | --chunk_id $IDX & 79 | done 80 | 81 | wait 82 | ``` 83 | 84 | To launch the inference, run the following command: 85 | ```bash 86 | srun -p $PARTITION --cpus-per-task=8 --nodes=2 --ntasks-per-node=1 --gres=gpu:8 bash inference.sh 2 87 | ``` 88 | In this example, we initialize 2 nodes with a total of 16 GPUs for inference. Each node processes a portion of the video sequences in parallel, which significantly accelerates the overall inference process. 89 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | self.d_model = d_model 75 | for dim in backbone_channel_list: 76 | current = nn.Sequential() 77 | current.add_module( 78 | "conv", 79 | nn.Conv2d( 80 | in_channels=dim, 81 | out_channels=d_model, 82 | kernel_size=kernel_size, 83 | stride=stride, 84 | padding=padding, 85 | ), 86 | ) 87 | 88 | self.convs.append(current) 89 | self.fpn_interp_model = fpn_interp_model 90 | assert fuse_type in ["sum", "avg"] 91 | self.fuse_type = fuse_type 92 | 93 | # levels to have top-down features in its outputs 94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 95 | # have top-down propagation, while outputs of level 0 and level 1 have only 96 | # lateral features from the same backbone level. 
97 | if fpn_top_down_levels is None: 98 | # default is to have top-down features on all levels 99 | fpn_top_down_levels = range(len(self.convs)) 100 | self.fpn_top_down_levels = list(fpn_top_down_levels) 101 | 102 | def forward(self, xs: List[torch.Tensor]): 103 | 104 | out = [None] * len(self.convs) 105 | pos = [None] * len(self.convs) 106 | assert len(xs) == len(self.convs) 107 | # fpn forward pass 108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 109 | prev_features = None 110 | # forward in top-down order (from low to high resolution) 111 | n = len(self.convs) - 1 112 | for i in range(n, -1, -1): 113 | x = xs[i] 114 | lateral_features = self.convs[n - i](x) 115 | if i in self.fpn_top_down_levels and prev_features is not None: 116 | top_down_features = F.interpolate( 117 | prev_features.to(dtype=torch.float32), 118 | scale_factor=2.0, 119 | mode=self.fpn_interp_model, 120 | align_corners=( 121 | None if self.fpn_interp_model == "nearest" else False 122 | ), 123 | antialias=False, 124 | ) 125 | prev_features = lateral_features + top_down_features 126 | if self.fuse_type == "avg": 127 | prev_features /= 2 128 | else: 129 | prev_features = lateral_features 130 | x_out = prev_features 131 | out[i] = x_out 132 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 133 | 134 | return out, pos 135 | -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 51 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 52 | 53 | Returns 54 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 
55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 71 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform PostProcessing on output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.fill_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # We fill holes with negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks 119 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | 10 | # Package metadata 11 | NAME = "SAM-2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" 14 | URL = "https://github.com/facebookresearch/sam2" 15 | AUTHOR = "Meta AI" 16 | AUTHOR_EMAIL = "segment-anything@meta.com" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of README file 20 | with open("README.md", "r", encoding="utf-8") as f: 21 | LONG_DESCRIPTION = f.read() 22 | 23 | # Required dependencies 24 | REQUIRED_PACKAGES = [ 25 | "torch>=2.3.1", 26 | "torchvision>=0.18.1", 27 | "numpy>=1.24.4", 28 | "tqdm>=4.66.1", 29 | "hydra-core>=1.3.2", 30 | "iopath>=0.1.10", 31 | "pillow>=9.4.0", 32 | ] 33 | 34 | EXTRA_PACKAGES = { 35 | "notebooks": [ 36 | "matplotlib>=3.9.1", 37 | "jupyter>=1.0.0", 38 | "opencv-python>=4.7.0", 39 | "eva-decord>=0.6.1", 40 | ], 41 | "interactive-demo": [ 42 | "Flask>=3.0.3", 43 | "Flask-Cors>=5.0.0", 44 | "av>=13.0.0", 45 | "dataclasses-json>=0.6.7", 46 | "eva-decord>=0.6.1", 47 | "gunicorn>=23.0.0", 48 | "imagesize>=1.4.1", 49 | "pycocotools>=2.0.8", 50 | "strawberry-graphql>=0.239.2", 51 | ], 52 | "dev": [ 53 | "black==24.2.0", 54 | "usort==1.0.2", 55 | "ufmt==2.0.0b2", 56 | "fvcore>=0.1.5.post20221221", 57 | "pandas>=2.2.2", 58 | "scikit-image>=0.24.0", 59 | "tensorboard>=2.17.0", 60 | "pycocotools>=2.0.8", 61 | "tensordict>=0.5.0", 62 | "opencv-python>=4.7.0", 63 | "submitit>=1.5.1", 64 | ], 65 | } 66 | 67 | # By default, we also build the SAM 2 CUDA extension. 68 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. 69 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" 70 | # By default, we allow SAM 2 installation to proceed even with build errors. 71 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. 72 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" 73 | 74 | # Catch and skip errors during extension building and print a warning message 75 | # (note that this message only shows up under verbose build mode 76 | # "pip install -v -e ." or "python setup.py build_ext -v") 77 | CUDA_ERROR_MSG = ( 78 | "{}\n\n" 79 | "Failed to build the SAM 2 CUDA extension due to the error above. 
" 80 | "You can still use SAM 2 and it's OK to ignore the error above, although some " 81 | "post-processing functionality may be limited (which doesn't affect the results in most cases; " 82 | "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" 83 | ) 84 | 85 | 86 | def get_extensions(): 87 | if not BUILD_CUDA: 88 | return [] 89 | 90 | try: 91 | from torch.utils.cpp_extension import CUDAExtension 92 | 93 | srcs = ["sam2/csrc/connected_components.cu"] 94 | compile_args = { 95 | "cxx": [], 96 | "nvcc": [ 97 | "-DCUDA_HAS_FP16=1", 98 | "-D__CUDA_NO_HALF_OPERATORS__", 99 | "-D__CUDA_NO_HALF_CONVERSIONS__", 100 | "-D__CUDA_NO_HALF2_OPERATORS__", 101 | ], 102 | } 103 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 104 | except Exception as e: 105 | if BUILD_ALLOW_ERRORS: 106 | print(CUDA_ERROR_MSG.format(e)) 107 | ext_modules = [] 108 | else: 109 | raise e 110 | 111 | return ext_modules 112 | 113 | 114 | try: 115 | from torch.utils.cpp_extension import BuildExtension 116 | 117 | class BuildExtensionIgnoreErrors(BuildExtension): 118 | 119 | def finalize_options(self): 120 | try: 121 | super().finalize_options() 122 | except Exception as e: 123 | print(CUDA_ERROR_MSG.format(e)) 124 | self.extensions = [] 125 | 126 | def build_extensions(self): 127 | try: 128 | super().build_extensions() 129 | except Exception as e: 130 | print(CUDA_ERROR_MSG.format(e)) 131 | self.extensions = [] 132 | 133 | def get_ext_filename(self, ext_name): 134 | try: 135 | return super().get_ext_filename(ext_name) 136 | except Exception as e: 137 | print(CUDA_ERROR_MSG.format(e)) 138 | self.extensions = [] 139 | return "_C.so" 140 | 141 | cmdclass = { 142 | "build_ext": ( 143 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) 144 | if BUILD_ALLOW_ERRORS 145 | else BuildExtension.with_options(no_python_abi_suffix=True) 146 | ) 147 | } 148 | except Exception as e: 149 | cmdclass = {} 150 | if BUILD_ALLOW_ERRORS: 151 | print(CUDA_ERROR_MSG.format(e)) 152 | else: 153 | raise e 154 | 155 | 156 | # Setup configuration 157 | setup( 158 | name=NAME, 159 | version=VERSION, 160 | description=DESCRIPTION, 161 | long_description=LONG_DESCRIPTION, 162 | long_description_content_type="text/markdown", 163 | url=URL, 164 | author=AUTHOR, 165 | author_email=AUTHOR_EMAIL, 166 | license=LICENSE, 167 | packages=find_packages(exclude="notebooks"), 168 | include_package_data=True, 169 | install_requires=REQUIRED_PACKAGES, 170 | extras_require=EXTRA_PACKAGES, 171 | python_requires=">=3.10.0", 172 | ext_modules=get_extensions(), 173 | cmdclass=cmdclass, 174 | ) 175 | -------------------------------------------------------------------------------- /sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 
23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim 
of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /sav_dataset/utils/sav_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the sav_dataset directory of this source tree. 6 | import json 7 | import os 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pycocotools.mask as mask_util 14 | 15 | 16 | def decode_video(video_path: str) -> List[np.ndarray]: 17 | """ 18 | Decode the video and return the RGB frames 19 | """ 20 | video = cv2.VideoCapture(video_path) 21 | video_frames = [] 22 | while video.isOpened(): 23 | ret, frame = video.read() 24 | if ret: 25 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 26 | video_frames.append(frame) 27 | else: 28 | break 29 | return video_frames 30 | 31 | 32 | def show_anns(masks, colors: List, borders=True) -> None: 33 | """ 34 | show the annotations 35 | """ 36 | # return if no masks 37 | if len(masks) == 0: 38 | return 39 | 40 | # sort masks by size 41 | sorted_annot_and_color = sorted( 42 | zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True 43 | ) 44 | H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1] 45 | 46 | canvas = np.ones((H, W, 4)) 47 | canvas[:, :, 3] = 0 # set the alpha channel 48 | contour_thickness = max(1, int(min(5, 0.01 * min(H, W)))) 49 | for mask, color in sorted_annot_and_color: 50 | canvas[mask] = np.concatenate([color, [0.55]]) 51 | if borders: 52 | contours, _ = cv2.findContours( 53 | np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE 54 | ) 55 | cv2.drawContours( 56 | canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness 57 | ) 58 | 59 | ax = plt.gca() 60 | ax.imshow(canvas) 61 | 62 | 63 | class SAVDataset: 64 | """ 65 | SAVDataset is a class to load the SAV dataset and visualize the annotations. 66 | """ 67 | 68 | def __init__(self, sav_dir, annot_sample_rate=4): 69 | """ 70 | Args: 71 | sav_dir: the directory of the SAV dataset 72 | annot_sample_rate: the sampling rate of the annotations. 73 | The annotations are aligned with the videos at 6 fps. 
74 | """ 75 | self.sav_dir = sav_dir 76 | self.annot_sample_rate = annot_sample_rate 77 | self.manual_mask_colors = np.random.random((256, 3)) 78 | self.auto_mask_colors = np.random.random((256, 3)) 79 | 80 | def read_frames(self, mp4_path: str) -> None: 81 | """ 82 | Read the frames and downsample them to align with the annotations. 83 | """ 84 | if not os.path.exists(mp4_path): 85 | print(f"{mp4_path} doesn't exist.") 86 | return None 87 | else: 88 | # decode the video 89 | frames = decode_video(mp4_path) 90 | print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).") 91 | 92 | # downsample the frames to align with the annotations 93 | frames = frames[:: self.annot_sample_rate] 94 | print( 95 | f"Videos are annotated every {self.annot_sample_rate} frames. " 96 | "To align with the annotations, " 97 | f"downsample the video to {len(frames)} frames." 98 | ) 99 | return frames 100 | 101 | def get_frames_and_annotations( 102 | self, video_id: str 103 | ) -> Tuple[List | None, Dict | None, Dict | None]: 104 | """ 105 | Get the frames and annotations for video. 106 | """ 107 | # load the video 108 | mp4_path = os.path.join(self.sav_dir, video_id + ".mp4") 109 | frames = self.read_frames(mp4_path) 110 | if frames is None: 111 | return None, None, None 112 | 113 | # load the manual annotations 114 | manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json") 115 | if not os.path.exists(manual_annot_path): 116 | print(f"{manual_annot_path} doesn't exist. Something might be wrong.") 117 | manual_annot = None 118 | else: 119 | manual_annot = json.load(open(manual_annot_path)) 120 | 121 | # load the manual annotations 122 | auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json") 123 | if not os.path.exists(auto_annot_path): 124 | print(f"{auto_annot_path} doesn't exist.") 125 | auto_annot = None 126 | else: 127 | auto_annot = json.load(open(auto_annot_path)) 128 | 129 | return frames, manual_annot, auto_annot 130 | 131 | def visualize_annotation( 132 | self, 133 | frames: List[np.ndarray], 134 | auto_annot: Optional[Dict], 135 | manual_annot: Optional[Dict], 136 | annotated_frame_id: int, 137 | show_auto=True, 138 | show_manual=True, 139 | ) -> None: 140 | """ 141 | Visualize the annotations on the annotated_frame_id. 142 | If show_manual is True, show the manual annotations. 143 | If show_auto is True, show the auto annotations. 144 | By default, show both auto and manual annotations. 
145 | """ 146 | 147 | if annotated_frame_id >= len(frames): 148 | print("invalid annotated_frame_id") 149 | return 150 | 151 | rles = [] 152 | colors = [] 153 | if show_manual and manual_annot is not None: 154 | rles.extend(manual_annot["masklet"][annotated_frame_id]) 155 | colors.extend( 156 | self.manual_mask_colors[ 157 | : len(manual_annot["masklet"][annotated_frame_id]) 158 | ] 159 | ) 160 | if show_auto and auto_annot is not None: 161 | rles.extend(auto_annot["masklet"][annotated_frame_id]) 162 | colors.extend( 163 | self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])] 164 | ) 165 | 166 | plt.imshow(frames[annotated_frame_id]) 167 | 168 | if len(rles) > 0: 169 | masks = [mask_util.decode(rle) > 0 for rle in rles] 170 | show_anns(masks, colors) 171 | else: 172 | print("No annotation will be shown") 173 | 174 | plt.axis("off") 175 | plt.show() 176 | -------------------------------------------------------------------------------- /sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | 10 | import torch 11 | from hydra import compose 12 | from hydra.utils import instantiate 13 | from omegaconf import OmegaConf 14 | 15 | import sam2 16 | 17 | # Check if the user is running Python from the parent directory of the sam2 repo 18 | # (i.e. the directory where this repo is cloned into) -- this is not supported since 19 | # it could shadow the sam2 package and cause issues. 20 | if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): 21 | # If the user has "sam2/sam2" in their path, they are likey importing the repo itself 22 | # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). 23 | # This typically happens because the user is running Python from the parent directory 24 | # that contains the sam2 repo they cloned. 25 | raise RuntimeError( 26 | "You're likely running Python from the parent directory of the sam2 repository " 27 | "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " 28 | "This is not supported since the `sam2` Python package could be shadowed by the " 29 | "repository name (the repository is also named `sam2` and contains the Python package " 30 | "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " 31 | "rather than its parent dir, or from your home directory) after installing SAM 2." 
32 | ) 33 | 34 | 35 | HF_MODEL_ID_TO_FILENAMES = { 36 | "facebook/sam2-hiera-tiny": ( 37 | "configs/sam2/sam2_hiera_t.yaml", 38 | "sam2_hiera_tiny.pt", 39 | ), 40 | "facebook/sam2-hiera-small": ( 41 | "configs/sam2/sam2_hiera_s.yaml", 42 | "sam2_hiera_small.pt", 43 | ), 44 | "facebook/sam2-hiera-base-plus": ( 45 | "configs/sam2/sam2_hiera_b+.yaml", 46 | "sam2_hiera_base_plus.pt", 47 | ), 48 | "facebook/sam2-hiera-large": ( 49 | "configs/sam2/sam2_hiera_l.yaml", 50 | "sam2_hiera_large.pt", 51 | ), 52 | "facebook/sam2.1-hiera-tiny": ( 53 | "configs/sam2.1/sam2.1_hiera_t.yaml", 54 | "sam2.1_hiera_tiny.pt", 55 | ), 56 | "facebook/sam2.1-hiera-small": ( 57 | "configs/sam2.1/sam2.1_hiera_s.yaml", 58 | "sam2.1_hiera_small.pt", 59 | ), 60 | "facebook/sam2.1-hiera-base-plus": ( 61 | "configs/sam2.1/sam2.1_hiera_b+.yaml", 62 | "sam2.1_hiera_base_plus.pt", 63 | ), 64 | "facebook/sam2.1-hiera-large": ( 65 | "configs/sam2.1/sam2.1_hiera_l.yaml", 66 | "sam2.1_hiera_large.pt", 67 | ), 68 | } 69 | 70 | 71 | def build_sam2( 72 | config_file, 73 | ckpt_path=None, 74 | device="cuda", 75 | mode="eval", 76 | hydra_overrides_extra=[], 77 | apply_postprocessing=True, 78 | **kwargs, 79 | ): 80 | 81 | if apply_postprocessing: 82 | hydra_overrides_extra = hydra_overrides_extra.copy() 83 | hydra_overrides_extra += [ 84 | # dynamically fall back to multi-mask if the single mask is not stable 85 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 86 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 87 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 88 | ] 89 | # Read config and init model 90 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 91 | OmegaConf.resolve(cfg) 92 | model = instantiate(cfg.model, _recursive_=True) 93 | _load_checkpoint(model, ckpt_path) 94 | model = model.to(device) 95 | if mode == "eval": 96 | model.eval() 97 | return model 98 | 99 | 100 | def build_sam2_video_predictor( 101 | config_file, 102 | ckpt_path=None, 103 | device="cuda", 104 | mode="eval", 105 | hydra_overrides_extra=[], 106 | apply_postprocessing=True, 107 | **kwargs, 108 | ): 109 | hydra_overrides = [ 110 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 111 | ] 112 | if apply_postprocessing: 113 | hydra_overrides_extra = hydra_overrides_extra.copy() 114 | hydra_overrides_extra += [ 115 | # dynamically fall back to multi-mask if the single mask is not stable 116 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 117 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 118 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 119 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 120 | "++model.binarize_mask_from_pts_for_mem_enc=true", 121 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 122 | "++model.fill_hole_area=8", 123 | ] 124 | hydra_overrides.extend(hydra_overrides_extra) 125 | 126 | # Read config and init model 127 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 128 | OmegaConf.resolve(cfg) 129 | model = instantiate(cfg.model, _recursive_=True) 130 | _load_checkpoint(model, ckpt_path) 131 | model = model.to(device) 132 | if mode == "eval": 133 | model.eval() 134 | return model 135 | 136 | 137 | def _hf_download(model_id): 138 
| from huggingface_hub import hf_hub_download 139 | 140 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] 141 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 142 | return config_name, ckpt_path 143 | 144 | 145 | def build_sam2_hf(model_id, **kwargs): 146 | config_name, ckpt_path = _hf_download(model_id) 147 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) 148 | 149 | 150 | def build_sam2_video_predictor_hf(model_id, **kwargs): 151 | config_name, ckpt_path = _hf_download(model_id) 152 | return build_sam2_video_predictor( 153 | config_file=config_name, ckpt_path=ckpt_path, **kwargs 154 | ) 155 | 156 | 157 | def _load_checkpoint(model, ckpt_path): 158 | if ckpt_path is not None: 159 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] 160 | missing_keys, unexpected_keys = model.load_state_dict(sd) 161 | if missing_keys: 162 | logging.error(missing_keys) 163 | raise RuntimeError() 164 | if unexpected_keys: 165 | logging.error(unexpected_keys) 166 | raise RuntimeError() 167 | logging.info("Loaded checkpoint sucessfully") 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAM2Long 2 | This repository is the official implementation of SAM2Long. 3 | 7 | 8 |

9 | 10 | [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC_BY--NC_4.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/)
11 | 12 | 13 | 14 | 15 |

16 | 17 | 18 | 19 | >[**SAM2Long: Enhancing SAM 2 for Long Video Segmentation with a Training-Free Memory Tree**](https://arxiv.org/abs/2410.16268)
20 | > [Shuangrui Ding](https://mark12ding.github.io/), [Rui Qian](https://shvdiwnkozbw.github.io/), [Xiaoyi Dong](https://lightdxy.github.io/), [Pan Zhang](https://panzhang0212.github.io/)
21 | [Yuhang Zang](https://yuhangzang.github.io/), [Yuhang Cao](https://scholar.google.com/citations?user=sJkqsqkAAAAJ), [Yuwei Guo](https://guoyww.github.io/), [Dahua Lin](http://dahua.site/), [Jiaqi Wang](https://myownskyw7.github.io/)
22 | CUHK, Shanghai AI Lab 23 | 24 | https://github.com/user-attachments/assets/265a1f01-ea60-4480-b1d6-ce1b85e48c89 25 | 26 | (More Demos are shown in the [project page](https://mark12ding.github.io/project/SAM2Long/)!) 27 | 28 | ## 📰 News 29 | [2025/7/29]🔥🔥🔥 Our latest work [SeC](https://arxiv.org/abs/2507.15852) enables SAM 2 to track objects across multi-transition videos and it is fully open-sourced!
30 | [2025/6/26]🍺🍺🍺 Aloha! SAM2Long is accepted at ICCV 2025. See you in Hawaii.
31 | [2024/12/18]🔥🔥🔥 We include SAM2Long's performance on the VOT benchmarks LaSOT, LaSOT_ext, and GOT-10k. Refer to our updated [paper📄](https://arxiv.org/pdf/2410.16268).
32 | [2024/12/9]🔥🔥🔥 The SAM2Long demo is now live on Hugging Face Spaces 🤗[Link](https://huggingface.co/spaces/Mar2Ding/SAM2Long-Demo). Take a look! 33 | 34 | 35 | 36 | ## 💡 Highlights 37 | 38 | ### 🔥 Enhanced Capability in Long-Term Video Segmentation 39 | 40 | SAM2Long significantly improves upon SAM 2 by addressing **error accumulation** issue, particularly in challenging long-term video scenarios involving object occlusion and reappearance. With SAM2Long, the segmentation process becomes more resilient and accurate over time, maintaining strong performance even as objects are occluded or reappear in the video stream. 41 | 42 | 46 | 47 | ### ⚡️ A Simple Training-free Memory Tree 48 | 49 | SAM2Long introduces a **training-free** memory tree that effectively reduces the risk of error propagation over time. By maintaining diverse segmentation hypotheses and dynamically pruning less optimal paths as the video progresses, this approach enhances segmentation without the need for additional parameters or further training. It maximizes the potential of SAM 2 to deliver better results in complex video scenarios. 50 | 51 | ### 🤯 Superior Performance Compared to SAM 2 52 | 53 | SAM2Long pushes the performance limits of SAM 2 even further across various video object segmentation benchmarks, especially achieving an average improvement of 3 in J & F scores across all 24 head-to-head comparisons on long-term video datasets like SA-V and LVOS. 54 | 55 | 56 | ## 🚀 Main Results 57 | 58 | ### SAM 2 checkpoints 59 | The table below provides a one-to-one comparison between SAM 2 and SAM2Long using the improved SAM 2 checkpoints. 60 | | Method | Backbone | SA-V val (J & F) | SA-V test (J & F) | LVOS v2 (J & F) | 61 | | :------: | :--------: | :------: | :--------: | :--------: | 62 | | SAM 2 | Tiny | 73.5 | 74.6 | 77.8 | 63 | | SAM2Long| Tiny | 77.0 | 78.7 | 81.4 | 64 | | SAM 2 | Small | 73.0 | 74.6 | 79.7 | 65 | | SAM2Long| Small | 77.7 | 78.1 | 83.2 | 66 | | SAM 2 | Base+ | 75.4 | 74.6 | 80.2 | 67 | | SAM2Long| Base+ | 78.4 | 78.5 | 82.3 | 68 | | SAM 2 | Large | 76.3 | 75.5 | 83.0 | 69 | | SAM2Long| Large | 80.8 | 80.8 | 85.2 | 70 | 71 | ### SAM 2.1 checkpoints 72 | The table below provides a one-to-one comparison between SAM 2 and SAM2Long using the SAM 2.1 checkpoints. 73 | 74 | | Method | Backbone | SA-V val (J & F) | SA-V test (J & F) | LVOS v2 (J & F) | 75 | | :------: | :--------: | :------: | :--------: | :--------: | 76 | | SAM 2 | Tiny | 75.1 | 76.3 | 81.6 | 77 | | SAM2Long| Tiny | 78.9 | 79.0 | 82.4 | 78 | | SAM 2 | Small | 76.9 | 76.9 | 82.1 | 79 | | SAM2Long| Small | 79.6 | 80.4 | 84.3 | 80 | | SAM 2 | Base+ | 78.0 | 77.7 | 83.1 | 81 | | SAM2Long| Base+ | 80.5 | 80.8 | 85.2 | 82 | | SAM 2 | Large | 78.6 | 79.6 | 84.0 | 83 | | SAM2Long| Large | 81.1 | 81.2 | 85.3 | 84 | 85 | ## 🛠️ Usage 86 | 87 | ### Installation 88 | Please follow the instruction of [official SAM 2 repo](https://github.com/facebookresearch/sam2?tab=readme-ov-file#installation). If you encounter issues running the code, it's recommended to create a new environment specifically for SAM2Long instead of sharing it with SAM2. For further details, please check this issue [here](https://github.com/Mark12Ding/SAM2Long/issues/5#issuecomment-2458974462). 89 | 90 | ### Download Checkpoints 91 | All the model checkpoints can be downloaded by running: 92 | ``` 93 | bash 94 | cd checkpoints && \ 95 | ./download_ckpts.sh && \ 96 | cd .. 97 | ``` 98 | 99 | ### Inference 100 | The inference instruction is in [INFERENCE.md](tools/README.md). 
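If you just want to try SAM2Long interactively from Python rather than through the evaluation script, the snippet below is a minimal sketch. It assumes a SAM 2.1 Base+ checkpoint and config (as in the tables above), a folder of JPEG frames at a placeholder path, and that SAM2Long keeps SAM 2's video-predictor API (`build_sam2_video_predictor`, `init_state`, `add_new_points_or_box`, `propagate_in_video`); it is not the benchmark pipeline.

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Placeholders: point these at your own checkpoint, config, and frame folder.
checkpoint = "./checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
video_dir = "./path-to-frames"  # directory of JPEG frames named 00000.jpg, 00001.jpg, ...

predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(video_path=video_dir)

    # Prompt object 1 with a single positive click on the first frame.
    points = np.array([[210, 350]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)  # 1 = positive click, 0 = negative click
    predictor.add_new_points_or_box(
        inference_state=state, frame_idx=0, obj_id=1, points=points, labels=labels
    )

    # Propagate the prompt through the rest of the video.
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        masks = (mask_logits > 0.0).cpu().numpy()  # one binary mask per tracked object
        print(f"frame {frame_idx}: {masks.sum(axis=(1, 2, 3))} foreground pixels per object")
```

For the benchmark numbers reported above, use `tools/vos_inference.py` instead, which runs the multi-pathway memory-tree inference controlled by `--num_pathway`, `--iou_thre`, and `--uncertainty`.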
101 | 102 | ### Evaluation 103 | 104 | The evaluation code can be found [here](sav_dataset/README.md). 105 | 106 | To evaluate performance on seen and unseen categories in the LVOS dataset, refer to the evaluation code available [here](https://github.com/LingyiHongfd/lvos-evaluation). 107 | 108 | ## ☎️ Contact 109 | Shuangrui Ding: mark12ding@gmail.com 110 | 111 | 112 | ## 🔒 License 113 | The majority of this project is released under the CC-BY-NC 4.0 license as found in the LICENSE file. The original SAM 2 model checkpoints and SAM 2 training code are licensed under [Apache 2.0](https://github.com/facebookresearch/sam2/blob/main/LICENSE). 114 | 115 | 116 | ## 👍 Acknowledgements 117 | I would like to thank [Yixuan Wang](https://wangyixuan12.github.io/) for his assistance with dataset preparation and [Haohang Xu](https://scholar.google.com/citations?user=9nqZkmUAAAAJ) for his insightful disscusion. 118 | 119 | This project is built upon [SAM 2](https://github.com/facebookresearch/sam2) and the format of this README is inspired by [VideoMAE](https://github.com/MCG-NJU/VideoMAE/blob/main/README.md). 120 | 121 | ## ✒️ Citation 122 | If you find our work helpful for your research, please consider giving a star ⭐ and citation 📝. 123 | ```bibtex 124 | @article{ding2024sam2long, 125 | title={SAM2Long: Enhancing SAM 2 for Long Video Segmentation with a Training-Free Memory Tree}, 126 | author={Ding, Shuangrui and Qian, Rui and Dong, Xiaoyi and Zhang, Pan and Zang, Yuhang and Cao, Yuhang and Guo, Yuwei and Lin, Dahua and Wang, Jiaqi}, 127 | journal={arXiv preprint arXiv:2410.16268}, 128 | year={2024} 129 | } 130 | ``` 131 | 132 | 133 | -------------------------------------------------------------------------------- /sam2/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom 13 | 14 | from sam2.modeling.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 
39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | point_embedding[labels == 2] += self.point_embeddings[2].weight 100 | point_embedding[labels == 3] += self.point_embeddings[3].weight 101 | return point_embedding 102 | 103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 104 | """Embeds box prompts.""" 105 | boxes = boxes + 0.5 # Shift to center of pixel 106 | coords = boxes.reshape(-1, 2, 2) 107 | corner_embedding = self.pe_layer.forward_with_coords( 108 | coords, self.input_image_size 109 | ) 110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 112 | return corner_embedding 113 | 114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 115 | """Embeds mask inputs.""" 116 | mask_embedding = self.mask_downscaling(masks) 117 | return mask_embedding 118 | 119 | def _get_batch_size( 120 | self, 121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 122 | boxes: Optional[torch.Tensor], 123 | masks: Optional[torch.Tensor], 124 | ) -> int: 125 | """ 126 | Gets the batch size of the output given the batch size of the input prompts. 
127 | """ 128 | if points is not None: 129 | return points[0].shape[0] 130 | elif boxes is not None: 131 | return boxes.shape[0] 132 | elif masks is not None: 133 | return masks.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | ) -> Tuple[torch.Tensor, torch.Tensor]: 146 | """ 147 | Embeds different types of prompts, returning both sparse and dense 148 | embeddings. 149 | 150 | Arguments: 151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 152 | and labels to embed. 153 | boxes (torch.Tensor or none): boxes to embed 154 | masks (torch.Tensor or none): masks to embed 155 | 156 | Returns: 157 | torch.Tensor: sparse embeddings for the points and boxes, with shape 158 | BxNx(embed_dim), where N is determined by the number of input points 159 | and boxes. 160 | torch.Tensor: dense embeddings for the masks, in the shape 161 | Bx(embed_dim)x(embed_H)x(embed_W) 162 | """ 163 | bs = self._get_batch_size(points, boxes, masks) 164 | sparse_embeddings = torch.empty( 165 | (bs, 0, self.embed_dim), device=self._get_device() 166 | ) 167 | if points is not None: 168 | coords, labels = points 169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 171 | if boxes is not None: 172 | box_embeddings = self._embed_boxes(boxes) 173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 174 | 175 | if masks is not None: 176 | dense_embeddings = self._embed_masks(masks) 177 | else: 178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 180 | ) 181 | 182 | return sparse_embeddings, dense_embeddings 183 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | import pdb 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0, object_frame_scores=None, object_ptr_scores=None): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | if object_frame_scores is None: 75 | key = memory + pos if self.pos_enc_at_cross_attn_keys else memory 76 | else: # relative 77 | key_original = memory + pos if self.pos_enc_at_cross_attn_keys else memory 78 | num_frame, num_ptr = len(object_frame_scores), len(object_ptr_scores) 79 | num_frame_ = int(num_frame*4096) 80 | num_object = key_original.shape[0] 81 | key_frame = key_original[:, :num_frame_].reshape(num_object, num_frame, 4096, -1) 82 | key_ptr = key_original[:, num_frame_:].reshape(num_object, num_ptr, 4, -1) 83 | scaling_low = 0.95 84 | scaling_high = 1.05 85 | if num_frame == 1: 86 | key = key_original 87 | else: 88 | weight_frame = torch.stack(object_frame_scores, dim=1) # num_object, num_frame 89 | weight_ptr = torch.stack(object_ptr_scores, dim=1) # num_object, num_ptr 90 | 91 | standard_weight_frame = torch.linspace(scaling_low, scaling_high, num_frame).to(weight_frame) # num_frame 92 | standard_weight_ptr = torch.linspace(scaling_low, scaling_high, num_ptr).to(weight_ptr) # num_ptr 93 | 94 | new_weight_frame = torch.zeros_like(weight_frame) 95 | new_weight_ptr = torch.zeros_like(weight_ptr) 96 | 97 | new_weight_frame.scatter_(1, torch.argsort(weight_frame, dim=1), standard_weight_frame.unsqueeze(0).repeat([num_object, 1])) 98 | new_weight_ptr.scatter_(1, torch.argsort(weight_ptr, dim=1), 
standard_weight_ptr.unsqueeze(0).repeat([num_object, 1])) 99 | 100 | key_frame_scale = (new_weight_frame[:, :, None, None].to(key_frame.device) * key_frame) 101 | key_ptr_scale = (new_weight_ptr[:, :, None, None].to(key_ptr.device) * key_ptr) 102 | key = torch.cat([key_frame_scale.reshape(num_object, num_frame_, -1), key_ptr_scale.reshape(num_object, int(num_ptr*4), -1)], dim=1) 103 | # key = memory + pos if self.pos_enc_at_cross_attn_keys else memory 104 | tgt2 = self.cross_attn_image( 105 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 106 | k=key, 107 | v=memory, 108 | **kwds, 109 | ) 110 | tgt = tgt + self.dropout2(tgt2) 111 | return tgt 112 | 113 | def forward( 114 | self, 115 | tgt, 116 | memory, 117 | pos: Optional[Tensor] = None, 118 | query_pos: Optional[Tensor] = None, 119 | num_k_exclude_rope: int = 0, 120 | object_frame_scores = None, 121 | object_ptr_scores = None, 122 | ) -> torch.Tensor: 123 | 124 | # Self-Attn, Cross-Attn 125 | tgt = self._forward_sa(tgt, query_pos) 126 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope, object_frame_scores, object_ptr_scores) 127 | # MLP 128 | tgt2 = self.norm3(tgt) 129 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 130 | tgt = tgt + self.dropout3(tgt2) 131 | return tgt 132 | 133 | 134 | class MemoryAttention(nn.Module): 135 | def __init__( 136 | self, 137 | d_model: int, 138 | pos_enc_at_input: bool, 139 | layer: nn.Module, 140 | num_layers: int, 141 | batch_first: bool = True, # Do layers expect batch first input? 142 | ): 143 | super().__init__() 144 | self.d_model = d_model 145 | self.layers = get_clones(layer, num_layers) 146 | self.num_layers = num_layers 147 | self.norm = nn.LayerNorm(d_model) 148 | self.pos_enc_at_input = pos_enc_at_input 149 | self.batch_first = batch_first 150 | 151 | def forward( 152 | self, 153 | curr: torch.Tensor, # self-attention inputs 154 | memory: torch.Tensor, # cross-attention inputs 155 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 156 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 157 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 158 | object_frame_scores=None, 159 | object_ptr_scores=None, 160 | ): 161 | if isinstance(curr, list): 162 | assert isinstance(curr_pos, list) 163 | assert len(curr) == len(curr_pos) == 1 164 | curr, curr_pos = ( 165 | curr[0], 166 | curr_pos[0], 167 | ) 168 | 169 | assert ( 170 | curr.shape[1] == memory.shape[1] 171 | ), "Batch size must be the same for curr and memory" 172 | 173 | output = curr 174 | if self.pos_enc_at_input and curr_pos is not None: 175 | output = output + 0.1 * curr_pos 176 | 177 | if self.batch_first: 178 | # Convert to batch first 179 | output = output.transpose(0, 1) 180 | curr_pos = curr_pos.transpose(0, 1) 181 | memory = memory.transpose(0, 1) 182 | memory_pos = memory_pos.transpose(0, 1) 183 | 184 | for layer in self.layers: 185 | kwds = {} 186 | if isinstance(layer.cross_attn_image, RoPEAttention): 187 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens, 188 | "object_frame_scores": object_frame_scores, 189 | "object_ptr_scores":object_ptr_scores} 190 | 191 | output = layer( 192 | tgt=output, 193 | memory=memory, 194 | pos=memory_pos, 195 | query_pos=curr_pos, 196 | **kwds, 197 | ) 198 | normed_output = self.norm(output) 199 | 200 | if self.batch_first: 201 | # Convert back to seq first 202 | normed_output = normed_output.transpose(0, 1) 203 | curr_pos = curr_pos.transpose(0, 1) 204 | 205 | return 
normed_output -------------------------------------------------------------------------------- /sav_dataset/README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything Video (SA-V) Dataset 2 | 3 | ## Overview 4 | 5 | [Segment Anything Video (SA-V)](https://ai.meta.com/datasets/segment-anything-video/), consists of 51K diverse videos and 643K high-quality spatio-temporal segmentation masks (i.e., masklets). The dataset is released under the CC by 4.0 license. Browse the dataset [here](https://sam2.metademolab.com/dataset). 6 | 7 | ![SA-V dataset](../assets/sa_v_dataset.jpg?raw=true) 8 | 9 | ## Getting Started 10 | 11 | ### Download the dataset 12 | 13 | Visit [here](https://ai.meta.com/datasets/segment-anything-video-downloads/) to download SA-V including the training, val and test sets. 14 | 15 | ### Dataset Stats 16 | 17 | | | Num Videos | Num Masklets | 18 | | ---------- | ---------- | ----------------------------------------- | 19 | | SA-V train | 50,583 | 642,036 (auto 451,720 and manual 190,316) | 20 | | SA-V val | 155 | 293 | 21 | | SA-V test | 150 | 278 | 22 | 23 | ### Notebooks 24 | 25 | To load and visualize the SA-V training set annotations, refer to the example [sav_visualization_example.ipynb](./sav_visualization_example.ipynb) notebook. 26 | 27 | ### SA-V train 28 | 29 | For SA-V training set we release the mp4 videos and store the masklet annotations per video as json files . Automatic masklets and manual masklets are stored separately as two json files: `{video_id}_auto.json` and `{video_id}_manual.json`. They can be loaded as dictionaries in python in the format below. 30 | 31 | ``` 32 | { 33 | "video_id" : str; video id 34 | "video_duration" : float64; the duration in seconds of this video 35 | "video_frame_count" : float64; the number of frames in the video 36 | "video_height" : float64; the height of the video 37 | "video_width" : float64; the width of the video 38 | "video_resolution" : float64; video_height $\times$ video_width 39 | "video_environment" : List[str]; "Indoor" or "Outdoor" 40 | "video_split" : str; "train" for training set 41 | "masklet" : List[List[Dict]]; masklet annotations in list of list of RLEs. 42 | The outer list is over frames in the video and the inner list 43 | is over objects in the video. 44 | "masklet_id" : List[int]; the masklet ids 45 | "masklet_size_rel" : List[float]; the average mask area normalized by resolution 46 | across all the frames where the object is visible 47 | "masklet_size_abs" : List[float]; the average mask area (in pixels) 48 | across all the frames where the object is visible 49 | "masklet_size_bucket" : List[str]; "small": $1$ <= masklet_size_abs < $32^2$, 50 | "medium": $32^2$ <= masklet_size_abs < $96^2$, 51 | and "large": masklet_size_abs > $96^2$ 52 | "masklet_visibility_changes" : List[int]; the number of times where the visibility changes 53 | after the first appearance (e.g., invisible -> visible 54 | or visible -> invisible) 55 | "masklet_first_appeared_frame" : List[int]; the index of the frame where the object appears 56 | the first time in the video. Always 0 for auto masklets. 57 | "masklet_frame_count" : List[int]; the number of frames being annotated. Note that 58 | videos are annotated at 6 fps (annotated every 4 frames) 59 | while the videos are at 24 fps. 60 | "masklet_edited_frame_count" : List[int]; the number of frames being edited by human annotators. 61 | Always 0 for auto masklets. 
62 | "masklet_type" : List[str]; "auto" or "manual" 63 | "masklet_stability_score" : Optional[List[List[float]]]; per-mask stability scores. Auto annotation only. 64 | "masklet_num" : int; the number of manual/auto masklets in the video 65 | 66 | } 67 | ``` 68 | 69 | Note that in SA-V train, there are in total 50,583 videos where all of them have manual annotations. Among the 50,583 videos there are 48,436 videos that also have automatic annotations. 70 | 71 | ### SA-V val and test 72 | 73 | For SA-V val and test sets, we release the extracted frames as jpeg files, and the masks as png files with the following directory structure: 74 | 75 | ``` 76 | sav_val(sav_test) 77 | ├── sav_val.txt (sav_test.txt): a list of video ids in the split 78 | ├── JPEGImages_24fps # videos are extracted at 24 fps 79 | │ ├── {video_id} 80 | │ │ ├── 00000.jpg # video frame 81 | │ │ ├── 00001.jpg # video frame 82 | │ │ ├── 00002.jpg # video frame 83 | │ │ ├── 00003.jpg # video frame 84 | │ │ └── ... 85 | │ ├── {video_id} 86 | │ ├── {video_id} 87 | │ └── ... 88 | └── Annotations_6fps # videos are annotated at 6 fps 89 | ├── {video_id} 90 | │ ├── 000 # obj 000 91 | │ │ ├── 00000.png # mask for object 000 in 00000.jpg 92 | │ │ ├── 00004.png # mask for object 000 in 00004.jpg 93 | │ │ ├── 00008.png # mask for object 000 in 00008.jpg 94 | │ │ ├── 00012.png # mask for object 000 in 00012.jpg 95 | │ │ └── ... 96 | │ ├── 001 # obj 001 97 | │ ├── 002 # obj 002 98 | │ └── ... 99 | ├── {video_id} 100 | ├── {video_id} 101 | └── ... 102 | ``` 103 | 104 | All masklets in val and test sets are manually annotated in every frame by annotators. For each annotated object in a video, we store the annotated masks in a single png. This is because the annotated objects may overlap, e.g., it is possible in our SA-V dataset for there to be a mask for the whole person as well as a separate mask for their hands. 105 | 106 | ## SA-V Val and Test Evaluation 107 | 108 | We provide an evaluator to compute the common J and F metrics on SA-V val and test sets. To run the evaluation, we need to first install a few dependencies as follows: 109 | 110 | ``` 111 | pip install -r requirements.txt 112 | ``` 113 | 114 | Then we can evaluate the predictions as follows: 115 | 116 | ``` 117 | python sav_evaluator.py --gt_root {GT_ROOT} --pred_root {PRED_ROOT} 118 | ``` 119 | 120 | or run 121 | 122 | ``` 123 | python sav_evaluator.py --help 124 | ``` 125 | 126 | to print a complete help message. 127 | 128 | The evaluator expects the `GT_ROOT` to be one of the following folder structures, and `GT_ROOT` and `PRED_ROOT` to have the same structure. 129 | 130 | - Same as SA-V val and test directory structure 131 | 132 | ``` 133 | {GT_ROOT} # gt root folder 134 | ├── {video_id} 135 | │ ├── 000 # all masks associated with obj 000 136 | │ │ ├── 00000.png # mask for object 000 in frame 00000 (binary mask) 137 | │ │ └── ... 138 | │ ├── 001 # all masks associated with obj 001 139 | │ ├── 002 # all masks associated with obj 002 140 | │ └── ... 141 | ├── {video_id} 142 | ├── {video_id} 143 | └── ... 144 | ``` 145 | 146 | In the paper for the experiments on SA-V val and test, we run inference on the 24 fps videos, and evaluate on the subset of frames where we have ground truth annotations (first and last annotated frames dropped). The evaluator will ignore the masks in frames where we don't have ground truth annotations. 
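As a hypothetical sketch of producing `PRED_ROOT` in the layout above, the snippet below writes one predicted object mask to `{PRED_ROOT}/{video_id}/{obj_id:03d}/{frame_id:05d}.png`. The helper name and the 0/255 PNG convention are illustrative assumptions; check them against the evaluator's mask loader before relying on this.

```
import os

import numpy as np
from PIL import Image


def save_pred_mask(pred_root, video_id, obj_id, frame_id, mask):
    # mask: (H, W) boolean array for one object in one annotated frame
    out_dir = os.path.join(pred_root, video_id, f"{obj_id:03d}")
    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(mask.astype(np.uint8) * 255).save(
        os.path.join(out_dir, f"{frame_id:05d}.png")
    )


# e.g., an (all-background) prediction for object 000 in frame 00000 of video sav_000001
save_pred_mask("PRED_ROOT", "sav_000001", 0, 0, np.zeros((480, 854), dtype=bool))
```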
147 | 148 | - Same as [DAVIS](https://github.com/davisvideochallenge/davis2017-evaluation) directory structure 149 | 150 | ``` 151 | {GT_ROOT} # gt root folder 152 | ├── {video_id} 153 | │ ├── 00000.png # annotations in frame 00000 (may contain multiple objects) 154 | │ └── ... 155 | ├── {video_id} 156 | ├── {video_id} 157 | └── ... 158 | ``` 159 | 160 | 161 | 162 | 163 | ## License 164 | 165 | The evaluation code is licensed under the [BSD 3 license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in SA-V Dataset are released under CC BY 4.0. 166 | 167 | Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)). 168 | -------------------------------------------------------------------------------- /sam2/csrc/connected_components.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch 8 | // with license found in the LICENSE_cctorch file in the root directory. 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // 2d 17 | #define BLOCK_ROWS 16 18 | #define BLOCK_COLS 16 19 | 20 | namespace cc2d { 21 | 22 | template 23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { 24 | return (bitmap >> pos) & 1; 25 | } 26 | 27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) { 28 | while (s_buf[n] != n) 29 | n = s_buf[n]; 30 | return n; 31 | } 32 | 33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { 34 | const int32_t id = n; 35 | while (s_buf[n] != n) { 36 | n = s_buf[n]; 37 | s_buf[id] = n; 38 | } 39 | return n; 40 | } 41 | 42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { 43 | bool done; 44 | do { 45 | a = find(s_buf, a); 46 | b = find(s_buf, b); 47 | 48 | if (a < b) { 49 | int32_t old = atomicMin(s_buf + b, a); 50 | done = (old == b); 51 | b = old; 52 | } else if (b < a) { 53 | int32_t old = atomicMin(s_buf + a, b); 54 | done = (old == a); 55 | a = old; 56 | } else 57 | done = true; 58 | 59 | } while (!done); 60 | } 61 | 62 | __global__ void 63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { 64 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 65 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 66 | const uint32_t idx = row * W + col; 67 | 68 | if (row < H && col < W) 69 | label[idx] = idx; 70 | } 71 | 72 | __global__ void 73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { 74 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 75 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 76 | const uint32_t idx = row * W + col; 77 | 78 | if (row >= H || col >= W) 79 | return; 80 | 81 | uint32_t P = 0; 82 | 83 | if (img[idx]) 84 | P |= 0x777; 85 | if (row + 1 < H && img[idx + W]) 86 | P |= 0x777 << 4; 87 | if (col + 1 < W && img[idx + 1]) 88 | P |= 0x777 << 1; 89 | 90 | if (col == 0) 91 | P &= 0xEEEE; 92 | if (col + 1 >= W) 93 | P &= 0x3333; 94 | 
else if (col + 2 >= W) 95 | P &= 0x7777; 96 | 97 | if (row == 0) 98 | P &= 0xFFF0; 99 | if (row + 1 >= H) 100 | P &= 0xFF; 101 | 102 | if (P > 0) { 103 | // If need check about top-left pixel(if flag the first bit) and hit the 104 | // top-left pixel 105 | if (hasBit(P, 0) && img[idx - W - 1]) { 106 | union_(label, idx, idx - 2 * W - 2); // top left block 107 | } 108 | 109 | if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) 110 | union_(label, idx, idx - 2 * W); // top bottom block 111 | 112 | if (hasBit(P, 3) && img[idx + 2 - W]) 113 | union_(label, idx, idx - 2 * W + 2); // top right block 114 | 115 | if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) 116 | union_(label, idx, idx - 2); // just left block 117 | } 118 | } 119 | 120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) { 121 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 122 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 123 | const uint32_t idx = row * W + col; 124 | 125 | if (row < H && col < W) 126 | find_n_compress(label, idx); 127 | } 128 | 129 | __global__ void final_labeling( 130 | const uint8_t* img, 131 | int32_t* label, 132 | const int32_t W, 133 | const int32_t H) { 134 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 135 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 136 | const uint32_t idx = row * W + col; 137 | 138 | if (row >= H || col >= W) 139 | return; 140 | 141 | int32_t y = label[idx] + 1; 142 | 143 | if (img[idx]) 144 | label[idx] = y; 145 | else 146 | label[idx] = 0; 147 | 148 | if (col + 1 < W) { 149 | if (img[idx + 1]) 150 | label[idx + 1] = y; 151 | else 152 | label[idx + 1] = 0; 153 | 154 | if (row + 1 < H) { 155 | if (img[idx + W + 1]) 156 | label[idx + W + 1] = y; 157 | else 158 | label[idx + W + 1] = 0; 159 | } 160 | } 161 | 162 | if (row + 1 < H) { 163 | if (img[idx + W]) 164 | label[idx + W] = y; 165 | else 166 | label[idx + W] = 0; 167 | } 168 | } 169 | 170 | __global__ void init_counting( 171 | const int32_t* label, 172 | int32_t* count_init, 173 | const int32_t W, 174 | const int32_t H) { 175 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 176 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 177 | const uint32_t idx = row * W + col; 178 | 179 | if (row >= H || col >= W) 180 | return; 181 | 182 | int32_t y = label[idx]; 183 | if (y > 0) { 184 | int32_t count_idx = y - 1; 185 | atomicAdd(count_init + count_idx, 1); 186 | } 187 | } 188 | 189 | __global__ void final_counting( 190 | const int32_t* label, 191 | const int32_t* count_init, 192 | int32_t* count_final, 193 | const int32_t W, 194 | const int32_t H) { 195 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 196 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 197 | const uint32_t idx = row * W + col; 198 | 199 | if (row >= H || col >= W) 200 | return; 201 | 202 | int32_t y = label[idx]; 203 | if (y > 0) { 204 | int32_t count_idx = y - 1; 205 | count_final[idx] = count_init[count_idx]; 206 | } else { 207 | count_final[idx] = 0; 208 | } 209 | } 210 | 211 | } // namespace cc2d 212 | 213 | std::vector get_connected_componnets( 214 | const torch::Tensor& inputs) { 215 | AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); 216 | AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); 217 | AT_ASSERTM( 218 | inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); 219 | 220 | const uint32_t N 
= inputs.size(0); 221 | const uint32_t C = inputs.size(1); 222 | const uint32_t H = inputs.size(2); 223 | const uint32_t W = inputs.size(3); 224 | 225 | AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); 226 | AT_ASSERTM((H % 2) == 0, "height must be an even number"); 227 | AT_ASSERTM((W % 2) == 0, "width must be an even number"); 228 | 229 | // label must be uint32_t 230 | auto label_options = 231 | torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); 232 | torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); 233 | torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); 234 | torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); 235 | 236 | dim3 grid = dim3( 237 | ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, 238 | ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); 239 | dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); 240 | dim3 grid_count = 241 | dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); 242 | dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); 243 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 244 | 245 | for (int n = 0; n < N; n++) { 246 | uint32_t offset = n * H * W; 247 | 248 | cc2d::init_labeling<<>>( 249 | labels.data_ptr() + offset, W, H); 250 | cc2d::merge<<>>( 251 | inputs.data_ptr() + offset, 252 | labels.data_ptr() + offset, 253 | W, 254 | H); 255 | cc2d::compression<<>>( 256 | labels.data_ptr() + offset, W, H); 257 | cc2d::final_labeling<<>>( 258 | inputs.data_ptr() + offset, 259 | labels.data_ptr() + offset, 260 | W, 261 | H); 262 | 263 | // get the counting of each pixel 264 | cc2d::init_counting<<>>( 265 | labels.data_ptr() + offset, 266 | counts_init.data_ptr() + offset, 267 | W, 268 | H); 269 | cc2d::final_counting<<>>( 270 | labels.data_ptr() + offset, 271 | counts_init.data_ptr() + offset, 272 | counts_final.data_ptr() + offset, 273 | W, 274 | H); 275 | } 276 | 277 | // returned values are [labels, counts] 278 | std::vector outputs; 279 | outputs.push_back(labels); 280 | outputs.push_back(counts_final); 281 | return outputs; 282 | } 283 | 284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 285 | m.def( 286 | "get_connected_componnets", 287 | &get_connected_componnets, 288 | "get_connected_componnets"); 289 | } 290 | -------------------------------------------------------------------------------- /sam2/modeling/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention Is All You Need paper, generalized to work on images. 
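    Each axis coordinate is (optionally) normalized to the range (0, 2*pi] and encoded with
    interleaved sine/cosine features, giving num_pos_feats // 2 channels per axis and
    num_pos_feats channels in total (the y channels come first, then the x channels).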
20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_pos_feats, 25 | temperature: int = 10000, 26 | normalize: bool = True, 27 | scale: Optional[float] = None, 28 | ): 29 | super().__init__() 30 | assert num_pos_feats % 2 == 0, "Expecting even model width" 31 | self.num_pos_feats = num_pos_feats // 2 32 | self.temperature = temperature 33 | self.normalize = normalize 34 | if scale is not None and normalize is False: 35 | raise ValueError("normalize should be True if scale is passed") 36 | if scale is None: 37 | scale = 2 * math.pi 38 | self.scale = scale 39 | 40 | self.cache = {} 41 | 42 | def _encode_xy(self, x, y): 43 | # The positions are expected to be normalized 44 | assert len(x) == len(y) and x.ndim == y.ndim == 1 45 | x_embed = x * self.scale 46 | y_embed = y * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, None] / dim_t 52 | pos_y = y_embed[:, None] / dim_t 53 | pos_x = torch.stack( 54 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 55 | ).flatten(1) 56 | pos_y = torch.stack( 57 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 58 | ).flatten(1) 59 | return pos_x, pos_y 60 | 61 | @torch.no_grad() 62 | def encode_boxes(self, x, y, w, h): 63 | pos_x, pos_y = self._encode_xy(x, y) 64 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 65 | return pos 66 | 67 | encode = encode_boxes # Backwards compatibility 68 | 69 | @torch.no_grad() 70 | def encode_points(self, x, y, labels): 71 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 72 | assert bx == by and nx == ny and bx == bl and nx == nl 73 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 74 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 75 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 76 | return pos 77 | 78 | @torch.no_grad() 79 | def forward(self, x: torch.Tensor): 80 | cache_key = (x.shape[-2], x.shape[-1]) 81 | if cache_key in self.cache: 82 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) 83 | y_embed = ( 84 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) 85 | .view(1, -1, 1) 86 | .repeat(x.shape[0], 1, x.shape[-1]) 87 | ) 88 | x_embed = ( 89 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) 90 | .view(1, 1, -1) 91 | .repeat(x.shape[0], x.shape[-2], 1) 92 | ) 93 | 94 | if self.normalize: 95 | eps = 1e-6 96 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 97 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 98 | 99 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 100 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 101 | 102 | pos_x = x_embed[:, :, :, None] / dim_t 103 | pos_y = y_embed[:, :, :, None] / dim_t 104 | pos_x = torch.stack( 105 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 106 | ).flatten(3) 107 | pos_y = torch.stack( 108 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 109 | ).flatten(3) 110 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 111 | self.cache[cache_key] = pos[0] 112 | return pos 113 | 114 | 115 | class PositionEmbeddingRandom(nn.Module): 116 | """ 117 | Positional encoding using random spatial frequencies. 
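    Normalized [0, 1] coordinates are shifted to [-1, 1], projected through a fixed random
    Gaussian matrix of shape (2, num_pos_feats), and encoded as sine/cosine pairs, giving
    2 * num_pos_feats output channels (a Fourier-feature style encoding).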
118 | """ 119 | 120 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 121 | super().__init__() 122 | if scale is None or scale <= 0.0: 123 | scale = 1.0 124 | self.register_buffer( 125 | "positional_encoding_gaussian_matrix", 126 | scale * torch.randn((2, num_pos_feats)), 127 | ) 128 | 129 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 130 | """Positionally encode points that are normalized to [0,1].""" 131 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 132 | coords = 2 * coords - 1 133 | coords = coords @ self.positional_encoding_gaussian_matrix 134 | coords = 2 * np.pi * coords 135 | # outputs d_1 x ... x d_n x C shape 136 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 137 | 138 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 139 | """Generate positional encoding for a grid of the specified size.""" 140 | h, w = size 141 | device: Any = self.positional_encoding_gaussian_matrix.device 142 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 143 | y_embed = grid.cumsum(dim=0) - 0.5 144 | x_embed = grid.cumsum(dim=1) - 0.5 145 | y_embed = y_embed / h 146 | x_embed = x_embed / w 147 | 148 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 149 | return pe.permute(2, 0, 1) # C x H x W 150 | 151 | def forward_with_coords( 152 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 153 | ) -> torch.Tensor: 154 | """Positionally encode points that are not normalized to [0,1].""" 155 | coords = coords_input.clone() 156 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 157 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 158 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 159 | 160 | 161 | # Rotary Positional Encoding, adapted from: 162 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 163 | # 2. https://github.com/naver-ai/rope-vit 164 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 165 | 166 | 167 | def init_t_xy(end_x: int, end_y: int): 168 | t = torch.arange(end_x * end_y, dtype=torch.float32) 169 | t_x = (t % end_x).float() 170 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 171 | return t_x, t_y 172 | 173 | 174 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 175 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 176 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 177 | 178 | t_x, t_y = init_t_xy(end_x, end_y) 179 | freqs_x = torch.outer(t_x, freqs_x) 180 | freqs_y = torch.outer(t_y, freqs_y) 181 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 182 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 183 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 184 | 185 | 186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 187 | ndim = x.ndim 188 | assert 0 <= 1 < ndim 189 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 190 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 191 | return freqs_cis.view(*shape) 192 | 193 | 194 | def apply_rotary_enc( 195 | xq: torch.Tensor, 196 | xk: torch.Tensor, 197 | freqs_cis: torch.Tensor, 198 | repeat_freqs_k: bool = False, 199 | ): 200 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 201 | xk_ = ( 202 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 203 | if xk.shape[-2] != 0 204 | else None 205 | ) 206 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 207 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 208 | if xk_ is None: 209 | # no keys to rotate, due to dropout 210 | return xq_out.type_as(xq).to(xq.device), xk 211 | # repeat freqs along seq_len dim to match k seq_len 212 | if repeat_freqs_k: 213 | r = xk_.shape[-2] // xq_.shape[-2] 214 | if freqs_cis.is_cuda: 215 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 216 | else: 217 | # torch.repeat on complex numbers may not be supported on non-CUDA devices 218 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten 219 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) 220 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 221 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 222 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
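# --- Illustrative aside (not part of hieradet.py): a minimal sketch of how the rotary
# position encoding helpers defined above in position_encoding.py fit together. The batch
# size, head count, feature-map size, and head dimension are arbitrary assumptions for this
# example, not values taken from the SAM 2 configs.
import torch
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

B, n_heads, H, W, head_dim = 1, 2, 8, 8, 32
q = torch.randn(B, n_heads, H * W, head_dim)
k = torch.randn(B, n_heads, H * W, head_dim)

# Axial frequencies: half of the channel pairs rotate with the x coordinate, half with y.
freqs_cis = compute_axial_cis(dim=head_dim, end_x=W, end_y=H)  # (H*W, head_dim // 2), complex

q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis)
assert q_rot.shape == k_rot.shape == (B, n_heads, H * W, head_dim)
# --- End of aside.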
6 | 7 | import logging 8 | from functools import partial 9 | from typing import List, Tuple, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from iopath.common.file_io import g_pathmgr 15 | 16 | from sam2.modeling.backbones.utils import ( 17 | PatchEmbed, 18 | window_partition, 19 | window_unpartition, 20 | ) 21 | 22 | from sam2.modeling.sam2_utils import DropPath, MLP 23 | 24 | 25 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 26 | if pool is None: 27 | return x 28 | # (B, H, W, C) -> (B, C, H, W) 29 | x = x.permute(0, 3, 1, 2) 30 | x = pool(x) 31 | # (B, C, H', W') -> (B, H', W', C) 32 | x = x.permute(0, 2, 3, 1) 33 | if norm: 34 | x = norm(x) 35 | 36 | return x 37 | 38 | 39 | class MultiScaleAttention(nn.Module): 40 | def __init__( 41 | self, 42 | dim: int, 43 | dim_out: int, 44 | num_heads: int, 45 | q_pool: nn.Module = None, 46 | ): 47 | super().__init__() 48 | 49 | self.dim = dim 50 | self.dim_out = dim_out 51 | self.num_heads = num_heads 52 | self.q_pool = q_pool 53 | self.qkv = nn.Linear(dim, dim_out * 3) 54 | self.proj = nn.Linear(dim_out, dim_out) 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | B, H, W, _ = x.shape 58 | # qkv with shape (B, H * W, 3, nHead, C) 59 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 60 | # q, k, v with shape (B, H * W, nheads, C) 61 | q, k, v = torch.unbind(qkv, 2) 62 | 63 | # Q pooling (for downsample at stage changes) 64 | if self.q_pool: 65 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 66 | H, W = q.shape[1:3] # downsampled shape 67 | q = q.reshape(B, H * W, self.num_heads, -1) 68 | 69 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 70 | x = F.scaled_dot_product_attention( 71 | q.transpose(1, 2), 72 | k.transpose(1, 2), 73 | v.transpose(1, 2), 74 | ) 75 | # Transpose back 76 | x = x.transpose(1, 2) 77 | x = x.reshape(B, H, W, -1) 78 | 79 | x = self.proj(x) 80 | 81 | return x 82 | 83 | 84 | class MultiScaleBlock(nn.Module): 85 | def __init__( 86 | self, 87 | dim: int, 88 | dim_out: int, 89 | num_heads: int, 90 | mlp_ratio: float = 4.0, 91 | drop_path: float = 0.0, 92 | norm_layer: Union[nn.Module, str] = "LayerNorm", 93 | q_stride: Tuple[int, int] = None, 94 | act_layer: nn.Module = nn.GELU, 95 | window_size: int = 0, 96 | ): 97 | super().__init__() 98 | 99 | if isinstance(norm_layer, str): 100 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 101 | 102 | self.dim = dim 103 | self.dim_out = dim_out 104 | self.norm1 = norm_layer(dim) 105 | 106 | self.window_size = window_size 107 | 108 | self.pool, self.q_stride = None, q_stride 109 | if self.q_stride: 110 | self.pool = nn.MaxPool2d( 111 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 112 | ) 113 | 114 | self.attn = MultiScaleAttention( 115 | dim, 116 | dim_out, 117 | num_heads=num_heads, 118 | q_pool=self.pool, 119 | ) 120 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 121 | 122 | self.norm2 = norm_layer(dim_out) 123 | self.mlp = MLP( 124 | dim_out, 125 | int(dim_out * mlp_ratio), 126 | dim_out, 127 | num_layers=2, 128 | activation=act_layer, 129 | ) 130 | 131 | if dim != dim_out: 132 | self.proj = nn.Linear(dim, dim_out) 133 | 134 | def forward(self, x: torch.Tensor) -> torch.Tensor: 135 | shortcut = x # B, H, W, C 136 | x = self.norm1(x) 137 | 138 | # Skip connection 139 | if self.dim != self.dim_out: 140 | shortcut = do_pool(self.proj(x), self.pool) 141 | 142 | # Window partition 143 | window_size = 
self.window_size 144 | if window_size > 0: 145 | H, W = x.shape[1], x.shape[2] 146 | x, pad_hw = window_partition(x, window_size) 147 | 148 | # Window Attention + Q Pooling (if stage change) 149 | x = self.attn(x) 150 | if self.q_stride: 151 | # Shapes have changed due to Q pooling 152 | window_size = self.window_size // self.q_stride[0] 153 | H, W = shortcut.shape[1:3] 154 | 155 | pad_h = (window_size - H % window_size) % window_size 156 | pad_w = (window_size - W % window_size) % window_size 157 | pad_hw = (H + pad_h, W + pad_w) 158 | 159 | # Reverse window partition 160 | if self.window_size > 0: 161 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 162 | 163 | x = shortcut + self.drop_path(x) 164 | # MLP 165 | x = x + self.drop_path(self.mlp(self.norm2(x))) 166 | return x 167 | 168 | 169 | class Hiera(nn.Module): 170 | """ 171 | Reference: https://arxiv.org/abs/2306.00989 172 | """ 173 | 174 | def __init__( 175 | self, 176 | embed_dim: int = 96, # initial embed dim 177 | num_heads: int = 1, # initial number of heads 178 | drop_path_rate: float = 0.0, # stochastic depth 179 | q_pool: int = 3, # number of q_pool stages 180 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 181 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 182 | dim_mul: float = 2.0, # dim_mul factor at stage shift 183 | head_mul: float = 2.0, # head_mul factor at stage shift 184 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 185 | # window size per stage, when not using global att. 186 | window_spec: Tuple[int, ...] = ( 187 | 8, 188 | 4, 189 | 14, 190 | 7, 191 | ), 192 | # global attn in these blocks 193 | global_att_blocks: Tuple[int, ...] = ( 194 | 12, 195 | 16, 196 | 20, 197 | ), 198 | weights_path=None, 199 | return_interm_layers=True, # return feats from every stage 200 | ): 201 | super().__init__() 202 | 203 | assert len(stages) == len(window_spec) 204 | self.window_spec = window_spec 205 | 206 | depth = sum(stages) 207 | self.q_stride = q_stride 208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 209 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 211 | self.return_interm_layers = return_interm_layers 212 | 213 | self.patch_embed = PatchEmbed( 214 | embed_dim=embed_dim, 215 | ) 216 | # Which blocks have global att? 
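        # (Worked example with the defaults above: stages=(2, 3, 16, 3) gives depth=24 and
        #  stage_ends=[1, 4, 20, 23]; with q_pool=3, Q-pooling happens at blocks [2, 5, 21];
        #  global_att_blocks=(12, 16, 20) means those three blocks use global attention
        #  (window_size=0), while all other blocks use the per-stage sizes in window_spec.)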
217 | self.global_att_blocks = global_att_blocks 218 | 219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 221 | self.pos_embed = nn.Parameter( 222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 223 | ) 224 | self.pos_embed_window = nn.Parameter( 225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 226 | ) 227 | 228 | dpr = [ 229 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 230 | ] # stochastic depth decay rule 231 | 232 | cur_stage = 1 233 | self.blocks = nn.ModuleList() 234 | 235 | for i in range(depth): 236 | dim_out = embed_dim 237 | # lags by a block, so first block of 238 | # next stage uses an initial window size 239 | # of previous stage and final window size of current stage 240 | window_size = self.window_spec[cur_stage - 1] 241 | 242 | if self.global_att_blocks is not None: 243 | window_size = 0 if i in self.global_att_blocks else window_size 244 | 245 | if i - 1 in self.stage_ends: 246 | dim_out = int(embed_dim * dim_mul) 247 | num_heads = int(num_heads * head_mul) 248 | cur_stage += 1 249 | 250 | block = MultiScaleBlock( 251 | dim=embed_dim, 252 | dim_out=dim_out, 253 | num_heads=num_heads, 254 | drop_path=dpr[i], 255 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 256 | window_size=window_size, 257 | ) 258 | 259 | embed_dim = dim_out 260 | self.blocks.append(block) 261 | 262 | self.channel_list = ( 263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 264 | if return_interm_layers 265 | else [self.blocks[-1].dim_out] 266 | ) 267 | 268 | if weights_path is not None: 269 | with g_pathmgr.open(weights_path, "rb") as f: 270 | chkpt = torch.load(f, map_location="cpu") 271 | logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) 272 | 273 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 274 | h, w = hw 275 | window_embed = self.pos_embed_window 276 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 277 | pos_embed = pos_embed + window_embed.tile( 278 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 279 | ) 280 | pos_embed = pos_embed.permute(0, 2, 3, 1) 281 | return pos_embed 282 | 283 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 284 | x = self.patch_embed(x) 285 | # x: (B, H, W, C) 286 | 287 | # Add pos embed 288 | x = x + self._get_pos_embed(x.shape[1:3]) 289 | 290 | outputs = [] 291 | for i, blk in enumerate(self.blocks): 292 | x = blk(x) 293 | if (i == self.stage_ends[-1]) or ( 294 | i in self.stage_ends and self.return_interm_layers 295 | ): 296 | feats = x.permute(0, 3, 1, 2) 297 | outputs.append(feats) 298 | 299 | return outputs 300 | 301 | def get_layer_id(self, layer_name): 302 | # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 303 | num_layers = self.get_num_layers() 304 | 305 | if layer_name.find("rel_pos") != -1: 306 | return num_layers + 1 307 | elif layer_name.find("pos_embed") != -1: 308 | return 0 309 | elif layer_name.find("patch_embed") != -1: 310 | return 0 311 | elif layer_name.find("blocks") != -1: 312 | return int(layer_name.split("blocks")[1].split(".")[1]) + 1 313 | else: 314 | return num_layers + 1 315 | 316 | def get_num_layers(self) -> int: 317 | return len(self.blocks) 318 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | scratch: 4 | resolution: 1024 5 | train_batch_size: 1 6 | num_train_workers: 10 7 | num_frames: 8 8 | max_num_objects: 3 9 | base_lr: 5.0e-6 10 | vision_lr: 3.0e-06 11 | phases_per_epoch: 1 12 | num_epochs: 40 13 | 14 | dataset: 15 | # PATHS to Dataset 16 | img_folder: /fsx-onevision/shared/data/academic_vos_data/MOSE/train/JPEGImages # PATH to MOSE JPEGImages folder 17 | gt_folder: /fsx-onevision/shared/data/academic_vos_data/MOSE/train/Annotations/ # PATH to MOSE Annotations folder 18 | file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training 19 | multiplier: 2 20 | 21 | # Video transforms 22 | vos: 23 | train_transforms: 24 | - _target_: training.dataset.transforms.ComposeAPI 25 | transforms: 26 | - _target_: training.dataset.transforms.RandomHorizontalFlip 27 | consistent_transform: True 28 | - _target_: training.dataset.transforms.RandomAffine 29 | degrees: 25 30 | shear: 20 31 | image_interpolation: bilinear 32 | consistent_transform: True 33 | - _target_: training.dataset.transforms.RandomResizeAPI 34 | sizes: ${scratch.resolution} 35 | square: true 36 | consistent_transform: True 37 | - _target_: training.dataset.transforms.ColorJitter 38 | consistent_transform: True 39 | brightness: 0.1 40 | contrast: 0.03 41 | saturation: 0.03 42 | hue: null 43 | - _target_: training.dataset.transforms.RandomGrayscale 44 | p: 0.05 45 | consistent_transform: True 46 | - _target_: training.dataset.transforms.ColorJitter 47 | consistent_transform: False 48 | brightness: 0.1 49 | contrast: 0.05 50 | saturation: 0.05 51 | hue: null 52 | - _target_: training.dataset.transforms.ToTensorAPI 53 | - _target_: training.dataset.transforms.NormalizeAPI 54 | mean: [0.485, 0.456, 0.406] 55 | std: [0.229, 0.224, 0.225] 56 | 57 | trainer: 58 | _target_: training.trainer.Trainer 59 | mode: train_only 60 | max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}} 61 | accelerator: cuda 62 | seed_value: 123 63 | 64 | model: 65 | _target_: training.model.sam2.SAM2Train 66 | image_encoder: 67 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 68 | scalp: 1 69 | trunk: 70 | _target_: sam2.modeling.backbones.hieradet.Hiera 71 | embed_dim: 112 72 | num_heads: 2 73 | drop_path_rate: 0.1 74 | neck: 75 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 76 | position_encoding: 77 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 78 | num_pos_feats: 256 79 | normalize: true 80 | scale: null 81 | temperature: 10000 82 | d_model: 256 83 | backbone_channel_list: [896, 448, 224, 112] 84 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 85 | fpn_interp_model: nearest 86 | 87 | memory_attention: 88 | _target_: sam2.modeling.memory_attention.MemoryAttention 89 | d_model: 256 90 | pos_enc_at_input: true 91 | layer: 92 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 93 | activation: relu 94 | dim_feedforward: 2048 95 | dropout: 0.1 96 | pos_enc_at_attn: false 97 | self_attention: 98 | _target_: sam2.modeling.sam.transformer.RoPEAttention 99 | rope_theta: 10000.0 100 | feat_sizes: [32, 32] 101 | embedding_dim: 256 102 | num_heads: 1 103 | downsample_rate: 1 104 | dropout: 0.1 105 | d_model: 256 106 | pos_enc_at_cross_attn_keys: true 107 | pos_enc_at_cross_attn_queries: false 108 | cross_attention: 109 | _target_: 
sam2.modeling.sam.transformer.RoPEAttention 110 | rope_theta: 10000.0 111 | feat_sizes: [32, 32] 112 | rope_k_repeat: True 113 | embedding_dim: 256 114 | num_heads: 1 115 | downsample_rate: 1 116 | dropout: 0.1 117 | kv_in_dim: 64 118 | num_layers: 4 119 | 120 | memory_encoder: 121 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 122 | out_dim: 64 123 | position_encoding: 124 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 125 | num_pos_feats: 64 126 | normalize: true 127 | scale: null 128 | temperature: 10000 129 | mask_downsampler: 130 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 131 | kernel_size: 3 132 | stride: 2 133 | padding: 1 134 | fuser: 135 | _target_: sam2.modeling.memory_encoder.Fuser 136 | layer: 137 | _target_: sam2.modeling.memory_encoder.CXBlock 138 | dim: 256 139 | kernel_size: 7 140 | padding: 3 141 | layer_scale_init_value: 1e-6 142 | use_dwconv: True # depth-wise convs 143 | num_layers: 2 144 | 145 | num_maskmem: 7 146 | image_size: ${scratch.resolution} 147 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 148 | sigmoid_scale_for_mem_enc: 20.0 149 | sigmoid_bias_for_mem_enc: -10.0 150 | use_mask_input_as_output_without_sam: true 151 | # Memory 152 | directly_add_no_mem_embed: true 153 | no_obj_embed_spatial: true 154 | # use high-resolution feature map in the SAM mask decoder 155 | use_high_res_features_in_sam: true 156 | # output 3 masks on the first click on initial conditioning frames 157 | multimask_output_in_sam: true 158 | # SAM heads 159 | iou_prediction_use_sigmoid: True 160 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 161 | use_obj_ptrs_in_encoder: true 162 | add_tpos_enc_to_obj_ptrs: true 163 | proj_tpos_enc_in_obj_ptrs: true 164 | use_signed_tpos_enc_to_obj_ptrs: true 165 | only_obj_ptrs_in_the_past_for_eval: true 166 | # object occlusion prediction 167 | pred_obj_scores: true 168 | pred_obj_scores_mlp: true 169 | fixed_no_obj_ptr: true 170 | # multimask tracking settings 171 | multimask_output_for_tracking: true 172 | use_multimask_token_for_obj_ptr: true 173 | multimask_min_pt_num: 0 174 | multimask_max_pt_num: 1 175 | use_mlp_for_obj_ptr_proj: true 176 | # Compilation flag 177 | # compile_image_encoder: False 178 | 179 | ####### Training specific params ####### 180 | # box/point input and corrections 181 | prob_to_use_pt_input_for_train: 0.5 182 | prob_to_use_pt_input_for_eval: 0.0 183 | prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points 184 | prob_to_use_box_input_for_eval: 0.0 185 | prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors 186 | num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame) 187 | num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame 188 | rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2 189 | add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame) 190 | # maximum 2 initial conditioning frames 191 | num_init_cond_frames_for_train: 2 192 | rand_init_cond_frames_for_train: True # random 1~2 193 | num_correction_pt_per_frame: 7 194 | use_act_ckpt_iterative_pt_sampling: false 195 | 196 | 197 | 198 | num_init_cond_frames_for_eval: 1 # only mask on the first frame 199 | 
forward_backbone_per_frame_for_eval: True 200 | 201 | 202 | data: 203 | train: 204 | _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset 205 | phases_per_epoch: ${scratch.phases_per_epoch} 206 | batch_sizes: 207 | - ${scratch.train_batch_size} 208 | 209 | datasets: 210 | - _target_: training.dataset.utils.RepeatFactorWrapper 211 | dataset: 212 | _target_: training.dataset.utils.ConcatDataset 213 | datasets: 214 | - _target_: training.dataset.vos_dataset.VOSDataset 215 | transforms: ${vos.train_transforms} 216 | training: true 217 | video_dataset: 218 | _target_: training.dataset.vos_raw_dataset.PNGRawDataset 219 | img_folder: ${dataset.img_folder} 220 | gt_folder: ${dataset.gt_folder} 221 | file_list_txt: ${dataset.file_list_txt} 222 | sampler: 223 | _target_: training.dataset.vos_sampler.RandomUniformSampler 224 | num_frames: ${scratch.num_frames} 225 | max_num_objects: ${scratch.max_num_objects} 226 | multiplier: ${dataset.multiplier} 227 | shuffle: True 228 | num_workers: ${scratch.num_train_workers} 229 | pin_memory: True 230 | drop_last: True 231 | collate_fn: 232 | _target_: training.utils.data_utils.collate_fn 233 | _partial_: true 234 | dict_key: all 235 | 236 | optim: 237 | amp: 238 | enabled: True 239 | amp_dtype: bfloat16 240 | 241 | optimizer: 242 | _target_: torch.optim.AdamW 243 | 244 | gradient_clip: 245 | _target_: training.optimizer.GradientClipper 246 | max_norm: 0.1 247 | norm_type: 2 248 | 249 | param_group_modifiers: 250 | - _target_: training.optimizer.layer_decay_param_modifier 251 | _partial_: True 252 | layer_decay_value: 0.9 253 | apply_to: 'image_encoder.trunk' 254 | overrides: 255 | - pattern: '*pos_embed*' 256 | value: 1.0 257 | 258 | options: 259 | lr: 260 | - scheduler: 261 | _target_: fvcore.common.param_scheduler.CosineParamScheduler 262 | start_value: ${scratch.base_lr} 263 | end_value: ${divide:${scratch.base_lr},10} 264 | - scheduler: 265 | _target_: fvcore.common.param_scheduler.CosineParamScheduler 266 | start_value: ${scratch.vision_lr} 267 | end_value: ${divide:${scratch.vision_lr},10} 268 | param_names: 269 | - 'image_encoder.*' 270 | weight_decay: 271 | - scheduler: 272 | _target_: fvcore.common.param_scheduler.ConstantParamScheduler 273 | value: 0.1 274 | - scheduler: 275 | _target_: fvcore.common.param_scheduler.ConstantParamScheduler 276 | value: 0.0 277 | param_names: 278 | - '*bias*' 279 | module_cls_names: ['torch.nn.LayerNorm'] 280 | 281 | loss: 282 | all: 283 | _target_: training.loss_fns.MultiStepMultiMasksAndIous 284 | weight_dict: 285 | loss_mask: 20 286 | loss_dice: 1 287 | loss_iou: 1 288 | loss_class: 1 289 | supervise_all_iou: true 290 | iou_use_l1_loss: true 291 | pred_obj_scores: true 292 | focal_gamma_obj_score: 0.0 293 | focal_alpha_obj_score: -1.0 294 | 295 | distributed: 296 | backend: nccl 297 | find_unused_parameters: True 298 | 299 | logging: 300 | tensorboard_writer: 301 | _target_: training.utils.logger.make_tensorboard_logger 302 | log_dir: ${launcher.experiment_log_dir}/tensorboard 303 | flush_secs: 120 304 | should_log: True 305 | log_dir: ${launcher.experiment_log_dir}/logs 306 | log_freq: 10 307 | 308 | # initialize from a SAM 2 checkpoint 309 | checkpoint: 310 | save_dir: ${launcher.experiment_log_dir}/checkpoints 311 | save_freq: 0 # 0 only last checkpoint is saved. 
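  # (model_weight_initializer below initializes the model from the SAM 2.1 Hiera base+
  #  checkpoint given under state_dict.checkpoint_path before fine-tuning on MOSE; with
  #  strict: True the load is presumably strict, so mismatched keys would have to be listed
  #  under ignore_unexpected_keys / ignore_missing_keys.)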
312 | model_weight_initializer: 313 | _partial_: True 314 | _target_: training.utils.checkpoint_utils.load_state_dict_into_model 315 | strict: True 316 | ignore_unexpected_keys: null 317 | ignore_missing_keys: null 318 | 319 | state_dict: 320 | _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels 321 | checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint 322 | ckpt_state_dict_keys: ['model'] 323 | 324 | launcher: 325 | num_nodes: 1 326 | gpus_per_node: 8 327 | experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} 328 | 329 | # SLURM args if running on a cluster 330 | submitit: 331 | partition: null 332 | account: null 333 | qos: null 334 | cpus_per_task: 10 335 | use_cluster: false 336 | timeout_hour: 24 337 | name: null 338 | port_range: [10000, 65000] 339 | 340 | -------------------------------------------------------------------------------- /sam2/modeling/sam/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.sam2_utils import LayerNorm2d, MLP 13 | 14 | 15 | class MaskDecoder(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | transformer_dim: int, 20 | transformer: nn.Module, 21 | num_multimask_outputs: int = 3, 22 | activation: Type[nn.Module] = nn.GELU, 23 | iou_head_depth: int = 3, 24 | iou_head_hidden_dim: int = 256, 25 | use_high_res_features: bool = False, 26 | iou_prediction_use_sigmoid=False, 27 | dynamic_multimask_via_stability=False, 28 | dynamic_multimask_stability_delta=0.05, 29 | dynamic_multimask_stability_thresh=0.98, 30 | pred_obj_scores: bool = False, 31 | pred_obj_scores_mlp: bool = False, 32 | use_multimask_token_for_obj_ptr: bool = False, 33 | ) -> None: 34 | """ 35 | Predicts masks given an image and prompt embeddings, using a 36 | transformer architecture. 
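        In addition to the predicted masks, the decoder returns per-mask IoU (quality)
        predictions, the SAM output tokens used to build object pointers, and object score
        logits indicating whether an object is present.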
37 | 38 | Arguments: 39 | transformer_dim (int): the channel dimension of the transformer 40 | transformer (nn.Module): the transformer used to predict masks 41 | num_multimask_outputs (int): the number of masks to predict 42 | when disambiguating masks 43 | activation (nn.Module): the type of activation to use when 44 | upscaling masks 45 | iou_head_depth (int): the depth of the MLP used to predict 46 | mask quality 47 | iou_head_hidden_dim (int): the hidden dimension of the MLP 48 | used to predict mask quality 49 | """ 50 | super().__init__() 51 | self.transformer_dim = transformer_dim 52 | self.transformer = transformer 53 | 54 | self.num_multimask_outputs = num_multimask_outputs 55 | 56 | self.iou_token = nn.Embedding(1, transformer_dim) 57 | self.num_mask_tokens = num_multimask_outputs + 1 58 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 59 | 60 | self.pred_obj_scores = pred_obj_scores 61 | if self.pred_obj_scores: 62 | self.obj_score_token = nn.Embedding(1, transformer_dim) 63 | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr 64 | 65 | self.output_upscaling = nn.Sequential( 66 | nn.ConvTranspose2d( 67 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 68 | ), 69 | LayerNorm2d(transformer_dim // 4), 70 | activation(), 71 | nn.ConvTranspose2d( 72 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 73 | ), 74 | activation(), 75 | ) 76 | self.use_high_res_features = use_high_res_features 77 | if use_high_res_features: 78 | self.conv_s0 = nn.Conv2d( 79 | transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 80 | ) 81 | self.conv_s1 = nn.Conv2d( 82 | transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 83 | ) 84 | 85 | self.output_hypernetworks_mlps = nn.ModuleList( 86 | [ 87 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 88 | for i in range(self.num_mask_tokens) 89 | ] 90 | ) 91 | 92 | self.iou_prediction_head = MLP( 93 | transformer_dim, 94 | iou_head_hidden_dim, 95 | self.num_mask_tokens, 96 | iou_head_depth, 97 | sigmoid_output=iou_prediction_use_sigmoid, 98 | ) 99 | if self.pred_obj_scores: 100 | self.pred_obj_score_head = nn.Linear(transformer_dim, 1) 101 | if pred_obj_scores_mlp: 102 | self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) 103 | 104 | # When outputting a single mask, optionally we can dynamically fall back to the best 105 | # multimask output token if the single mask output token gives low stability scores. 106 | self.dynamic_multimask_via_stability = dynamic_multimask_via_stability 107 | self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta 108 | self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh 109 | 110 | def forward( 111 | self, 112 | image_embeddings: torch.Tensor, 113 | image_pe: torch.Tensor, 114 | sparse_prompt_embeddings: torch.Tensor, 115 | dense_prompt_embeddings: torch.Tensor, 116 | multimask_output: bool, 117 | repeat_image: bool, 118 | high_res_features: Optional[List[torch.Tensor]] = None, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """ 121 | Predict masks given image and prompt embeddings. 
122 | 123 | Arguments: 124 | image_embeddings (torch.Tensor): the embeddings from the image encoder 125 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 126 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 127 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 128 | multimask_output (bool): Whether to return multiple masks or a single 129 | mask. 130 | 131 | Returns: 132 | torch.Tensor: batched predicted masks 133 | torch.Tensor: batched predictions of mask quality 134 | torch.Tensor: batched SAM token for mask output 135 | """ 136 | masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( 137 | image_embeddings=image_embeddings, 138 | image_pe=image_pe, 139 | sparse_prompt_embeddings=sparse_prompt_embeddings, 140 | dense_prompt_embeddings=dense_prompt_embeddings, 141 | repeat_image=repeat_image, 142 | high_res_features=high_res_features, 143 | ) 144 | 145 | # Select the correct mask or masks for output 146 | if multimask_output: 147 | masks = masks[:, 1:, :, :] 148 | iou_pred = iou_pred[:, 1:] 149 | elif self.dynamic_multimask_via_stability and not self.training: 150 | masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) 151 | else: 152 | masks = masks[:, 0:1, :, :] 153 | iou_pred = iou_pred[:, 0:1] 154 | 155 | if multimask_output and self.use_multimask_token_for_obj_ptr: 156 | sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape 157 | else: 158 | # Take the mask output token. Here we *always* use the token for single mask output. 159 | # At test time, even if we track after 1-click (and using multimask_output=True), 160 | # we still take the single mask token here. The rationale is that we always track 161 | # after multiple clicks during training, so the past tokens seen during training 162 | # are always the single mask token (and we'll let it be the object-memory token). 163 | sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape 164 | 165 | # Prepare output 166 | return masks, iou_pred, sam_tokens_out, object_score_logits 167 | 168 | def predict_masks( 169 | self, 170 | image_embeddings: torch.Tensor, 171 | image_pe: torch.Tensor, 172 | sparse_prompt_embeddings: torch.Tensor, 173 | dense_prompt_embeddings: torch.Tensor, 174 | repeat_image: bool, 175 | high_res_features: Optional[List[torch.Tensor]] = None, 176 | ) -> Tuple[torch.Tensor, torch.Tensor]: 177 | """Predicts masks. 
See 'forward' for more details.""" 178 | # Concatenate output tokens 179 | s = 0 180 | if self.pred_obj_scores: 181 | output_tokens = torch.cat( 182 | [ 183 | self.obj_score_token.weight, 184 | self.iou_token.weight, 185 | self.mask_tokens.weight, 186 | ], 187 | dim=0, 188 | ) 189 | s = 1 190 | else: 191 | output_tokens = torch.cat( 192 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 193 | ) 194 | output_tokens = output_tokens.unsqueeze(0).expand( 195 | sparse_prompt_embeddings.size(0), -1, -1 196 | ) 197 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 198 | 199 | # Expand per-image data in batch direction to be per-mask 200 | if repeat_image: 201 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 202 | else: 203 | assert image_embeddings.shape[0] == tokens.shape[0] 204 | src = image_embeddings 205 | src = src + dense_prompt_embeddings 206 | assert ( 207 | image_pe.size(0) == 1 208 | ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" 209 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 210 | b, c, h, w = src.shape 211 | 212 | # Run the transformer 213 | hs, src = self.transformer(src, pos_src, tokens) 214 | iou_token_out = hs[:, s, :] 215 | mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] 216 | 217 | # Upscale mask embeddings and predict masks using the mask tokens 218 | src = src.transpose(1, 2).view(b, c, h, w) 219 | if not self.use_high_res_features: 220 | upscaled_embedding = self.output_upscaling(src) 221 | else: 222 | dc1, ln1, act1, dc2, act2 = self.output_upscaling 223 | feat_s0, feat_s1 = high_res_features 224 | upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) 225 | upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) 226 | 227 | hyper_in_list: List[torch.Tensor] = [] 228 | for i in range(self.num_mask_tokens): 229 | hyper_in_list.append( 230 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 231 | ) 232 | hyper_in = torch.stack(hyper_in_list, dim=1) 233 | b, c, h, w = upscaled_embedding.shape 234 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 235 | 236 | # Generate mask quality predictions 237 | iou_pred = self.iou_prediction_head(iou_token_out) 238 | if self.pred_obj_scores: 239 | assert s == 1 240 | object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) 241 | else: 242 | # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 243 | object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) 244 | 245 | return masks, iou_pred, mask_tokens_out, object_score_logits 246 | 247 | def _get_stability_scores(self, mask_logits): 248 | """ 249 | Compute stability scores of the mask logits based on the IoU between upper and 250 | lower thresholds. 251 | """ 252 | mask_logits = mask_logits.flatten(-2) 253 | stability_delta = self.dynamic_multimask_stability_delta 254 | area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() 255 | area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() 256 | stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) 257 | return stability_scores 258 | 259 | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): 260 | """ 261 | When outputting a single mask, if the stability score from the current single-mask 262 | output (based on output token 0) falls below a threshold, we instead select from 263 | multi-mask outputs (based on output token 1~3) the mask with the highest predicted 264 | IoU score. 
This is intended to ensure a valid mask for both clicking and tracking. 265 | """ 266 | # The best mask from multimask output tokens (1~3) 267 | multimask_logits = all_mask_logits[:, 1:, :, :] 268 | multimask_iou_scores = all_iou_scores[:, 1:] 269 | best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) 270 | batch_inds = torch.arange( 271 | multimask_iou_scores.size(0), device=all_iou_scores.device 272 | ) 273 | best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] 274 | best_multimask_logits = best_multimask_logits.unsqueeze(1) 275 | best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] 276 | best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) 277 | 278 | # The mask from singlemask output token 0 and its stability score 279 | singlemask_logits = all_mask_logits[:, 0:1, :, :] 280 | singlemask_iou_scores = all_iou_scores[:, 0:1] 281 | stability_scores = self._get_stability_scores(singlemask_logits) 282 | is_stable = stability_scores >= self.dynamic_multimask_stability_thresh 283 | 284 | # Dynamically fall back to best multimask output upon low stability scores. 285 | mask_logits_out = torch.where( 286 | is_stable[..., None, None].expand_as(singlemask_logits), 287 | singlemask_logits, 288 | best_multimask_logits, 289 | ) 290 | iou_scores_out = torch.where( 291 | is_stable.expand_as(singlemask_iou_scores), 292 | singlemask_iou_scores, 293 | best_multimask_iou_scores, 294 | ) 295 | return mask_logits_out, iou_scores_out 296 | -------------------------------------------------------------------------------- /sam2/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from copy import deepcopy 9 | from itertools import product 10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 11 | 12 | import numpy as np 13 | import torch 14 | 15 | # Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py 16 | 17 | 18 | class MaskData: 19 | """ 20 | A structure for storing masks and their related data in batched format. 21 | Implements basic filtering and concatenation. 22 | """ 23 | 24 | def __init__(self, **kwargs) -> None: 25 | for v in kwargs.values(): 26 | assert isinstance( 27 | v, (list, np.ndarray, torch.Tensor) 28 | ), "MaskData only supports list, numpy arrays, and torch tensors." 29 | self._stats = dict(**kwargs) 30 | 31 | def __setitem__(self, key: str, item: Any) -> None: 32 | assert isinstance( 33 | item, (list, np.ndarray, torch.Tensor) 34 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
35 | self._stats[key] = item 36 | 37 | def __delitem__(self, key: str) -> None: 38 | del self._stats[key] 39 | 40 | def __getitem__(self, key: str) -> Any: 41 | return self._stats[key] 42 | 43 | def items(self) -> ItemsView[str, Any]: 44 | return self._stats.items() 45 | 46 | def filter(self, keep: torch.Tensor) -> None: 47 | for k, v in self._stats.items(): 48 | if v is None: 49 | self._stats[k] = None 50 | elif isinstance(v, torch.Tensor): 51 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 52 | elif isinstance(v, np.ndarray): 53 | self._stats[k] = v[keep.detach().cpu().numpy()] 54 | elif isinstance(v, list) and keep.dtype == torch.bool: 55 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 56 | elif isinstance(v, list): 57 | self._stats[k] = [v[i] for i in keep] 58 | else: 59 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 60 | 61 | def cat(self, new_stats: "MaskData") -> None: 62 | for k, v in new_stats.items(): 63 | if k not in self._stats or self._stats[k] is None: 64 | self._stats[k] = deepcopy(v) 65 | elif isinstance(v, torch.Tensor): 66 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 67 | elif isinstance(v, np.ndarray): 68 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 69 | elif isinstance(v, list): 70 | self._stats[k] = self._stats[k] + deepcopy(v) 71 | else: 72 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 73 | 74 | def to_numpy(self) -> None: 75 | for k, v in self._stats.items(): 76 | if isinstance(v, torch.Tensor): 77 | self._stats[k] = v.float().detach().cpu().numpy() 78 | 79 | 80 | def is_box_near_crop_edge( 81 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 82 | ) -> torch.Tensor: 83 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 84 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 85 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 86 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 87 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 88 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 89 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 90 | return torch.any(near_crop_edge, dim=1) 91 | 92 | 93 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 94 | box_xywh = deepcopy(box_xyxy) 95 | box_xywh[2] = box_xywh[2] - box_xywh[0] 96 | box_xywh[3] = box_xywh[3] - box_xywh[1] 97 | return box_xywh 98 | 99 | 100 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 101 | assert len(args) > 0 and all( 102 | len(a) == len(args[0]) for a in args 103 | ), "Batched iteration must have inputs of all the same size." 104 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 105 | for b in range(n_batches): 106 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 107 | 108 | 109 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 110 | """ 111 | Encodes masks to an uncompressed RLE, in the format expected by 112 | pycoco tools. 
113 | """ 114 | # Put in fortran order and flatten h,w 115 | b, h, w = tensor.shape 116 | tensor = tensor.permute(0, 2, 1).flatten(1) 117 | 118 | # Compute change indices 119 | diff = tensor[:, 1:] ^ tensor[:, :-1] 120 | change_indices = diff.nonzero() 121 | 122 | # Encode run length 123 | out = [] 124 | for i in range(b): 125 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 126 | cur_idxs = torch.cat( 127 | [ 128 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | cur_idxs + 1, 130 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 131 | ] 132 | ) 133 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 134 | counts = [] if tensor[i, 0] == 0 else [0] 135 | counts.extend(btw_idxs.detach().cpu().tolist()) 136 | out.append({"size": [h, w], "counts": counts}) 137 | return out 138 | 139 | 140 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 141 | """Compute a binary mask from an uncompressed RLE.""" 142 | h, w = rle["size"] 143 | mask = np.empty(h * w, dtype=bool) 144 | idx = 0 145 | parity = False 146 | for count in rle["counts"]: 147 | mask[idx : idx + count] = parity 148 | idx += count 149 | parity ^= True 150 | mask = mask.reshape(w, h) 151 | return mask.transpose() # Put in C order 152 | 153 | 154 | def area_from_rle(rle: Dict[str, Any]) -> int: 155 | return sum(rle["counts"][1::2]) 156 | 157 | 158 | def calculate_stability_score( 159 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 160 | ) -> torch.Tensor: 161 | """ 162 | Computes the stability score for a batch of masks. The stability 163 | score is the IoU between the binary masks obtained by thresholding 164 | the predicted mask logits at high and low values. 165 | """ 166 | # One mask is always contained inside the other. 167 | # Save memory by preventing unnecessary cast to torch.int64 168 | intersections = ( 169 | (masks > (mask_threshold + threshold_offset)) 170 | .sum(-1, dtype=torch.int16) 171 | .sum(-1, dtype=torch.int32) 172 | ) 173 | unions = ( 174 | (masks > (mask_threshold - threshold_offset)) 175 | .sum(-1, dtype=torch.int16) 176 | .sum(-1, dtype=torch.int32) 177 | ) 178 | return intersections / unions 179 | 180 | 181 | def build_point_grid(n_per_side: int) -> np.ndarray: 182 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 183 | offset = 1 / (2 * n_per_side) 184 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 185 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 186 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 187 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 188 | return points 189 | 190 | 191 | def build_all_layer_point_grids( 192 | n_per_side: int, n_layers: int, scale_per_layer: int 193 | ) -> List[np.ndarray]: 194 | """Generates point grids for all crop layers.""" 195 | points_by_layer = [] 196 | for i in range(n_layers + 1): 197 | n_points = int(n_per_side / (scale_per_layer**i)) 198 | points_by_layer.append(build_point_grid(n_points)) 199 | return points_by_layer 200 | 201 | 202 | def generate_crop_boxes( 203 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 204 | ) -> Tuple[List[List[int]], List[int]]: 205 | """ 206 | Generates a list of crop boxes of different sizes. Each layer 207 | has (2**i)**2 boxes for the ith layer. 
208 | """ 209 | crop_boxes, layer_idxs = [], [] 210 | im_h, im_w = im_size 211 | short_side = min(im_h, im_w) 212 | 213 | # Original image 214 | crop_boxes.append([0, 0, im_w, im_h]) 215 | layer_idxs.append(0) 216 | 217 | def crop_len(orig_len, n_crops, overlap): 218 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 219 | 220 | for i_layer in range(n_layers): 221 | n_crops_per_side = 2 ** (i_layer + 1) 222 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 223 | 224 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 225 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 226 | 227 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 228 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 229 | 230 | # Crops in XYWH format 231 | for x0, y0 in product(crop_box_x0, crop_box_y0): 232 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 233 | crop_boxes.append(box) 234 | layer_idxs.append(i_layer + 1) 235 | 236 | return crop_boxes, layer_idxs 237 | 238 | 239 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 240 | x0, y0, _, _ = crop_box 241 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 242 | # Check if boxes has a channel dimension 243 | if len(boxes.shape) == 3: 244 | offset = offset.unsqueeze(1) 245 | return boxes + offset 246 | 247 | 248 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 249 | x0, y0, _, _ = crop_box 250 | offset = torch.tensor([[x0, y0]], device=points.device) 251 | # Check if points has a channel dimension 252 | if len(points.shape) == 3: 253 | offset = offset.unsqueeze(1) 254 | return points + offset 255 | 256 | 257 | def uncrop_masks( 258 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 259 | ) -> torch.Tensor: 260 | x0, y0, x1, y1 = crop_box 261 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 262 | return masks 263 | # Coordinate transform masks 264 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 265 | pad = (x0, pad_x - x0, y0, pad_y - y0) 266 | return torch.nn.functional.pad(masks, pad, value=0) 267 | 268 | 269 | def remove_small_regions( 270 | mask: np.ndarray, area_thresh: float, mode: str 271 | ) -> Tuple[np.ndarray, bool]: 272 | """ 273 | Removes small disconnected regions and holes in a mask. Returns the 274 | mask and an indicator of if the mask has been modified. 
275 | """ 276 | import cv2 # type: ignore 277 | 278 | assert mode in ["holes", "islands"] 279 | correct_holes = mode == "holes" 280 | working_mask = (correct_holes ^ mask).astype(np.uint8) 281 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 282 | sizes = stats[:, -1][1:] # Row 0 is background label 283 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 284 | if len(small_regions) == 0: 285 | return mask, False 286 | fill_labels = [0] + small_regions 287 | if not correct_holes: 288 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 289 | # If every region is below threshold, keep largest 290 | if len(fill_labels) == 0: 291 | fill_labels = [int(np.argmax(sizes)) + 1] 292 | mask = np.isin(regions, fill_labels) 293 | return mask, True 294 | 295 | 296 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 297 | from pycocotools import mask as mask_utils # type: ignore 298 | 299 | h, w = uncompressed_rle["size"] 300 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 301 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 302 | return rle 303 | 304 | 305 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 306 | """ 307 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 308 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 309 | """ 310 | # torch.max below raises an error on empty inputs, just skip in this case 311 | if torch.numel(masks) == 0: 312 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 313 | 314 | # Normalize shape to CxHxW 315 | shape = masks.shape 316 | h, w = shape[-2:] 317 | if len(shape) > 2: 318 | masks = masks.flatten(0, -3) 319 | else: 320 | masks = masks.unsqueeze(0) 321 | 322 | # Get top and bottom edges 323 | in_height, _ = torch.max(masks, dim=-1) 324 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 325 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 326 | in_height_coords = in_height_coords + h * (~in_height) 327 | top_edges, _ = torch.min(in_height_coords, dim=-1) 328 | 329 | # Get left and right edges 330 | in_width, _ = torch.max(masks, dim=-2) 331 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 332 | right_edges, _ = torch.max(in_width_coords, dim=-1) 333 | in_width_coords = in_width_coords + w * (~in_width) 334 | left_edges, _ = torch.min(in_width_coords, dim=-1) 335 | 336 | # If the mask is empty the right edge will be to the left of the left edge. 337 | # Replace these boxes with [0, 0, 0, 0] 338 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 339 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 340 | out = out * (~empty_filter).unsqueeze(-1) 341 | 342 | # Return to original shape 343 | if len(shape) > 2: 344 | out = out.reshape(*shape[:-2], 4) 345 | else: 346 | out = out[0] 347 | 348 | return out 349 | -------------------------------------------------------------------------------- /sam2/modeling/sam/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import contextlib 8 | import math 9 | import warnings 10 | from functools import partial 11 | from typing import Tuple, Type 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis 18 | from sam2.modeling.sam2_utils import MLP 19 | from sam2.utils.misc import get_sdpa_settings 20 | 21 | warnings.simplefilter(action="ignore", category=FutureWarning) 22 | # Check whether Flash Attention is available (and use it by default) 23 | OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() 24 | # A fallback setting to allow all available kernels if Flash Attention fails 25 | ALLOW_ALL_KERNELS = False 26 | 27 | 28 | def sdp_kernel_context(dropout_p): 29 | """ 30 | Get the context for the attention scaled dot-product kernel. We use Flash Attention 31 | by default, but fall back to all available kernels if Flash Attention fails. 32 | """ 33 | if ALLOW_ALL_KERNELS: 34 | return contextlib.nullcontext() 35 | 36 | return torch.backends.cuda.sdp_kernel( 37 | enable_flash=USE_FLASH_ATTN, 38 | # if Flash attention kernel is off, then math kernel needs to be enabled 39 | enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, 40 | enable_mem_efficient=OLD_GPU, 41 | ) 42 | 43 | 44 | class TwoWayTransformer(nn.Module): 45 | def __init__( 46 | self, 47 | depth: int, 48 | embedding_dim: int, 49 | num_heads: int, 50 | mlp_dim: int, 51 | activation: Type[nn.Module] = nn.ReLU, 52 | attention_downsample_rate: int = 2, 53 | ) -> None: 54 | """ 55 | A transformer decoder that attends to an input image using 56 | queries whose positional embedding is supplied. 57 | 58 | Args: 59 | depth (int): number of layers in the transformer 60 | embedding_dim (int): the channel dimension for the input embeddings 61 | num_heads (int): the number of heads for multihead attention. Must 62 | divide embedding_dim 63 | mlp_dim (int): the channel dimension internal to the MLP block 64 | activation (nn.Module): the activation to use in the MLP block 65 | """ 66 | super().__init__() 67 | self.depth = depth 68 | self.embedding_dim = embedding_dim 69 | self.num_heads = num_heads 70 | self.mlp_dim = mlp_dim 71 | self.layers = nn.ModuleList() 72 | 73 | for i in range(depth): 74 | self.layers.append( 75 | TwoWayAttentionBlock( 76 | embedding_dim=embedding_dim, 77 | num_heads=num_heads, 78 | mlp_dim=mlp_dim, 79 | activation=activation, 80 | attention_downsample_rate=attention_downsample_rate, 81 | skip_first_layer_pe=(i == 0), 82 | ) 83 | ) 84 | 85 | self.final_attn_token_to_image = Attention( 86 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 87 | ) 88 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 89 | 90 | def forward( 91 | self, 92 | image_embedding: Tensor, 93 | image_pe: Tensor, 94 | point_embedding: Tensor, 95 | ) -> Tuple[Tensor, Tensor]: 96 | """ 97 | Args: 98 | image_embedding (torch.Tensor): image to attend to. Should be shape 99 | B x embedding_dim x h x w for any h and w. 100 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 101 | have the same shape as image_embedding. 102 | point_embedding (torch.Tensor): the embedding to add to the query points. 103 | Must have shape B x N_points x embedding_dim for any N_points. 
104 | 105 | Returns: 106 | torch.Tensor: the processed point_embedding 107 | torch.Tensor: the processed image_embedding 108 | """ 109 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 110 | bs, c, h, w = image_embedding.shape 111 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 112 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 113 | 114 | # Prepare queries 115 | queries = point_embedding 116 | keys = image_embedding 117 | 118 | # Apply transformer blocks and final layernorm 119 | for layer in self.layers: 120 | queries, keys = layer( 121 | queries=queries, 122 | keys=keys, 123 | query_pe=point_embedding, 124 | key_pe=image_pe, 125 | ) 126 | 127 | # Apply the final attention layer from the points to the image 128 | q = queries + point_embedding 129 | k = keys + image_pe 130 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 131 | queries = queries + attn_out 132 | queries = self.norm_final_attn(queries) 133 | 134 | return queries, keys 135 | 136 | 137 | class TwoWayAttentionBlock(nn.Module): 138 | def __init__( 139 | self, 140 | embedding_dim: int, 141 | num_heads: int, 142 | mlp_dim: int = 2048, 143 | activation: Type[nn.Module] = nn.ReLU, 144 | attention_downsample_rate: int = 2, 145 | skip_first_layer_pe: bool = False, 146 | ) -> None: 147 | """ 148 | A transformer block with four layers: (1) self-attention of sparse 149 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 150 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 151 | inputs. 152 | 153 | Arguments: 154 | embedding_dim (int): the channel dimension of the embeddings 155 | num_heads (int): the number of heads in the attention layers 156 | mlp_dim (int): the hidden dimension of the mlp block 157 | activation (nn.Module): the activation of the mlp block 158 | skip_first_layer_pe (bool): skip the PE on the first layer 159 | """ 160 | super().__init__() 161 | self.self_attn = Attention(embedding_dim, num_heads) 162 | self.norm1 = nn.LayerNorm(embedding_dim) 163 | 164 | self.cross_attn_token_to_image = Attention( 165 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 166 | ) 167 | self.norm2 = nn.LayerNorm(embedding_dim) 168 | 169 | self.mlp = MLP( 170 | embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation 171 | ) 172 | self.norm3 = nn.LayerNorm(embedding_dim) 173 | 174 | self.norm4 = nn.LayerNorm(embedding_dim) 175 | self.cross_attn_image_to_token = Attention( 176 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 177 | ) 178 | 179 | self.skip_first_layer_pe = skip_first_layer_pe 180 | 181 | def forward( 182 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 183 | ) -> Tuple[Tensor, Tensor]: 184 | # Self attention block 185 | if self.skip_first_layer_pe: 186 | queries = self.self_attn(q=queries, k=queries, v=queries) 187 | else: 188 | q = queries + query_pe 189 | attn_out = self.self_attn(q=q, k=q, v=queries) 190 | queries = queries + attn_out 191 | queries = self.norm1(queries) 192 | 193 | # Cross attention block, tokens attending to image embedding 194 | q = queries + query_pe 195 | k = keys + key_pe 196 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 197 | queries = queries + attn_out 198 | queries = self.norm2(queries) 199 | 200 | # MLP block 201 | mlp_out = self.mlp(queries) 202 | queries = queries + mlp_out 203 | queries = self.norm3(queries) 204 | 205 | # Cross attention block, image embedding attending to tokens 206 | q = queries + query_pe 207 | k = keys + 
key_pe 208 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 209 | keys = keys + attn_out 210 | keys = self.norm4(keys) 211 | 212 | return queries, keys 213 | 214 | 215 | class Attention(nn.Module): 216 | """ 217 | An attention layer that allows for downscaling the size of the embedding 218 | after projection to queries, keys, and values. 219 | """ 220 | 221 | def __init__( 222 | self, 223 | embedding_dim: int, 224 | num_heads: int, 225 | downsample_rate: int = 1, 226 | dropout: float = 0.0, 227 | kv_in_dim: int = None, 228 | ) -> None: 229 | super().__init__() 230 | self.embedding_dim = embedding_dim 231 | self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim 232 | self.internal_dim = embedding_dim // downsample_rate 233 | self.num_heads = num_heads 234 | assert ( 235 | self.internal_dim % num_heads == 0 236 | ), "num_heads must divide embedding_dim." 237 | 238 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 239 | self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 240 | self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 241 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 242 | 243 | self.dropout_p = dropout 244 | 245 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 246 | b, n, c = x.shape 247 | x = x.reshape(b, n, num_heads, c // num_heads) 248 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 249 | 250 | def _recombine_heads(self, x: Tensor) -> Tensor: 251 | b, n_heads, n_tokens, c_per_head = x.shape 252 | x = x.transpose(1, 2) 253 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 254 | 255 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 256 | # Input projections 257 | q = self.q_proj(q) 258 | k = self.k_proj(k) 259 | v = self.v_proj(v) 260 | 261 | # Separate into heads 262 | q = self._separate_heads(q, self.num_heads) 263 | k = self._separate_heads(k, self.num_heads) 264 | v = self._separate_heads(v, self.num_heads) 265 | 266 | dropout_p = self.dropout_p if self.training else 0.0 267 | # Attention 268 | try: 269 | with sdp_kernel_context(dropout_p): 270 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 271 | except Exception as e: 272 | # Fall back to all kernels if the Flash attention kernel fails 273 | warnings.warn( 274 | f"Flash Attention kernel failed due to: {e}\nFalling back to all available " 275 | f"kernels for scaled_dot_product_attention (which may have a slower speed).", 276 | category=UserWarning, 277 | stacklevel=2, 278 | ) 279 | global ALLOW_ALL_KERNELS 280 | ALLOW_ALL_KERNELS = True 281 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 282 | 283 | out = self._recombine_heads(out) 284 | out = self.out_proj(out) 285 | 286 | return out 287 | 288 | 289 | class RoPEAttention(Attention): 290 | """Attention with rotary position encoding.""" 291 | 292 | def __init__( 293 | self, 294 | *args, 295 | rope_theta=10000.0, 296 | # whether to repeat q rope to match k length 297 | # this is needed for cross-attention to memories 298 | rope_k_repeat=False, 299 | feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution 300 | **kwargs, 301 | ): 302 | super().__init__(*args, **kwargs) 303 | 304 | self.compute_cis = partial( 305 | compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta 306 | ) 307 | freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) 308 | self.freqs_cis = freqs_cis 309 | self.rope_k_repeat = rope_k_repeat 310 | 311 | def forward( 312 | self, q: 
Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 313 | ) -> Tensor: 314 | # Input projections 315 | q = self.q_proj(q) 316 | k = self.k_proj(k) 317 | v = self.v_proj(v) 318 | 319 | # Separate into heads 320 | q = self._separate_heads(q, self.num_heads) 321 | k = self._separate_heads(k, self.num_heads) 322 | v = self._separate_heads(v, self.num_heads) 323 | 324 | # Apply rotary position encoding 325 | w = h = math.sqrt(q.shape[-2]) 326 | self.freqs_cis = self.freqs_cis.to(q.device) 327 | if self.freqs_cis.shape[0] != q.shape[-2]: 328 | self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) 329 | if q.shape[-2] != k.shape[-2]: 330 | assert self.rope_k_repeat 331 | 332 | num_k_rope = k.size(-2) - num_k_exclude_rope 333 | q, k[:, :, :num_k_rope] = apply_rotary_enc( 334 | q, 335 | k[:, :, :num_k_rope], 336 | freqs_cis=self.freqs_cis, 337 | repeat_freqs_k=self.rope_k_repeat, 338 | ) 339 | 340 | dropout_p = self.dropout_p if self.training else 0.0 341 | # Attention 342 | try: 343 | with sdp_kernel_context(dropout_p): 344 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 345 | except Exception as e: 346 | # Fall back to all kernels if the Flash attention kernel fails 347 | warnings.warn( 348 | f"Flash Attention kernel failed due to: {e}\nFalling back to all available " 349 | f"kernels for scaled_dot_product_attention (which may have a slower speed).", 350 | category=UserWarning, 351 | stacklevel=2, 352 | ) 353 | global ALLOW_ALL_KERNELS 354 | ALLOW_ALL_KERNELS = True 355 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 356 | 357 | out = self._recombine_heads(out) 358 | out = self.out_proj(out) 359 | 360 | return out 361 | --------------------------------------------------------------------------------
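For orientation, here is a minimal, self-contained sketch of driving the TwoWayTransformer defined in sam2/modeling/sam/transformer.py above. It is an illustrative example only, not repository code: the sizes (depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048) and the random inputs are assumptions chosen to resemble SAM-style settings rather than values read from the configs in this repo.

import torch

from sam2.modeling.sam.transformer import TwoWayTransformer

# Assumed SAM-style sizes; adjust to whatever config is actually in use.
transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
transformer.eval()

B, C, H, W = 1, 256, 64, 64            # image embedding is B x C x H x W
N = 7                                  # number of prompt tokens (B x N x C)
image_embedding = torch.randn(B, C, H, W)
image_pe = torch.randn(B, C, H, W)     # positional encoding, same shape as the embedding
point_embedding = torch.randn(B, N, C)

with torch.no_grad():
    # queries: processed prompt tokens (B x N x C); keys: processed image tokens (B x HW x C)
    queries, keys = transformer(image_embedding, image_pe, point_embedding)

assert queries.shape == (B, N, C)
assert keys.shape == (B, H * W, C)

In the mask decoder shown earlier, the returned queries carry the IoU and mask tokens into the prediction heads, while keys are reshaped back to B x C x H x W before the upscaling path.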