├── sav_dataset
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── endo_sav_utils.py
│   ├── requirements.txt
│   ├── LICENSE_VOS_BENCHMARK
│   ├── LICENSE
│   ├── LICENSE_DAVIS
│   ├── endo_sav_evaluator.py
│   └── README.md
├── sam2
│   ├── sam2_hiera_l.yaml
│   ├── sam2_hiera_s.yaml
│   ├── sam2_hiera_t.yaml
│   ├── sam2_hiera_b+.yaml
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   └── prompt_encoder.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   └── image_encoder.py
│   │   ├── memory_attention.py
│   │   ├── memory_encoder.py
│   │   └── position_encoding.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── transforms.py
│   ├── __init__.py
│   ├── benchmark.py
│   ├── configs
│   │   ├── sam2
│   │   │   ├── sam2_hiera_b+.yaml
│   │   │   ├── sam2_hiera_s.yaml
│   │   │   ├── sam2_hiera_l.yaml
│   │   │   └── sam2_hiera_t.yaml
│   │   └── sam2.1
│   │       ├── sam2.1_hiera_b+.yaml
│   │       ├── sam2.1_hiera_s.yaml
│   │       ├── sam2.1_hiera_l.yaml
│   │       └── sam2.1_hiera_t.yaml
│   ├── build_sam.py
│   └── csrc
│       └── connected_components.cu
├── assets
│   ├── architecture.png
│   └── comparison.png
├── pyproject.toml
├── training
│   ├── __init__.py
│   ├── model
│   │   └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── logger.py
│   │   └── train_utils.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── vos_sampler.py
│   │   ├── utils.py
│   │   ├── vos_dataset.py
│   │   └── sam2_datasets.py
│   ├── assets
│   │   └── MOSE_sample_val_list.txt
│   ├── scripts
│   │   └── sav_frame_extraction_submitit.py
│   └── README.md
├── .gitignore
├── LICENSE_cctorch
├── tools
│   └── README.md
├── checkpoints
│   └── download_ckpts.sh
├── README.md
├── setup.py
└── INSTALL.md
/sav_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sav_dataset/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | 
configs/sam2/sam2_hiera_l.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_s.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_t.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_b+.yaml -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Surgical-SAM-2/HEAD/assets/architecture.png -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinlab-imvr/Surgical-SAM-2/HEAD/assets/comparison.png -------------------------------------------------------------------------------- /sav_dataset/requirements.txt: -------------------------------------------------------------------------------- 1 | pycocoevalcap 2 | scikit-image 3 | opencv-python 4 | tqdm 5 | pillow 6 | numpy 7 | matplotlib -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=61.0", 4 | "torch>=2.5.1", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /training/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /training/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /training/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .idea/ 3 | .DS_Store 4 | __pycache__/ 5 | *-checkpoint.ipynb 6 | .venv 7 | *.egg* 8 | build/* 9 | _C.* 10 | outputs/* 11 | checkpoints/*.pt 12 | demo/backend/checkpoints/*.pt 13 | datasets/* 14 | model_weights/* 15 | /sav_dataset/example/ 16 | -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE_VOS_BENCHMARK: -------------------------------------------------------------------------------- 1 | Copyright 2023 Rex Cheng 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /sav_dataset/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For SAM 2 Eval software 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. 
6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Meta nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE_DAVIS: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /LICENSE_cctorch: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | ## SAM 2 toolkits 2 | 3 | This directory provides toolkits for additional SAM 2 use cases. 
4 | 5 | ### Semi-supervised VOS inference 6 | 7 | The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset. 8 | 9 | After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`. 10 | ```bash 11 | python ./tools/vos_inference.py \ 12 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 13 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 14 | --base_video_dir /path-to-davis-2017/JPEGImages/480p \ 15 | --input_mask_dir /path-to-davis-2017/Annotations/480p \ 16 | --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ 17 | --output_mask_dir ./outputs/davis_2017_pred_pngs 18 | ``` 19 | (replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset) 20 | 21 | To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag. 22 | ```bash 23 | python ./tools/vos_inference.py \ 24 | --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \ 25 | --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \ 26 | --base_video_dir /path-to-sav-val/JPEGImages_24fps \ 27 | --input_mask_dir /path-to-sav-val/Annotations_6fps \ 28 | --video_list_file /path-to-sav-val/sav_val.txt \ 29 | --per_obj_png_file \ 30 | --output_mask_dir ./outputs/sav_val_pred_pngs 31 | ``` 32 | (replace `/path-to-sav-val` with the path to SA-V val) 33 | 34 | Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. 
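As a rough illustration of what those evaluation tools measure, the region similarity J reported by DAVIS-style benchmarks is the intersection-over-union (Jaccard index) between a predicted and a ground-truth binary mask. The sketch below is a minimal NumPy version for one object, not the official DAVIS or SA-V evaluator: the helper names `region_similarity` and `mean_region_similarity` are hypothetical, void/ignore regions are not handled, and an empty-vs-empty pair is scored 1.0, following the convention used by the DAVIS 2017 evaluation code.

```python
import numpy as np


def region_similarity(pred: np.ndarray, gt: np.ndarray) -> float:
    """Jaccard index J (IoU) between two binary masks of identical shape.

    Hypothetical helper for illustration; void/ignore regions are not handled.
    """
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # Both masks are empty: count the frame as a perfect match.
        return 1.0
    intersection = np.logical_and(pred, gt).sum()
    return float(intersection) / float(union)


def mean_region_similarity(preds, gts) -> float:
    """Average J over paired per-frame masks of a single object."""
    if len(preds) != len(gts):
        raise ValueError("preds and gts must contain the same number of frames")
    scores = [region_similarity(p, g) for p, g in zip(preds, gts)]
    return float(np.mean(scores)) if scores else float("nan")


if __name__ == "__main__":
    gt = np.zeros((4, 4), dtype=np.uint8)
    gt[0, :3] = 1    # 3 foreground pixels
    pred = np.zeros((4, 4), dtype=np.uint8)
    pred[0, :2] = 1  # 2 pixels, both inside gt
    print(region_similarity(pred, gt))  # 2 / 3
```

Per-object PNGs such as those written with `--per_obj_png_file` map naturally onto one `pred`/`gt` pair per object and frame; DAVIS-style palette PNGs would first need to be split into one binary mask per object id.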
35 | 36 | Note: by default, the `vos_inference.py` script above assumes that all objects to track already appear on frame 0 in each video (as is the case in DAVIS, MOSE or SA-V). **For VOS datasets that don't have all objects to track appearing in the first frame (such as LVOS or YouTube-VOS), please add the `--track_object_appearing_later_in_video` flag when using `vos_inference.py`**. 37 | -------------------------------------------------------------------------------- /checkpoints/download_ckpts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # Use either wget or curl to download the checkpoints 10 | if command -v wget &> /dev/null; then 11 | CMD="wget" 12 | elif command -v curl &> /dev/null; then 13 | CMD="curl -L -O" 14 | else 15 | echo "Please install wget or curl to download the checkpoints." 16 | exit 1 17 | fi 18 | 19 | # Define the URLs for SAM 2 checkpoints 20 | # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824" 21 | # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt" 22 | # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt" 23 | # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt" 24 | # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt" 25 | 26 | # Download each of the four checkpoints using wget 27 | # echo "Downloading sam2_hiera_tiny.pt checkpoint..." 28 | # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; } 29 | 30 | # echo "Downloading sam2_hiera_small.pt checkpoint..." 31 | # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; } 32 | 33 | # echo "Downloading sam2_hiera_base_plus.pt checkpoint..." 
34 | # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; } 35 | 36 | # echo "Downloading sam2_hiera_large.pt checkpoint..." 37 | # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; } 38 | 39 | # Define the URLs for SAM 2.1 checkpoints 40 | SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824" 41 | sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt" 42 | sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt" 43 | sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt" 44 | sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt" 45 | 46 | # SAM 2.1 checkpoints 47 | #echo "Downloading sam2.1_hiera_tiny.pt checkpoint..." 48 | #$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; } 49 | 50 | echo "Downloading sam2.1_hiera_small.pt checkpoint..." 51 | $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; } 52 | 53 | #echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..." 54 | #$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; } 55 | 56 | #echo "Downloading sam2.1_hiera_large.pt checkpoint..." 57 | #$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; } 58 | 59 | #echo "All checkpoints are downloaded successfully." 
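`download_ckpts.sh` builds every checkpoint URL from `SAM2p1_BASE_URL` and, as shipped here, only fetches the SAM 2.1 small checkpoint; the other downloads are commented out. For readers who prefer to script this from Python, a minimal sketch of the same URL scheme follows. The helper names `checkpoint_url` and `download` are hypothetical, and the variant list simply mirrors the four URLs defined in the script.

```python
import urllib.request
from pathlib import Path

# Base URL copied from download_ckpts.sh (SAM 2.1 release).
SAM2P1_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824"

# Variant names as they appear in the checkpoint file names in the script.
VARIANTS = ("tiny", "small", "base_plus", "large")


def checkpoint_url(variant: str) -> str:
    """Return the download URL of one SAM 2.1 Hiera checkpoint."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; expected one of {VARIANTS}")
    return f"{SAM2P1_BASE_URL}/sam2.1_hiera_{variant}.pt"


def download(variants=("small",), out_dir="checkpoints") -> list:
    """Download the selected checkpoints into out_dir (hypothetical helper).

    Defaults to the small model only, matching what the shell script fetches.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    paths = []
    for variant in variants:
        url = checkpoint_url(variant)
        dest = Path(out_dir) / url.rsplit("/", 1)[-1]
        print(f"Downloading {dest.name} checkpoint...")
        urllib.request.urlretrieve(url, dest)
        paths.append(dest)
    return paths
```

Note that `sam2/benchmark.py` loads `checkpoints/sam2.1_hiera_base_plus.pt`, which the shell script as shipped does not download; with the sketch above that would be `download(("small", "base_plus"))`.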
60 | -------------------------------------------------------------------------------- /training/assets/MOSE_sample_val_list.txt: -------------------------------------------------------------------------------- 1 | 32e5d721 2 | 5bad0bab 3 | 267bfd6c 4 | 0a43a414 5 | 56c56ca9 6 | 9a1146b3 7 | c6ad7aaf 8 | 78a1f4b1 9 | fc455e73 10 | 072e7b3f 11 | 77ccb57d 12 | a76ee415 13 | 8cdcfc17 14 | 5d518b42 15 | 376dd830 16 | 0e843fc8 17 | 2af0e766 18 | 2bd4e845 19 | de2f2a6a 20 | ade9ee91 21 | 001ca3cb 22 | fc4c1c67 23 | 8ef55579 24 | b84ce852 25 | 4cc8528a 26 | 767ffaaa 27 | 112a2ef0 28 | a338c8aa 29 | cbd144f5 30 | 5ff72128 31 | 86a949e2 32 | 9f2323ac 33 | 1fab1d1c 34 | 75924351 35 | ef55817b 36 | 02deca50 37 | 4d979d99 38 | 4d65f873 39 | 28470fa0 40 | 0d1575fe 41 | 06ea172e 42 | 29a6ddc2 43 | 797f1bec 44 | 780e7a99 45 | b9ed5b44 46 | 02a236b4 47 | 607d8ff5 48 | af5666b2 49 | 0558d0ed 50 | a938c6b2 51 | 103df575 52 | 77110e80 53 | 739e5a07 54 | 6763a576 55 | 06ebc138 56 | ba4b3b09 57 | b35cc2f3 58 | 4e0597a0 59 | 5949ee84 60 | 5348d547 61 | 323c4236 62 | b3b51117 63 | 55727ddd 64 | ab2714f3 65 | d2878895 66 | c0734cb3 67 | 94f7c53e 68 | 2a2745e5 69 | 442ffb54 70 | 3592425a 71 | 50ae03b0 72 | 5f150435 73 | 3067f9fa 74 | 9ffb2818 75 | adeaf5aa 76 | 31caacec 77 | 1cd99b86 78 | aa22f9d0 79 | 8fa50320 80 | e6348d2c 81 | 42ff84a5 82 | 8c8b7913 83 | c96adcbc 84 | 495be321 85 | db735509 86 | ee113fc4 87 | a678cdab 88 | c409ca4d 89 | 68d2b259 90 | 592b4dee 91 | 4e2b4dc7 92 | eb4d26e1 93 | 2009a00f 94 | bec5c89d 95 | 67191f24 96 | a3e85b4b 97 | da7080cd 98 | 80d978e9 99 | 36dcb93f 100 | a41e8c44 101 | 12fdc864 102 | 46d140ea 103 | 657c9dd9 104 | a86f84ee 105 | 90c1c43d 106 | 33015509 107 | afc7664d 108 | 23df06e1 109 | 291d4799 110 | 0ab75563 111 | 251bf059 112 | bcefdcc4 113 | ce9a2796 114 | 94d3403a 115 | 8f2e04bc 116 | f9cda066 117 | 9dfa2cc5 118 | 66924c91 119 | e765a09e 120 | 15654ee1 121 | 48e0bd39 122 | ee095221 123 | 2463609b 124 | 544d0d1f 125 | 51b8c2e1 126 | d321dde4 127 | 
4cb11a5f 128 | d7058a0d 129 | 37af282a 130 | fabae187 131 | 7be91184 132 | 181ec185 133 | 2d16ceeb 134 | b56be4b1 135 | 6699eff0 136 | 79acac96 137 | d61c4665 138 | 0c13e1e7 139 | 100f6ecf 140 | 71217dfc 141 | 82df0888 142 | 4c42c747 143 | c9fdf703 144 | d2efeb4b 145 | 69ed9d14 146 | 64914fb6 147 | 255bedbc 148 | 4ea934d8 149 | a034feb2 150 | e4f4ddae 151 | e36a3026 152 | c1489591 153 | 111bb373 154 | e1d9fb32 155 | 93e22d48 156 | c1ec4b26 157 | d9638e69 158 | 60ab04c5 159 | cfe7773a 160 | 62132822 161 | 2f5fb2a3 162 | 7bdd197d 163 | 033333fd 164 | 130fcdbe 165 | 12e509c2 166 | 67138c33 167 | 6f90cc5f 168 | 4e3020fe 169 | bbdd8bb7 170 | b399ccdb 171 | fecd10d2 172 | 2e0967f7 173 | f509054f 174 | 792c6ff7 175 | 48e2afc5 176 | d904c048 177 | 111e0a5c 178 | b83024e2 179 | e6a7b79c 180 | bdc5ccf7 181 | b8146d00 182 | 9d394f1a 183 | 645b84f9 184 | 95ab2d0f 185 | e6f8a31d 186 | b4f876fb 187 | dc2c570d 188 | 3afd02d7 189 | 5c80c82c 190 | b1b32ddd 191 | 9f25fc61 192 | ba538072 193 | f8916fef 194 | 43c04ad2 195 | a658e949 196 | 2861dd53 197 | f6e40aba 198 | 09d305d1 199 | aac33bff 200 | 8d9d4c08 201 | -------------------------------------------------------------------------------- /sam2/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | from sam2.build_sam import build_sam2_video_predictor 15 | 16 | # Only CUDA is supported 17 | assert torch.cuda.is_available() 18 | device = torch.device("cuda") 19 | 20 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 21 | if torch.cuda.get_device_properties(0).major >= 8: 22 | # turn on TensorFloat-32 (TF32) for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) 23 | torch.backends.cuda.matmul.allow_tf32 = True 24 | torch.backends.cudnn.allow_tf32 = True 25 | 26 | # Config and checkpoint 27 | sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt" 28 | model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" 29 | 30 | # Build video predictor with vos_optimized=True setting 31 | predictor = build_sam2_video_predictor( 32 | model_cfg, sam2_checkpoint, device=device, vos_optimized=True 33 | ) 34 | 35 | 36 | # Initialize with video 37 | video_dir = "notebooks/videos/bedroom" 38 | # scan all the JPEG and PNG frame names in this directory 39 | frame_names = [ 40 | p 41 | for p in os.listdir(video_dir) 42 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"] 43 | ] 44 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 45 | inference_state = predictor.init_state(video_path=video_dir) 46 | 47 | 48 | # Number of warm-up runs and total runs 49 | warm_up, runs = 5, 25 50 | verbose = True 51 | num_frames = len(frame_names) 52 | total, count = 0, 0 53 | torch.cuda.empty_cache() 54 | 55 | # We will select an object with a click.
56 | # See video_predictor_example.ipynb for more detailed explanation 57 | ann_frame_idx, ann_obj_id = 0, 1 58 | # Add a positive click at (x, y) = (210, 350) 59 | # For labels, `1` means positive click 60 | points = np.array([[210, 350]], dtype=np.float32) 61 | labels = np.array([1], np.int32) 62 | 63 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( 64 | inference_state=inference_state, 65 | frame_idx=ann_frame_idx, 66 | obj_id=ann_obj_id, 67 | points=points, 68 | labels=labels, 69 | ) 70 | 71 | # Warmup and then average FPS over several runs 72 | with torch.autocast("cuda", torch.bfloat16): 73 | with torch.inference_mode(): 74 | for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): 75 | start = time.time() 76 | # Start tracking 77 | for ( 78 | out_frame_idx, 79 | out_obj_ids, 80 | out_mask_logits, 81 | ) in predictor.propagate_in_video(inference_state): 82 | pass 83 | 84 | end = time.time() 85 | total += end - start 86 | count += 1 87 | if i == warm_up - 1: 88 | print("Warmup FPS: ", count * num_frames / total) 89 | total = 0 90 | count = 0 91 | 92 | print("FPS: ", count * num_frames / total) 93 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 
22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) 36 | return windows, (Hp, Wp) 37 | 38 | 39 | def window_unpartition(windows, window_size, pad_hw, hw): 40 | """ 41 | Unpartition windows into the original sequences and remove padding. 42 | Args: 43 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 44 | window_size (int): window size. 45 | pad_hw (Tuple): padded height and width (Hp, Wp). 46 | hw (Tuple): original height and width (H, W) before padding. 47 | Returns: 48 | x: unpartitioned sequences with [B, H, W, C]. 49 | """ 50 | Hp, Wp = pad_hw 51 | H, W = hw 52 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 53 | x = windows.reshape( 54 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 55 | ) 56 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) 57 | 58 | if Hp > H or Wp > W: 59 | x = x[:, :H, :W, :] 60 | return x 61 | 62 | 63 | class PatchEmbed(nn.Module): 64 | """ 65 | Image to Patch Embedding. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | kernel_size: Tuple[int, ...] = (7, 7), 71 | stride: Tuple[int, ...] = (4, 4), 72 | padding: Tuple[int, ...] = (3, 3), 73 | in_chans: int = 3, 74 | embed_dim: int = 768, 75 | ): 76 | """ 77 | Args: 78 | kernel_size (Tuple): kernel size of the projection layer. 79 | stride (Tuple): stride of the projection layer. 80 | padding (Tuple): padding size of the projection layer.
81 | in_chans (int): Number of input image channels. 82 | embed_dim (int): Patch embedding dimension. 83 | """ 84 | super().__init__() 85 | self.proj = nn.Conv2d( 86 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 87 | ) 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | x = self.proj(x) 91 | # B C H W -> B H W C 92 | x = x.permute(0, 2, 3, 1) 93 | return x 94 | -------------------------------------------------------------------------------- /sav_dataset/endo_sav_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the sav_dataset directory of this source tree. 6 | 7 | # adapted from https://github.com/hkchengrex/vos-benchmark 8 | # and https://github.com/davisvideochallenge/davis2017-evaluation 9 | # with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files 10 | # in the sav_dataset directory. 11 | from argparse import ArgumentParser 12 | 13 | from utils.endo_sav_benchmark import benchmark 14 | 15 | """ 16 | The {GT_ROOT} folder can follow either of the two structures below. 17 | {GT_ROOT} and {PRED_ROOT} must use the same format. 18 | 19 | 1. SA-V val/test structure 20 | {GT_ROOT} # gt root folder 21 | ├── {video_id} 22 | │ ├── 000 # all masks associated with obj 000 23 | │ │ ├── {frame_id}.png # mask for object 000 in {frame_id} (binary mask) 24 | │ │ └── ... 25 | │ ├── 001 # all masks associated with obj 001 26 | │ ├── 002 # all masks associated with obj 002 27 | │ └── ... 28 | ├── {video_id} 29 | ├── {video_id} 30 | └── ... 31 | 32 | 2. Similar to DAVIS structure: 33 | 34 | {GT_ROOT} # gt root folder 35 | ├── {video_id} 36 | │ ├── {frame_id}.png # annotation in {frame_id} (may contain multiple objects) 37 | │ └── ...
38 | ├── {video_id} 39 | ├── {video_id} 40 | └── ... 41 | """ 42 | 43 | 44 | parser = ArgumentParser() 45 | parser.add_argument( 46 | "--gt_root", 47 | required=True, 48 | help="Path to the GT folder. For SA-V, it's sav_val/Annotations_6fps or sav_test/Annotations_6fps", 49 | ) 50 | parser.add_argument( 51 | "--pred_root", 52 | required=True, 53 | help="Path to a folder containing folders of masks to be evaluated, with exactly the same structure as gt_root", 54 | ) 55 | parser.add_argument( 56 | "-n", "--num_processes", default=16, type=int, help="Number of concurrent processes" 57 | ) 58 | parser.add_argument( 59 | "-s", 60 | "--strict", 61 | help="Make sure every video in the gt_root folder has a corresponding video in the prediction", 62 | action="store_true", 63 | ) 64 | parser.add_argument( 65 | "-q", 66 | "--quiet", 67 | help="Quietly run evaluation without printing the information out", 68 | action="store_true", 69 | ) 70 | # https://github.com/davisvideochallenge/davis2017-evaluation/blob/d34fdef71ce3cb24c1a167d860b707e575b3034c/davis2017/evaluation.py#L85 71 | parser.add_argument( 72 | "--do_not_skip_first_and_last_frame", 73 | help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. " 74 | "Pass this flag for evaluation settings that don't skip the first and last frames", 75 | action="store_true", 76 | ) 77 | parser.add_argument( 78 | "--evaluate_blank_frame", 79 | type=int, default=1 80 | ) 81 | 82 | 83 | if __name__ == "__main__": 84 | args = parser.parse_args() 85 | benchmark( 86 | [args.gt_root], 87 | [args.pred_root], 88 | args.strict, 89 | args.num_processes, 90 | verbose=not args.quiet, 91 | skip_first_and_last=not args.do_not_skip_first_and_last_frame, 92 | ) 93 | -------------------------------------------------------------------------------- /training/dataset/vos_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from dataclasses import dataclass 9 | from typing import List 10 | 11 | from training.dataset.vos_segment_loader import LazySegments 12 | 13 | MAX_RETRIES = 1000 14 | 15 | 16 | @dataclass 17 | class SampledFramesAndObjects: 18 | frames: List[int] 19 | object_ids: List[int] 20 | 21 | 22 | class VOSSampler: 23 | def __init__(self, sort_frames=True): 24 | # frames are ordered by frame id when sort_frames is True 25 | self.sort_frames = sort_frames 26 | 27 | def sample(self, video): 28 | raise NotImplementedError() 29 | 30 | 31 | class RandomUniformSampler(VOSSampler): 32 | def __init__( 33 | self, 34 | num_frames, 35 | max_num_objects, 36 | reverse_time_prob=0.0, 37 | ): 38 | self.num_frames = num_frames 39 | self.max_num_objects = max_num_objects 40 | self.reverse_time_prob = reverse_time_prob 41 | 42 | def sample(self, video, segment_loader, epoch=None): 43 | 44 | for retry in range(MAX_RETRIES): 45 | if len(video.frames) < self.num_frames: 46 | raise Exception( 47 | f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." 
48 | ) 49 | start = random.randrange(0, len(video.frames) - self.num_frames + 1) 50 | frames = [video.frames[start + step] for step in range(self.num_frames)] 51 | if random.uniform(0, 1) < self.reverse_time_prob: 52 | # Reverse time 53 | frames = frames[::-1] 54 | 55 | # Get first frame object ids 56 | visible_object_ids = [] 57 | loaded_segms = segment_loader.load(frames[0].frame_idx) 58 | if isinstance(loaded_segms, LazySegments): 59 | # LazySegments for SA1BRawDataset 60 | visible_object_ids = list(loaded_segms.keys()) 61 | else: 62 | # Reuse the segments loaded above instead of loading them a second time 63 | for object_id, segment in loaded_segms.items(): 64 | # Keep only objects with a non-empty mask in the first frame 65 | if segment.sum(): 66 | visible_object_ids.append(object_id) 67 | 68 | # The first frame needs at least one visible object to track 69 | if len(visible_object_ids) > 0: 70 | break 71 | if retry >= MAX_RETRIES - 1: 72 | raise Exception("No visible objects") 73 | 74 | object_ids = random.sample( 75 | visible_object_ids, 76 | min(len(visible_object_ids), self.max_num_objects), 77 | ) 78 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 79 | 80 | 81 | class EvalSampler(VOSSampler): 82 | """ 83 | VOS Sampler for evaluation: sampling all the frames and all the objects in a video 84 | """ 85 | 86 | def __init__( 87 | self, 88 | ): 89 | super().__init__() 90 | 91 | def sample(self, video, segment_loader, epoch=None): 92 | """ 93 | Sampling all the frames and all the objects 94 | """ 95 | if self.sort_frames: 96 | # ordered by frame id 97 | frames = sorted(video.frames, key=lambda x: x.frame_idx) 98 | else: 99 | # use the original order 100 | frames = video.frames 101 | object_ids = list(segment_loader.load(frames[0].frame_idx).keys()) 102 | if len(object_ids) == 0: 103 | raise Exception("First frame of the video has no objects") 104 | 105 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 106 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_b+.yaml:
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [64, 64] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [64, 64] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 
68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: 
sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [64, 64] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [64, 64] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 
| _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | no_obj_embed_spatial: true 93 | # use high-resolution feature map in the SAM mask decoder 94 | use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | add_tpos_enc_to_obj_ptrs: true 102 | proj_tpos_enc_in_obj_ptrs: true 103 | use_signed_tpos_enc_to_obj_ptrs: true 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | 
embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: 
sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | 
window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [64, 64] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [64, 64] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | 
padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: 
sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 
86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /training/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Some wrapping utilities extending PyTorch's datasets, in particular to support repeat factor sampling""" 8 | 9 | from typing import Iterable 10 | 11 | import torch 12 | from torch.utils.data import ( 13 | ConcatDataset as TorchConcatDataset, 14 | Dataset, 15 | Subset as TorchSubset, 16 | ) 17 | 18 | 19 | class ConcatDataset(TorchConcatDataset): 20 | def __init__(self, datasets: Iterable[Dataset]) -> None: 21 | super(ConcatDataset, self).__init__(datasets) 22 | 23 | self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) 24 | 25 | def set_epoch(self, epoch: int): 26 | for dataset in self.datasets: 27 | if hasattr(dataset, "epoch"): 28 | dataset.epoch = epoch 29 | if hasattr(dataset, "set_epoch"): 30 | dataset.set_epoch(epoch) 31 | 32 | 33 | class Subset(TorchSubset): 34 | def __init__(self, dataset, indices) -> None: 35 | super(Subset, self).__init__(dataset, indices) 36 | 37 | self.repeat_factors = dataset.repeat_factors[indices] 38 | assert len(indices) == len(self.repeat_factors) 39 | 40 | 41 | # Adapted from Detectron2 42 | class RepeatFactorWrapper(Dataset): 43 | """ 44 | Thin wrapper around a dataset to implement repeat factor sampling. 45 | The underlying dataset must have a repeat_factors member to indicate the per-image factor. 46 | Set it to all ones to disable repeat factor sampling. 47 | """ 48 | 49 | def __init__(self, dataset, seed: int = 0): 50 | self.dataset = dataset 51 | self.epoch_ids = None 52 | self._seed = seed 53 | 54 | # Split into whole number (_int_part) and fractional (_frac_part) parts. 55 | self._int_part = torch.trunc(dataset.repeat_factors) 56 | self._frac_part = dataset.repeat_factors - self._int_part 57 | 58 | def _get_epoch_indices(self, generator): 59 | """ 60 | Create a list of dataset indices (with repeats) to use for one epoch. 61 | 62 | Args: 63 | generator (torch.Generator): pseudo random number generator used for 64 | stochastic rounding.
65 | 66 | Returns: 67 | torch.Tensor: list of dataset indices to use in one epoch. Each index 68 | is repeated based on its calculated repeat factor. 69 | """ 70 | # Since repeat factors are fractional, we use stochastic rounding so 71 | # that the target repeat factor is achieved in expectation over the 72 | # course of training. 73 | rands = torch.rand(len(self._frac_part), generator=generator) 74 | rep_factors = self._int_part + (rands < self._frac_part).float() 75 | # Construct a list of indices in which we repeat images as specified 76 | indices = [] 77 | for dataset_index, rep_factor in enumerate(rep_factors): 78 | indices.extend([dataset_index] * int(rep_factor.item())) 79 | return torch.tensor(indices, dtype=torch.int64) 80 | 81 | def __len__(self): 82 | if self.epoch_ids is None: 83 | # Raise an error instead of returning the original len(self.dataset) to avoid 84 | # accidentally using the unwrapped length. That would be error-prone, since the 85 | # length changes to `len(self.epoch_ids)` after set_epoch is called. 86 | raise RuntimeError("Please call set_epoch first to get the wrapped length") 87 | # return len(self.dataset) 88 | 89 | return len(self.epoch_ids) 90 | 91 | def set_epoch(self, epoch: int): 92 | g = torch.Generator() 93 | g.manual_seed(self._seed + epoch) 94 | self.epoch_ids = self._get_epoch_indices(g) 95 | if hasattr(self.dataset, "set_epoch"): 96 | self.dataset.set_epoch(epoch) 97 | 98 | def __getitem__(self, idx): 99 | if self.epoch_ids is None: 100 | raise RuntimeError( 101 | "Repeat ids haven't been computed. Did you forget to call set_epoch?"
102 | ) 103 | 104 | return self.dataset[self.epoch_ids[idx]] 105 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 
60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_frame_to_prune: 2 88 | num_maskmem: 7 89 | image_size: 512 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: 
true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [64, 64] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [64, 64] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | 
downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 
| use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [64, 64] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [64, 64] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | 
dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | 
use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Surgical SAM 2: Real-time Segment Anything in Surgical Video by Efficient Frame Pruning 2 | 3 | Official implementation of SurgSAM-2, a model that integrates the Segment Anything Model 2 (SAM2) with an efficient frame pruning mechanism for real-time surgical video segmentation. 4 | 5 | > [Surgical SAM 2: Real-time Segment Anything in Surgical Video by Efficient Frame Pruning](https://openreview.net/forum?id=WSDrF5mKVp) 6 | > 7 | > Haofeng Liu, Erli Zhang, Junde Wu, Mingxuan Hong, Yueming Jin 8 | > 9 | > NeurIPS 2024 Workshop AIM-FM 10 | 11 | ## Overview 12 | 13 | We introduce Surgical SAM 2 (SurgSAM-2), which combines SAM2 with an efficient frame pruning mechanism that keeps only the most informative frames in memory. The proposed SurgSAM-2 14 | 15 | - substantially reduces the memory usage and computational cost of SAM2, enabling real-time clinical application; 16 | - achieves superior performance at 3× the FPS of SAM2 (86 FPS), making real-time surgical segmentation feasible in resource-constrained environments. 17 | 18 | ![architecture](./assets/architecture.png) 19 | 20 | ## Dataset Acquisition and Preprocessing 21 | 22 | ### Data Download 23 | 24 | 1. Please download the training and validation sets used in our experiments: 25 | 1. [VOS-Endovis17](https://drive.google.com/file/d/1tw7KzpXqOC3HsjsUknro4MOqQ2Nr3vD1/view?usp=drive_link) 26 | 2. [VOS-Endovis18](https://drive.google.com/file/d/1Vod5jKoC8CAEqlYdiXZ2HMexP9IFXbGp/view?usp=drive_link) 27 | 2.
The original image data can be obtained from the official websites: 28 | 1. [Endovis17 Official Dataset](https://endovissub2017-roboticinstrumentsegmentation.grand-challenge.org/) 29 | 2. [Endovis18 Official Dataset](http://endovissub2018-roboticscenesegmentation.grand-challenge.org/) 30 | 31 | ### Data Preprocessing 32 | 33 | Follow the data preprocessing instructions provided in the [ISINet](https://github.com/BCV-Uniandes/ISINet) repository. 34 | 35 | ### Dataset Structure 36 | 37 | After downloading, organize your data according to the following structure: 38 | 39 | ``` 40 | project_root/ 41 | └── datasets/ 42 | └── VOS-Endovis18/ 43 | └── train/ 44 | └── JPEGImages/ 45 | └── Annotations/ 46 | └── valid/ 47 | └── JPEGImages/ 48 | └── Annotations/ 49 | └── VOS/ 50 | ``` 51 | 52 | ## Training 53 | 54 | To train the model, run: 55 | 56 | ``` 57 | CUDA_VISIBLE_DEVICES=0 python training/train.py --config configs/sam2.1_training/sam2.1_hiera_s_endovis18_instrument 58 | ``` 59 | 60 | ## Evaluation 61 | 62 | Download the pretrained weights from [sam2.1_hiera_s_endo18](https://drive.google.com/file/d/1DyrrLKst1ZQwkgKM7BWCCwLxSXAgOcMI/view?usp=drive_link). Place the file at `project_root/checkpoints/sam2.1_hiera_s_endo18.pth`. 63 | 64 | ``` 65 | python tools/vos_inference.py --sam2_cfg configs/sam2.1/sam2.1_hiera_s.yaml --sam2_checkpoint ./checkpoints/sam2.1_hiera_s_endo18.pth --output_mask_dir ./results/sam2.1/endovis_2018/instrument --input_mask_dir ./datasets/VOS-Endovis18/valid/VOS/Annotations_vos_instrument --base_video_dir ./datasets/VOS-Endovis18/valid/JPEGImages --gt_root ./datasets/VOS-Endovis18/valid/Annotations --gpu_id 0 66 | ``` 67 | 68 | ## Demo 69 | 70 | Demo data from Endovis 2018 can be downloaded from [2018 demo data](https://drive.google.com/file/d/1RG9DIGXFQwXckYpaTLEUyxYq4DexgBOY/view?usp=sharing). 
71 | 72 | After downloading, arrange the files according to the following structure: 73 | 74 | ``` 75 | project_root/ 76 | └── datasets/ 77 | └── endovis18/ 78 | └── images/ 79 | └── seq_2/ 80 | └── ... 81 | ``` 82 | 83 | ## Acknowledgement 84 | 85 | This research uses datasets from [Endovis 2017](https://endovissub2017-roboticinstrumentsegmentation.grand-challenge.org/Downloads/) and [Endovis 2018](https://endovissub2018-roboticscenesegmentation.grand-challenge.org/Downloads/). If you wish to use these datasets, please request access through their respective official websites. 86 | 87 | Our implementation builds upon the [Segment Anything 2](https://github.com/facebookresearch/segment-anything-2) framework. We sincerely thank its authors for their outstanding work and significant contributions to the field of video segmentation. 88 | 89 | ## Citation 90 | 91 | ``` 92 | @misc{liu2024surgicalsam2realtime, 93 | title={Surgical SAM 2: Real-time Segment Anything in Surgical Video by Efficient Frame Pruning}, 94 | author={Haofeng Liu and Erli Zhang and Junde Wu and Mingxuan Hong and Yueming Jin}, 95 | year={2024}, 96 | eprint={2408.07931}, 97 | archivePrefix={arXiv}, 98 | primaryClass={cs.CV}, 99 | url={https://arxiv.org/abs/2408.07931}, 100 | } 101 | ``` -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck (we remove the output conv; 48 | top-down features are upsampled 2x with the interpolation mode given by 49 | fpn_interp_model, e.g. "bilinear" or "nearest") 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param position_encoding: the positional encoding to use 66 | :param d_model: the output channel dimension of every FPN level 67 | :param backbone_channel_list: trunk output channels, ordered from lowest to highest resolution 68 | :param fpn_top_down_levels: output levels that receive top-down features (default: all levels) 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | self.d_model = d_model 75 | for dim in
backbone_channel_list: 76 | current = nn.Sequential() 77 | current.add_module( 78 | "conv", 79 | nn.Conv2d( 80 | in_channels=dim, 81 | out_channels=d_model, 82 | kernel_size=kernel_size, 83 | stride=stride, 84 | padding=padding, 85 | ), 86 | ) 87 | 88 | self.convs.append(current) 89 | self.fpn_interp_model = fpn_interp_model 90 | assert fuse_type in ["sum", "avg"] 91 | self.fuse_type = fuse_type 92 | 93 | # levels to have top-down features in its outputs 94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 95 | # have top-down propagation, while outputs of level 0 and level 1 have only 96 | # lateral features from the same backbone level. 97 | if fpn_top_down_levels is None: 98 | # default is to have top-down features on all levels 99 | fpn_top_down_levels = range(len(self.convs)) 100 | self.fpn_top_down_levels = list(fpn_top_down_levels) 101 | 102 | def forward(self, xs: List[torch.Tensor]): 103 | 104 | out = [None] * len(self.convs) 105 | pos = [None] * len(self.convs) 106 | assert len(xs) == len(self.convs) 107 | # fpn forward pass 108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 109 | prev_features = None 110 | # forward in top-down order (from low to high resolution) 111 | n = len(self.convs) - 1 112 | for i in range(n, -1, -1): 113 | x = xs[i] 114 | lateral_features = self.convs[n - i](x) 115 | if i in self.fpn_top_down_levels and prev_features is not None: 116 | top_down_features = F.interpolate( 117 | prev_features.to(dtype=torch.float32), 118 | scale_factor=2.0, 119 | mode=self.fpn_interp_model, 120 | align_corners=( 121 | None if self.fpn_interp_model == "nearest" else False 122 | ), 123 | antialias=False, 124 | ) 125 | prev_features = lateral_features + top_down_features 126 | if self.fuse_type == "avg": 127 | prev_features /= 2 128 | else: 129 | prev_features = lateral_features 130 | x_out = prev_features 131 | out[i] = x_out 132 | pos[i] = 
self.position_encoding(x_out).to(x_out.dtype) 133 | 134 | return out, pos 135 | -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates. 51 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
52 | 53 | Returns 54 | Coordinates scaled to the model input resolution (i.e. in the range [0, self.resolution]), as expected by the SAM2 model. 55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates. 71 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform post-processing on output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.max_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # Sprinkles are small foreground components; we fill them with a negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks 119 | -------------------------------------------------------------------------------- /training/scripts/sav_frame_extraction_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | import argparse 4 | import os 5 | from pathlib import Path 6 | 7 | import cv2 8 | 9 | import numpy as np 10 | import submitit 11 | import tqdm 12 | 13 | 14 | def get_args_parser(): 15 | parser = argparse.ArgumentParser( 16 | description="[SA-V Preprocessing] Extracting JPEG frames", 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 18 | ) 19 | 20 | # ------------ 21 | # DATA 22 | # ------------ 23 | data_parser = parser.add_argument_group( 24 | title="SA-V dataset data root", 25 | description="What data to load and how to process it.", 26 | ) 27 | data_parser.add_argument( 28 | "--sav-vid-dir", 29 | type=str, 30 | required=True, 31 | help=("Where to find the SAV videos"), 32 | ) 33 | data_parser.add_argument( 34 | "--sav-frame-sample-rate", 35 | type=int, 36 | default=4, 37 | help="Rate at which to sub-sample frames", 38 | ) 39 | 40 | # ------------ 41 | # LAUNCH 42 | # ------------ 43 | launch_parser = parser.add_argument_group( 44 | title="Cluster launch settings", 45 | description="Number of jobs and retry settings.", 46 | ) 47 | launch_parser.add_argument( 48 | "--n-jobs", 49 | type=int, 50 | required=True, 51 | help="Shard the run over this many jobs.", 52 | ) 53 | launch_parser.add_argument( 54 | "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes." 55 | ) 56 | launch_parser.add_argument( 57 | "--partition", type=str, required=True, help="Partition to launch on." 58 | ) 59 | launch_parser.add_argument( 60 | "--account", type=str, required=True, help="SLURM account to charge the jobs to." 61 | ) 62 | launch_parser.add_argument("--qos", type=str, required=True, help="QOS.") 63 | 64 | # ------------ 65 | # OUTPUT 66 | # ------------ 67 | output_parser = parser.add_argument_group( 68 | title="Settings for results output", description="Where and how to save results."
69 | ) 70 | output_parser.add_argument( 71 | "--output-dir", 72 | type=str, 73 | required=True, 74 | help=("Where to dump the extracted jpeg frames"), 75 | ) 76 | output_parser.add_argument( 77 | "--slurm-output-root-dir", 78 | type=str, 79 | required=True, 80 | help=("Where to save slurm outputs"), 81 | ) 82 | return parser 83 | 84 | 85 | def decode_video(video_path: str): 86 | assert os.path.exists(video_path) 87 | video = cv2.VideoCapture(video_path) 88 | video_frames = [] 89 | while video.isOpened(): 90 | ret, frame = video.read() 91 | if ret: 92 | video_frames.append(frame) 93 | else: 94 | break 95 | return video_frames 96 | 97 | 98 | def extract_frames(video_path, sample_rate): 99 | frames = decode_video(video_path) 100 | return frames[::sample_rate] 101 | 102 | 103 | def submitit_launch(video_paths, sample_rate, save_root): 104 | for path in tqdm.tqdm(video_paths): 105 | frames = extract_frames(path, sample_rate) 106 | output_folder = os.path.join(save_root, Path(path).stem) 107 | if not os.path.exists(output_folder): 108 | os.makedirs(output_folder) 109 | for fid, frame in enumerate(frames): 110 | frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg") 111 | cv2.imwrite(frame_path, frame) 112 | print(f"Saved output to {save_root}") 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = get_args_parser() 117 | args = parser.parse_args() 118 | 119 | sav_vid_dir = args.sav_vid_dir 120 | save_root = args.output_dir 121 | sample_rate = args.sav_frame_sample_rate 122 | 123 | # List all SA-V videos 124 | mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")]) 125 | mp4_files = np.array(mp4_files) 126 | chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)] 127 | 128 | print(f"Processing videos in: {sav_vid_dir}") 129 | print(f"Processing {len(mp4_files)} files") 130 | print(f"Beginning processing in {args.n_jobs} processes") 131 | 132 | # Submitit params 133 | jobs_dir = 
os.path.join(args.slurm_output_root_dir, "%j") 134 | cpus_per_task = 4 135 | executor = submitit.AutoExecutor(folder=jobs_dir) 136 | executor.update_parameters( 137 | timeout_min=args.timeout, 138 | gpus_per_node=0, 139 | tasks_per_node=1, 140 | slurm_array_parallelism=args.n_jobs, 141 | cpus_per_task=cpus_per_task, 142 | slurm_partition=args.partition, 143 | slurm_account=args.account, 144 | slurm_qos=args.qos, 145 | ) 146 | executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"]) 147 | 148 | # Launch 149 | jobs = [] 150 | with executor.batch(): 151 | for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)): 152 | job = executor.submit( 153 | submitit_launch, 154 | video_paths=mp4_chunk, 155 | sample_rate=sample_rate, 156 | save_root=save_root, 157 | ) 158 | jobs.append(job) 159 | 160 | for j in jobs: 161 | print(f"Slurm JobID: {j.job_id}") 162 | print(f"Saving outputs to {save_root}") 163 | print(f"Slurm outputs at {args.slurm_output_root_dir}") 164 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | 10 | # Package metadata 11 | NAME = "SAM-2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" 14 | URL = "https://github.com/facebookresearch/sam2" 15 | AUTHOR = "Meta AI" 16 | AUTHOR_EMAIL = "segment-anything@meta.com" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of README file 20 | with open("README.md", "r", encoding="utf-8") as f: 21 | LONG_DESCRIPTION = f.read() 22 | 23 | # Required dependencies 24 | REQUIRED_PACKAGES = [ 25 | "torch>=2.5.1", 26 | "torchvision>=0.20.1", 27 | "numpy>=1.24.4", 28 | "tqdm>=4.66.1", 29 | "hydra-core>=1.3.2", 30 | "iopath>=0.1.10", 31 | "pillow>=9.4.0", 32 | ] 33 | 34 | EXTRA_PACKAGES = { 35 | "notebooks": [ 36 | "matplotlib>=3.9.1", 37 | "jupyter>=1.0.0", 38 | "opencv-python>=4.7.0", 39 | "eva-decord>=0.6.1", 40 | ], 41 | "interactive-demo": [ 42 | "Flask>=3.0.3", 43 | "Flask-Cors>=5.0.0", 44 | "av>=13.0.0", 45 | "dataclasses-json>=0.6.7", 46 | "eva-decord>=0.6.1", 47 | "gunicorn>=23.0.0", 48 | "imagesize>=1.4.1", 49 | "pycocotools>=2.0.8", 50 | "strawberry-graphql>=0.243.0", 51 | ], 52 | "dev": [ 53 | "black==24.2.0", 54 | "usort==1.0.2", 55 | "ufmt==2.0.0b2", 56 | "fvcore>=0.1.5.post20221221", 57 | "pandas>=2.2.2", 58 | "scikit-image>=0.24.0", 59 | "tensorboard>=2.17.0", 60 | "pycocotools>=2.0.8", 61 | "tensordict>=0.6.0", 62 | "opencv-python>=4.7.0", 63 | "submitit>=1.5.1", 64 | ], 65 | } 66 | 67 | # By default, we also build the SAM 2 CUDA extension. 68 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. 69 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" 70 | # By default, we allow SAM 2 installation to proceed even with build errors. 71 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. 
72 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" 73 | 74 | # Catch and skip errors during extension building and print a warning message 75 | # (note that this message only shows up under verbose build mode 76 | # "pip install -v -e ." or "python setup.py build_ext -v") 77 | CUDA_ERROR_MSG = ( 78 | "{}\n\n" 79 | "Failed to build the SAM 2 CUDA extension due to the error above. " 80 | "You can still use SAM 2 and it's OK to ignore the error above, although some " 81 | "post-processing functionality may be limited (which doesn't affect the results in most cases; " 82 | "see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" 83 | ) 84 | 85 | 86 | def get_extensions(): 87 | if not BUILD_CUDA: 88 | return [] 89 | 90 | try: 91 | from torch.utils.cpp_extension import CUDAExtension 92 | 93 | srcs = ["sam2/csrc/connected_components.cu"] 94 | compile_args = { 95 | "cxx": [], 96 | "nvcc": [ 97 | "-DCUDA_HAS_FP16=1", 98 | "-D__CUDA_NO_HALF_OPERATORS__", 99 | "-D__CUDA_NO_HALF_CONVERSIONS__", 100 | "-D__CUDA_NO_HALF2_OPERATORS__", 101 | ], 102 | } 103 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 104 | except Exception as e: 105 | if BUILD_ALLOW_ERRORS: 106 | print(CUDA_ERROR_MSG.format(e)) 107 | ext_modules = [] 108 | else: 109 | raise e 110 | 111 | return ext_modules 112 | 113 | 114 | try: 115 | from torch.utils.cpp_extension import BuildExtension 116 | 117 | class BuildExtensionIgnoreErrors(BuildExtension): 118 | 119 | def finalize_options(self): 120 | try: 121 | super().finalize_options() 122 | except Exception as e: 123 | print(CUDA_ERROR_MSG.format(e)) 124 | self.extensions = [] 125 | 126 | def build_extensions(self): 127 | try: 128 | super().build_extensions() 129 | except Exception as e: 130 | print(CUDA_ERROR_MSG.format(e)) 131 | self.extensions = [] 132 | 133 | def get_ext_filename(self, ext_name): 134 | try: 135 | return super().get_ext_filename(ext_name) 136 | except Exception as e:
print(CUDA_ERROR_MSG.format(e)) 138 | self.extensions = [] 139 | return "_C.so" 140 | 141 | cmdclass = { 142 | "build_ext": ( 143 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) 144 | if BUILD_ALLOW_ERRORS 145 | else BuildExtension.with_options(no_python_abi_suffix=True) 146 | ) 147 | } 148 | except Exception as e: 149 | cmdclass = {} 150 | if BUILD_ALLOW_ERRORS: 151 | print(CUDA_ERROR_MSG.format(e)) 152 | else: 153 | raise e 154 | 155 | 156 | # Setup configuration 157 | setup( 158 | name=NAME, 159 | version=VERSION, 160 | description=DESCRIPTION, 161 | long_description=LONG_DESCRIPTION, 162 | long_description_content_type="text/markdown", 163 | url=URL, 164 | author=AUTHOR, 165 | author_email=AUTHOR_EMAIL, 166 | license=LICENSE, 167 | packages=find_packages(exclude=["notebooks"]), 168 | include_package_data=True, 169 | install_requires=REQUIRED_PACKAGES, 170 | extras_require=EXTRA_PACKAGES, 171 | python_requires=">=3.10.0", 172 | ext_modules=get_extensions(), 173 | cmdclass=cmdclass, 174 | ) 175 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, 
RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 
110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2/modeling/memory_encoder.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. 
There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if 
input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> dict: 164 | ## Process masks 165 | # apply sigmoid to reduce the domain shift from ground-truth masks, which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, move them to the same device as the masks 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /training/dataset/vos_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import random 9 | from copy import deepcopy 10 | 11 | import numpy as np 12 | 13 | import torch 14 | from iopath.common.file_io import g_pathmgr 15 | from PIL import Image as PILImage 16 | from torchvision.datasets.vision import VisionDataset 17 | 18 | from training.dataset.vos_raw_dataset import VOSRawDataset 19 | from training.dataset.vos_sampler import VOSSampler 20 | from training.dataset.vos_segment_loader import JSONSegmentLoader 21 | 22 | from training.utils.data_utils import Frame, Object, VideoDatapoint 23 | 24 | MAX_RETRIES = 100 25 | 26 | 27 | class VOSDataset(VisionDataset): 28 | def __init__( 29 | self, 30 | transforms, 31 | training: bool, 32 | video_dataset: VOSRawDataset, 33 | sampler: VOSSampler, 34 | multiplier: int, 35 | always_target=True, 36 | target_segments_available=True, 37 | ): 38 | self._transforms = transforms 39 | self.training = training 40 | self.video_dataset = video_dataset 41 | self.sampler = sampler 42 | 43 | self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32) 44 | self.repeat_factors *= multiplier 45 | print(f"Raw dataset length = {len(self.video_dataset)}") 46 | 47 | self.curr_epoch = 0 # Used in case data loader behavior changes across epochs 48 | self.always_target = always_target 49 | self.target_segments_available = target_segments_available 50 | 51 | def _get_datapoint(self, idx): 52 | 53 | for retry in range(MAX_RETRIES): 54 | try: 55 | if isinstance(idx, torch.Tensor): 56 | idx = idx.item() 57 | # sample a video 58 | video, segment_loader = self.video_dataset.get_video(idx) 59 | # sample frames and object indices to be used in a datapoint 60 | sampled_frms_and_objs = self.sampler.sample( 61 | video, segment_loader, epoch=self.curr_epoch 62 | ) 63 | break # Successfully loaded video 64 | except Exception as e: 65 | if self.training: 66 |
logging.warning( 67 | f"Loading failed (id={idx}); Retry {retry} with exception: {e}" 68 | ) 69 | idx = random.randrange(0, len(self.video_dataset)) 70 | else: 71 | # Shouldn't fail to load a val video 72 | raise e 73 | 74 | datapoint = self.construct(video, sampled_frms_and_objs, segment_loader) 75 | for transform in self._transforms: 76 | datapoint = transform(datapoint, epoch=self.curr_epoch) 77 | return datapoint 78 | 79 | def construct(self, video, sampled_frms_and_objs, segment_loader): 80 | """ 81 | Constructs a VideoDatapoint sample to pass to transforms 82 | """ 83 | sampled_frames = sampled_frms_and_objs.frames 84 | sampled_object_ids = sampled_frms_and_objs.object_ids 85 | 86 | images = [] 87 | rgb_images = load_images(sampled_frames) 88 | # Iterate over the sampled frames and store their rgb data and object data (bbox, segment) 89 | for frame_idx, frame in enumerate(sampled_frames): 90 | w, h = rgb_images[frame_idx].size 91 | images.append( 92 | Frame( 93 | data=rgb_images[frame_idx], 94 | objects=[], 95 | ) 96 | ) 97 | # We load the gt segments associated with the current frame 98 | if isinstance(segment_loader, JSONSegmentLoader): 99 | segments = segment_loader.load( 100 | frame.frame_idx, obj_ids=sampled_object_ids 101 | ) 102 | else: 103 | segments = segment_loader.load(frame.frame_idx) 104 | for obj_id in sampled_object_ids: 105 | # Extract the segment 106 | if obj_id in segments: 107 | assert ( 108 | segments[obj_id] is not None 109 | ), "None targets are not supported" 110 | # segment is uint8 and remains uint8 throughout the transforms 111 | segment = segments[obj_id].to(torch.uint8) 112 | else: 113 | # There is no target, we either use a zero mask target or drop this object 114 | if not self.always_target: 115 | continue 116 | segment = torch.zeros(h, w, dtype=torch.uint8) 117 | 118 | images[frame_idx].objects.append( 119 | Object( 120 | object_id=obj_id, 121 | frame_index=frame.frame_idx, 122 | segment=segment, 123 | ) 124 | ) 125 | return 
VideoDatapoint( 126 | frames=images, 127 | video_id=video.video_id, 128 | size=(h, w), 129 | ) 130 | 131 | def __getitem__(self, idx): 132 | return self._get_datapoint(idx) 133 | 134 | def __len__(self): 135 | return len(self.video_dataset) 136 | 137 | 138 | def load_images(frames): 139 | all_images = [] 140 | cache = {} 141 | for frame in frames: 142 | if frame.data is None: 143 | # Load the frame rgb data from file 144 | path = frame.image_path 145 | if path in cache: 146 | all_images.append(deepcopy(all_images[cache[path]])) 147 | continue 148 | with g_pathmgr.open(path, "rb") as fopen: 149 | all_images.append(PILImage.open(fopen).convert("RGB")) 150 | cache[path] = len(all_images) - 1 151 | else: 152 | # The frame rgb data has already been loaded 153 | # Convert it to a PILImage 154 | all_images.append(tensor_2_PIL(frame.data)) 155 | 156 | return all_images 157 | 158 | 159 | def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image: 160 | data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0 161 | data = data.astype(np.uint8) 162 | return PILImage.fromarray(data) 163 | -------------------------------------------------------------------------------- /sav_dataset/utils/endo_sav_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the sav_dataset directory of this source tree. 
6 | import json 7 | import os 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pycocotools.mask as mask_util 14 | 15 | 16 | def decode_video(video_path: str) -> List[np.ndarray]: 17 | """ 18 | Decode the video and return the RGB frames 19 | """ 20 | video = cv2.VideoCapture(video_path) 21 | video_frames = [] 22 | while video.isOpened(): 23 | ret, frame = video.read() 24 | if ret: 25 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 26 | video_frames.append(frame) 27 | else: 28 | break 29 | return video_frames 30 | 31 | 32 | def show_anns(masks, colors: List, borders=True) -> None: 33 | """ 34 | show the annotations 35 | """ 36 | # return if no masks 37 | if len(masks) == 0: 38 | return 39 | 40 | # sort masks by size 41 | sorted_annot_and_color = sorted( 42 | zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True 43 | ) 44 | H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1] 45 | 46 | canvas = np.ones((H, W, 4)) 47 | canvas[:, :, 3] = 0 # set the alpha channel 48 | contour_thickness = max(1, int(min(5, 0.01 * min(H, W)))) 49 | for mask, color in sorted_annot_and_color: 50 | canvas[mask] = np.concatenate([color, [0.55]]) 51 | if borders: 52 | contours, _ = cv2.findContours( 53 | np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE 54 | ) 55 | cv2.drawContours( 56 | canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness 57 | ) 58 | 59 | ax = plt.gca() 60 | ax.imshow(canvas) 61 | 62 | 63 | class SAVDataset: 64 | """ 65 | SAVDataset is a class to load the SAV dataset and visualize the annotations. 66 | """ 67 | 68 | def __init__(self, sav_dir, annot_sample_rate=4): 69 | """ 70 | Args: 71 | sav_dir: the directory of the SAV dataset 72 | annot_sample_rate: the sampling rate of the annotations. 73 | The annotations are aligned with the videos at 6 fps. 
74 | """ 75 | self.sav_dir = sav_dir 76 | self.annot_sample_rate = annot_sample_rate 77 | self.manual_mask_colors = np.random.random((256, 3)) 78 | self.auto_mask_colors = np.random.random((256, 3)) 79 | 80 | def read_frames(self, mp4_path: str) -> Optional[List[np.ndarray]]: 81 | """ 82 | Read the frames and downsample them to align with the annotations. 83 | """ 84 | if not os.path.exists(mp4_path): 85 | print(f"{mp4_path} doesn't exist.") 86 | return None 87 | else: 88 | # decode the video 89 | frames = decode_video(mp4_path) 90 | print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).") 91 | 92 | # downsample the frames to align with the annotations 93 | frames = frames[:: self.annot_sample_rate] 94 | print( 95 | f"Videos are annotated every {self.annot_sample_rate} frames. " 96 | "To align with the annotations, " 97 | f"downsample the video to {len(frames)} frames." 98 | ) 99 | return frames 100 | 101 | def get_frames_and_annotations( 102 | self, video_id: str 103 | ) -> Tuple[List | None, Dict | None, Dict | None]: 104 | """ 105 | Get the frames and annotations for a video. 106 | """ 107 | # load the video 108 | mp4_path = os.path.join(self.sav_dir, video_id + ".mp4") 109 | frames = self.read_frames(mp4_path) 110 | if frames is None: 111 | return None, None, None 112 | 113 | # load the manual annotations 114 | manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json") 115 | if not os.path.exists(manual_annot_path): 116 | print(f"{manual_annot_path} doesn't exist.
Something might be wrong.") 117 | manual_annot = None 118 | else: 119 | manual_annot = json.load(open(manual_annot_path)) 120 | 121 | # load the automatic annotations 122 | auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json") 123 | if not os.path.exists(auto_annot_path): 124 | print(f"{auto_annot_path} doesn't exist.") 125 | auto_annot = None 126 | else: 127 | auto_annot = json.load(open(auto_annot_path)) 128 | 129 | return frames, manual_annot, auto_annot 130 | 131 | def visualize_annotation( 132 | self, 133 | frames: List[np.ndarray], 134 | auto_annot: Optional[Dict], 135 | manual_annot: Optional[Dict], 136 | annotated_frame_id: int, 137 | show_auto=True, 138 | show_manual=True, 139 | ) -> None: 140 | """ 141 | Visualize the annotations on frame annotated_frame_id. 142 | If show_manual is True, show the manual annotations. 143 | If show_auto is True, show the auto annotations. 144 | By default, show both auto and manual annotations. 145 | """ 146 | 147 | if annotated_frame_id >= len(frames): 148 | print("invalid annotated_frame_id") 149 | return 150 | 151 | rles = [] 152 | colors = [] 153 | if show_manual and manual_annot is not None: 154 | rles.extend(manual_annot["masklet"][annotated_frame_id]) 155 | colors.extend( 156 | self.manual_mask_colors[ 157 | : len(manual_annot["masklet"][annotated_frame_id]) 158 | ] 159 | ) 160 | if show_auto and auto_annot is not None: 161 | rles.extend(auto_annot["masklet"][annotated_frame_id]) 162 | colors.extend( 163 | self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])] 164 | ) 165 | 166 | plt.imshow(frames[annotated_frame_id]) 167 | 168 | if len(rles) > 0: 169 | masks = [mask_util.decode(rle) > 0 for rle in rles] 170 | show_anns(masks, colors) 171 | else: 172 | print("No annotation will be shown") 173 | 174 | plt.axis("off") 175 | plt.show() 176 | -------------------------------------------------------------------------------- /training/utils/data_utils.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Misc functions, including distributed helpers. 9 | 10 | Mostly copy-paste from torchvision references. 11 | """ 12 | 13 | from dataclasses import dataclass 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import torch 17 | 18 | from PIL import Image as PILImage 19 | from tensordict import tensorclass 20 | 21 | 22 | @tensorclass 23 | class BatchedVideoMetaData: 24 | """ 25 | This class represents metadata about a batch of videos. 26 | Attributes: 27 | unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) 28 | frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. 29 | """ 30 | 31 | unique_objects_identifier: torch.LongTensor 32 | frame_orig_size: torch.LongTensor 33 | 34 | 35 | @tensorclass 36 | class BatchedVideoDatapoint: 37 | """ 38 | This class represents a batch of videos with associated annotations and metadata. 39 | Attributes: 40 | img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. 41 | obj_to_frame_idx: A [TxOx2] tensor containing, for each object, the (frame index, video index) pair that locates its image in img_batch. O is the number of objects in the batch. 42 | masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. 43 | metadata: An instance of BatchedVideoMetaData containing metadata about the batch. 44 | dict_key: A string key used to identify the batch.
45 | """ 46 | 47 | img_batch: torch.FloatTensor 48 | obj_to_frame_idx: torch.IntTensor 49 | masks: torch.BoolTensor 50 | metadata: BatchedVideoMetaData 51 | 52 | dict_key: str 53 | 54 | def pin_memory(self, device=None): 55 | return self.apply(torch.Tensor.pin_memory, device=device) 56 | 57 | @property 58 | def num_frames(self) -> int: 59 | """ 60 | Returns the number of frames per video. 61 | """ 62 | return self.batch_size[0] 63 | 64 | @property 65 | def num_videos(self) -> int: 66 | """ 67 | Returns the number of videos in the batch. 68 | """ 69 | return self.img_batch.shape[1] 70 | 71 | @property 72 | def flat_obj_to_img_idx(self) -> torch.IntTensor: 73 | """ 74 | Returns a flattened tensor mapping each object to the index of its image. 75 | The flat index can be used to access a flattened img_batch of shape [(B*T)xCxHxW], as returned by flat_img_batch. 76 | """ 77 | frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) 78 | flat_idx = video_idx * self.num_frames + frame_idx 79 | return flat_idx 80 | 81 | @property 82 | def flat_img_batch(self) -> torch.FloatTensor: 83 | """ 84 | Returns a flattened img_batch tensor of shape [(B*T)xCxHxW] 85 | """ 86 | 87 | return self.img_batch.transpose(0, 1).flatten(0, 1) 88 | 89 | 90 | @dataclass 91 | class Object: 92 | # Id of the object in the media 93 | object_id: int 94 | # Index of the frame in the media (0 if single image) 95 | frame_index: int 96 | segment: Union[torch.Tensor, dict] # RLE dict or binary mask 97 | 98 | 99 | @dataclass 100 | class Frame: 101 | data: Union[torch.Tensor, PILImage.Image] 102 | objects: List[Object] 103 | 104 | 105 | @dataclass 106 | class VideoDatapoint: 107 | """Refers to an image/video and all its annotations""" 108 | 109 | frames: List[Frame] 110 | video_id: int 111 | size: Tuple[int, int] 112 | 113 | 114 | def collate_fn( 115 | batch: List[VideoDatapoint], 116 | dict_key, 117 | ) -> BatchedVideoDatapoint: 118 | """ 119 | Args: 120 | batch: A list of VideoDatapoint instances.
121 | dict_key (str): A string key used to identify the batch. 122 | """ 123 | img_batch = [] 124 | for video in batch: 125 | img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] 126 | 127 | img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) 128 | T = img_batch.shape[0] 129 | # Prepare data structures for sequential processing. Per-frame processing but batched across videos. 130 | step_t_objects_identifier = [[] for _ in range(T)] 131 | step_t_frame_orig_size = [[] for _ in range(T)] 132 | 133 | step_t_masks = [[] for _ in range(T)] 134 | step_t_obj_to_frame_idx = [ 135 | [] for _ in range(T) 136 | ] # List to store frame indices for each time step 137 | 138 | for video_idx, video in enumerate(batch): 139 | orig_video_id = video.video_id 140 | orig_frame_size = video.size 141 | for t, frame in enumerate(video.frames): 142 | objects = frame.objects 143 | for obj in objects: 144 | orig_obj_id = obj.object_id 145 | orig_frame_idx = obj.frame_index 146 | step_t_obj_to_frame_idx[t].append( 147 | torch.tensor([t, video_idx], dtype=torch.int) 148 | ) 149 | step_t_masks[t].append(obj.segment.to(torch.bool)) 150 | step_t_objects_identifier[t].append( 151 | torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) 152 | ) 153 | step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) 154 | 155 | obj_to_frame_idx = torch.stack( 156 | [ 157 | torch.stack(obj_to_frame_idx, dim=0) 158 | for obj_to_frame_idx in step_t_obj_to_frame_idx 159 | ], 160 | dim=0, 161 | ) 162 | masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) 163 | objects_identifier = torch.stack( 164 | [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 165 | ) 166 | frame_orig_size = torch.stack( 167 | [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 168 | ) 169 | return BatchedVideoDatapoint( 170 | img_batch=img_batch, 171 | obj_to_frame_idx=obj_to_frame_idx, 172 | masks=masks, 173 | 
metadata=BatchedVideoMetaData( 174 | unique_objects_identifier=objects_identifier, 175 | frame_orig_size=frame_orig_size, 176 | ), 177 | dict_key=dict_key, 178 | batch_size=[T], 179 | ) 180 | -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | # Training Code for SAM 2 2 | 3 | This folder contains the training code for SAM 2, a foundation model for promptable visual segmentation in images and videos. 4 | The code allows users to train and fine-tune SAM 2 on their own datasets (image, video, or both). 5 | 6 | ## Structure 7 | 8 | The training code is organized into the following subfolders: 9 | 10 | * `dataset`: This folder contains image and video dataset and dataloader classes as well as their transforms. 11 | * `model`: This folder contains the main model class (`SAM2Train`) for training/fine-tuning. `SAM2Train` inherits from `SAM2Base` model and provides functions to enable training or fine-tuning SAM 2. It also accepts all training-time parameters used for simulating user prompts (e.g. iterative point sampling). 12 | * `utils`: This folder contains training utils such as loggers and distributed training utils. 13 | * `scripts`: This folder contains the script to extract the frames of SA-V dataset to be used in training. 14 | * `loss_fns.py`: This file has the main loss class (`MultiStepMultiMasksAndIous`) used for training. 15 | * `optimizer.py`: This file contains all optimizer utils that support arbitrary schedulers. 16 | * `trainer.py`: This file contains the `Trainer` class that accepts all the `Hydra` configurable modules (model, optimizer, datasets, etc..) and implements the main train/eval loop. 17 | * `train.py`: This script is used to launch training jobs. It supports single and multi-node jobs. 
For usage, please check the [Getting Started](README.md#getting-started) section or run `python training/train.py -h`. 18 | 19 | ## Getting Started 20 | 21 | To get started with the training code, we provide a simple example of fine-tuning our checkpoints on the [MOSE](https://henghuiding.github.io/MOSE/) dataset, which can be extended to your custom datasets. 22 | 23 | #### Requirements: 24 | - We assume training on A100 GPUs with **80 GB** of memory. 25 | - Download the MOSE dataset using one of the provided links from [here](https://github.com/henghuiding/MOSE-api?tab=readme-ov-file#download). 26 | 27 | #### Steps to fine-tune on MOSE: 28 | - Install the packages required for training by running `pip install -e ".[dev]"`. 29 | - Set the paths for the MOSE dataset in `configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml`. 30 | ```yaml 31 | dataset: 32 | # PATHS to Dataset 33 | img_folder: null # PATH to MOSE JPEGImages folder 34 | gt_folder: null # PATH to MOSE Annotations folder 35 | file_list_txt: null # Optional PATH to filelist containing a subset of videos to be used for training 36 | ``` 37 | - To fine-tune the base model on MOSE using 8 GPUs, run: 38 | 39 | ```bash 40 | python training/train.py \ 41 | -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \ 42 | --use-cluster 0 \ 43 | --num-gpus 8 44 | ``` 45 | 46 | We also support multi-node training on a cluster using [SLURM](https://slurm.schedmd.com/documentation.html). For example, you can train on 2 nodes by running: 47 | 48 | ```bash 49 | python training/train.py \ 50 | -c configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml \ 51 | --use-cluster 1 \ 52 | --num-gpus 8 \ 53 | --num-nodes 2 \ 54 | --partition $PARTITION \ 55 | --qos $QOS \ 56 | --account $ACCOUNT 57 | ``` 58 | where `--partition`, `--qos`, and `--account` are optional and depend on your SLURM configuration. 59 | By default, the checkpoint and logs will be saved under the `sam2_logs` directory in the root of the repo.
Alternatively, you can set the experiment log directory in the config file as follows: 60 | 61 | ```yaml 62 | experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name} 63 | ``` 64 | The training losses can be monitored using `tensorboard` logs stored under `tensorboard/` in the experiment log directory. We also provide a sample validation [split](../training/assets/MOSE_sample_val_list.txt) for evaluation purposes. To generate predictions, follow this [guide](../tools/README.md) on how to use our `vos_inference.py` script. After generating the predictions, you can run `sav_evaluator.py` as detailed [here](../sav_dataset/README.md#sa-v-val-and-test-evaluation). The expected MOSE J&F after fine-tuning the Base Plus model is 79.4. 65 | 66 | 67 | After training/fine-tuning, you can use the new checkpoint (saved in `checkpoints/` in the experiment log directory) in the same way as the released SAM 2 checkpoints (as illustrated [here](../README.md#image-prediction)). 68 | ## Training on images and videos 69 | The code supports training on images and videos (similar to how SAM 2 is trained). We provide classes for loading SA-1B as a sample image dataset, SA-V as a sample video dataset, and any DAVIS-style video dataset (e.g. MOSE). Note that to train on SA-V, you must first extract all videos to JPEG frames using the provided extraction [script](./scripts/sav_frame_extraction_submitit.py).
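
Images and videos go through the same `VOSDataset` path: in the mixed config, the image dataset uses `num_frames: 1`, so an image is effectively a one-frame video. The sketch below is a hypothetical NumPy stand-in for the batch layout built by `collate_fn` and `BatchedVideoDatapoint.flat_img_batch` in `training/utils/data_utils.py`; the names `collate_frames`, `flatten_video_major` and `flat_index` are illustrative and not part of the repo. Frames are stacked into a `[T, B, C, H, W]` array, and flattening is video-major, so frame `t` of video `b` lands at row `b * T + t`. Because stacking requires every sample in a batch to have the same `T`, one-frame image samples and multi-frame video samples cannot share a batch, which is likely why the mixed dataset takes a separate batch size per dataset.

```python
import numpy as np


def collate_frames(videos):
    """Stack a list of videos (each a list of [C, H, W] frames) into [T, B, C, H, W].

    Hypothetical stand-in for the image stacking in collate_fn.
    np.stack raises ValueError if the videos have different numbers of frames.
    """
    per_video = [np.stack(frames, axis=0) for frames in videos]  # each [T, C, H, W]
    batch = np.stack(per_video, axis=0)  # [B, T, C, H, W]
    return batch.transpose(1, 0, 2, 3, 4)  # [T, B, C, H, W]


def flatten_video_major(img_batch):
    """Flatten [T, B, C, H, W] to [(B*T), C, H, W], keeping each video's frames contiguous."""
    num_frames, num_videos = img_batch.shape[:2]
    return img_batch.transpose(1, 0, 2, 3, 4).reshape(
        num_videos * num_frames, *img_batch.shape[2:]
    )


def flat_index(frame_idx, video_idx, num_frames):
    """Row of (frame_idx, video_idx) in the flattened batch (video-major order)."""
    return video_idx * num_frames + frame_idx


# Two videos with three 1x1x1 frames each; each pixel value encodes 10 * video + frame.
videos = [
    [np.full((1, 1, 1), 10 * v + t, dtype=np.float32) for t in range(3)]
    for v in range(2)
]
batch = collate_frames(videos)
flat = flatten_video_major(batch)
print(batch.shape, flat.shape)  # (3, 2, 1, 1, 1) (6, 1, 1, 1)
print(flat[flat_index(2, 1, num_frames=3), 0, 0, 0])  # 12.0
```

An image-only batch goes through the same functions with `T = 1`; only mixing different `T` values inside one batch fails.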
Below is an example of how to set up the datasets in your config to train on a mix of image and video datasets: 70 | 71 | ```yaml 72 | data: 73 | train: 74 | _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset 75 | phases_per_epoch: ${phases_per_epoch} # Chunks a single epoch into smaller phases 76 | batch_sizes: # List of batch sizes corresponding to each dataset 77 | - ${bs1} # Batch size of dataset 1 78 | - ${bs2} # Batch size of dataset 2 79 | datasets: 80 | # SA1B as an example of an image dataset 81 | - _target_: training.dataset.vos_dataset.VOSDataset 82 | training: true 83 | video_dataset: 84 | _target_: training.dataset.vos_raw_dataset.SA1BRawDataset 85 | img_folder: ${path_to_img_folder} 86 | gt_folder: ${path_to_gt_folder} 87 | file_list_txt: ${path_to_train_filelist} # Optional 88 | sampler: 89 | _target_: training.dataset.vos_sampler.RandomUniformSampler 90 | num_frames: 1 91 | max_num_objects: ${max_num_objects_per_image} 92 | transforms: ${image_transforms} 93 | # SA-V as an example of a video dataset 94 | - _target_: training.dataset.vos_dataset.VOSDataset 95 | training: true 96 | video_dataset: 97 | _target_: training.dataset.vos_raw_dataset.JSONRawDataset 98 | img_folder: ${path_to_img_folder} 99 | gt_folder: ${path_to_gt_folder} 100 | file_list_txt: ${path_to_train_filelist} # Optional 101 | ann_every: 4 102 | sampler: 103 | _target_: training.dataset.vos_sampler.RandomUniformSampler 104 | num_frames: 8 # Number of frames per video 105 | max_num_objects: ${max_num_objects_per_video} 106 | reverse_time_prob: ${reverse_time_prob} # probability to reverse video 107 | transforms: ${video_transforms} 108 | shuffle: True 109 | num_workers: ${num_train_workers} 110 | pin_memory: True 111 | drop_last: True 112 | collate_fn: 113 | _target_: training.utils.data_utils.collate_fn 114 | _partial_: true 115 | dict_key: all 116 | ``` 117 | -------------------------------------------------------------------------------- /sam2/build_sam.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

import sam2

# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
    # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
    # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
    # This typically happens because the user is running Python from the parent directory
    # that contains the sam2 repo they cloned.
    raise RuntimeError(
        "You're likely running Python from the parent directory of the sam2 repository "
        "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
        "This is not supported since the `sam2` Python package could be shadowed by the "
        "repository name (the repository is also named `sam2` and contains the Python package "
        "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
        "rather than its parent dir, or from your home directory) after installing SAM 2."
    )


HF_MODEL_ID_TO_FILENAMES = {
    "facebook/sam2-hiera-tiny": (
        "configs/sam2/sam2_hiera_t.yaml",
        "sam2_hiera_tiny.pt",
    ),
    "facebook/sam2-hiera-small": (
        "configs/sam2/sam2_hiera_s.yaml",
        "sam2_hiera_small.pt",
    ),
    "facebook/sam2-hiera-base-plus": (
        "configs/sam2/sam2_hiera_b+.yaml",
        "sam2_hiera_base_plus.pt",
    ),
    "facebook/sam2-hiera-large": (
        "configs/sam2/sam2_hiera_l.yaml",
        "sam2_hiera_large.pt",
    ),
    "facebook/sam2.1-hiera-tiny": (
        "configs/sam2.1/sam2.1_hiera_t.yaml",
        "sam2.1_hiera_tiny.pt",
    ),
    "facebook/sam2.1-hiera-small": (
        "configs/sam2.1/sam2.1_hiera_s.yaml",
        "sam2.1_hiera_small.pt",
    ),
    "facebook/sam2.1-hiera-base-plus": (
        "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_base_plus.pt",
    ),
    "facebook/sam2.1-hiera-large": (
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "sam2.1_hiera_large.pt",
    ),
}


def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    vos_optimized=False,
    **kwargs,
):
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if vos_optimized:
        hydra_overrides = [
            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
            "++model.compile_image_encoder=True",  # Let sam2_base handle this
        ]

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder,
            # so that the encoded masks are exactly what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def _hf_download(model_id):
    from huggingface_hub import hf_hub_download

    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return config_name, ckpt_path


def build_sam2_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2_video_predictor(
        config_file=config_name, ckpt_path=ckpt_path, **kwargs
    )


def _load_checkpoint(model, ckpt_path):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
        missing_keys, unexpected_keys = model.load_state_dict(sd)
        if missing_keys:
            logging.error(missing_keys)
            raise RuntimeError()
        if unexpected_keys:
            logging.error(unexpected_keys)
            raise RuntimeError()
        logging.info("Loaded checkpoint successfully")

--------------------------------------------------------------------------------
/training/dataset/sam2_datasets.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence

import torch

from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset

from torch.utils.data.distributed import DistributedSampler


class MixedDataLoader:
    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of each dataloader to be sampled from.

        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Iterator state
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        return sum([len(d) for d in self.dataloaders])

    def __iter__(self):
        # Synchronize dataloader seeds
        self.random_generator.manual_seed(42)
        self._iter_dls = [iter(loader) for loader in self.dataloaders]
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Pick the dataloader to draw the next batch from, based on the mixing probabilities. If one of the dataloaders is exhausted, keep sampling from the other loaders until all are exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        while self._iter_mixing_prob.any():  # at least one D-Loader with non-zero prob.
            dataset_idx = self._iter_mixing_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                item = next(self._iter_dls[dataset_idx])
                return item
            except StopIteration:
                # No more iterations for this dataset, set its mixing probability to zero and try again.
                self._iter_mixing_prob[dataset_idx] = 0
            except Exception as e:
                # log and raise any other unexpected error.
                logging.error(e)
                raise e

        # Exhausted all iterators
        raise StopIteration


class TorchTrainMixedDataset:
    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, the dataloaders copy tensors into pinned memory before returning them.
            drop_last (bool): Whether or not to drop the last incomplete batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to sample from.
                Should sum to 1.0
        """

        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to its number of batches.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        # Compare with a tolerance: float32 sums of probabilities need not be exactly 1.0
        assert math.isclose(
            dataset_prob.sum().item(), 1.0, abs_tol=1e-6
        ), "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch) -> Iterable:
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Major epoch that loops over the entire dataset
                # len(main_epoch) == phases_per_epoch * len(epoch)
                main_epoch = epoch // self.phases_per_epoch

                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch

                # Start of a new data epoch, or the job is resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # set seed for dataset epoch
                    # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)

                    # Separate random generator for subset sampling
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )

                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)

            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)

            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)

--------------------------------------------------------------------------------
/sam2/modeling/sam/prompt_encoder.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Type

import torch
from torch import nn

from sam2.modeling.position_encoding import PositionEmbeddingRandom

from sam2.modeling.sam2_utils import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [
            nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )

        point_embedding = torch.where(
            (labels == -1).unsqueeze(-1),
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 0).unsqueeze(-1),
            point_embedding + self.point_embeddings[0].weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 1).unsqueeze(-1),
            point_embedding + self.point_embeddings[1].weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 2).unsqueeze(-1),
            point_embedding + self.point_embeddings[2].weight,
            point_embedding,
        )
        point_embedding = torch.where(
            (labels == 3).unsqueeze(-1),
            point_embedding + self.point_embeddings[3].weight,
            point_embedding,
        )
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings

--------------------------------------------------------------------------------
/sav_dataset/README.md:
--------------------------------------------------------------------------------
# Segment Anything Video (SA-V) Dataset

## Overview

[Segment Anything Video (SA-V)](https://ai.meta.com/datasets/segment-anything-video/) consists of 51K diverse videos and 643K high-quality spatio-temporal segmentation masks (i.e., masklets). The dataset is released under the CC BY 4.0 license. Browse the dataset [here](https://sam2.metademolab.com/dataset).

![SA-V dataset](../assets/sa_v_dataset.jpg?raw=true)

## Getting Started

### Download the dataset

Visit [here](https://ai.meta.com/datasets/segment-anything-video-downloads/) to download SA-V, including the training, val and test sets.

### Dataset Stats

|            | Num Videos | Num Masklets                              |
| ---------- | ---------- | ----------------------------------------- |
| SA-V train | 50,583     | 642,036 (auto 451,720 and manual 190,316) |
| SA-V val   | 155        | 293                                       |
| SA-V test  | 150        | 278                                       |

### Notebooks

To load and visualize the SA-V training set annotations, refer to the example [sav_visualization_example.ipynb](./sav_visualization_example.ipynb) notebook.

### SA-V train

For the SA-V training set, we release the mp4 videos and store the masklet annotations per video as JSON files. Automatic masklets and manual masklets are stored separately in two JSON files: `{video_id}_auto.json` and `{video_id}_manual.json`. They can be loaded as dictionaries in Python in the format below.

```
{
    "video_id"                      : str; video id
    "video_duration"                : float64; the duration in seconds of this video
    "video_frame_count"             : float64; the number of frames in the video
    "video_height"                  : float64; the height of the video
    "video_width"                   : float64; the width of the video
    "video_resolution"              : float64; video_height $\times$ video_width
    "video_environment"             : List[str]; "Indoor" or "Outdoor"
    "video_split"                   : str; "train" for training set
    "masklet"                       : List[List[Dict]]; masklet annotations in list of list of RLEs.
                                      The outer list is over frames in the video and the inner list
                                      is over objects in the video.
    "masklet_id"                    : List[int]; the masklet ids
    "masklet_size_rel"              : List[float]; the average mask area normalized by resolution
                                      across all the frames where the object is visible
    "masklet_size_abs"              : List[float]; the average mask area (in pixels)
                                      across all the frames where the object is visible
    "masklet_size_bucket"           : List[str]; "small": $1$ <= masklet_size_abs < $32^2$,
                                      "medium": $32^2$ <= masklet_size_abs < $96^2$,
                                      and "large": masklet_size_abs > $96^2$
    "masklet_visibility_changes"    : List[int]; the number of times where the visibility changes
                                      after the first appearance (e.g., invisible -> visible
                                      or visible -> invisible)
    "masklet_first_appeared_frame"  : List[int]; the index of the frame where the object appears
                                      the first time in the video. Always 0 for auto masklets.
    "masklet_frame_count"           : List[int]; the number of frames being annotated. Note that
                                      videos are annotated at 6 fps (annotated every 4 frames)
                                      while the videos are at 24 fps.
    "masklet_edited_frame_count"    : List[int]; the number of frames being edited by human annotators.
                                      Always 0 for auto masklets.
    "masklet_type"                  : List[str]; "auto" or "manual"
    "masklet_stability_score"       : Optional[List[List[float]]]; per-mask stability scores. Auto annotation only.
    "masklet_num"                   : int; the number of manual/auto masklets in the video

}
```

Note that SA-V train contains 50,583 videos in total, all of which have manual annotations. Of these, 48,436 videos also have automatic annotations.
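The size buckets and the 6 fps annotation rate in the schema above can be reproduced with two small helpers. This is a hedged sketch, not part of the released tooling: the names `masklet_size_bucket`, `annotated_frame_indices` and `ANN_EVERY` are hypothetical, an area of exactly $96^2$ is treated as "large" (the schema leaves that boundary unspecified), and it assumes that position `i` in the outer `masklet` list corresponds to 24 fps frame `4 * i`, following the "annotated every 4 frames" note.

```python
from typing import List

ANN_EVERY = 4  # 24 fps videos annotated at 6 fps, i.e. every 4th frame (assumption from the schema note)


def masklet_size_bucket(size_abs: float) -> str:
    """Map an average mask area in pixels to the SA-V size bucket."""
    if size_abs < 1:
        raise ValueError("the schema only defines buckets for areas >= 1 pixel")
    if size_abs < 32**2:
        return "small"
    if size_abs < 96**2:
        return "medium"
    # Assumption: an area of exactly 96^2 is counted as "large".
    return "large"


def annotated_frame_indices(num_annotated: int, ann_every: int = ANN_EVERY) -> List[int]:
    """Assumed mapping from positions in the outer `masklet` list to 24 fps frame indices."""
    return [i * ann_every for i in range(num_annotated)]


if __name__ == "__main__":
    print(masklet_size_bucket(500), masklet_size_bucket(5000), masklet_size_bucket(20000))
    # small medium large
    print([f"{i:05d}" for i in annotated_frame_indices(4)])
    # ['00000', '00004', '00008', '00012']
```

Under the same assumption, these indices line up with the mask file names (`00000.png`, `00004.png`, ...) used in the val and test `Annotations_6fps` folders.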

### SA-V val and test

For the SA-V val and test sets, we release the extracted frames as JPEG files and the masks as PNG files, with the following directory structure:

```
sav_val(sav_test)
├── sav_val.txt (sav_test.txt): a list of video ids in the split
├── JPEGImages_24fps # videos are extracted at 24 fps
│   ├── {video_id}
│   │     ├── 00000.jpg        # video frame
│   │     ├── 00001.jpg        # video frame
│   │     ├── 00002.jpg        # video frame
│   │     ├── 00003.jpg        # video frame
│   │     └── ...
│   ├── {video_id}
│   ├── {video_id}
│   └── ...
└── Annotations_6fps # videos are annotated at 6 fps
    ├── {video_id}
    │     ├── 000               # obj 000
    │     │    ├── 00000.png    # mask for object 000 in 00000.jpg
    │     │    ├── 00004.png    # mask for object 000 in 00004.jpg
    │     │    ├── 00008.png    # mask for object 000 in 00008.jpg
    │     │    ├── 00012.png    # mask for object 000 in 00012.jpg
    │     │    └── ...
    │     ├── 001               # obj 001
    │     ├── 002               # obj 002
    │     └── ...
    ├── {video_id}
    ├── {video_id}
    └── ...
```

All masklets in the val and test sets are manually annotated by annotators in every annotated frame (i.e., at 6 fps). For each annotated object in a video, its mask in each annotated frame is stored as a separate PNG under that object's folder, rather than merging all objects into one PNG per frame. This is because annotated objects may overlap: e.g., in SA-V there can be a mask for a whole person as well as a separate mask for their hands.

## SA-V Val and Test Evaluation

We provide an evaluator to compute the common J and F metrics on the SA-V val and test sets.
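J (region similarity) is the Jaccard index, i.e. the intersection-over-union between a predicted and a ground-truth binary mask; F (contour accuracy) is an F-measure computed on mask boundaries, and the reported J&F is the mean of the two. The NumPy sketch below computes J for a single object in a single frame. It is only an illustration, not the evaluator's code: the function name `region_similarity` is hypothetical, the boundary-based F is omitted, and scoring two empty masks as a perfect match (J = 1) is an assumption that follows the DAVIS 2017 evaluation toolkit's convention.

```python
import numpy as np


def region_similarity(pred: np.ndarray, gt: np.ndarray) -> float:
    """Jaccard index (J) between two binary masks of the same shape."""
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # Both masks empty: count as a perfect match (assumed DAVIS-style convention).
        return 1.0
    inter = np.logical_and(pred, gt).sum()
    return float(inter) / float(union)


if __name__ == "__main__":
    gt = np.zeros((4, 4), dtype=np.uint8)
    gt[:2, :2] = 1  # 4-pixel object
    pred = np.zeros_like(gt)
    pred[:2, :3] = 1  # 6 pixels, covering the whole object plus 2 extra
    print(region_similarity(pred, gt))  # 0.6666666666666666
```

Per-object scores like this are then averaged over the evaluated frames and objects; the exact averaging and the F computation are left to the evaluator itself.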
To run the evaluation, first install a few dependencies as follows:

```
pip install -r requirements.txt
```

Then evaluate the predictions as follows:

```
python sav_evaluator.py --gt_root {GT_ROOT} --pred_root {PRED_ROOT}
```

or run

```
python sav_evaluator.py --help
```

to print a complete help message.

The evaluator expects `GT_ROOT` to follow one of the folder structures below, and `PRED_ROOT` to have the same structure as `GT_ROOT`.

- Same as SA-V val and test directory structure

```
{GT_ROOT}  # gt root folder
├── {video_id}
│     ├── 000               # all masks associated with obj 000
│     │    ├── 00000.png    # mask for object 000 in frame 00000 (binary mask)
│     │    └── ...
│     ├── 001               # all masks associated with obj 001
│     ├── 002               # all masks associated with obj 002
│     └── ...
├── {video_id}
├── {video_id}
└── ...
```

For the SA-V val and test experiments in the paper, we run inference on the 24 fps videos and evaluate only on the subset of frames that have ground-truth annotations (dropping the first and last annotated frames). The evaluator ignores predicted masks in frames without ground-truth annotations.

- Same as [DAVIS](https://github.com/davisvideochallenge/davis2017-evaluation) directory structure

```
{GT_ROOT}  # gt root folder
├── {video_id}
│     ├── 00000.png        # annotations in frame 00000 (may contain multiple objects)
│     └── ...
├── {video_id}
├── {video_id}
└── ...
```

## License

The evaluation code is licensed under the [BSD 3 license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in the SA-V Dataset are released under CC BY 4.0.

Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)).

--------------------------------------------------------------------------------
/training/utils/logger.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
import atexit
import functools
import logging
import sys
import uuid
from typing import Any, Dict, Optional, Union

from hydra.utils import instantiate

from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

from training.utils.train_utils import get_machine_local_and_dist_rank, makedir

Scalar = Union[Tensor, ndarray, int, float]


def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
    makedir(log_dir)
    summary_writer_method = SummaryWriter
    return TensorBoardLogger(
        path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
    )


class TensorBoardWriterWrapper:
    """
    A wrapper around a SummaryWriter object.
    """

    def __init__(
        self,
        path: str,
        *args: Any,
        filename_suffix: str = None,
        summary_writer_method: Any = SummaryWriter,
        **kwargs: Any,
    ) -> None:
        """Create a new TensorBoard logger.
        On construction, the logger creates a new events file that logs
        will be written to. If the environment variable `RANK` is defined,
        the logger will only log if RANK = 0.

        NOTE: If using the logger with distributed training:
        - This logger can call collective operations
        - Logs will be written on rank 0 only
        - Logger must be constructed synchronously *after* initializing distributed process group.

        Args:
            path (str): path to write logs to
            *args, **kwargs: Extra arguments to pass to SummaryWriter
        """
        self._writer: Optional[SummaryWriter] = None
        _, self._rank = get_machine_local_and_dist_rank()
        self._path: str = path
        if self._rank == 0:
            logging.info(
                f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
            )
            self._writer = summary_writer_method(
                log_dir=path,
                *args,
                filename_suffix=filename_suffix or str(uuid.uuid4()),
                **kwargs,
            )
        else:
            logging.debug(
                f"Not logging meters on this host because env RANK: {self._rank} != 0"
            )
        atexit.register(self.close)

    @property
    def writer(self) -> Optional[SummaryWriter]:
        return self._writer

    @property
    def path(self) -> str:
        return self._path

    def flush(self) -> None:
        """Writes pending logs to disk."""

        if not self._writer:
            return

        self._writer.flush()

    def close(self) -> None:
        """Close writer, flushing pending logs to disk.
        Logs cannot be written after `close` is called.
        """

        if not self._writer:
            return

        self._writer.close()
        self._writer = None


class TensorBoardLogger(TensorBoardWriterWrapper):
    """
    A simple logger for TensorBoard.
    """

    def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
        """Add multiple scalar values to TensorBoard.
116 | 
117 |         Args:
118 |             payload (dict): dictionary of tag name and scalar value
119 |             step (int, Optional): step value to record
120 |         """
121 |         if not self._writer:
122 |             return
123 |         for k, v in payload.items():
124 |             self.log(k, v, step)
125 | 
126 |     def log(self, name: str, data: Scalar, step: int) -> None:
127 |         """Add scalar data to TensorBoard.
128 | 
129 |         Args:
130 |             name (string): tag name used to group scalars
131 |             data (float/int/Tensor): scalar data to log
132 |             step (int, optional): step value to record
133 |         """
134 |         if not self._writer:
135 |             return
136 |         self._writer.add_scalar(name, data, global_step=step, new_style=True)
137 | 
138 |     def log_hparams(
139 |         self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
140 |     ) -> None:
141 |         """Add hyperparameter data to TensorBoard.
142 | 
143 |         Args:
144 |             hparams (dict): dictionary of hyperparameter names and corresponding values
145 |             meters (dict): dictionary of meter names and corresponding values
146 |         """
147 |         if not self._writer:
148 |             return
149 |         self._writer.add_hparams(hparams, meters)
150 | 
151 | 
152 | class Logger:
153 |     """
154 |     A logger class that can interface with multiple loggers. It currently supports only TensorBoard for simplicity, but it can be extended with other loggers.
155 |     """
156 | 
157 |     def __init__(self, logging_conf):
158 |         # allow turning off TensorBoard with "should_log: false" in config
159 |         tb_config = logging_conf.tensorboard_writer
160 |         tb_should_log = tb_config and tb_config.pop("should_log", True)
161 |         self.tb_logger = instantiate(tb_config) if tb_should_log else None
162 | 
163 |     def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
164 |         if self.tb_logger:
165 |             self.tb_logger.log_dict(payload, step)
166 | 
167 |     def log(self, name: str, data: Scalar, step: int) -> None:
168 |         if self.tb_logger:
169 |             self.tb_logger.log(name, data, step)
170 | 
171 |     def log_hparams(
172 |         self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
173 |     ) -> None:
174 |         if self.tb_logger:
175 |             self.tb_logger.log_hparams(hparams, meters)
176 | 
177 | 
178 | # cache the opened file object, so that different calls to `setup_logging`
179 | # with the same file name can safely write to the same file.
180 | @functools.lru_cache(maxsize=None)
181 | def _cached_log_stream(filename):
182 |     # we tune the buffering value so that the logs are updated
183 |     # frequently.
184 |     log_buffer_kb = 10 * 1024  # 10KB
185 |     io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
186 |     atexit.register(io.close)
187 |     return io
188 | 
189 | 
190 | def setup_logging(
191 |     name,
192 |     output_dir=None,
193 |     rank=0,
194 |     log_level_primary="INFO",
195 |     log_level_secondary="ERROR",
196 | ):
197 |     """
198 |     Set up the logging streams: a stdout handler and, optionally, a file handler.
199 |     The file handler is set up only for the master GPU (rank 0).
200 |     """
201 |     # get the filename if we want to log to the file as well
202 |     log_filename = None
203 |     if output_dir:
204 |         makedir(output_dir)
205 |         if rank == 0:
206 |             log_filename = f"{output_dir}/log.txt"
207 | 
208 |     logger = logging.getLogger(name)
209 |     logger.setLevel(log_level_primary)
210 | 
211 |     # create formatter
212 |     FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
213 |     formatter = logging.Formatter(FORMAT)
214 | 
215 |     # Clean up any existing handlers (iterate over a copy, since we remove while looping)
216 |     for h in list(logger.handlers):
217 |         logger.removeHandler(h)
218 |     logger.root.handlers = []
219 | 
220 |     # set up the console handler
221 |     console_handler = logging.StreamHandler(sys.stdout)
222 |     console_handler.setFormatter(formatter)
223 |     logger.addHandler(console_handler)
224 |     if rank == 0:
225 |         console_handler.setLevel(log_level_primary)
226 |     else:
227 |         console_handler.setLevel(log_level_secondary)
228 | 
229 |     # we log to file as well if user wants
230 |     if log_filename and rank == 0:
231 |         file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
232 |         file_handler.setLevel(log_level_primary)
233 |         file_handler.setFormatter(formatter)
234 |         logger.addHandler(file_handler)
235 | 
236 |     logging.root = logger
237 | 
238 | 
239 | def shutdown_logging():
240 |     """
241 |     After training is done, shut down all the logger streams.
242 |     """
243 |     logging.info("Shutting down loggers...")
244 |     handlers = logging.root.handlers
245 |     for handler in handlers:
246 |         handler.close()
247 | 
--------------------------------------------------------------------------------
/sam2/csrc/connected_components.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Meta Platforms, Inc. and affiliates.
2 | // All rights reserved.
3 | 
4 | // This source code is licensed under the license found in the
5 | // LICENSE file in the root directory of this source tree.
6 | 
7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch
8 | // with license found in the LICENSE_cctorch file in the root directory.
9 | #include <ATen/cuda/CUDAContext.h>
10 | #include <cuda.h>
11 | #include <cuda_runtime.h>
12 | #include <torch/extension.h>
13 | #include <torch/script.h>
14 | #include <vector>
15 | 
16 | // 2d
17 | #define BLOCK_ROWS 16
18 | #define BLOCK_COLS 16
19 | 
20 | namespace cc2d {
21 | 
22 | template <typename T>
23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24 |   return (bitmap >> pos) & 1;
25 | }
26 | 
27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28 |   while (s_buf[n] != n)
29 |     n = s_buf[n];
30 |   return n;
31 | }
32 | 
33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34 |   const int32_t id = n;
35 |   while (s_buf[n] != n) {
36 |     n = s_buf[n];
37 |     s_buf[id] = n;
38 |   }
39 |   return n;
40 | }
41 | 
42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43 |   bool done;
44 |   do {
45 |     a = find(s_buf, a);
46 |     b = find(s_buf, b);
47 | 
48 |     if (a < b) {
49 |       int32_t old = atomicMin(s_buf + b, a);
50 |       done = (old == b);
51 |       b = old;
52 |     } else if (b < a) {
53 |       int32_t old = atomicMin(s_buf + a, b);
54 |       done = (old == a);
55 |       a = old;
56 |     } else
57 |       done = true;
58 | 
59 |   } while (!done);
60 | }
61 | 
62 | __global__ void
63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66 |   const uint32_t idx = row * W + col;
67 | 
68 |   if (row < H && col < W)
69 |     label[idx] = idx;
70 | }
71 | 
72 | __global__ void
73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76 |   const uint32_t idx = row * W + col;
77 | 
78 |   if (row >= H || col >= W)
79 |     return;
80 | 
81 |   uint32_t P = 0;
82 | 
83 |   if (img[idx])
84 |     P |= 0x777;
85 |   if (row + 1 < H && img[idx + W])
86 |     P |= 0x777 << 4;
87 |   if (col + 1 < W && img[idx + 1])
88 |     P |= 0x777 << 1;
89 | 
90 |   if (col == 0)
91 |     P &= 0xEEEE;
92 |   if (col + 1 >= W)
93 |     P &= 0x3333;
94 |   else if (col + 2 >= W)
95 |     P &= 0x7777;
96 | 
97 |   if (row == 0)
98 |     P &= 0xFFF0;
99 |   if (row + 1 >= H)
100 |     P &= 0xFF;
101 | 
102 |   if (P > 0) {
103 |     // If the top-left pixel needs to be checked (bit 0 of P is set) and that
104 |     // pixel is foreground, merge with the top-left block
105 |     if (hasBit(P, 0) && img[idx - W - 1]) {
106 |       union_(label, idx, idx - 2 * W - 2); // top left block
107 |     }
108 | 
109 |     if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110 |       union_(label, idx, idx - 2 * W); // top block
111 | 
112 |     if (hasBit(P, 3) && img[idx + 2 - W])
113 |       union_(label, idx, idx - 2 * W + 2); // top right block
114 | 
115 |     if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116 |       union_(label, idx, idx - 2); // just left block
117 |   }
118 | }
119 | 
120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123 |   const uint32_t idx = row * W + col;
124 | 
125 |   if (row < H && col < W)
126 |     find_n_compress(label, idx);
127 | }
128 | 
129 | __global__ void final_labeling(
130 |     const uint8_t* img,
131 |     int32_t* label,
132 |     const int32_t W,
133 |     const int32_t H) {
134 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136 |   const uint32_t idx = row * W + col;
137 | 
138 |   if (row >= H || col >= W)
139 |     return;
140 | 
141 |   int32_t y = label[idx] + 1;
142 | 
143 |   if (img[idx])
144 |     label[idx] = y;
145 |   else
146 |     label[idx] = 0;
147 | 
148 |   if (col + 1 < W) {
149 |     if (img[idx + 1])
150 |       label[idx + 1] = y;
151 |     else
152 |       label[idx + 1] = 0;
153 | 
154 |     if (row + 1 < H) {
155 |       if (img[idx + W + 1])
156 |         label[idx + W + 1] = y;
157 |       else
158 |         label[idx + W + 1] = 0;
159 |     }
160 |   }
161 | 
162 |   if (row + 1 < H) {
163 |     if (img[idx + W])
164 |       label[idx + W] = y;
165 |     else
166 |       label[idx + W] = 0;
167 |   }
168 | }
169 | 
170 | __global__ void init_counting(
171 |     const int32_t* label,
172 |     int32_t* count_init,
173 |     const int32_t W,
174 |     const int32_t H) {
175 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177 |   const uint32_t idx = row * W + col;
178 | 
179 |   if (row >= H || col >= W)
180 |     return;
181 | 
182 |   int32_t y = label[idx];
183 |   if (y > 0) {
184 |     int32_t count_idx = y - 1;
185 |     atomicAdd(count_init + count_idx, 1);
186 |   }
187 | }
188 | 
189 | __global__ void final_counting(
190 |     const int32_t* label,
191 |     const int32_t* count_init,
192 |     int32_t* count_final,
193 |     const int32_t W,
194 |     const int32_t H) {
195 |   const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196 |   const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197 |   const uint32_t idx = row * W + col;
198 | 
199 |   if (row >= H || col >= W)
200 |     return;
201 | 
202 |   int32_t y = label[idx];
203 |   if (y > 0) {
204 |     int32_t count_idx = y - 1;
205 |     count_final[idx] = count_init[count_idx];
206 |   } else {
207 |     count_final[idx] = 0;
208 |   }
209 | }
210 | 
211 | } // namespace cc2d
212 | 
213 | std::vector<torch::Tensor> get_connected_componnets(
214 |     const torch::Tensor& inputs) {
215 |   AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216 |   AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217 |   AT_ASSERTM(
218 |       inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219 | 
220 |   const uint32_t N = inputs.size(0);
221 |   const uint32_t C = inputs.size(1);
222 |   const uint32_t H = inputs.size(2);
223 |   const uint32_t W = inputs.size(3);
224 | 
225 |   AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226 |   AT_ASSERTM((H % 2) == 0, "height must be an even number");
227 |   AT_ASSERTM((W % 2) == 0, "width must be an even number");
228 | 
229 |   // labels and counts are stored as int32 (torch::kInt32)
230 |   auto label_options =
231 |       torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232 |   torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233 |   torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234 |   torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235 | 
236 |   dim3 grid = dim3(
237 |       ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238 |       ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239 |   dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240 |   dim3 grid_count =
241 |       dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242 |   dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244 | 
245 |   for (int n = 0; n < N; n++) {
246 |     uint32_t offset = n * H * W;
247 | 
248 |     cc2d::init_labeling<<<grid, block, 0, stream>>>(
249 |         labels.data_ptr<int32_t>() + offset, W, H);
250 |     cc2d::merge<<<grid, block, 0, stream>>>(
251 |         inputs.data_ptr<uint8_t>() + offset,
252 |         labels.data_ptr<int32_t>() + offset,
253 |         W,
254 |         H);
255 |     cc2d::compression<<<grid, block, 0, stream>>>(
256 |         labels.data_ptr<int32_t>() + offset, W, H);
257 |     cc2d::final_labeling<<<grid, block, 0, stream>>>(
258 |         inputs.data_ptr<uint8_t>() + offset,
259 |         labels.data_ptr<int32_t>() + offset,
260 |         W,
261 |         H);
262 | 
263 |     // for each pixel, count the size of its component (0 for background)
264 |     cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265 |         labels.data_ptr<int32_t>() + offset,
266 |         counts_init.data_ptr<int32_t>() + offset,
267 |         W,
268 |         H);
269 |     cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270 |         labels.data_ptr<int32_t>() + offset,
271 |         counts_init.data_ptr<int32_t>() + offset,
272 |         counts_final.data_ptr<int32_t>() + offset,
273 |         W,
274 |         H);
275 |   }
276 | 
277 |   // returned values are [labels, counts]
278 |   std::vector<torch::Tensor> outputs;
279 |   outputs.push_back(labels);
280 |   outputs.push_back(counts_final);
281 |   return outputs;
282 | }
283 | 
284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285 |   m.def(
286 |       "get_connected_componnets",
287 |       &get_connected_componnets,
288 |       "get_connected_componnets");
289 | }
290 | 
--------------------------------------------------------------------------------
/sam2/modeling/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import math
8 | from typing import Any, Optional, Tuple
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | from torch import nn
14 | 
15 | 
16 | class PositionEmbeddingSine(nn.Module):
17 |     """
18 |     This is a more standard version of the position embedding, very similar to the one
19 |     used by the Attention Is All You Need paper, generalized to work on images.
20 |     """
21 | 
22 |     def __init__(
23 |         self,
24 |         num_pos_feats,
25 |         temperature: int = 10000,
26 |         normalize: bool = True,
27 |         scale: Optional[float] = None,
28 |         # The following settings are only relevant
29 |         # for warming up the cache for compilation
30 |         warmup_cache: bool = True,
31 |         image_size: int = 1024,
32 |         strides: Tuple[int] = (4, 8, 16, 32),
33 |     ):
34 |         super().__init__()
35 |         assert num_pos_feats % 2 == 0, "Expecting even model width"
36 |         self.num_pos_feats = num_pos_feats // 2
37 |         self.temperature = temperature
38 |         self.normalize = normalize
39 |         if scale is not None and normalize is False:
40 |             raise ValueError("normalize should be True if scale is passed")
41 |         if scale is None:
42 |             scale = 2 * math.pi
43 |         self.scale = scale
44 | 
45 |         self.cache = {}
46 |         if warmup_cache and torch.cuda.is_available():
47 |             # Warm up the cache for CUDA, to help with compilation
48 |             device = torch.device("cuda")
49 |             for stride in strides:
50 |                 cache_key = (image_size // stride, image_size // stride)
51 |                 self._pe(1, device, *cache_key)
52 | 
53 |     def _encode_xy(self, x, y):
54 |         # The positions are expected to be normalized
55 |         assert len(x) == len(y) and x.ndim == y.ndim == 1
56 |         x_embed = x * self.scale
57 |         y_embed = y * self.scale
58 | 
59 |         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
60 |         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
61 | 
62 |         pos_x = x_embed[:, None] / dim_t
63 |         pos_y = y_embed[:, None] / dim_t
64 |         pos_x = torch.stack(
65 |             (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
66 |         ).flatten(1)
67 |         pos_y = torch.stack(
68 |             (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
69 |         ).flatten(1)
70 |         return pos_x, pos_y
71 | 
72 |     @torch.no_grad()
73 |     def encode_boxes(self, x, y, w, h):
74 |         pos_x, pos_y = self._encode_xy(x, y)
75 |         pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
76 |         return pos
77 | 
78 |     encode = encode_boxes  # Backwards compatibility
79 | 
80 |     @torch.no_grad()
81 |     def encode_points(self, x, y, labels):
82 |         (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
83 |         assert bx == by and nx == ny and bx == bl and nx == nl
84 |         pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
85 |         pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
86 |         pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
87 |         return pos
88 | 
89 |     @torch.no_grad()
90 |     def _pe(self, B, device, *cache_key):
91 |         H, W = cache_key
92 |         if cache_key in self.cache:
93 |             return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
94 | 
95 |         y_embed = (
96 |             torch.arange(1, H + 1, dtype=torch.float32, device=device)
97 |             .view(1, -1, 1)
98 |             .repeat(B, 1, W)
99 |         )
100 |         x_embed = (
101 |             torch.arange(1, W + 1, dtype=torch.float32, device=device)
102 |             .view(1, 1, -1)
103 |             .repeat(B, H, 1)
104 |         )
105 | 
106 |         if self.normalize:
107 |             eps = 1e-6
108 |             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
109 |             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
110 | 
111 |         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
112 |         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
113 | 
114 |         pos_x = x_embed[:, :, :, None] / dim_t
115 |         pos_y = y_embed[:, :, :, None] / dim_t
116 |         pos_x = torch.stack(
117 |             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
118 |         ).flatten(3)
119 |         pos_y = torch.stack(
120 |             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
121 |         ).flatten(3)
122 |         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
123 |         self.cache[cache_key] = pos[0]
124 |         return pos
125 | 
126 |     @torch.no_grad()
127 |     def forward(self, x: torch.Tensor):
128 |         B = x.shape[0]
129 |         cache_key = (x.shape[-2], x.shape[-1])
130 |         return self._pe(B, x.device, *cache_key)
131 | 
132 | 
133 | class PositionEmbeddingRandom(nn.Module):
134 |     """
135 |     Positional encoding using random spatial frequencies.
136 |     """
137 | 
138 |     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
139 |         super().__init__()
140 |         if scale is None or scale <= 0.0:
141 |             scale = 1.0
142 |         self.register_buffer(
143 |             "positional_encoding_gaussian_matrix",
144 |             scale * torch.randn((2, num_pos_feats)),
145 |         )
146 | 
147 |     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
148 |         """Positionally encode points that are normalized to [0,1]."""
149 |         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
150 |         coords = 2 * coords - 1
151 |         coords = coords @ self.positional_encoding_gaussian_matrix
152 |         coords = 2 * np.pi * coords
153 |         # outputs d_1 x ... x d_n x C shape
154 |         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
155 | 
156 |     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
157 |         """Generate positional encoding for a grid of the specified size."""
158 |         h, w = size
159 |         device: Any = self.positional_encoding_gaussian_matrix.device
160 |         grid = torch.ones((h, w), device=device, dtype=torch.float32)
161 |         y_embed = grid.cumsum(dim=0) - 0.5
162 |         x_embed = grid.cumsum(dim=1) - 0.5
163 |         y_embed = y_embed / h
164 |         x_embed = x_embed / w
165 | 
166 |         pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
167 |         return pe.permute(2, 0, 1)  # C x H x W
168 | 
169 |     def forward_with_coords(
170 |         self, coords_input: torch.Tensor, image_size: Tuple[int, int]
171 |     ) -> torch.Tensor:
172 |         """Positionally encode points that are not normalized to [0,1]."""
173 |         coords = coords_input.clone()
174 |         coords[:, :, 0] = coords[:, :, 0] / image_size[1]
175 |         coords[:, :, 1] = coords[:, :, 1] / image_size[0]
176 |         return self._pe_encoding(coords.to(torch.float))  # B x N x C
177 | 
178 | 
179 | # Rotary Positional Encoding, adapted from:
180 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
181 | # 2. https://github.com/naver-ai/rope-vit
182 | # 3. https://github.com/lucidrains/rotary-embedding-torch
183 | 
184 | 
185 | def init_t_xy(end_x: int, end_y: int):
186 |     t = torch.arange(end_x * end_y, dtype=torch.float32)
187 |     t_x = (t % end_x).float()
188 |     t_y = torch.div(t, end_x, rounding_mode="floor").float()
189 |     return t_x, t_y
190 | 
191 | 
192 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
193 |     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
194 |     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
195 | 
196 |     t_x, t_y = init_t_xy(end_x, end_y)
197 |     freqs_x = torch.outer(t_x, freqs_x)
198 |     freqs_y = torch.outer(t_y, freqs_y)
199 |     freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
200 |     freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
201 |     return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
202 | 
203 | 
204 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
205 |     ndim = x.ndim
206 |     assert 0 <= 1 < ndim
207 |     assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
208 |     shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
209 |     return freqs_cis.view(*shape)
210 | 
211 | 
212 | def apply_rotary_enc(
213 |     xq: torch.Tensor,
214 |     xk: torch.Tensor,
215 |     freqs_cis: torch.Tensor,
216 |     repeat_freqs_k: bool = False,
217 | ):
218 |     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
219 |     xk_ = (
220 |         torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
221 |         if xk.shape[-2] != 0
222 |         else None
223 |     )
224 |     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
225 |     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
226 |     if xk_ is None:
227 |         # no keys to rotate, due to dropout
228 |         return xq_out.type_as(xq).to(xq.device), xk
229 |     # repeat freqs along seq_len dim to match k seq_len
230 |     if repeat_freqs_k:
231 |         r = xk_.shape[-2] // xq_.shape[-2]
232 |         if freqs_cis.is_cuda:
233 |             freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
234 |         else:
235 |             # torch.repeat on complex numbers may not be supported on non-CUDA devices
236 |             # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
237 |             freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
238 |     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
239 |     return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
240 | 
--------------------------------------------------------------------------------
/training/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import logging
8 | import math
9 | import os
10 | import random
11 | import re
12 | from datetime import timedelta
13 | from typing import Optional
14 | 
15 | import hydra
16 | 
17 | import numpy as np
18 | import omegaconf
19 | import torch
20 | import torch.distributed as dist
21 | from iopath.common.file_io import g_pathmgr
22 | from omegaconf import OmegaConf
23 | 
24 | 
25 | def multiply_all(*args):
26 |     return np.prod(np.array(args)).item()
27 | 
28 | 
29 | def collect_dict_keys(config):
30 |     """Recursively iterate through a dataset configuration and collect all the `dict_key` values that are defined."""
31 |     val_keys = []
32 |     # If this config points to the collate function, then it has a key
33 |     if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
34 |         val_keys.append(config["dict_key"])
35 |     else:
36 |         # Recursively proceed
37 |         for v in config.values():
38 |             if isinstance(v, type(config)):
39 |                 val_keys.extend(collect_dict_keys(v))
40 |             elif isinstance(v, omegaconf.listconfig.ListConfig):
41 |                 for item in v:
42 |                     if isinstance(item, type(config)):
43 |                         val_keys.extend(collect_dict_keys(item))
44 |     return val_keys
45 | 
46 | 
47 | class Phase:
48 |     TRAIN = "train"
49 |     VAL = "val"
50 | 
51 | 
52 | def register_omegaconf_resolvers():
53 |     OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
54 |     OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
55 |     OmegaConf.register_new_resolver("add", lambda x, y: x + y)
56 |     OmegaConf.register_new_resolver("times", multiply_all)
57 |     OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
58 |     OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
59 |     OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
60 |     OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
61 |     OmegaConf.register_new_resolver("int", lambda x: int(x))
62 |     OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
63 |     OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
64 | 
65 | 
66 | def setup_distributed_backend(backend, timeout_mins):
67 |     """
68 |     Initialize torch.distributed and set the CUDA device.
69 |     Expects environment variables to be set as per
70 |     https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
71 |     along with the environment variable "LOCAL_RANK", which is used to set the CUDA device.
72 |     """
73 |     # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
74 |     # of waiting
75 |     os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
76 |     logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
77 |     dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
78 |     return dist.get_rank()
79 | 
80 | 
81 | def get_machine_local_and_dist_rank():
82 |     """
83 |     Get the distributed and local rank of the current GPU.
84 |     """
85 |     local_rank = os.environ.get("LOCAL_RANK")
86 |     distributed_rank = os.environ.get("RANK")
87 |     assert (
88 |         local_rank is not None and distributed_rank is not None
89 |     ), "Please set the RANK and LOCAL_RANK environment variables."
90 |     return int(local_rank), int(distributed_rank)
91 | 
92 | 
93 | def print_cfg(cfg):
94 |     """
95 |     Supports printing both Hydra DictConfig and also the AttrDict config
96 |     """
97 |     logging.info("Training with config:")
98 |     logging.info(OmegaConf.to_yaml(cfg))
99 | 
100 | 
101 | def set_seeds(seed_value, max_epochs, dist_rank):
102 |     """
103 |     Set the Python random, NumPy and torch seeds for each GPU, and the CUDA
104 |     seeds if CUDA is available, to make training more reproducible.
105 |     """
106 |     # The sampler increments the seed by 1 every epoch, so per-rank seeds are spaced max_epochs apart.
107 |     seed_value = (seed_value + dist_rank) * max_epochs
108 |     logging.info(f"MACHINE SEED: {seed_value}")
109 |     random.seed(seed_value)
110 |     np.random.seed(seed_value)
111 |     torch.manual_seed(seed_value)
112 |     if torch.cuda.is_available():
113 |         torch.cuda.manual_seed_all(seed_value)
114 | 
115 | 
116 | def makedir(dir_path):
117 |     """
118 |     Create the directory if it does not exist.
119 |     """
120 |     is_success = False
121 |     try:
122 |         if not g_pathmgr.exists(dir_path):
123 |             g_pathmgr.mkdirs(dir_path)
124 |         is_success = True
125 |     except BaseException:
126 |         logging.info(f"Error creating directory: {dir_path}")
127 |     return is_success
128 | 
129 | 
130 | def is_dist_avail_and_initialized():
131 |     if not dist.is_available():
132 |         return False
133 |     if not dist.is_initialized():
134 |         return False
135 |     return True
136 | 
137 | 
138 | def get_amp_type(amp_type: Optional[str] = None):
139 |     if amp_type is None:
140 |         return None
141 |     assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
142 |     if amp_type == "bfloat16":
143 |         return torch.bfloat16
144 |     else:
145 |         return torch.float16
146 | 
147 | 
148 | def log_env_variables():
149 |     env_keys = sorted(list(os.environ.keys()))
150 |     st = ""
151 |     for k in env_keys:
152 |         v = os.environ[k]
153 |         st += f"{k}={v}\n"
154 |     logging.info("Logging ENV_VARIABLES")
155 |     logging.info(st)
156 | 
157 | 
158 | class AverageMeter:
159 |     """Computes and stores the average and current value"""
160 | 
161 |     def __init__(self, name, device, fmt=":f"):
162 |         self.name = name
163 |         self.fmt = fmt
164 |         self.device = device
165 |         self.reset()
166 | 
167 |     def reset(self):
168 |         self.val = 0
169 |         self.avg = 0
170 |         self.sum = 0
171 |         self.count = 0
172 |         self._allow_updates = True
173 | 
174 |     def update(self, val, n=1):
175 |         self.val = val
176 |         self.sum += val * n
177 |         self.count += n
178 |         self.avg = self.sum / self.count
179 | 
180 |     def __str__(self):
181 |         fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
182 |         return fmtstr.format(**self.__dict__)
183 | 
184 | 
185 | class MemMeter:
186 |     """Computes and stores the current, avg, and max of peak Mem usage per iteration"""
187 | 
188 |     def __init__(self, name, device, fmt=":f"):
189 |         self.name = name
190 |         self.fmt = fmt
191 |         self.device = device
192 |         self.reset()
193 | 
194 |     def reset(self):
195 |         self.val = 0  # Per iteration max usage
196 |         self.avg = 0  # Avg per iteration max usage
197 |         self.peak = 0  # Peak usage for lifetime of program
198 |         self.sum = 0
199 |         self.count = 0
200 |         self._allow_updates = True
201 | 
202 |     def update(self, n=1, reset_peak_usage=True):
203 |         self.val = torch.cuda.max_memory_allocated() // 1e9
204 |         self.sum += self.val * n
205 |         self.count += n
206 |         self.avg = self.sum / self.count
207 |         self.peak = max(self.peak, self.val)
208 |         if reset_peak_usage:
209 |             torch.cuda.reset_peak_memory_stats()
210 | 
211 |     def __str__(self):
212 |         fmtstr = (
213 |             "{name}: {val"
214 |             + self.fmt
215 |             + "} ({avg"
216 |             + self.fmt
217 |             + "}/{peak"
218 |             + self.fmt
219 |             + "})"
220 |         )
221 |         return fmtstr.format(**self.__dict__)
222 | 
223 | 
224 | def human_readable_time(time_seconds):
225 |     time = int(time_seconds)
226 |     minutes, seconds = divmod(time, 60)
227 |     hours, minutes = divmod(minutes, 60)
228 |     days, hours = divmod(hours, 24)
229 |     return f"{days:02}d {hours:02}h {minutes:02}m"
230 | 
231 | 
232 | class DurationMeter:
233 |     def __init__(self, name, device, fmt=":f"):
234 |         self.name = name
235 |         self.device = device
236 |         self.fmt = fmt
237 |         self.val = 0
238 | 
239 |     def reset(self):
240 |         self.val = 0
241 | 
242 |     def update(self, val):
243 |         self.val = val
244 | 
245 |     def add(self, val):
246 |         self.val += val
247 | 
248 |     def __str__(self):
249 |         return f"{self.name}: {human_readable_time(self.val)}"
250 | 
251 | 
252 | class ProgressMeter:
253 |     def __init__(self, num_batches, meters, real_meters, prefix=""):
254 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
255 |         self.meters = meters
256 |         self.real_meters = real_meters
257 |         self.prefix = prefix
258 | 
259 |     def display(self, batch, enable_print=False):
260 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
261 |         entries += [str(meter) for meter in self.meters]
262 |         entries += [
263 |             " | ".join(
264 |                 [
265 |                     f"{os.path.join(name, subname)}: {val:.4f}"
266 |                     for subname, val in meter.compute().items()
267 |                 ]
268 |             )
269 |             for name, meter in self.real_meters.items()
270 |         ]
271 |         logging.info(" | ".join(entries))
272 |         if enable_print:
273 |             print(" | ".join(entries))
274 | 
275 |     def _get_batch_fmtstr(self, num_batches):
276 |         num_digits = len(str(num_batches // 1))
277 |         fmt = "{:" + str(num_digits) + "d}"
278 |         return "[" + fmt + "/" + fmt.format(num_batches) + "]"
279 | 
280 | 
281 | def get_resume_checkpoint(checkpoint_save_dir):
282 |     if not g_pathmgr.isdir(checkpoint_save_dir):
283 |         return None
284 |     ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
285 |     if not g_pathmgr.isfile(ckpt_file):
286 |         return None
287 | 
288 |     return ckpt_file
289 | 
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 | 
3 | ### Requirements
4 | 
5 | - Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
6 |   * Note that older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
7 | - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
8 | - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
9 | 
10 | Then, install SAM 2 from the root of this repository via:
11 | ```bash
12 | pip install -e ".[notebooks]"
13 | ```
14 | 
15 | Note that you may skip building the SAM 2 CUDA extension during installation via the environment variable `SAM2_BUILD_CUDA=0`, as follows:
16 | ```bash
17 | # skip the SAM 2 CUDA extension
18 | SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]"
19 | ```
20 | This also skips the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
21 | 
22 | ### Building the SAM 2 CUDA extension
23 | 
24 | By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)
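The requirements above call for a CUDA toolkit that matches the CUDA version of your PyTorch build, and because build errors are hidden by default, a mismatch may go unnoticed until runtime. The following is a minimal sketch, not part of SAM 2, for comparing the two versions before reinstalling. The helper names `parse_nvcc_release`, `nvcc_release` and `cuda_versions_match` are hypothetical, and the parser assumes `nvcc --version` prints a line such as `Cuda compilation tools, release 12.1, V12.1.105`.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract the "major.minor" release (e.g. "12.1") from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def nvcc_release() -> Optional[str]:
    """Run `nvcc --version` if nvcc is on PATH; return its release, or None if unavailable."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


def cuda_versions_match(torch_cuda: Optional[str], toolkit: Optional[str]) -> bool:
    """Strictly compare major.minor versions; a missing version counts as a mismatch."""
    if torch_cuda is None or toolkit is None:
        return False
    return torch_cuda.split(".")[:2] == toolkit.split(".")[:2]
```

For example, `cuda_versions_match(torch.version.cuda, nvcc_release())` returns `True` only when both versions are found and agree on major.minor (`torch.version.cuda` is `None` for CPU-only PyTorch builds). The comparison is deliberately strict to mirror the "match" requirement above; it does not try to model which mixed versions might still build.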

If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.

If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with the environment variable `SAM2_BUILD_ALLOW_ERRORS=0`. This forces the CUDA extension to build and raises errors if the build fails:
```bash
pip uninstall -y SAM-2 && \
rm -f ./sam2/*.so && \
SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]"
```

Note that PyTorch must be installed before building the SAM 2 CUDA extension. You also need to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version of your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check their version via `nvcc --version`.

If the CUDA extension fails to build during installation or fails to load at runtime, see the section below on common installation issues.

### Common Installation Issues

Click each issue for its solutions:

<details>
<summary>
I got `ImportError: cannot import name '_C' from 'sam2'`
</summary>

This is usually because you haven't run the `pip install -e ".[notebooks]"` step above, or because the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.

On some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/sam2/issues/77.
</details>

<details>
<summary>
I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'`
</summary>

This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. If it still fails after installation, you may try manually adding the root of this repo to `PYTHONPATH` via
```bash
export SAM2_REPO_ROOT=/path/to/sam2  # path to this repo
export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
```
This makes the `sam2` package, including its configs under `sam2/configs`, findable on your Python's `sys.path`.

</details>

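For the `MissingConfigException` above, you can check whether the config is reachable without running SAM 2. A minimal sketch asks Python's import system where the `sam2` package lives and then looks for the config file under it. It assumes the configs ship inside the package directory, as in this repo's `sam2/configs/sam2.1/...` layout. `find_sam2_config` is a hypothetical helper, and Hydra's own config search may behave differently.

```python
import importlib.util
from pathlib import Path
from typing import Optional

DEFAULT_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"


def find_sam2_config(config_name: str = DEFAULT_CONFIG) -> Optional[Path]:
    """Return the config file's path inside the `sam2` package, or None.

    Hypothetical helper. Assumes configs live inside the package directory
    (e.g. sam2/configs/sam2.1/...), as in this repository's layout.
    """
    # For a top-level name, find_spec locates the package without importing it.
    spec = importlib.util.find_spec("sam2")
    if spec is None or not spec.submodule_search_locations:
        # `sam2` is not on sys.path: install it or extend PYTHONPATH.
        return None
    for location in spec.submodule_search_locations:
        candidate = Path(location) / config_name
        if candidate.is_file():
            return candidate
    return None


if __name__ == "__main__":
    path = find_sam2_config()
    print(path if path is not None else "config not found on sys.path")
```

A `None` result points to a `sys.path`/`PYTHONPATH` problem and not to the config file itself.
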
<details>
<summary>
I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints
</summary>

This is likely because you have installed a previous version of this repo, which doesn't yet have the new modules that support the SAM 2.1 checkpoints. Please try the following steps:

1. pull the latest code from the `main` branch of this repo
2. run `pip uninstall -y SAM-2` to uninstall any previous installations
3. then install the latest repo again using `pip install -e ".[notebooks]"`

If the steps above still don't resolve the error, please run the following in your Python environment
```python
from sam2.modeling import sam2_base

print(sam2_base.__file__)
```
Then check whether the content of the printed local `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`). This tells you whether you're still using a previous installation.

</details>

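The manual check described above can be scripted as a rough heuristic. The sketch reads a given `sam2_base.py` and looks for the `no_obj_embed_spatial` string that this issue mentions. `has_sam21_modules` is a hypothetical helper. A plain substring match is an assumption, not an official version check.

```python
from pathlib import Path


def has_sam21_modules(sam2_base_path: str, marker: str = "no_obj_embed_spatial") -> bool:
    """Heuristically decide whether a local sam2_base.py supports SAM 2.1.

    Hypothetical helper: it only checks whether `marker` appears in the
    file's source text, mirroring the manual comparison described above.
    """
    source = Path(sam2_base_path).read_text(encoding="utf-8")
    return marker in source
```

For example, call `has_sam21_modules(sam2_base.__file__)` after the import in the snippet above. A `False` result suggests that an outdated installation is still being picked up.
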
<details>
<summary>
My installation failed with `CUDA_HOME environment variable is not set`
</summary>

This usually happens because the installation step cannot find the CUDA toolkits (which contain the NVCC compiler) needed to build a custom CUDA kernel in SAM 2. Please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) of the version that matches the CUDA version of your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
```bash
export CUDA_HOME=/usr/local/cuda  # change to your CUDA toolkit path
```
and rerun the installation.

Also, make sure that
```bash
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
```
prints `True` followed by your CUDA toolkit directory (and not `False` or `None`). This verifies that the CUDA toolkits are correctly set up.

If you still have problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command:
```bash
pip install --no-build-isolation -e .
```

</details>

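The toolkit has to match PyTorch's CUDA version. One possible sketch parses the `release X.Y` field of the `nvcc --version` output and compares it with `torch.version.cuda`. The helper names are hypothetical. Requiring strict major.minor equality is a simplifying assumption made for this sketch.

```python
import re
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract the "major.minor" release from `nvcc --version` output, if present."""
    # nvcc prints a line like: "Cuda compilation tools, release 12.1, V12.1.105"
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def cuda_versions_match(nvcc_output: str, torch_cuda_version: Optional[str]) -> bool:
    """Return True if the toolkit's major.minor equals PyTorch's CUDA version.

    `torch_cuda_version` is meant to be `torch.version.cuda`, which is None on
    CPU-only PyTorch builds. Strict major.minor equality is a simplification.
    """
    nvcc_release = parse_nvcc_release(nvcc_output)
    if nvcc_release is None or torch_cuda_version is None:
        return False
    torch_major_minor = ".".join(torch_cuda_version.split(".")[:2])
    return nvcc_release == torch_major_minor


# Example usage on the build machine (needs nvcc on PATH and PyTorch installed):
#   import subprocess, torch
#   out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
#   print(cuda_versions_match(out, torch.version.cuda))
```
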
<details>
<summary>
I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
</summary>

This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one library version, while at runtime it links against another. This can happen when different versions of PyTorch or CUDA were installed separately via `pip` and `conda`. You may delete one of the duplicates so that only a single PyTorch and CUDA version remains.

In particular, if your PyTorch version is lower than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`. This could sometimes lead to a duplicated PyTorch installation if you previously installed another PyTorch version using `conda`.

We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. If the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
</details>

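To look for the duplicate installations described above, a rough sketch can enumerate installed distributions with the standard library's `importlib.metadata` and report any requested package that appears more than once. This is a heuristic built on assumptions. It only sees packages that have Python metadata (`*.dist-info` / `*.egg-info`) on the search path, so a conda package without such metadata will not show up. `find_duplicate_distributions` is a hypothetical helper.

```python
from importlib import metadata
from typing import Dict, Iterable, List, Optional


def find_duplicate_distributions(
    names: Iterable[str] = ("torch", "torchvision"),
    search_path: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    """Map each requested package name to its versions if it is installed more than once.

    Hypothetical helper. Only distributions with Python metadata on
    `search_path` (default: sys.path) are considered.
    """
    wanted = {name.lower() for name in names}
    kwargs = {"path": search_path} if search_path is not None else {}
    found: Dict[str, List[str]] = {}
    for dist in metadata.distributions(**kwargs):
        meta = dist.metadata
        name = ((meta.get("Name") if meta is not None else None) or "").lower()
        if name in wanted:
            found.setdefault(name, []).append(dist.version)
    # Keep only packages that were found more than once.
    return {name: versions for name, versions in found.items() if len(versions) > 1}


if __name__ == "__main__":
    duplicates = find_duplicate_distributions()
    print(duplicates or "no duplicate torch/torchvision installations found")
```
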
<details>
<summary>
I got `CUDA error: no kernel image is available for execution on the device`
</summary>

A possible cause is that the CUDA kernel was not compiled for your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This can happen if the installation is done in an environment different from the runtime one (e.g. on a Slurm system).

You can try pulling the latest code from the SAM 2 repo and running the following
```bash
export TORCH_CUDA_ARCH_LIST="9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0"
```
to manually specify compilation targets that include your GPU's CUDA capability. Then reinstall SAM 2 so the extension is rebuilt for those targets.
</details>

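You can also narrow the list to the GPUs you actually run on instead of listing every architecture. The sketch below is a hypothetical helper. It formats `(major, minor)` compute-capability pairs, such as those returned by `torch.cuda.get_device_capability()`, into a space-separated `TORCH_CUDA_ARCH_LIST` value. The formatting itself runs anywhere. The capability query, shown only in the comment, has to run on the machine that has the target GPU.

```python
from typing import Iterable, Tuple


def arch_list_from_capabilities(capabilities: Iterable[Tuple[int, int]]) -> str:
    """Format (major, minor) compute capabilities as a TORCH_CUDA_ARCH_LIST value.

    Hypothetical helper: duplicates are removed and entries are sorted,
    so [(8, 6), (8, 0), (8, 6)] becomes "8.0 8.6".
    """
    unique = sorted(set(capabilities))
    return " ".join(f"{major}.{minor}" for major, minor in unique)


# Example usage on the machine with the target GPU(s) (needs PyTorch with CUDA):
#   import torch
#   caps = [torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())]
#   print(f'export TORCH_CUDA_ARCH_LIST="{arch_list_from_capabilities(caps)}"')
```
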
<details>
<summary>
I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
</summary>

This is probably because your machine doesn't have a GPU, or doesn't have a PyTorch version compatible with Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in the PyTorch forum). You may be able to resolve this error by replacing the line
```python
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
```
in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
```python
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
```
This relaxes the attention kernel setting so that kernels other than Flash Attention can be used.
</details>

<details>
<summary>
I got `Error compiling objects for extension`
</summary>

You may see an error log like:
> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.

This is probably because your versions of CUDA and Visual Studio are incompatible (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion on Stack Overflow). You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py). After adding the argument, `get_extensions()` will look like this:
```python
def get_extensions():
    srcs = ["sam2/csrc/connected_components.cu"]
    compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
            "-allow-unsupported-compiler",  # Add this argument
        ],
    }
    ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
    return ext_modules
```
</details>
--------------------------------------------------------------------------------