├── LICENSE ├── README.md ├── app.py ├── checkpoints └── README.md ├── download.sh ├── examples └── infer_CT_LUNA25.py ├── gitignore ├── medsam2_infer_3D_CT.py ├── medsam2_infer_video.py ├── multi_node_train.sh ├── notebooks ├── MedSAM2_Inference_Video.ipynb └── MedSAM2_inference_CT_Lesion.ipynb ├── pyproject.toml ├── sam2 ├── _C.so ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-312.pyc │ ├── build_sam.cpython-312.pyc │ ├── sam2_image_predictor.cpython-312.pyc │ └── sam2_video_predictor_npz.cpython-312.pyc ├── build_sam.py ├── configs │ ├── sam2.1_hiera_t512.yaml │ └── sam2.1_hiera_tiny_finetune512.yaml ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-312.pyc │ │ ├── memory_attention.cpython-312.pyc │ │ ├── memory_encoder.cpython-312.pyc │ │ ├── position_encoding.cpython-312.pyc │ │ ├── sam2_base.cpython-312.pyc │ │ └── sam2_utils.cpython-312.pyc │ ├── backbones │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-312.pyc │ │ │ ├── hieradet.cpython-312.pyc │ │ │ ├── image_encoder.cpython-312.pyc │ │ │ └── utils.cpython-312.pyc │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-312.pyc │ │ │ ├── mask_decoder.cpython-312.pyc │ │ │ ├── prompt_encoder.cpython-312.pyc │ │ │ └── transformer.cpython-312.pyc │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_image_predictor.py ├── sam2_video_predictor.py ├── sam2_video_predictor_npz.py ├── sam2_video_trainer.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-312.pyc │ ├── misc.cpython-312.pyc │ └── transforms.cpython-312.pyc │ ├── amg.py │ ├── misc.py │ └── transforms.py ├── setup.py └── training ├── __init__.py ├── __pycache__ ├── __init__.cpython-312.pyc ├── loss_fns.cpython-312.pyc ├── optimizer.cpython-312.pyc └── trainer.cpython-312.pyc ├── assets ├── MOSE_sample_train_list.txt └── MOSE_sample_val_list.txt ├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-312.pyc │ ├── sam2_datasets.cpython-312.pyc │ ├── transforms.cpython-312.pyc │ ├── utils.cpython-312.pyc │ ├── vos_dataset.cpython-312.pyc │ ├── vos_raw_dataset.cpython-312.pyc │ ├── vos_sampler.cpython-312.pyc │ └── vos_segment_loader.cpython-312.pyc ├── sam2_datasets.py ├── transforms.py ├── utils.py ├── vos_dataset.py ├── vos_raw_dataset.py ├── vos_sampler.py └── vos_segment_loader.py ├── loss_fns.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-312.pyc │ └── sam2.cpython-312.pyc └── sam2.py ├── optimizer.py ├── scripts └── sav_frame_extraction_submitit.py ├── train.py ├── trainer.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-312.pyc ├── checkpoint_utils.cpython-312.pyc ├── data_utils.cpython-312.pyc ├── distributed.cpython-312.pyc ├── logger.cpython-312.pyc └── train_utils.cpython-312.pyc ├── checkpoint_utils.py ├── data_utils.py ├── distributed.py ├── logger.py └── train_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MedSAM2 2 | Segment Anything in 3D Medical Images and Videos 3 | 4 |
5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 |
PaperProjectCodeHuggingFace Model
Dataset ListCT_DeepLesion-MedSAM2LLD-MMRI-MedSAM23D Slicer
Gradio AppCT-Seg-DemoVideo-Seg-DemoBibTeX
25 |
26 | 27 | Join our [mailing list](https://forms.gle/bLxGb5SEpdLCUChQ7) to get updates. We're also actively looking to collaborate on annotating new large-scale 3D datasets. If you have unlabeled medical images or videos and want to share them with the community, let's connect! 28 | 29 | 30 | ## Installation 31 | 32 | - Create a virtual environment: `conda create -n medsam2 python=3.12 -y` and `conda activate medsam2` 33 | - Install [PyTorch](https://pytorch.org/get-started/locally/): `pip3 install torch torchvision` (Linux CUDA 12.4) 34 | - Download the code: `git clone https://github.com/bowang-lab/MedSAM2.git && cd MedSAM2`, then run `pip install -e ".[dev]"` 35 | - Download checkpoints: `sh download.sh` 36 | - Optional: install the following dependencies for the Gradio demo 37 | 38 | ```bash 39 | sudo apt-get update 40 | sudo apt-get install ffmpeg 41 | pip install gradio==3.38.0 42 | pip install numpy==1.26.3 43 | pip install ffmpeg-python 44 | pip install moviepy 45 | ``` 46 | 47 | ## Download annotated datasets 48 | 49 | - [CT_DeepLesion-MedSAM2](https://huggingface.co/datasets/wanglab/CT_DeepLesion-MedSAM2) 50 | 51 | 52 | 53 | - [LLD-MMRI-MedSAM2](https://huggingface.co/datasets/wanglab/LLD-MMRI-MedSAM2) 54 | 55 | Note: Please also cite the raw [DeepLesion](https://doi.org/10.1117/1.JMI.5.3.036501) and [LLD-MMRI](https://www.sciencedirect.com/science/article/pii/S0893608025001078) dataset papers when using these datasets. 56 | 57 | - [RVENET](https://rvenet.github.io/dataset/): Waiting for the authors' approval to release the masks. 58 | 59 | 60 | ## Inference 61 | 62 | ### 3D medical image segmentation 63 | 64 | - [Colab](https://colab.research.google.com/drive/1MKna9Sg9c78LNcrVyG58cQQmaePZq2k2?usp=sharing): [MedSAM2_inference_CT_Lesion_Demo.ipynb](notebooks/MedSAM2_inference_CT_Lesion.ipynb) 65 | 66 | - CMD 67 | 68 | ```bash 69 | python medsam2_infer_3D_CT.py -i CT_DeepLesion/images -o CT_DeepLesion/segmentation 70 | ``` 71 | 72 | ### Medical video segmentation 73 | 74 | - [Colab](https://colab.research.google.com/drive/16niRHqdDZMCGV7lKuagNq_r_CEHtKY1f?usp=sharing): [MedSAM2_Inference_Video_Demo.ipynb](notebooks/MedSAM2_Inference_Video.ipynb) 75 | 76 | 77 | - CMD 78 | 79 | ```bash 80 | python medsam2_infer_video.py -i input_video_path -m input_mask_path -o output_video_path 81 | ``` 82 | 83 | 84 | 85 | 86 | ### Gradio demo 87 | 88 | ```bash 89 | python app.py 90 | ``` 91 | 92 | ## Training 93 | 94 | Specify the dataset path in `sam2/configs/sam2.1_hiera_tiny_finetune512.yaml`, then submit the training job: 95 | 96 | ```bash 97 | sbatch multi_node_train.sh 98 | ``` 99 | 100 | ## Acknowledgements 101 | 102 | - We highly appreciate all the challenge organizers and dataset owners for providing the public datasets to the community. 103 | - We thank Meta AI for making the source code of [SAM2](https://github.com/facebookresearch/sam2) publicly available. Please also cite the SAM2 paper when using MedSAM2. 
104 | 105 | 106 | ## BibTeX 107 | 108 | ``` 109 | @article{MedSAM2, 110 | title={MedSAM2: Segment Anything in 3D Medical Images and Videos}, 111 | author={Ma, Jun and Yang, Zongxin and Kim, Sumin and Chen, Bihui and Baharoon, Mohammed and Fallahpour, Adibvafa and Asakereh, Reza and Lyu, Hongwei and Wang, Bo}, 112 | journal={arXiv preprint arXiv:2504.03600}, 113 | year={2025} 114 | } 115 | ``` 116 | Please also cite SAM2: 117 | ``` 118 | @article{ravi2024sam2, 119 | title={SAM 2: Segment Anything in Images and Videos}, 120 | author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, 121 | journal={arXiv preprint arXiv:2408.00714}, 122 | url={https://arxiv.org/abs/2408.00714}, 123 | year={2024} 124 | } 125 | ``` 126 | 127 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | 2 | Download checkpoints: `sh download.sh` 3 | 4 | - `MedSAM2_2411.pt`: The base model trained in Nov. 2024 5 | - `MedSAM2_US_Heart.pt`: Fine-tuned model for heart ultrasound video segmentation 6 | - `MedSAM2_MRI_LiverLesion.pt`: Fine-tuned model for liver lesion MRI segmentation 7 | - `MedSAM2_CTLesion.pt`: Fine-tuned model for CT lesion segmentation 8 | - `MedSAM2_latest.pt` (recommended): Latest model trained on the combination of existing public datasets and newly annotated datasets 9 | 10 | 11 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Script to download MedSAM2 model checkpoints 3 | # Create checkpoints directory if it doesn't exist 4 | mkdir -p checkpoints 5 | # Use either wget or curl to download the checkpoints 6 | if command -v wget > /dev/null 2>&1; then 7 | CMD="wget -P checkpoints" 8 | elif command -v curl > /dev/null 2>&1; then 9 | CMD="curl -L -o" 10 | CURL=1 11 | else 12 | echo "Please install wget or curl to download the checkpoints." 13 | exit 1 14 | fi 15 | # Define the base URL for MedSAM2 models on Hugging Face 16 | HF_BASE_URL="https://huggingface.co/wanglab/MedSAM2/resolve/main" 17 | # Define the model checkpoint files (as separate variables instead of an array) 18 | MODEL1="MedSAM2_2411.pt" 19 | MODEL2="MedSAM2_US_Heart.pt" 20 | MODEL3="MedSAM2_MRI_LiverLesion.pt" 21 | MODEL4="MedSAM2_CTLesion.pt" 22 | MODEL5="MedSAM2_latest.pt" 23 | 24 | # Download each checkpoint 25 | for model in $MODEL1 $MODEL2 $MODEL3 $MODEL4 $MODEL5; do 26 | echo "Downloading ${model}..." 27 | model_url="${HF_BASE_URL}/${model}" 28 | 29 | if [ -n "$CURL" ]; then 30 | $CMD "checkpoints/${model}" "$model_url" || { echo "Failed to download checkpoint from $model_url"; exit 1; } 31 | else 32 | $CMD "$model_url" || { echo "Failed to download checkpoint from $model_url"; exit 1; } 33 | fi 34 | done 35 | echo "All MedSAM2 model checkpoints have been downloaded successfully to the 'checkpoints' directory." 
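After running `download.sh`, the checkpoints can be loaded with the repository's predictor builder as a quick sanity check. The sketch below mirrors how `medsam2_infer_3D_CT.py` initializes the model (the paths are its defaults); any of the checkpoints listed above can be substituted.

```python
# Minimal sketch: load a downloaded checkpoint with the repo's video predictor builder,
# mirroring the initialization in medsam2_infer_3D_CT.py.
from sam2.build_sam import build_sam2_video_predictor_npz

model_cfg = "configs/sam2.1_hiera_t512.yaml"   # config shipped under sam2/configs
checkpoint = "checkpoints/MedSAM2_latest.pt"   # any checkpoint fetched by download.sh
predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint)
print(type(predictor).__name__)  # SAM2VideoPredictorNPZ
```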
-------------------------------------------------------------------------------- /examples/infer_CT_LUNA25.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to run inference on the LUNA25 dataset using the MedSAM2 CT lesion model with point prompts. 3 | 4 | Manually refined masks: https://huggingface.co/datasets/wanglab/LUNA25-MedSAM2 5 | image: https://zenodo.org/records/14223624 6 | annotation: https://zenodo.org/records/14673658 7 | """ 8 | 9 | from tqdm import tqdm 10 | import os 11 | from os.path import join 12 | import pandas as pd 13 | import numpy as np 14 | import argparse 15 | 16 | from PIL import Image 17 | import SimpleITK as sitk 18 | import torch 19 | import torch.multiprocessing as mp 20 | from sam2.build_sam import build_sam2_video_predictor_npz 21 | 22 | torch.set_float32_matmul_precision('high') 23 | torch.manual_seed(2024) 24 | torch.cuda.manual_seed(2024) 25 | np.random.seed(2024) 26 | 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument( 30 | '--checkpoint', 31 | type=str, 32 | default="./hg_checkpoints/MedSAM2_CTLesion.pt", 33 | help='checkpoint path', 34 | ) 35 | parser.add_argument( 36 | '--cfg', 37 | type=str, 38 | default="configs/sam2.1/sam2.1_hiera_t512.yaml", 39 | help='model config', 40 | ) 41 | parser.add_argument( 42 | '-i', 43 | '--imgs_path', 44 | type=str, 45 | default="/path/to/luna25_images", 46 | help='imgs path', 47 | ) 48 | parser.add_argument( 49 | '-o', 50 | '--pred_save_dir', 51 | type=str, 52 | default="./segs/MedSAM2_release", 53 | help='segs path', 54 | ) 55 | parser.add_argument( 56 | '--num_workers', 57 | type=int, 58 | default=2, 59 | ) 60 | parser.add_argument( 61 | '--df_path', 62 | type=str, 63 | default='/path/to/LUNA25_Public_Training_Development_Data.csv', 64 | ) 65 | 66 | args = parser.parse_args() 67 | imsize = 512 68 | df = pd.read_csv(args.df_path) 69 | df = df[['SeriesInstanceUID', 'CoordX', 'CoordY', 'CoordZ']] 70 | 71 | checkpoint = args.checkpoint 72 | model_cfg = args.cfg 73 | imgs_path = args.imgs_path 74 | pred_save_dir = args.pred_save_dir 75 | num_workers = args.num_workers 76 | predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint) 77 | os.makedirs(pred_save_dir, exist_ok=True) 78 | 79 | 80 | def preprocess(image_data, modality="CT", window_level=-750, window_width=1500): 81 | if modality == "CT": 82 | assert window_level is not None and window_width is not None, "CT modality requires window_level and window_width" 83 | lower_bound = window_level - window_width / 2 84 | upper_bound = window_level + window_width / 2 85 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 86 | image_data_pre = ( 87 | (image_data_pre - np.min(image_data_pre)) 88 | / (np.max(image_data_pre) - np.min(image_data_pre)) 89 | * 255.0 90 | ) 91 | else: 92 | lower_bound, upper_bound = np.percentile( 93 | image_data[image_data > 0], 0.5 94 | ), np.percentile(image_data[image_data > 0], 99.5) 95 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 96 | image_data_pre = ( 97 | (image_data_pre - np.min(image_data_pre)) 98 | / (np.max(image_data_pre) - np.min(image_data_pre)) 99 | * 255.0 100 | ) 101 | image_data_pre[image_data == 0] = 0 102 | 103 | return image_data_pre 104 | 105 | 106 | def resize_grayscale_to_rgb_and_resize(array, image_size): 107 | """ 108 | Resize a 3D grayscale NumPy array to an RGB image and then resize it. 109 | 110 | Parameters: 111 | array (np.ndarray): Input array of shape (d, h, w). 
112 | image_size (int): Desired size for the width and height. 113 | 114 | Returns: 115 | np.ndarray: Resized array of shape (d, 3, image_size, image_size). 116 | """ 117 | d, h, w = array.shape 118 | resized_array = np.zeros((d, 3, image_size, image_size)) 119 | 120 | for i in range(d): 121 | img_pil = Image.fromarray(array[i].astype(np.uint8)) 122 | img_rgb = img_pil.convert("RGB") 123 | img_resized = img_rgb.resize((image_size, image_size)) 124 | img_array = np.array(img_resized).transpose(2, 0, 1) # (3, image_size, image_size) 125 | resized_array[i] = img_array 126 | 127 | return resized_array 128 | 129 | 130 | @torch.inference_mode() 131 | def infer_3d(mha_name): 132 | print(f'processing {mha_name}') 133 | # get the corresponding keypoints of mha_name 134 | df_file = df[df['SeriesInstanceUID'] == mha_name.replace('.mha', '')] 135 | 136 | # read and preprocess the image 137 | sitk_img = sitk.ReadImage(join(imgs_path, mha_name)) 138 | img_3D = preprocess(sitk.GetArrayFromImage(sitk_img)) 139 | assert np.max(img_3D) < 256, f'input data should be in range [0, 255], but got {np.unique(img_3D)}' 140 | 141 | # initialize segmentation mask 142 | segs_3D = np.zeros(img_3D.shape, dtype=np.uint8) 143 | 144 | # resize and normalize the image 145 | video_height = img_3D.shape[1] 146 | video_width = img_3D.shape[2] 147 | if video_height != imsize or video_width != imsize: 148 | img_resized = resize_grayscale_to_rgb_and_resize(img_3D, imsize) #d, 3, 512, 512 149 | else: 150 | img_resized = img_3D[:,None].repeat(3, axis=1) # d, 3, h, w 151 | img_resized = img_resized / 255.0 152 | img_resized = torch.from_numpy(img_resized).cuda() 153 | img_mean=(0.485, 0.456, 0.406) 154 | img_std=(0.229, 0.224, 0.225) 155 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None].cuda() 156 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None].cuda() 157 | img_resized -= img_mean 158 | img_resized /= img_std 159 | z_mids = [] 160 | coords = [] 161 | 162 | # for each point in the dataframe, get the corresponding 3D mask using keypoint prompts 163 | for index, (_, row) in enumerate(df_file.iterrows(), 1): 164 | 165 | x = row['CoordX'] 166 | y = row['CoordY'] 167 | z = row['CoordZ'] 168 | # convert the coordinates to voxel coordinates 169 | voxel_x, voxel_y, voxel_z = sitk_img.TransformPhysicalPointToIndex((x, y, z)) 170 | coords.append([voxel_x, voxel_y, voxel_z]) 171 | z_mids.append(voxel_z) 172 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): 173 | inference_state = predictor.init_state(img_resized, video_height, video_width) 174 | 175 | points = np.array([[voxel_x, voxel_y]], dtype=np.float32) 176 | # for labels, `1` means positive click and `0` means negative click 177 | labels = np.array([1], np.int32) 178 | # add point prompt 179 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( 180 | inference_state=inference_state, 181 | frame_idx=voxel_z, 182 | obj_id=1, 183 | points=points, 184 | labels=labels, 185 | ) 186 | mask_prompt = (out_mask_logits[0] > 0.0).squeeze(0).cpu().numpy().astype(np.uint8) 187 | 188 | 189 | frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=voxel_z, obj_id=1, mask=mask_prompt) 190 | segs_3D[voxel_z, ((masks[0] > 0.0).cpu().numpy())[0]] = index 191 | # propagate in the video 192 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=voxel_z, reverse=False): 193 | segs_3D[(out_frame_idx), (out_mask_logits[0] > 
0.0).cpu().numpy()[0]] = index 194 | 195 | # reverse process, delete old memory and initialize new predictor 196 | predictor.reset_state(inference_state) 197 | inference_state = predictor.init_state(img_resized, video_height, video_width) 198 | frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=voxel_z, obj_id=1, mask=mask_prompt) 199 | 200 | 201 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=voxel_z, reverse=True): 202 | segs_3D[(out_frame_idx), (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = index 203 | 204 | predictor.reset_state(inference_state) 205 | 206 | sitk.WriteImage(sitk.GetImageFromArray(segs_3D), join(pred_save_dir, mha_name.replace('.mha', '.nii.gz'))) 207 | 208 | return 209 | 210 | 211 | 212 | if __name__ == '__main__': 213 | img_mha_files = os.listdir(imgs_path) 214 | img_mha_files = [x for x in img_mha_files if x.endswith('.mha')] 215 | process_files = list(set(df['SeriesInstanceUID'].values)) 216 | img_mha_files = [x for x in img_mha_files if x.replace('.mha', '') in process_files] 217 | 218 | print(f'number of files to process: {len(img_mha_files)}') 219 | 220 | mp.set_start_method('spawn', force=True) 221 | with mp.Pool(processes=num_workers) as pool: 222 | with tqdm(total=len(img_mha_files)) as pbar: 223 | for i, ret in tqdm(enumerate(pool.imap_unordered(infer_3d, img_mha_files))): 224 | pbar.update() 225 | 226 | 227 | 228 | -------------------------------------------------------------------------------- /gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | __pycache__/ 4 | *-checkpoint.ipynb 5 | .venv 6 | *.egg* 7 | build/* 8 | _C.* 9 | *.nii.gz 10 | *.csv 11 | outputs/* 12 | checkpoints/*.pt 13 | *.pt 14 | -------------------------------------------------------------------------------- /medsam2_infer_3D_CT.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from tqdm import tqdm 3 | import os 4 | from os.path import join, basename 5 | import re 6 | import matplotlib.pyplot as plt 7 | from collections import OrderedDict 8 | import pandas as pd 9 | import numpy as np 10 | import argparse 11 | 12 | from PIL import Image 13 | import SimpleITK as sitk 14 | import torch 15 | import torch.multiprocessing as mp 16 | from sam2.build_sam import build_sam2_video_predictor_npz 17 | import SimpleITK as sitk 18 | from skimage import measure, morphology 19 | 20 | torch.set_float32_matmul_precision('high') 21 | torch.manual_seed(2024) 22 | torch.cuda.manual_seed(2024) 23 | np.random.seed(2024) 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument( 28 | '--checkpoint', 29 | type=str, 30 | default="checkpoints/MedSAM2_latest.pt", 31 | help='checkpoint path', 32 | ) 33 | parser.add_argument( 34 | '--cfg', 35 | type=str, 36 | default="configs/sam2.1_hiera_t512.yaml", 37 | help='model config', 38 | ) 39 | 40 | parser.add_argument( 41 | '-i', 42 | '--imgs_path', 43 | type=str, 44 | default="CT_DeepLesion/images", 45 | help='imgs path', 46 | ) 47 | parser.add_argument( 48 | '--gts_path', 49 | default=None, 50 | help='simulate prompts based on ground truth', 51 | ) 52 | parser.add_argument( 53 | '-o', 54 | '--pred_save_dir', 55 | type=str, 56 | default="./DeeLesion_results", 57 | help='path to save segmentation results', 58 | ) 59 | # add option to propagate with either box or mask 60 | parser.add_argument( 61 | '--propagate_with_box', 62 | 
default=True, 63 | action='store_true', 64 | help='whether to propagate with box' 65 | ) 66 | 67 | args = parser.parse_args() 68 | checkpoint = args.checkpoint 69 | model_cfg = args.cfg 70 | imgs_path = args.imgs_path 71 | gts_path = args.gts_path 72 | pred_save_dir = args.pred_save_dir 73 | os.makedirs(pred_save_dir, exist_ok=True) 74 | propagate_with_box = args.propagate_with_box 75 | 76 | def getLargestCC(segmentation): 77 | labels = measure.label(segmentation) 78 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 79 | return largestCC 80 | 81 | def dice_multi_class(preds, targets): 82 | smooth = 1.0 83 | assert preds.shape == targets.shape 84 | labels = np.unique(targets)[1:] 85 | dices = [] 86 | for label in labels: 87 | pred = preds == label 88 | target = targets == label 89 | intersection = (pred * target).sum() 90 | dices.append((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)) 91 | return np.mean(dices) 92 | 93 | def show_mask(mask, ax, mask_color=None, alpha=0.5): 94 | """ 95 | show mask on the image 96 | 97 | Parameters 98 | ---------- 99 | mask : numpy.ndarray 100 | mask of the image 101 | ax : matplotlib.axes.Axes 102 | axes to plot the mask 103 | mask_color : numpy.ndarray 104 | color of the mask 105 | alpha : float 106 | transparency of the mask 107 | """ 108 | if mask_color is not None: 109 | color = np.concatenate([mask_color, np.array([alpha])], axis=0) 110 | else: 111 | color = np.array([251/255, 252/255, 30/255, alpha]) 112 | h, w = mask.shape[-2:] 113 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 114 | ax.imshow(mask_image) 115 | 116 | 117 | def show_box(box, ax, edgecolor='blue'): 118 | """ 119 | show bounding box on the image 120 | 121 | Parameters 122 | ---------- 123 | box : numpy.ndarray 124 | bounding box coordinates in the original image 125 | ax : matplotlib.axes.Axes 126 | axes to plot the bounding box 127 | edgecolor : str 128 | color of the bounding box 129 | """ 130 | x0, y0 = box[0], box[1] 131 | w, h = box[2] - box[0], box[3] - box[1] 132 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2)) 133 | 134 | 135 | def resize_grayscale_to_rgb_and_resize(array, image_size): 136 | """ 137 | Resize a 3D grayscale NumPy array to an RGB image and then resize it. 138 | 139 | Parameters: 140 | array (np.ndarray): Input array of shape (d, h, w). 141 | image_size (int): Desired size for the width and height. 142 | 143 | Returns: 144 | np.ndarray: Resized array of shape (d, 3, image_size, image_size). 
145 | """ 146 | d, h, w = array.shape 147 | resized_array = np.zeros((d, 3, image_size, image_size)) 148 | 149 | for i in range(d): 150 | img_pil = Image.fromarray(array[i].astype(np.uint8)) 151 | img_rgb = img_pil.convert("RGB") 152 | img_resized = img_rgb.resize((image_size, image_size)) 153 | img_array = np.array(img_resized).transpose(2, 0, 1) # (3, image_size, image_size) 154 | resized_array[i] = img_array 155 | 156 | return resized_array 157 | 158 | def mask2D_to_bbox(gt2D, max_shift=20): 159 | y_indices, x_indices = np.where(gt2D > 0) 160 | x_min, x_max = np.min(x_indices), np.max(x_indices) 161 | y_min, y_max = np.min(y_indices), np.max(y_indices) 162 | H, W = gt2D.shape 163 | bbox_shift = np.random.randint(0, max_shift + 1, 1)[0] 164 | x_min = max(0, x_min - bbox_shift) 165 | x_max = min(W-1, x_max + bbox_shift) 166 | y_min = max(0, y_min - bbox_shift) 167 | y_max = min(H-1, y_max + bbox_shift) 168 | boxes = np.array([x_min, y_min, x_max, y_max]) 169 | return boxes 170 | 171 | def mask3D_to_bbox(gt3D, max_shift=20): 172 | z_indices, y_indices, x_indices = np.where(gt3D > 0) 173 | x_min, x_max = np.min(x_indices), np.max(x_indices) 174 | y_min, y_max = np.min(y_indices), np.max(y_indices) 175 | z_min, z_max = np.min(z_indices), np.max(z_indices) 176 | D, H, W = gt3D.shape 177 | bbox_shift = np.random.randint(0, max_shift + 1, 1)[0] 178 | x_min = max(0, x_min - bbox_shift) 179 | x_max = min(W-1, x_max + bbox_shift) 180 | y_min = max(0, y_min - bbox_shift) 181 | y_max = min(H-1, y_max + bbox_shift) 182 | z_min = max(0, z_min) 183 | z_max = min(D-1, z_max) 184 | boxes3d = np.array([x_min, y_min, z_min, x_max, y_max, z_max]) 185 | return boxes3d 186 | 187 | 188 | DL_info = pd.read_csv('CT_DeepLesion/DeepLesion_Dataset_Info.csv') 189 | nii_fnames = sorted(os.listdir(imgs_path)) 190 | nii_fnames = [i for i in nii_fnames if i.endswith('.nii.gz')] 191 | nii_fnames = [i for i in nii_fnames if not i.startswith('._')] 192 | print(f'Processing {len(nii_fnames)} nii files') 193 | seg_info = OrderedDict() 194 | seg_info['nii_name'] = [] 195 | seg_info['key_slice_index'] = [] 196 | seg_info['DICOM_windows'] = [] 197 | # initialized predictor 198 | predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint) 199 | 200 | for nii_fname in tqdm(nii_fnames): 201 | # get corresponding case info 202 | range_suffix = re.findall(r'\d{3}-\d{3}', nii_fname)[0] 203 | slice_range = range_suffix.split('-') 204 | slice_range = [str(int(s)) for s in slice_range] 205 | slice_range = ', '.join(slice_range) 206 | nii_image = sitk.ReadImage(join(imgs_path, nii_fname)) 207 | nii_image_data = sitk.GetArrayFromImage(nii_image) 208 | 209 | case_name = re.findall(r'^(\d{6}_\d{2}_\d{2})', nii_fname)[0] 210 | case_df = DL_info[ 211 | DL_info['File_name'].str.contains(case_name) & 212 | DL_info['Slice_range'].str.contains(slice_range) 213 | ].copy() 214 | 215 | segs_3D = np.zeros(nii_image_data.shape, dtype=np.uint8) 216 | 217 | for row_id, row in case_df.iterrows(): 218 | # print(f'Processing {case_name} tumor {tumor_idx}') 219 | # get the key slice info 220 | lower_bound, upper_bound = row['DICOM_windows'].split(',') 221 | lower_bound, upper_bound = float(lower_bound), float(upper_bound) 222 | nii_image_data_pre = np.clip(nii_image_data, lower_bound, upper_bound) 223 | nii_image_data_pre = (nii_image_data_pre - np.min(nii_image_data_pre))/(np.max(nii_image_data_pre)-np.min(nii_image_data_pre))*255.0 224 | nii_image_data_pre = np.uint8(nii_image_data_pre) 225 | key_slice_idx = row['Key_slice_index'] 226 | 
key_slice_idx = int(key_slice_idx) 227 | slice_range = row['Slice_range'] 228 | slice_idx_start, slice_idx_end = slice_range.split(',') 229 | slice_idx_start, slice_idx_end = int(slice_idx_start), int(slice_idx_end) 230 | bbox_coords = row['Bounding_boxes'] 231 | bbox_coords = bbox_coords.split(',') 232 | bbox_coords = [int(float(coord)) for coord in bbox_coords] 233 | #bbox_coords = expand_box(bbox_coords) 234 | bbox = np.array(bbox_coords) # y_min, x_min, y_max, x_max 235 | bbox = np.array([bbox[1], bbox[0], bbox[3], bbox[2]]) 236 | 237 | key_slice_idx_offset = key_slice_idx - slice_idx_start 238 | key_slice_img = nii_image_data_pre[key_slice_idx_offset, :,:] 239 | 240 | img_3D_ori = nii_image_data_pre 241 | assert np.max(img_3D_ori) < 256, f'input data should be in range [0, 255], but got {np.unique(img_3D_ori)}' 242 | 243 | video_height = key_slice_img.shape[0] 244 | video_width = key_slice_img.shape[1] 245 | img_resized = resize_grayscale_to_rgb_and_resize(img_3D_ori, 512) 246 | img_resized = img_resized / 255.0 247 | img_resized = torch.from_numpy(img_resized).cuda() 248 | img_mean=(0.485, 0.456, 0.406) 249 | img_std=(0.229, 0.224, 0.225) 250 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None].cuda() 251 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None].cuda() 252 | img_resized -= img_mean 253 | img_resized /= img_std 254 | z_mids = [] 255 | 256 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): 257 | inference_state = predictor.init_state(img_resized, video_height, video_width) 258 | if propagate_with_box: 259 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( 260 | inference_state=inference_state, 261 | frame_idx=key_slice_idx_offset, 262 | obj_id=1, 263 | box=bbox, 264 | ) 265 | else: # gt 266 | pass 267 | 268 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): 269 | segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1 270 | predictor.reset_state(inference_state) 271 | if propagate_with_box: 272 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( 273 | inference_state=inference_state, 274 | frame_idx=key_slice_idx_offset, 275 | obj_id=1, 276 | box=bbox, 277 | ) 278 | else: # gt 279 | pass 280 | 281 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, reverse=True): 282 | segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1 283 | predictor.reset_state(inference_state) 284 | if np.max(segs_3D) > 0: 285 | segs_3D = getLargestCC(segs_3D) 286 | segs_3D = np.uint8(segs_3D) 287 | sitk_image = sitk.GetImageFromArray(img_3D_ori) 288 | sitk_image.CopyInformation(nii_image) 289 | sitk_mask = sitk.GetImageFromArray(segs_3D) 290 | sitk_mask.CopyInformation(nii_image) 291 | # save single lesion 292 | key_slice_idx = row['Key_slice_index'] 293 | save_seg_name = nii_fname.split('.nii.gz')[0] + f'_k{key_slice_idx}_mask.nii.gz' 294 | sitk.WriteImage(sitk_image, os.path.join(pred_save_dir, nii_fname.replace('.nii.gz', '_img.nii.gz'))) 295 | sitk.WriteImage(sitk_mask, os.path.join(pred_save_dir, save_seg_name)) 296 | seg_info['nii_name'].append(save_seg_name) 297 | seg_info['key_slice_index'].append(key_slice_idx) 298 | seg_info['DICOM_windows'].append(row['DICOM_windows']) 299 | 300 | seg_info_df = pd.DataFrame(seg_info) 301 | seg_info_df.to_csv(join(pred_save_dir, 'tiny_seg_info202412.csv'), index=False) 302 | 303 | 304 | 305 | 
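The DICOM windowing used in the script above (and in the `preprocess` function of `examples/infer_CT_LUNA25.py`) is a clip to the window range followed by min-max scaling to [0, 255]. The snippet below is a small standalone check of that arithmetic with made-up HU values, not part of the repository.

```python
# Standalone check of the CT intensity windowing used in the inference scripts:
# clip to [lower, upper], then min-max scale the clipped values to [0, 255].
import numpy as np

window_level, window_width = -750, 1500            # defaults in examples/infer_CT_LUNA25.py
lower = window_level - window_width / 2            # -1500.0
upper = window_level + window_width / 2            # 0.0
hu = np.array([-2000.0, -1000.0, -500.0, 100.0])   # toy HU values (not from any dataset)
clipped = np.clip(hu, lower, upper)                # [-1500., -1000., -500., 0.]
scaled = (clipped - clipped.min()) / (clipped.max() - clipped.min()) * 255.0
print(scaled)                                      # [  0.  85. 170. 255.]
```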
-------------------------------------------------------------------------------- /multi_node_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 7-00:0:0 3 | #SBATCH -J medsam2-tr-tiny 4 | #SBATCH --mem=450G 5 | #SBATCH -c 60 6 | #SBATCH -N 3 7 | #SBATCH --ntasks-per-node=1 8 | #SBATCH --gres=gpu:4 9 | #SBATCH -o out_mnodes_tiny.out 10 | 11 | export PATH=/usr/local/cuda/bin:$PATH 12 | timestamp=$(date +"%Y%m%d-%H%M") 13 | 14 | # Set the master node address (first node in the allocation) 15 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 16 | # export MASTER_PORT=29500 17 | export MASTER_PORT=$(python - <=61.0", 4 | "torch>=2.5.1", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /sam2/_C.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/_C.so -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /sam2/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/__pycache__/build_sam.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/build_sam.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/__pycache__/sam2_image_predictor.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/sam2_image_predictor.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/__pycache__/sam2_video_predictor_npz.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/sam2_video_predictor_npz.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | HF_MODEL_ID_TO_FILENAMES = { 15 | "facebook/sam2-hiera-tiny": ( 16 | "configs/sam2/sam2_hiera_t.yaml", 17 | "sam2_hiera_tiny.pt", 18 | ), 19 | "facebook/sam2-hiera-small": ( 20 | "configs/sam2/sam2_hiera_s.yaml", 21 | "sam2_hiera_small.pt", 22 | ), 23 | "facebook/sam2-hiera-base-plus": ( 24 | "configs/sam2/sam2_hiera_b+.yaml", 25 | "sam2_hiera_base_plus.pt", 26 | ), 27 | "facebook/sam2-hiera-large": ( 28 | "configs/sam2/sam2_hiera_l.yaml", 29 | "sam2_hiera_large.pt", 30 | ), 31 | "facebook/sam2.1-hiera-tiny": ( 32 | "configs/sam2.1/sam2.1_hiera_t.yaml", 33 | "sam2.1_hiera_tiny.pt", 34 | ), 35 | "facebook/sam2.1-hiera-small": ( 36 | "configs/sam2.1/sam2.1_hiera_s.yaml", 37 | "sam2.1_hiera_small.pt", 38 | ), 39 | "facebook/sam2.1-hiera-base-plus": ( 40 | "configs/sam2.1/sam2.1_hiera_b+.yaml", 41 | "sam2.1_hiera_base_plus.pt", 42 | ), 43 | "facebook/sam2.1-hiera-large": ( 44 | "configs/sam2.1/sam2.1_hiera_l.yaml", 45 | "sam2.1_hiera_large.pt", 46 | ), 47 | } 48 | 49 | 50 | def get_best_available_device(): 51 | """ 52 | Get the best available device in the order: CUDA, MPS, CPU 53 | Returns: device string for torch.device 54 | """ 55 | if torch.cuda.is_available(): 56 | return "cuda" 57 | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): 58 | return "mps" 59 | else: 60 | return "cpu" 61 | 62 | 63 | def build_sam2( 64 | config_file, 65 | ckpt_path=None, 66 | device=None, 67 | mode="eval", 68 | hydra_overrides_extra=[], 69 | apply_postprocessing=True, 70 | **kwargs, 71 | ): 72 | # Use the provided device or get the best available one 73 | device = device or get_best_available_device() 74 | logging.info(f"Using device: {device}") 75 | 76 | if apply_postprocessing: 77 | hydra_overrides_extra = hydra_overrides_extra.copy() 78 | hydra_overrides_extra += [ 79 | # dynamically fall back to multi-mask if the single mask is not stable 80 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 81 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 82 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 83 | ] 84 | # Read config and init model 85 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 86 | OmegaConf.resolve(cfg) 87 | model = instantiate(cfg.model, _recursive_=True) 88 | _load_checkpoint(model, ckpt_path) 89 | model = model.to(device) 90 | if mode == "eval": 91 | model.eval() 92 | return model 93 | 94 | 95 | def build_sam2_video_predictor( 96 | config_file, 97 | ckpt_path=None, 98 | device=None, 99 | mode="eval", 100 | hydra_overrides_extra=[], 101 | apply_postprocessing=True, 102 | **kwargs, 103 | ): 104 | # Use the provided device or get the best available one 105 | device = device or get_best_available_device() 106 | logging.info(f"Using device: {device}") 107 | 108 | hydra_overrides = [ 109 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 110 | ] 111 | if apply_postprocessing: 112 | hydra_overrides_extra = hydra_overrides_extra.copy() 113 | hydra_overrides_extra += [ 114 | # dynamically fall back to multi-mask if the single mask is not stable 115 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 116 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 117 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 
118 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 119 | "++model.binarize_mask_from_pts_for_mem_enc=true", 120 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 121 | "++model.fill_hole_area=8", 122 | ] 123 | hydra_overrides.extend(hydra_overrides_extra) 124 | 125 | # Read config and init model 126 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 127 | OmegaConf.resolve(cfg) 128 | model = instantiate(cfg.model, _recursive_=True) 129 | _load_checkpoint(model, ckpt_path) 130 | model = model.to(device) 131 | if mode == "eval": 132 | model.eval() 133 | return model 134 | 135 | def build_sam2_video_predictor_npz( 136 | config_file, 137 | ckpt_path=None, 138 | device=None, 139 | mode="eval", 140 | hydra_overrides_extra=[], 141 | apply_postprocessing=True, 142 | **kwargs, 143 | ): 144 | # Use the provided device or get the best available one 145 | device = device or get_best_available_device() 146 | logging.info(f"Using device: {device}") 147 | 148 | hydra_overrides = [ 149 | "++model._target_=sam2.sam2_video_predictor_npz.SAM2VideoPredictorNPZ", 150 | ] 151 | if apply_postprocessing: 152 | hydra_overrides_extra = hydra_overrides_extra.copy() 153 | hydra_overrides_extra += [ 154 | # dynamically fall back to multi-mask if the single mask is not stable 155 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 156 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 157 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 158 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 159 | "++model.binarize_mask_from_pts_for_mem_enc=true", 160 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 161 | "++model.fill_hole_area=8", 162 | ] 163 | hydra_overrides.extend(hydra_overrides_extra) 164 | 165 | # Read config and init model 166 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 167 | OmegaConf.resolve(cfg) 168 | model = instantiate(cfg.model, _recursive_=True) 169 | _load_checkpoint(model, ckpt_path) 170 | model = model.to(device) 171 | if mode == "eval": 172 | model.eval() 173 | return model 174 | 175 | 176 | 177 | def _hf_download(model_id): 178 | from huggingface_hub import hf_hub_download 179 | 180 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] 181 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 182 | return config_name, ckpt_path 183 | 184 | 185 | def build_sam2_hf(model_id, **kwargs): 186 | config_name, ckpt_path = _hf_download(model_id) 187 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) 188 | 189 | 190 | def build_sam2_video_predictor_hf(model_id, **kwargs): 191 | config_name, ckpt_path = _hf_download(model_id) 192 | return build_sam2_video_predictor( 193 | config_file=config_name, ckpt_path=ckpt_path, **kwargs 194 | ) 195 | 196 | 197 | def _load_checkpoint(model, ckpt_path): 198 | if ckpt_path is not None: 199 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] 200 | missing_keys, unexpected_keys = model.load_state_dict(sd) 201 | if missing_keys: 202 | logging.error(missing_keys) 203 | raise RuntimeError() 204 | if unexpected_keys: 205 | 
logging.error(unexpected_keys) 206 | raise RuntimeError() 207 | logging.info("Loaded checkpoint sucessfully") 208 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1_hiera_t512.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 512 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | 
add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | -------------------------------------------------------------------------------- /sam2/csrc/connected_components.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch 8 | // with license found in the LICENSE_cctorch file in the root directory. 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // 2d 17 | #define BLOCK_ROWS 16 18 | #define BLOCK_COLS 16 19 | 20 | namespace cc2d { 21 | 22 | template 23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { 24 | return (bitmap >> pos) & 1; 25 | } 26 | 27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) { 28 | while (s_buf[n] != n) 29 | n = s_buf[n]; 30 | return n; 31 | } 32 | 33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { 34 | const int32_t id = n; 35 | while (s_buf[n] != n) { 36 | n = s_buf[n]; 37 | s_buf[id] = n; 38 | } 39 | return n; 40 | } 41 | 42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { 43 | bool done; 44 | do { 45 | a = find(s_buf, a); 46 | b = find(s_buf, b); 47 | 48 | if (a < b) { 49 | int32_t old = atomicMin(s_buf + b, a); 50 | done = (old == b); 51 | b = old; 52 | } else if (b < a) { 53 | int32_t old = atomicMin(s_buf + a, b); 54 | done = (old == a); 55 | a = old; 56 | } else 57 | done = true; 58 | 59 | } while (!done); 60 | } 61 | 62 | __global__ void 63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { 64 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 65 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 66 | const uint32_t idx = row * W + col; 67 | 68 | if (row < H && col < W) 69 | label[idx] = idx; 70 | } 71 | 72 | __global__ void 73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { 74 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 75 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 76 | const uint32_t idx = row * W + col; 77 | 78 | if (row >= H || col >= W) 79 | return; 80 | 81 | uint32_t P = 0; 82 | 83 | if (img[idx]) 84 | P |= 0x777; 85 | if (row + 1 < H && img[idx + W]) 86 | P |= 0x777 << 4; 87 | if (col + 1 < W && img[idx + 1]) 88 | P |= 0x777 << 1; 89 | 90 | if (col == 0) 91 | P &= 0xEEEE; 92 | if (col + 1 >= W) 93 | P &= 0x3333; 94 | else if (col + 2 >= W) 95 | P &= 0x7777; 96 | 97 | if (row == 0) 98 | P &= 0xFFF0; 99 | if (row + 1 >= H) 100 | P &= 0xFF; 101 | 102 | if (P > 0) { 103 | // If need check about top-left pixel(if flag the first bit) and hit the 104 | // top-left pixel 105 | if (hasBit(P, 0) && 
img[idx - W - 1]) { 106 | union_(label, idx, idx - 2 * W - 2); // top left block 107 | } 108 | 109 | if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) 110 | union_(label, idx, idx - 2 * W); // top bottom block 111 | 112 | if (hasBit(P, 3) && img[idx + 2 - W]) 113 | union_(label, idx, idx - 2 * W + 2); // top right block 114 | 115 | if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) 116 | union_(label, idx, idx - 2); // just left block 117 | } 118 | } 119 | 120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) { 121 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 122 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 123 | const uint32_t idx = row * W + col; 124 | 125 | if (row < H && col < W) 126 | find_n_compress(label, idx); 127 | } 128 | 129 | __global__ void final_labeling( 130 | const uint8_t* img, 131 | int32_t* label, 132 | const int32_t W, 133 | const int32_t H) { 134 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 135 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 136 | const uint32_t idx = row * W + col; 137 | 138 | if (row >= H || col >= W) 139 | return; 140 | 141 | int32_t y = label[idx] + 1; 142 | 143 | if (img[idx]) 144 | label[idx] = y; 145 | else 146 | label[idx] = 0; 147 | 148 | if (col + 1 < W) { 149 | if (img[idx + 1]) 150 | label[idx + 1] = y; 151 | else 152 | label[idx + 1] = 0; 153 | 154 | if (row + 1 < H) { 155 | if (img[idx + W + 1]) 156 | label[idx + W + 1] = y; 157 | else 158 | label[idx + W + 1] = 0; 159 | } 160 | } 161 | 162 | if (row + 1 < H) { 163 | if (img[idx + W]) 164 | label[idx + W] = y; 165 | else 166 | label[idx + W] = 0; 167 | } 168 | } 169 | 170 | __global__ void init_counting( 171 | const int32_t* label, 172 | int32_t* count_init, 173 | const int32_t W, 174 | const int32_t H) { 175 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 176 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 177 | const uint32_t idx = row * W + col; 178 | 179 | if (row >= H || col >= W) 180 | return; 181 | 182 | int32_t y = label[idx]; 183 | if (y > 0) { 184 | int32_t count_idx = y - 1; 185 | atomicAdd(count_init + count_idx, 1); 186 | } 187 | } 188 | 189 | __global__ void final_counting( 190 | const int32_t* label, 191 | const int32_t* count_init, 192 | int32_t* count_final, 193 | const int32_t W, 194 | const int32_t H) { 195 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 196 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 197 | const uint32_t idx = row * W + col; 198 | 199 | if (row >= H || col >= W) 200 | return; 201 | 202 | int32_t y = label[idx]; 203 | if (y > 0) { 204 | int32_t count_idx = y - 1; 205 | count_final[idx] = count_init[count_idx]; 206 | } else { 207 | count_final[idx] = 0; 208 | } 209 | } 210 | 211 | } // namespace cc2d 212 | 213 | std::vector get_connected_componnets( 214 | const torch::Tensor& inputs) { 215 | AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); 216 | AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); 217 | AT_ASSERTM( 218 | inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); 219 | 220 | const uint32_t N = inputs.size(0); 221 | const uint32_t C = inputs.size(1); 222 | const uint32_t H = inputs.size(2); 223 | const uint32_t W = inputs.size(3); 224 | 225 | AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); 226 | AT_ASSERTM((H % 2) == 0, "height must be an even number"); 
227 | AT_ASSERTM((W % 2) == 0, "width must be an even number"); 228 | 229 | // label must be uint32_t 230 | auto label_options = 231 | torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); 232 | torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); 233 | torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); 234 | torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); 235 | 236 | dim3 grid = dim3( 237 | ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, 238 | ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); 239 | dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); 240 | dim3 grid_count = 241 | dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); 242 | dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); 243 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 244 | 245 | for (int n = 0; n < N; n++) { 246 | uint32_t offset = n * H * W; 247 | 248 | cc2d::init_labeling<<<grid, block, 0, stream>>>( 249 | labels.data_ptr<int32_t>() + offset, W, H); 250 | cc2d::merge<<<grid, block, 0, stream>>>( 251 | inputs.data_ptr<uint8_t>() + offset, 252 | labels.data_ptr<int32_t>() + offset, 253 | W, 254 | H); 255 | cc2d::compression<<<grid, block, 0, stream>>>( 256 | labels.data_ptr<int32_t>() + offset, W, H); 257 | cc2d::final_labeling<<<grid, block, 0, stream>>>( 258 | inputs.data_ptr<uint8_t>() + offset, 259 | labels.data_ptr<int32_t>() + offset, 260 | W, 261 | H); 262 | 263 | // get the counting of each pixel 264 | cc2d::init_counting<<<grid_count, block_count, 0, stream>>>( 265 | labels.data_ptr<int32_t>() + offset, 266 | counts_init.data_ptr<int32_t>() + offset, 267 | W, 268 | H); 269 | cc2d::final_counting<<<grid_count, block_count, 0, stream>>>( 270 | labels.data_ptr<int32_t>() + offset, 271 | counts_init.data_ptr<int32_t>() + offset, 272 | counts_final.data_ptr<int32_t>() + offset, 273 | W, 274 | H); 275 | } 276 | 277 | // returned values are [labels, counts] 278 | std::vector<torch::Tensor> outputs; 279 | outputs.push_back(labels); 280 | outputs.push_back(counts_final); 281 | return outputs; 282 | } 283 | 284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 285 | m.def( 286 | "get_connected_componnets", 287 | &get_connected_componnets, 288 | "get_connected_componnets"); 289 | } 290 | -------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
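# Hedged usage sketch for the connected_components.cu extension above.
# The repo ships a prebuilt sam2/_C.so, so importing it as `sam2._C` is an
# assumption based on that file name rather than something stated in this dump.
# The binding exposes `get_connected_componnets` (sic) on uint8 masks of shape
# [N, 1, H, W] with even H and W, on a CUDA device, per the AT_ASSERTM checks.
import torch
from sam2 import _C  # assumed module name for the compiled extension

masks = (torch.rand(2, 1, 64, 64, device="cuda") > 0.5).to(torch.uint8)
labels, counts = _C.get_connected_componnets(masks.contiguous())
# `labels`: int32, same shape as `masks`; a component id per foreground pixel, 0 for background.
# `counts`: int32, same shape; the area of the component each pixel belongs to.
print(labels.shape, counts.shape)  # torch.Size([2, 1, 64, 64]) twice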
6 | -------------------------------------------------------------------------------- /sam2/modeling/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/__pycache__/memory_attention.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/memory_attention.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/__pycache__/position_encoding.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/position_encoding.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/__pycache__/sam2_base.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/sam2_base.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | from functools import partial 9 | from typing import List, Tuple, Union 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from iopath.common.file_io import g_pathmgr 15 | 16 | from sam2.modeling.backbones.utils import ( 17 | PatchEmbed, 18 | window_partition, 19 | window_unpartition, 20 | ) 21 | 22 | from sam2.modeling.sam2_utils import DropPath, MLP 23 | 24 | 25 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 26 | if pool is None: 27 | return x 28 | # (B, H, W, C) -> (B, C, H, W) 29 | x = x.permute(0, 3, 1, 2) 30 | x = pool(x) 31 | # (B, C, H', W') -> (B, H', W', C) 32 | x = x.permute(0, 2, 3, 1) 33 | if norm: 34 | x = norm(x) 35 | 36 | return x 37 | 38 | 39 | class MultiScaleAttention(nn.Module): 40 | def __init__( 41 | self, 42 | dim: int, 43 | dim_out: int, 44 | num_heads: int, 45 | q_pool: nn.Module = None, 46 | ): 47 | super().__init__() 48 | 49 | self.dim = dim 50 | self.dim_out = dim_out 51 | self.num_heads = num_heads 52 | self.q_pool = q_pool 53 | self.qkv = nn.Linear(dim, dim_out * 3) 54 | self.proj = nn.Linear(dim_out, dim_out) 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | B, H, W, _ = x.shape 58 | # qkv with shape (B, H * W, 3, nHead, C) 59 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 60 | # q, k, v with shape (B, H * W, nheads, C) 61 | q, k, v = torch.unbind(qkv, 2) 62 | 63 | # Q pooling (for downsample at stage changes) 64 | if self.q_pool: 65 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 66 | H, W = q.shape[1:3] # downsampled shape 67 | q = q.reshape(B, H * W, self.num_heads, -1) 68 | 69 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 70 | x = F.scaled_dot_product_attention( 71 | q.transpose(1, 2), 72 | k.transpose(1, 2), 73 | v.transpose(1, 2), 74 | ) 75 | # Transpose back 76 | x = x.transpose(1, 2) 77 | x = x.reshape(B, H, W, -1) 78 | 79 | x = self.proj(x) 80 | 81 | return x 82 | 83 | 84 | class MultiScaleBlock(nn.Module): 85 | def __init__( 86 | self, 87 | dim: int, 88 | dim_out: int, 89 | num_heads: int, 90 | mlp_ratio: float = 4.0, 91 | drop_path: float = 0.0, 92 | norm_layer: Union[nn.Module, str] = "LayerNorm", 93 | q_stride: Tuple[int, int] = None, 94 | act_layer: nn.Module = nn.GELU, 95 | window_size: int = 0, 96 | ): 97 | super().__init__() 98 | 99 | if isinstance(norm_layer, str): 100 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 101 | 102 | self.dim = dim 103 | self.dim_out = dim_out 104 | self.norm1 = norm_layer(dim) 105 | 106 | self.window_size = window_size 107 | 108 | self.pool, self.q_stride = None, q_stride 109 | if self.q_stride: 110 | self.pool = nn.MaxPool2d( 111 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 112 | ) 113 | 114 | self.attn = MultiScaleAttention( 115 | dim, 116 | dim_out, 117 | num_heads=num_heads, 118 | q_pool=self.pool, 119 | ) 120 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 121 | 122 | self.norm2 = norm_layer(dim_out) 123 | self.mlp = MLP( 124 | dim_out, 125 | int(dim_out * mlp_ratio), 126 | dim_out, 127 | num_layers=2, 128 | activation=act_layer, 129 | ) 130 | 131 | if dim != dim_out: 132 | self.proj = nn.Linear(dim, dim_out) 133 | 134 | def forward(self, x: torch.Tensor) -> torch.Tensor: 135 | shortcut = x # B, H, W, C 136 | x = self.norm1(x) 137 | 138 | # Skip connection 139 | if self.dim != self.dim_out: 140 | shortcut = do_pool(self.proj(x), self.pool) 141 | 142 | # Window partition 143 | window_size = 
self.window_size 144 | if window_size > 0: 145 | H, W = x.shape[1], x.shape[2] 146 | x, pad_hw = window_partition(x, window_size) 147 | 148 | # Window Attention + Q Pooling (if stage change) 149 | x = self.attn(x) 150 | if self.q_stride: 151 | # Shapes have changed due to Q pooling 152 | window_size = self.window_size // self.q_stride[0] 153 | H, W = shortcut.shape[1:3] 154 | 155 | pad_h = (window_size - H % window_size) % window_size 156 | pad_w = (window_size - W % window_size) % window_size 157 | pad_hw = (H + pad_h, W + pad_w) 158 | 159 | # Reverse window partition 160 | if self.window_size > 0: 161 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 162 | 163 | x = shortcut + self.drop_path(x) 164 | # MLP 165 | x = x + self.drop_path(self.mlp(self.norm2(x))) 166 | return x 167 | 168 | 169 | class Hiera(nn.Module): 170 | """ 171 | Reference: https://arxiv.org/abs/2306.00989 172 | """ 173 | 174 | def __init__( 175 | self, 176 | embed_dim: int = 96, # initial embed dim 177 | num_heads: int = 1, # initial number of heads 178 | drop_path_rate: float = 0.0, # stochastic depth 179 | q_pool: int = 3, # number of q_pool stages 180 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 181 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 182 | dim_mul: float = 2.0, # dim_mul factor at stage shift 183 | head_mul: float = 2.0, # head_mul factor at stage shift 184 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 185 | # window size per stage, when not using global att. 186 | window_spec: Tuple[int, ...] = ( 187 | 8, 188 | 4, 189 | 14, 190 | 7, 191 | ), 192 | # global attn in these blocks 193 | global_att_blocks: Tuple[int, ...] = ( 194 | 12, 195 | 16, 196 | 20, 197 | ), 198 | weights_path=None, 199 | return_interm_layers=True, # return feats from every stage 200 | ): 201 | super().__init__() 202 | 203 | assert len(stages) == len(window_spec) 204 | self.window_spec = window_spec 205 | 206 | depth = sum(stages) 207 | self.q_stride = q_stride 208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 209 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 211 | self.return_interm_layers = return_interm_layers 212 | 213 | self.patch_embed = PatchEmbed( 214 | embed_dim=embed_dim, 215 | ) 216 | # Which blocks have global att? 
217 | self.global_att_blocks = global_att_blocks 218 | 219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 221 | self.pos_embed = nn.Parameter( 222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 223 | ) 224 | self.pos_embed_window = nn.Parameter( 225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 226 | ) 227 | 228 | dpr = [ 229 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 230 | ] # stochastic depth decay rule 231 | 232 | cur_stage = 1 233 | self.blocks = nn.ModuleList() 234 | 235 | for i in range(depth): 236 | dim_out = embed_dim 237 | # lags by a block, so first block of 238 | # next stage uses an initial window size 239 | # of previous stage and final window size of current stage 240 | window_size = self.window_spec[cur_stage - 1] 241 | 242 | if self.global_att_blocks is not None: 243 | window_size = 0 if i in self.global_att_blocks else window_size 244 | 245 | if i - 1 in self.stage_ends: 246 | dim_out = int(embed_dim * dim_mul) 247 | num_heads = int(num_heads * head_mul) 248 | cur_stage += 1 249 | 250 | block = MultiScaleBlock( 251 | dim=embed_dim, 252 | dim_out=dim_out, 253 | num_heads=num_heads, 254 | drop_path=dpr[i], 255 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 256 | window_size=window_size, 257 | ) 258 | 259 | embed_dim = dim_out 260 | self.blocks.append(block) 261 | 262 | self.channel_list = ( 263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 264 | if return_interm_layers 265 | else [self.blocks[-1].dim_out] 266 | ) 267 | 268 | if weights_path is not None: 269 | with g_pathmgr.open(weights_path, "rb") as f: 270 | chkpt = torch.load(f, map_location="cpu") 271 | logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) 272 | 273 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 274 | h, w = hw 275 | window_embed = self.pos_embed_window 276 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 277 | pos_embed = pos_embed + window_embed.tile( 278 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 279 | ) 280 | pos_embed = pos_embed.permute(0, 2, 3, 1) 281 | return pos_embed 282 | 283 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 284 | x = self.patch_embed(x) 285 | # x: (B, H, W, C) 286 | 287 | # Add pos embed 288 | x = x + self._get_pos_embed(x.shape[1:3]) 289 | 290 | outputs = [] 291 | for i, blk in enumerate(self.blocks): 292 | x = blk(x) 293 | if (i == self.stage_ends[-1]) or ( 294 | i in self.stage_ends and self.return_interm_layers 295 | ): 296 | feats = x.permute(0, 3, 1, 2) 297 | outputs.append(feats) 298 | 299 | return outputs 300 | 301 | def get_layer_id(self, layer_name): 302 | # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 303 | num_layers = self.get_num_layers() 304 | 305 | if layer_name.find("rel_pos") != -1: 306 | return num_layers + 1 307 | elif layer_name.find("pos_embed") != -1: 308 | return 0 309 | elif layer_name.find("patch_embed") != -1: 310 | return 0 311 | elif layer_name.find("blocks") != -1: 312 | return int(layer_name.split("blocks")[1].split(".")[1]) + 1 313 | else: 314 | return num_layers + 1 315 | 316 | def get_num_layers(self) -> int: 317 | return len(self.blocks) 318 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: 
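# Hedged wiring sketch for the Hiera trunk above and the ImageEncoder/FpnNeck
# defined in the next file. The constructor values are copied from the
# image_encoder block of sam2/configs/sam2.1_hiera_t512.yaml shown earlier;
# the commented shapes are what the code implies for a 512x512 input with
# random weights, not values verified against a released checkpoint.
import torch
from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
from sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = Hiera(
    embed_dim=96,
    num_heads=1,
    stages=(1, 2, 7, 2),
    global_att_blocks=(5, 7, 9),
    window_pos_embed_bkg_spatial_size=(7, 7),
)
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),
    d_model=256,
    backbone_channel_list=[768, 384, 192, 96],  # matches trunk.channel_list
    fpn_top_down_levels=[2, 3],
    fpn_interp_model="nearest",
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)

out = encoder(torch.randn(1, 3, 512, 512))
print(out["vision_features"].shape)                   # (1, 256, 32, 32)
print([tuple(f.shape) for f in out["backbone_fpn"]])  # stride-4/8/16 maps; scalp=1 drops the coarsest level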
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | self.d_model = d_model 75 | for dim in backbone_channel_list: 76 | current = nn.Sequential() 77 | current.add_module( 78 | "conv", 79 | nn.Conv2d( 80 | in_channels=dim, 81 | out_channels=d_model, 82 | kernel_size=kernel_size, 83 | stride=stride, 84 | padding=padding, 85 | ), 86 | ) 87 | 88 | self.convs.append(current) 89 | self.fpn_interp_model = fpn_interp_model 90 | assert fuse_type in ["sum", "avg"] 91 | self.fuse_type = fuse_type 92 | 93 | # levels to have top-down features in its outputs 94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 95 | # have top-down propagation, while outputs of level 0 and level 1 have only 96 | # lateral features from the same backbone level. 
97 | if fpn_top_down_levels is None: 98 | # default is to have top-down features on all levels 99 | fpn_top_down_levels = range(len(self.convs)) 100 | self.fpn_top_down_levels = list(fpn_top_down_levels) 101 | 102 | def forward(self, xs: List[torch.Tensor]): 103 | 104 | out = [None] * len(self.convs) 105 | pos = [None] * len(self.convs) 106 | assert len(xs) == len(self.convs) 107 | # fpn forward pass 108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 109 | prev_features = None 110 | # forward in top-down order (from low to high resolution) 111 | n = len(self.convs) - 1 112 | for i in range(n, -1, -1): 113 | x = xs[i] 114 | lateral_features = self.convs[n - i](x) 115 | if i in self.fpn_top_down_levels and prev_features is not None: 116 | top_down_features = F.interpolate( 117 | prev_features.to(dtype=torch.float32), 118 | scale_factor=2.0, 119 | mode=self.fpn_interp_model, 120 | align_corners=( 121 | None if self.fpn_interp_model == "nearest" else False 122 | ), 123 | antialias=False, 124 | ) 125 | prev_features = lateral_features + top_down_features 126 | if self.fuse_type == "avg": 127 | prev_features /= 2 128 | else: 129 | prev_features = lateral_features 130 | x_out = prev_features 131 | out[i] = x_out 132 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 133 | 134 | return out, pos 135 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | 
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /sam2/modeling/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
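# Hedged usage sketch for the MemoryEncoder defined above. The wiring mirrors
# the memory_encoder block of sam2.1_hiera_t512.yaml (out_dim 64, 3x3/stride-2
# mask downsampling for a total stride of 16, two CXBlock fuser layers). The
# tensor shapes assume a 512x512 input whose stride-16 features are 32x32;
# they follow from the code, not from a run of the released model.
import torch
from sam2.modeling.memory_encoder import CXBlock, Fuser, MaskDownSampler, MemoryEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine

memory_encoder = MemoryEncoder(
    out_dim=64,
    mask_downsampler=MaskDownSampler(kernel_size=3, stride=2, padding=1),
    fuser=Fuser(CXBlock(dim=256, kernel_size=7, padding=3, use_dwconv=True), num_layers=2),
    position_encoding=PositionEmbeddingSine(num_pos_feats=64),
    in_dim=256,
)

pix_feat = torch.randn(1, 256, 32, 32)      # stride-16 image features
mask_logits = torch.randn(1, 1, 512, 512)   # predicted mask logits at image resolution
out = memory_encoder(pix_feat, mask_logits)
print(out["vision_features"].shape)    # (1, 64, 32, 32): memory features
print(out["vision_pos_enc"][0].shape)  # (1, 64, 32, 32): sine positional encoding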
6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention Is All You Need paper, generalized to work on images. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_pos_feats, 25 | temperature: int = 10000, 26 | normalize: bool = True, 27 | scale: Optional[float] = None, 28 | ): 29 | super().__init__() 30 | assert num_pos_feats % 2 == 0, "Expecting even model width" 31 | self.num_pos_feats = num_pos_feats // 2 32 | self.temperature = temperature 33 | self.normalize = normalize 34 | if scale is not None and normalize is False: 35 | raise ValueError("normalize should be True if scale is passed") 36 | if scale is None: 37 | scale = 2 * math.pi 38 | self.scale = scale 39 | 40 | self.cache = {} 41 | 42 | def _encode_xy(self, x, y): 43 | # The positions are expected to be normalized 44 | assert len(x) == len(y) and x.ndim == y.ndim == 1 45 | x_embed = x * self.scale 46 | y_embed = y * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, None] / dim_t 52 | pos_y = y_embed[:, None] / dim_t 53 | pos_x = torch.stack( 54 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 55 | ).flatten(1) 56 | pos_y = torch.stack( 57 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 58 | ).flatten(1) 59 | return pos_x, pos_y 60 | 61 | @torch.no_grad() 62 | def encode_boxes(self, x, y, w, h): 63 | pos_x, pos_y = self._encode_xy(x, y) 64 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 65 | return pos 66 | 67 | encode = encode_boxes # Backwards compatibility 68 | 69 | @torch.no_grad() 70 | def encode_points(self, x, y, labels): 71 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 72 | assert bx == by and nx == ny and bx == bl and nx == nl 73 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 74 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 75 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 76 | return pos 77 | 78 | @torch.no_grad() 79 | def forward(self, x: torch.Tensor): 80 | cache_key = (x.shape[-2], x.shape[-1]) 81 | if cache_key in self.cache: 82 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) 83 | y_embed = ( 84 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) 85 | .view(1, -1, 1) 86 | .repeat(x.shape[0], 1, x.shape[-1]) 87 | ) 88 | x_embed = ( 89 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) 90 | .view(1, 1, -1) 91 | .repeat(x.shape[0], x.shape[-2], 1) 92 | ) 93 | 94 | if self.normalize: 95 | eps = 1e-6 96 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 97 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 98 | 99 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 100 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 101 | 102 | pos_x = x_embed[:, :, :, None] / dim_t 103 | pos_y = y_embed[:, :, :, None] / dim_t 104 | pos_x = torch.stack( 105 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 106 | ).flatten(3) 107 | pos_y = torch.stack( 108 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 109 | ).flatten(3) 110 | pos = torch.cat((pos_y, pos_x), 
dim=3).permute(0, 3, 1, 2) 111 | self.cache[cache_key] = pos[0] 112 | return pos 113 | 114 | 115 | class PositionEmbeddingRandom(nn.Module): 116 | """ 117 | Positional encoding using random spatial frequencies. 118 | """ 119 | 120 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 121 | super().__init__() 122 | if scale is None or scale <= 0.0: 123 | scale = 1.0 124 | self.register_buffer( 125 | "positional_encoding_gaussian_matrix", 126 | scale * torch.randn((2, num_pos_feats)), 127 | ) 128 | 129 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 130 | """Positionally encode points that are normalized to [0,1].""" 131 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 132 | coords = 2 * coords - 1 133 | coords = coords @ self.positional_encoding_gaussian_matrix 134 | coords = 2 * np.pi * coords 135 | # outputs d_1 x ... x d_n x C shape 136 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 137 | 138 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 139 | """Generate positional encoding for a grid of the specified size.""" 140 | h, w = size 141 | device: Any = self.positional_encoding_gaussian_matrix.device 142 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 143 | y_embed = grid.cumsum(dim=0) - 0.5 144 | x_embed = grid.cumsum(dim=1) - 0.5 145 | y_embed = y_embed / h 146 | x_embed = x_embed / w 147 | 148 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 149 | return pe.permute(2, 0, 1) # C x H x W 150 | 151 | def forward_with_coords( 152 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 153 | ) -> torch.Tensor: 154 | """Positionally encode points that are not normalized to [0,1].""" 155 | coords = coords_input.clone() 156 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 157 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 158 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 159 | 160 | 161 | # Rotary Positional Encoding, adapted from: 162 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 163 | # 2. https://github.com/naver-ai/rope-vit 164 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 165 | 166 | 167 | def init_t_xy(end_x: int, end_y: int): 168 | t = torch.arange(end_x * end_y, dtype=torch.float32) 169 | t_x = (t % end_x).float() 170 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 171 | return t_x, t_y 172 | 173 | 174 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 175 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 176 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 177 | 178 | t_x, t_y = init_t_xy(end_x, end_y) 179 | freqs_x = torch.outer(t_x, freqs_x) 180 | freqs_y = torch.outer(t_y, freqs_y) 181 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 182 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 183 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 184 | 185 | 186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 187 | ndim = x.ndim 188 | assert 0 <= 1 < ndim 189 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 190 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 191 | return freqs_cis.view(*shape) 192 | 193 | 194 | def apply_rotary_enc( 195 | xq: torch.Tensor, 196 | xk: torch.Tensor, 197 | freqs_cis: torch.Tensor, 198 | repeat_freqs_k: bool = False, 199 | ): 200 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 201 | xk_ = ( 202 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 203 | if xk.shape[-2] != 0 204 | else None 205 | ) 206 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 207 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 208 | if xk_ is None: 209 | # no keys to rotate, due to dropout 210 | return xq_out.type_as(xq).to(xq.device), xk 211 | # repeat freqs along seq_len dim to match k seq_len 212 | if repeat_freqs_k: 213 | r = xk_.shape[-2] // xq_.shape[-2] 214 | if freqs_cis.is_cuda: 215 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 216 | else: 217 | # torch.repeat on complex numbers may not be supported on non-CUDA devices 218 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten 219 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) 220 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 221 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 222 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
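# Hedged sanity check for the axial RoPE helpers defined above.
# compute_axial_cis builds one complex rotation per (token, frequency) pair for
# an end_x * end_y grid, and apply_rotary_enc rotates query/key vectors without
# changing their shapes. The 32x32 grid and 64-dim heads mirror the
# RoPEAttention settings in sam2.1_hiera_t512.yaml (feat_sizes: [32, 32]);
# the batch and head counts are arbitrary choices for illustration.
import torch
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

head_dim = 64
freqs_cis = compute_axial_cis(dim=head_dim, end_x=32, end_y=32, theta=10000.0)
print(freqs_cis.shape)  # (1024, 32): 32*32 tokens, head_dim//2 complex rotations each

q = torch.randn(1, 1, 32 * 32, head_dim)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 1, 32 * 32, head_dim)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis)
print(q_rot.shape, k_rot.shape)  # both unchanged: (1, 1, 1024, 64)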
6 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom 13 | 14 | from sam2.modeling.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 
39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | point_embedding[labels == 2] += self.point_embeddings[2].weight 100 | point_embedding[labels == 3] += self.point_embeddings[3].weight 101 | return point_embedding 102 | 103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 104 | """Embeds box prompts.""" 105 | boxes = boxes + 0.5 # Shift to center of pixel 106 | coords = boxes.reshape(-1, 2, 2) 107 | corner_embedding = self.pe_layer.forward_with_coords( 108 | coords, self.input_image_size 109 | ) 110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 112 | return corner_embedding 113 | 114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 115 | """Embeds mask inputs.""" 116 | mask_embedding = self.mask_downscaling(masks) 117 | return mask_embedding 118 | 119 | def _get_batch_size( 120 | self, 121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 122 | boxes: Optional[torch.Tensor], 123 | masks: Optional[torch.Tensor], 124 | ) -> int: 125 | """ 126 | Gets the batch size of the output given the batch size of the input prompts. 
127 | """ 128 | if points is not None: 129 | return points[0].shape[0] 130 | elif boxes is not None: 131 | return boxes.shape[0] 132 | elif masks is not None: 133 | return masks.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | ) -> Tuple[torch.Tensor, torch.Tensor]: 146 | """ 147 | Embeds different types of prompts, returning both sparse and dense 148 | embeddings. 149 | 150 | Arguments: 151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 152 | and labels to embed. 153 | boxes (torch.Tensor or none): boxes to embed 154 | masks (torch.Tensor or none): masks to embed 155 | 156 | Returns: 157 | torch.Tensor: sparse embeddings for the points and boxes, with shape 158 | BxNx(embed_dim), where N is determined by the number of input points 159 | and boxes. 160 | torch.Tensor: dense embeddings for the masks, in the shape 161 | Bx(embed_dim)x(embed_H)x(embed_W) 162 | """ 163 | bs = self._get_batch_size(points, boxes, masks) 164 | sparse_embeddings = torch.empty( 165 | (bs, 0, self.embed_dim), device=self._get_device() 166 | ) 167 | if points is not None: 168 | coords, labels = points 169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 171 | if boxes is not None: 172 | box_embeddings = self._embed_boxes(boxes) 173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 174 | 175 | if masks is not None: 176 | dense_embeddings = self._embed_masks(masks) 177 | else: 178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 180 | ) 181 | 182 | return sparse_embeddings, dense_embeddings 183 | -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
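# Hedged usage sketch for the PromptEncoder defined above, configured for a
# 512x512 input with a stride-16 (32x32) image embedding as the 512-resolution
# config implies. embed_dim=256 and mask_in_chans=16 are assumptions made for
# this sketch, not values read from this dump. One positive click plus one box
# become sparse tokens; with no mask prompt, the learned no_mask_embed is
# broadcast as the dense embedding.
import torch
from sam2.modeling.sam.prompt_encoder import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(32, 32),
    input_image_size=(512, 512),
    mask_in_chans=16,
)

point_coords = torch.tensor([[[128.0, 200.0]]])   # (B=1, N=1, xy) in input-image pixels
point_labels = torch.tensor([[1]])                # 1 = foreground click
box = torch.tensor([[64.0, 64.0, 400.0, 380.0]])  # (B, 4) as x1, y1, x2, y2

sparse, dense = prompt_encoder(points=(point_coords, point_labels), boxes=box, masks=None)
print(sparse.shape)  # (1, 3, 256): one point token + two box-corner tokens
print(dense.shape)   # (1, 256, 32, 32): broadcast no_mask_embed
print(prompt_encoder.get_dense_pe().shape)  # (1, 256, 32, 32)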
6 | -------------------------------------------------------------------------------- /sam2/utils/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/utils/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/utils/__pycache__/misc.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/utils/__pycache__/misc.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/utils/__pycache__/transforms.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/utils/__pycache__/transforms.cpython-312.pyc -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 51 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 52 | 53 | Returns 54 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. 
The coordinates can be in absolute image or normalized coordinates. 71 | If the coords are in absolute image coordinates, set normalize to True and provide the original image size (orig_hw). 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform post-processing on the output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.max_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # We remove small sprinkles (tiny foreground components) by assigning them a negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks 119 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | 10 | # Package metadata 11 | NAME = "MedSAM2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "MedSAM2 was adapted from SAM2 (https://github.com/facebookresearch/sam2) for medical image segmentation."
14 | URL = "https://github.com/bowang-lab/MedSAM2" 15 | AUTHOR = "WangLab" 16 | AUTHOR_EMAIL = "medseg20s@gmail.com" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of README file 20 | with open("README.md", "r", encoding="utf-8") as f: 21 | LONG_DESCRIPTION = f.read() 22 | 23 | # Required dependencies 24 | REQUIRED_PACKAGES = [ 25 | "torch>=2.5.1", 26 | "torchvision>=0.20.1", 27 | "numpy>=2.0.1", 28 | "tqdm>=4.66.5", 29 | "hydra-core>=1.3.2", 30 | "iopath>=0.1.10", 31 | "pillow>=10.4.0", 32 | "SimpleITK>=2.4.0", 33 | ] 34 | 35 | EXTRA_PACKAGES = { 36 | "notebooks": [ 37 | "matplotlib>=3.9.1", 38 | "jupyter>=1.0.0", 39 | "opencv-python>=4.10.0", 40 | "eva-decord>=0.6.1", 41 | ], 42 | "interactive-demo": [ 43 | "Flask>=3.0.3", 44 | "Flask-Cors>=5.0.0", 45 | "av>=13.0.0", 46 | "dataclasses-json>=0.6.7", 47 | "eva-decord>=0.6.1", 48 | "gunicorn>=23.0.0", 49 | "imagesize>=1.4.1", 50 | "pycocotools>=2.0.8", 51 | "strawberry-graphql>=0.239.2", 52 | ], 53 | "dev": [ 54 | "matplotlib>=3.9.1", 55 | "jupyter>=1.0.0", 56 | "black==24.2.0", 57 | "usort==1.0.2", 58 | "ufmt==2.0.0b2", 59 | "fvcore>=0.1.5.post20221221", 60 | "pandas>=2.2.3", 61 | "scikit-image>=0.24.0", 62 | "tensorboard>=2.17.0", 63 | "pycocotools>=2.0.8", 64 | "tensordict>=0.5.0", 65 | "opencv-python>=4.10.0", 66 | "submitit>=1.5.1", 67 | ], 68 | } 69 | 70 | # By default, we also build the SAM 2 CUDA extension. 71 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. 72 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" 73 | # By default, we allow SAM 2 installation to proceed even with build errors. 74 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. 75 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" 76 | 77 | # Catch and skip errors during extension building and print a warning message 78 | # (note that this message only shows up under verbose build mode 79 | # "pip install -v -e ." or "python setup.py build_ext -v") 80 | CUDA_ERROR_MSG = ( 81 | "{}\n\n" 82 | "Failed to build the SAM 2 CUDA extension due to the error above. 
" 83 | "You can still use SAM 2 and it's OK to ignore the error above, although some " 84 | "post-processing functionality may be limited (which doesn't affect the results in most cases; " 85 | "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" 86 | ) 87 | 88 | 89 | def get_extensions(): 90 | if not BUILD_CUDA: 91 | return [] 92 | 93 | try: 94 | from torch.utils.cpp_extension import CUDAExtension 95 | 96 | srcs = ["sam2/csrc/connected_components.cu"] 97 | compile_args = { 98 | "cxx": [], 99 | "nvcc": [ 100 | "-DCUDA_HAS_FP16=1", 101 | "-D__CUDA_NO_HALF_OPERATORS__", 102 | "-D__CUDA_NO_HALF_CONVERSIONS__", 103 | "-D__CUDA_NO_HALF2_OPERATORS__", 104 | ], 105 | } 106 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 107 | except Exception as e: 108 | if BUILD_ALLOW_ERRORS: 109 | print(CUDA_ERROR_MSG.format(e)) 110 | ext_modules = [] 111 | else: 112 | raise e 113 | 114 | return ext_modules 115 | 116 | 117 | try: 118 | from torch.utils.cpp_extension import BuildExtension 119 | 120 | class BuildExtensionIgnoreErrors(BuildExtension): 121 | 122 | def finalize_options(self): 123 | try: 124 | super().finalize_options() 125 | except Exception as e: 126 | print(CUDA_ERROR_MSG.format(e)) 127 | self.extensions = [] 128 | 129 | def build_extensions(self): 130 | try: 131 | super().build_extensions() 132 | except Exception as e: 133 | print(CUDA_ERROR_MSG.format(e)) 134 | self.extensions = [] 135 | 136 | def get_ext_filename(self, ext_name): 137 | try: 138 | return super().get_ext_filename(ext_name) 139 | except Exception as e: 140 | print(CUDA_ERROR_MSG.format(e)) 141 | self.extensions = [] 142 | return "_C.so" 143 | 144 | cmdclass = { 145 | "build_ext": ( 146 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) 147 | if BUILD_ALLOW_ERRORS 148 | else BuildExtension.with_options(no_python_abi_suffix=True) 149 | ) 150 | } 151 | except Exception as e: 152 | cmdclass = {} 153 | if BUILD_ALLOW_ERRORS: 154 | print(CUDA_ERROR_MSG.format(e)) 155 | else: 156 | raise e 157 | 158 | 159 | # Setup configuration 160 | setup( 161 | name=NAME, 162 | version=VERSION, 163 | description=DESCRIPTION, 164 | long_description=LONG_DESCRIPTION, 165 | long_description_content_type="text/markdown", 166 | url=URL, 167 | author=AUTHOR, 168 | author_email=AUTHOR_EMAIL, 169 | license=LICENSE, 170 | packages=find_packages(exclude="notebooks"), 171 | include_package_data=True, 172 | install_requires=REQUIRED_PACKAGES, 173 | extras_require=EXTRA_PACKAGES, 174 | python_requires=">=3.10.0", 175 | ext_modules=get_extensions(), 176 | cmdclass=cmdclass, 177 | ) 178 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /training/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /training/__pycache__/loss_fns.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/loss_fns.cpython-312.pyc -------------------------------------------------------------------------------- /training/__pycache__/optimizer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/optimizer.cpython-312.pyc -------------------------------------------------------------------------------- /training/__pycache__/trainer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/trainer.cpython-312.pyc -------------------------------------------------------------------------------- /training/assets/MOSE_sample_val_list.txt: -------------------------------------------------------------------------------- 1 | 32e5d721 2 | 5bad0bab 3 | 267bfd6c 4 | 0a43a414 5 | 56c56ca9 6 | 9a1146b3 7 | c6ad7aaf 8 | 78a1f4b1 9 | fc455e73 10 | 072e7b3f 11 | 77ccb57d 12 | a76ee415 13 | 8cdcfc17 14 | 5d518b42 15 | 376dd830 16 | 0e843fc8 17 | 2af0e766 18 | 2bd4e845 19 | de2f2a6a 20 | ade9ee91 21 | 001ca3cb 22 | fc4c1c67 23 | 8ef55579 24 | b84ce852 25 | 4cc8528a 26 | 767ffaaa 27 | 112a2ef0 28 | a338c8aa 29 | cbd144f5 30 | 5ff72128 31 | 86a949e2 32 | 9f2323ac 33 | 1fab1d1c 34 | 75924351 35 | ef55817b 36 | 02deca50 37 | 4d979d99 38 | 4d65f873 39 | 28470fa0 40 | 0d1575fe 41 | 06ea172e 42 | 29a6ddc2 43 | 797f1bec 44 | 780e7a99 45 | b9ed5b44 46 | 02a236b4 47 | 607d8ff5 48 | af5666b2 49 | 0558d0ed 50 | a938c6b2 51 | 103df575 52 | 77110e80 53 | 739e5a07 54 | 6763a576 55 | 06ebc138 56 | ba4b3b09 57 | b35cc2f3 58 | 4e0597a0 59 | 5949ee84 60 | 5348d547 61 | 323c4236 62 | b3b51117 63 | 55727ddd 64 | ab2714f3 65 | d2878895 66 | c0734cb3 67 | 94f7c53e 68 | 2a2745e5 69 | 442ffb54 70 | 3592425a 71 | 50ae03b0 72 | 5f150435 73 | 3067f9fa 74 | 9ffb2818 75 | adeaf5aa 76 | 31caacec 77 | 1cd99b86 78 | aa22f9d0 79 | 8fa50320 80 | e6348d2c 81 | 42ff84a5 82 | 8c8b7913 83 | c96adcbc 84 | 495be321 85 | db735509 86 | ee113fc4 87 | a678cdab 88 | c409ca4d 89 | 68d2b259 90 | 592b4dee 91 | 4e2b4dc7 92 | eb4d26e1 93 | 2009a00f 94 | bec5c89d 95 | 67191f24 96 | a3e85b4b 97 | da7080cd 98 | 80d978e9 99 | 36dcb93f 100 | a41e8c44 101 | 12fdc864 102 | 46d140ea 103 | 657c9dd9 104 | a86f84ee 105 | 90c1c43d 106 | 33015509 107 | afc7664d 108 | 23df06e1 109 | 291d4799 110 | 0ab75563 111 | 251bf059 112 | bcefdcc4 113 | ce9a2796 114 | 94d3403a 115 | 8f2e04bc 116 | f9cda066 117 | 9dfa2cc5 118 | 66924c91 119 | e765a09e 120 | 15654ee1 121 | 48e0bd39 122 | ee095221 123 | 2463609b 124 | 544d0d1f 125 | 51b8c2e1 126 | d321dde4 127 | 4cb11a5f 128 | d7058a0d 129 | 37af282a 130 | fabae187 131 | 7be91184 132 | 181ec185 133 | 2d16ceeb 134 | b56be4b1 135 | 6699eff0 136 | 79acac96 137 | 
d61c4665 138 | 0c13e1e7 139 | 100f6ecf 140 | 71217dfc 141 | 82df0888 142 | 4c42c747 143 | c9fdf703 144 | d2efeb4b 145 | 69ed9d14 146 | 64914fb6 147 | 255bedbc 148 | 4ea934d8 149 | a034feb2 150 | e4f4ddae 151 | e36a3026 152 | c1489591 153 | 111bb373 154 | e1d9fb32 155 | 93e22d48 156 | c1ec4b26 157 | d9638e69 158 | 60ab04c5 159 | cfe7773a 160 | 62132822 161 | 2f5fb2a3 162 | 7bdd197d 163 | 033333fd 164 | 130fcdbe 165 | 12e509c2 166 | 67138c33 167 | 6f90cc5f 168 | 4e3020fe 169 | bbdd8bb7 170 | b399ccdb 171 | fecd10d2 172 | 2e0967f7 173 | f509054f 174 | 792c6ff7 175 | 48e2afc5 176 | d904c048 177 | 111e0a5c 178 | b83024e2 179 | e6a7b79c 180 | bdc5ccf7 181 | b8146d00 182 | 9d394f1a 183 | 645b84f9 184 | 95ab2d0f 185 | e6f8a31d 186 | b4f876fb 187 | dc2c570d 188 | 3afd02d7 189 | 5c80c82c 190 | b1b32ddd 191 | 9f25fc61 192 | ba538072 193 | f8916fef 194 | 43c04ad2 195 | a658e949 196 | 2861dd53 197 | f6e40aba 198 | 09d305d1 199 | aac33bff 200 | 8d9d4c08 201 | -------------------------------------------------------------------------------- /training/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /training/dataset/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/sam2_datasets.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/sam2_datasets.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/transforms.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/transforms.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/utils.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/vos_dataset.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_dataset.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/vos_raw_dataset.cpython-312.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_raw_dataset.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/vos_sampler.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_sampler.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/__pycache__/vos_segment_loader.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_segment_loader.cpython-312.pyc -------------------------------------------------------------------------------- /training/dataset/sam2_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import math 9 | from typing import Callable, Iterable, List, Optional, Sequence 10 | 11 | import torch 12 | 13 | from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset 14 | 15 | from torch.utils.data.distributed import DistributedSampler 16 | 17 | 18 | class MixedDataLoader: 19 | def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor): 20 | """ 21 | Args: 22 | dataloaders (List[DataLoader]): List of DataLoaders to be mixed. 23 | mixing_prob (torch.FloatTensor): Probability of sampling from each dataloader. 24 | 25 | """ 26 | assert len(dataloaders) == mixing_prob.shape[0] 27 | self.dataloaders = dataloaders 28 | self.mixing_prob = mixing_prob 29 | # Iterator state 30 | self._iter_dls = None 31 | self._iter_mixing_prob = None 32 | self.random_generator = torch.Generator() 33 | 34 | def __len__(self): 35 | return sum([len(d) for d in self.dataloaders]) 36 | 37 | def __iter__(self): 38 | # Synchronize dataloader seeds 39 | self.random_generator.manual_seed(42) 40 | self._iter_dls = [iter(loader) for loader in self.dataloaders] 41 | self._iter_mixing_prob = self.mixing_prob.clone() 42 | return self 43 | 44 | def __next__(self): 45 | """ 46 | Pick a dataloader to draw the next batch from, based on the mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the remaining loaders until all are exhausted. 47 | """ 48 | if self._iter_dls is None: 49 | raise TypeError(f"{type(self).__name__} object is not an iterator") 50 | 51 | while self._iter_mixing_prob.any(): # at least one D-Loader with non-zero prob. 52 | dataset_idx = self._iter_mixing_prob.multinomial( 53 | 1, generator=self.random_generator 54 | ).item() 55 | try: 56 | item = next(self._iter_dls[dataset_idx]) 57 | return item 58 | except StopIteration: 59 | # No more iterations for this dataset, set its mixing probability to zero and try again. 60 | self._iter_mixing_prob[dataset_idx] = 0 61 | except Exception as e: 62 | # log and raise any other unexpected error.
63 | logging.error(e) 64 | raise e 65 | 66 | # Exhausted all iterators 67 | raise StopIteration 68 | 69 | 70 | class TorchTrainMixedDataset: 71 | def __init__( 72 | self, 73 | datasets: List[Dataset], 74 | batch_sizes: List[int], 75 | num_workers: int, 76 | shuffle: bool, 77 | pin_memory: bool, 78 | drop_last: bool, 79 | collate_fn: Optional[Callable] = None, 80 | worker_init_fn: Optional[Callable] = None, 81 | phases_per_epoch: int = 1, 82 | dataset_prob: Optional[List[float]] = None, 83 | ) -> None: 84 | """ 85 | Args: 86 | datasets (List[Dataset]): List of Datasets to be mixed. 87 | batch_sizes (List[int]): Batch sizes for each dataset in the list. 88 | num_workers (int): Number of workers per dataloader. 89 | shuffle (bool): Whether or not to shuffle data. 90 | pin_memory (bool): If True, use pinned memory when loading tensors from disk. 91 | drop_last (bool): Whether or not to drop the last batch of data. 92 | collate_fn (Callable): Function to merge a list of samples into a mini-batch. 93 | worker_init_fn (Callable): Function to init each dataloader worker. 94 | phases_per_epoch (int): Number of phases per epoch. 95 | dataset_prob (List[float]): Probability of sampling from each dataloader. Should sum to 1.0. 96 | """ 97 | 98 | self.datasets = datasets 99 | self.batch_sizes = batch_sizes 100 | self.num_workers = num_workers 101 | self.shuffle = shuffle 102 | self.pin_memory = pin_memory 103 | self.drop_last = drop_last 104 | self.collate_fn = collate_fn 105 | self.worker_init_fn = worker_init_fn 106 | assert len(self.datasets) > 0 107 | for dataset in self.datasets: 108 | assert not isinstance(dataset, IterableDataset), "Not supported" 109 | # `RepeatFactorWrapper` requires calling set_epoch first to get its length 110 | self._set_dataset_epoch(dataset, 0) 111 | self.phases_per_epoch = phases_per_epoch 112 | self.chunks = [None] * len(datasets) 113 | if dataset_prob is None: 114 | # If not provided, assign each dataset a probability proportional to its length. 115 | dataset_lens = [ 116 | (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs)) 117 | for d, bs in zip(datasets, batch_sizes) 118 | ] 119 | total_len = sum(dataset_lens) 120 | dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens]) 121 | else: 122 | assert len(dataset_prob) == len(datasets) 123 | dataset_prob = torch.tensor(dataset_prob) 124 | 125 | logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}") 126 | assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0" 127 | self.dataset_prob = dataset_prob 128 | 129 | def _set_dataset_epoch(self, dataset, epoch: int) -> None: 130 | if hasattr(dataset, "epoch"): 131 | dataset.epoch = epoch 132 | if hasattr(dataset, "set_epoch"): 133 | dataset.set_epoch(epoch) 134 | 135 | def get_loader(self, epoch) -> Iterable: 136 | dataloaders = [] 137 | for d_idx, (dataset, batch_size) in enumerate( 138 | zip(self.datasets, self.batch_sizes) 139 | ): 140 | if self.phases_per_epoch > 1: 141 | # Major epoch that loops over the entire dataset 142 | # len(main_epoch) == phases_per_epoch * len(epoch) 143 | main_epoch = epoch // self.phases_per_epoch 144 | 145 | # Phase within the main epoch 146 | local_phase = epoch % self.phases_per_epoch 147 | 148 | # Start of a new data epoch, or the job was resumed after preemption. 149 | if local_phase == 0 or self.chunks[d_idx] is None: 150 | # set seed for dataset epoch 151 | # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
152 | self._set_dataset_epoch(dataset, main_epoch) 153 | 154 | # Separate random generator for subset sampling 155 | g = torch.Generator() 156 | g.manual_seed(main_epoch) 157 | self.chunks[d_idx] = torch.chunk( 158 | torch.randperm(len(dataset), generator=g), 159 | self.phases_per_epoch, 160 | ) 161 | 162 | dataset = Subset(dataset, self.chunks[d_idx][local_phase]) 163 | else: 164 | self._set_dataset_epoch(dataset, epoch) 165 | 166 | sampler = DistributedSampler(dataset, shuffle=self.shuffle) 167 | sampler.set_epoch(epoch) 168 | 169 | batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last) 170 | dataloaders.append( 171 | DataLoader( 172 | dataset, 173 | num_workers=self.num_workers, 174 | pin_memory=self.pin_memory, 175 | batch_sampler=batch_sampler, 176 | collate_fn=self.collate_fn, 177 | worker_init_fn=self.worker_init_fn, 178 | ) 179 | ) 180 | return MixedDataLoader(dataloaders, self.dataset_prob) 181 | -------------------------------------------------------------------------------- /training/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular""" 8 | 9 | from typing import Iterable 10 | 11 | import torch 12 | from torch.utils.data import ( 13 | ConcatDataset as TorchConcatDataset, 14 | Dataset, 15 | Subset as TorchSubset, 16 | ) 17 | 18 | 19 | class ConcatDataset(TorchConcatDataset): 20 | def __init__(self, datasets: Iterable[Dataset]) -> None: 21 | super(ConcatDataset, self).__init__(datasets) 22 | 23 | self.repeat_factors = torch.cat([d.repeat_factors for d in datasets]) 24 | 25 | def set_epoch(self, epoch: int): 26 | for dataset in self.datasets: 27 | if hasattr(dataset, "epoch"): 28 | dataset.epoch = epoch 29 | if hasattr(dataset, "set_epoch"): 30 | dataset.set_epoch(epoch) 31 | 32 | 33 | class Subset(TorchSubset): 34 | def __init__(self, dataset, indices) -> None: 35 | super(Subset, self).__init__(dataset, indices) 36 | 37 | self.repeat_factors = dataset.repeat_factors[indices] 38 | assert len(indices) == len(self.repeat_factors) 39 | 40 | 41 | # Adapted from Detectron2 42 | class RepeatFactorWrapper(Dataset): 43 | """ 44 | Thin wrapper around a dataset to implement repeat factor sampling. 45 | The underlying dataset must have a repeat_factors member to indicate the per-image factor. 46 | Set it to uniformly ones to disable repeat factor sampling 47 | """ 48 | 49 | def __init__(self, dataset, seed: int = 0): 50 | self.dataset = dataset 51 | self.epoch_ids = None 52 | self._seed = seed 53 | 54 | # Split into whole number (_int_part) and fractional (_frac_part) parts. 55 | self._int_part = torch.trunc(dataset.repeat_factors) 56 | self._frac_part = dataset.repeat_factors - self._int_part 57 | 58 | def _get_epoch_indices(self, generator): 59 | """ 60 | Create a list of dataset indices (with repeats) to use for one epoch. 61 | 62 | Args: 63 | generator (torch.Generator): pseudo random number generator used for 64 | stochastic rounding. 65 | 66 | Returns: 67 | torch.Tensor: list of dataset indices to use in one epoch. Each index 68 | is repeated based on its calculated repeat factor. 
69 | """ 70 | # Since repeat factors are fractional, we use stochastic rounding so 71 | # that the target repeat factor is achieved in expectation over the 72 | # course of training 73 | rands = torch.rand(len(self._frac_part), generator=generator) 74 | rep_factors = self._int_part + (rands < self._frac_part).float() 75 | # Construct a list of indices in which we repeat images as specified 76 | indices = [] 77 | for dataset_index, rep_factor in enumerate(rep_factors): 78 | indices.extend([dataset_index] * int(rep_factor.item())) 79 | return torch.tensor(indices, dtype=torch.int64) 80 | 81 | def __len__(self): 82 | if self.epoch_ids is None: 83 | # Here we raise an error instead of returning the original len(self.dataset) to avoid 84 | # accidentally using the unwrapped length. Otherwise it would be error-prone, since the 85 | # length changes to `len(self.epoch_ids)` after set_epoch is called. 86 | raise RuntimeError("please call set_epoch first to get wrapped length") 87 | # return len(self.dataset) 88 | 89 | return len(self.epoch_ids) 90 | 91 | def set_epoch(self, epoch: int): 92 | g = torch.Generator() 93 | g.manual_seed(self._seed + epoch) 94 | self.epoch_ids = self._get_epoch_indices(g) 95 | if hasattr(self.dataset, "set_epoch"): 96 | self.dataset.set_epoch(epoch) 97 | 98 | def __getitem__(self, idx): 99 | if self.epoch_ids is None: 100 | raise RuntimeError( 101 | "Repeat ids haven't been computed. Did you forget to call set_epoch?" 102 | ) 103 | 104 | return self.dataset[self.epoch_ids[idx]] 105 | -------------------------------------------------------------------------------- /training/dataset/vos_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
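# --- Editor's note (not part of the repository): an illustrative sketch of how the
# RepeatFactorWrapper defined in training/dataset/utils.py above is driven. The toy
# dataset and its repeat_factors values are assumptions made only for this example.
import torch
from torch.utils.data import Dataset
from training.dataset.utils import RepeatFactorWrapper

class ToyDataset(Dataset):
    def __init__(self):
        # one repeat factor per sample; fractional parts are rounded stochastically
        self.repeat_factors = torch.tensor([1.0, 2.5, 0.5])
    def __len__(self):
        return len(self.repeat_factors)
    def __getitem__(self, idx):
        return int(idx)

wrapped = RepeatFactorWrapper(ToyDataset(), seed=0)
wrapped.set_epoch(0)   # must be called before len() or indexing
# sample 0 appears once per epoch, sample 1 two or three times, sample 2 in about half the epochs
print([wrapped[i] for i in range(len(wrapped))])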
6 | 7 | import logging 8 | import random 9 | from copy import deepcopy 10 | 11 | import numpy as np 12 | 13 | import torch 14 | from iopath.common.file_io import g_pathmgr 15 | from PIL import Image as PILImage 16 | from torchvision.datasets.vision import VisionDataset 17 | 18 | from training.dataset.vos_raw_dataset import VOSRawDataset 19 | from training.dataset.vos_sampler import VOSSampler 20 | from training.dataset.vos_segment_loader import JSONSegmentLoader 21 | 22 | from training.utils.data_utils import Frame, Object, VideoDatapoint 23 | 24 | MAX_RETRIES = 100 25 | 26 | 27 | class VOSDataset(VisionDataset): 28 | def __init__( 29 | self, 30 | transforms, 31 | training: bool, 32 | video_dataset: VOSRawDataset, 33 | sampler: VOSSampler, 34 | multiplier: int, 35 | always_target=True, 36 | target_segments_available=True, 37 | ): 38 | self._transforms = transforms 39 | self.training = training 40 | self.video_dataset = video_dataset 41 | self.sampler = sampler 42 | 43 | self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32) 44 | self.repeat_factors *= multiplier 45 | print(f"Raw dataset length = {len(self.video_dataset)}") 46 | 47 | self.curr_epoch = 0 # Used in case data loader behavior changes across epochs 48 | self.always_target = always_target 49 | self.target_segments_available = target_segments_available 50 | 51 | def _get_datapoint(self, idx): 52 | 53 | for retry in range(MAX_RETRIES): 54 | try: 55 | if isinstance(idx, torch.Tensor): 56 | idx = idx.item() 57 | # sample a video 58 | video, segment_loader = self.video_dataset.get_video(idx) 59 | # sample frames and object indices to be used in a datapoint 60 | sampled_frms_and_objs = self.sampler.sample( 61 | video, segment_loader, epoch=self.curr_epoch 62 | ) 63 | break # Successfully loaded video 64 | except Exception as e: 65 | if self.training: 66 | logging.warning( 67 | f"Loading failed (id={idx}); Retry {retry} with exception: {e}" 68 | ) 69 | idx = random.randrange(0, len(self.video_dataset)) 70 | else: 71 | # Shouldn't fail to load a val video 72 | raise e 73 | 74 | datapoint = self.construct(video, sampled_frms_and_objs, segment_loader) 75 | for transform in self._transforms: 76 | datapoint = transform(datapoint, epoch=self.curr_epoch) 77 | return datapoint 78 | 79 | def construct(self, video, sampled_frms_and_objs, segment_loader): 80 | """ 81 | Constructs a VideoDatapoint sample to pass to transforms 82 | """ 83 | sampled_frames = sampled_frms_and_objs.frames 84 | sampled_object_ids = sampled_frms_and_objs.object_ids 85 | 86 | images = [] 87 | rgb_images = load_images(sampled_frames) 88 | # Iterate over the sampled frames and store their rgb data and object data (bbox, segment) 89 | for frame_idx, frame in enumerate(sampled_frames): 90 | w, h = rgb_images[frame_idx].size 91 | images.append( 92 | Frame( 93 | data=rgb_images[frame_idx], 94 | objects=[], 95 | ) 96 | ) 97 | # We load the gt segments associated with the current frame 98 | if isinstance(segment_loader, JSONSegmentLoader): 99 | segments = segment_loader.load( 100 | frame.frame_idx, obj_ids=sampled_object_ids 101 | ) 102 | else: 103 | segments = segment_loader.load(frame.frame_idx) 104 | for obj_id in sampled_object_ids: 105 | # Extract the segment 106 | if obj_id in segments: 107 | assert ( 108 | segments[obj_id] is not None 109 | ), "None targets are not supported" 110 | # segment is uint8 and remains uint8 throughout the transforms 111 | segment = segments[obj_id].to(torch.uint8) 112 | else: 113 | # There is no target, we either use a
zero mask target or drop this object 114 | if not self.always_target: 115 | continue 116 | segment = torch.zeros(h, w, dtype=torch.uint8) 117 | 118 | images[frame_idx].objects.append( 119 | Object( 120 | object_id=obj_id, 121 | frame_index=frame.frame_idx, 122 | segment=segment, 123 | ) 124 | ) 125 | return VideoDatapoint( 126 | frames=images, 127 | video_id=video.video_id, 128 | size=(h, w), 129 | ) 130 | 131 | def __getitem__(self, idx): 132 | return self._get_datapoint(idx) 133 | 134 | def __len__(self): 135 | return len(self.video_dataset) 136 | 137 | 138 | def load_images(frames): 139 | all_images = [] 140 | cache = {} 141 | for frame in frames: 142 | if frame.data is None: 143 | # Load the frame rgb data from file 144 | path = frame.image_path 145 | if path in cache: 146 | all_images.append(deepcopy(all_images[cache[path]])) 147 | continue 148 | with g_pathmgr.open(path, "rb") as fopen: 149 | all_images.append(PILImage.open(fopen).convert("RGB")) 150 | cache[path] = len(all_images) - 1 151 | else: 152 | # The frame rgb data has already been loaded 153 | # Convert it to a PILImage 154 | all_images.append(tensor_2_PIL(frame.data)) 155 | 156 | return all_images 157 | 158 | 159 | def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image: 160 | data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0 161 | data = data.astype(np.uint8) 162 | return PILImage.fromarray(data) 163 | -------------------------------------------------------------------------------- /training/dataset/vos_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | from dataclasses import dataclass 9 | from typing import List 10 | 11 | from training.dataset.vos_segment_loader import LazySegments 12 | 13 | MAX_RETRIES = 1000 14 | 15 | 16 | @dataclass 17 | class SampledFramesAndObjects: 18 | frames: List[int] 19 | object_ids: List[int] 20 | 21 | 22 | class VOSSampler: 23 | def __init__(self, sort_frames=True): 24 | # frames are ordered by frame id when sort_frames is True 25 | self.sort_frames = sort_frames 26 | 27 | def sample(self, video): 28 | raise NotImplementedError() 29 | 30 | 31 | class RandomUniformSampler(VOSSampler): 32 | def __init__( 33 | self, 34 | num_frames, 35 | max_num_objects, 36 | reverse_time_prob=0.0, 37 | ): 38 | self.num_frames = num_frames 39 | self.max_num_objects = max_num_objects 40 | self.reverse_time_prob = reverse_time_prob 41 | 42 | def sample(self, video, segment_loader, epoch=None): 43 | 44 | for retry in range(MAX_RETRIES): 45 | if len(video.frames) < self.num_frames: 46 | raise Exception( 47 | f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames." 
48 | ) 49 | start = random.randrange(0, len(video.frames) - self.num_frames + 1) 50 | frames = [video.frames[start + step] for step in range(self.num_frames)] 51 | if random.uniform(0, 1) < self.reverse_time_prob: 52 | # Reverse time 53 | frames = frames[::-1] 54 | 55 | # Get first frame object ids 56 | visible_object_ids = [] 57 | loaded_segms = segment_loader.load(frames[0].frame_idx) 58 | if isinstance(loaded_segms, LazySegments): 59 | # LazySegments for SA1BRawDataset 60 | visible_object_ids = list(loaded_segms.keys()) 61 | else: 62 | for object_id, segment in segment_loader.load( 63 | frames[0].frame_idx 64 | ).items(): 65 | if segment.sum(): 66 | visible_object_ids.append(object_id) 67 | 68 | # First frame needs to have at least a target to track 69 | if len(visible_object_ids) > 0: 70 | break 71 | if retry >= MAX_RETRIES - 1: 72 | raise Exception("No visible objects") 73 | 74 | object_ids = random.sample( 75 | visible_object_ids, 76 | min(len(visible_object_ids), self.max_num_objects), 77 | ) 78 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 79 | 80 | 81 | class EvalSampler(VOSSampler): 82 | """ 83 | VOS Sampler for evaluation: sampling all the frames and all the objects in a video 84 | """ 85 | 86 | def __init__( 87 | self, 88 | ): 89 | super().__init__() 90 | 91 | def sample(self, video, segment_loader, epoch=None): 92 | """ 93 | Sampling all the frames and all the objects 94 | """ 95 | if self.sort_frames: 96 | # ordered by frame id 97 | frames = sorted(video.frames, key=lambda x: x.frame_idx) 98 | else: 99 | # use the original order 100 | frames = video.frames 101 | object_ids = segment_loader.load(frames[0].frame_idx).keys() 102 | if len(object_ids) == 0: 103 | raise Exception("First frame of the video has no objects") 104 | 105 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids) 106 | -------------------------------------------------------------------------------- /training/dataset/vos_segment_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
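# --- Editor's note (not part of the repository): a rough sketch of the input contract
# that RandomUniformSampler in training/dataset/vos_sampler.py above expects. The
# SimpleNamespace video and DummyLoader are stand-ins invented for this example.
import torch
from types import SimpleNamespace
from training.dataset.vos_sampler import RandomUniformSampler

video = SimpleNamespace(
    video_name="toy_video",
    frames=[SimpleNamespace(frame_idx=i) for i in range(8)],
)

class DummyLoader:
    def load(self, frame_idx):
        # a single visible object with a non-empty binary mask
        return {1: torch.ones(4, 4, dtype=torch.uint8)}

sampler = RandomUniformSampler(num_frames=3, max_num_objects=1)
sample = sampler.sample(video, DummyLoader())
print([f.frame_idx for f in sample.frames], sample.object_ids)  # e.g. [2, 3, 4] [1]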
6 | 7 | import glob 8 | import json 9 | import os 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | 15 | from PIL import Image as PILImage 16 | 17 | try: 18 | from pycocotools import mask as mask_utils 19 | except: 20 | pass 21 | 22 | 23 | class JSONSegmentLoader: 24 | def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None): 25 | # Annotations in the json are provided every ann_every th frame 26 | self.ann_every = ann_every 27 | # Ids of the objects to consider when sampling this video 28 | self.valid_obj_ids = valid_obj_ids 29 | with open(video_json_path, "r") as f: 30 | data = json.load(f) 31 | if isinstance(data, list): 32 | self.frame_annots = data 33 | elif isinstance(data, dict): 34 | masklet_field_name = "masklet" if "masklet" in data else "masks" 35 | self.frame_annots = data[masklet_field_name] 36 | if "fps" in data: 37 | if isinstance(data["fps"], list): 38 | annotations_fps = int(data["fps"][0]) 39 | else: 40 | annotations_fps = int(data["fps"]) 41 | assert frames_fps % annotations_fps == 0 42 | self.ann_every = frames_fps // annotations_fps 43 | else: 44 | raise NotImplementedError 45 | 46 | def load(self, frame_id, obj_ids=None): 47 | assert frame_id % self.ann_every == 0 48 | rle_mask = self.frame_annots[frame_id // self.ann_every] 49 | 50 | valid_objs_ids = set(range(len(rle_mask))) 51 | if self.valid_obj_ids is not None: 52 | # Remove the masklets that have been filtered out for this video 53 | valid_objs_ids &= set(self.valid_obj_ids) 54 | if obj_ids is not None: 55 | # Only keep the objects that have been sampled 56 | valid_objs_ids &= set(obj_ids) 57 | valid_objs_ids = sorted(list(valid_objs_ids)) 58 | 59 | # Construct rle_masks_filtered that only contains the rle masks we are interested in 60 | id_2_idx = {} 61 | rle_mask_filtered = [] 62 | for obj_id in valid_objs_ids: 63 | if rle_mask[obj_id] is not None: 64 | id_2_idx[obj_id] = len(rle_mask_filtered) 65 | rle_mask_filtered.append(rle_mask[obj_id]) 66 | else: 67 | id_2_idx[obj_id] = None 68 | 69 | # Decode the masks 70 | raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute( 71 | 2, 0, 1 72 | ) # (num_obj, h, w) 73 | segments = {} 74 | for obj_id in valid_objs_ids: 75 | if id_2_idx[obj_id] is None: 76 | segments[obj_id] = None 77 | else: 78 | idx = id_2_idx[obj_id] 79 | segments[obj_id] = raw_segments[idx] 80 | return segments 81 | 82 | def get_valid_obj_frames_ids(self, num_frames_min=None): 83 | # For each object, find all the frames with a valid (not None) mask 84 | num_objects = len(self.frame_annots[0]) 85 | 86 | # The result dict associates each obj_id with the id of its valid frames 87 | res = {obj_id: [] for obj_id in range(num_objects)} 88 | 89 | for annot_idx, annot in enumerate(self.frame_annots): 90 | for obj_id in range(num_objects): 91 | if annot[obj_id] is not None: 92 | res[obj_id].append(int(annot_idx * self.ann_every)) 93 | 94 | if num_frames_min is not None: 95 | # Remove masklets that have less than num_frames_min valid masks 96 | for obj_id, valid_frames in list(res.items()): 97 | if len(valid_frames) < num_frames_min: 98 | res.pop(obj_id) 99 | 100 | return res 101 | 102 | 103 | class PalettisedPNGSegmentLoader: 104 | def __init__(self, video_png_root, sample_rate=1): 105 | """ 106 | SegmentLoader for datasets with masks stored as palettised PNGs. 
107 | video_png_root: the folder contains all the masks stored in png 108 | """ 109 | self.video_png_root = video_png_root 110 | self.sample_rate = sample_rate 111 | # build a mapping from frame id to their PNG mask path 112 | # note that in some datasets, the PNG paths could have more 113 | # than 5 digits, e.g. "00000000.png" instead of "00000.png" 114 | png_filenames = sorted(glob.glob(os.path.join(self.video_png_root, "*.png"))) # os.listdir(self.video_png_root) 115 | self.frame_id_to_png_filename = {} 116 | for idx, filename in enumerate(png_filenames[::self.sample_rate]): 117 | frame_id = idx # int(os.path.basename(filename).split(".")[0]) 118 | self.frame_id_to_png_filename[frame_id] = filename 119 | 120 | def load(self, frame_id): 121 | """ 122 | load the single palettised mask from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png') 123 | Args: 124 | frame_id: int, define the mask path 125 | Return: 126 | binary_segments: dict 127 | """ 128 | # check the path 129 | mask_path = os.path.join( 130 | self.video_png_root, self.frame_id_to_png_filename[frame_id] 131 | ) 132 | 133 | # load the mask 134 | masks = PILImage.open(mask_path).convert("P") 135 | masks = np.array(masks) 136 | 137 | object_id = pd.unique(masks.flatten()) 138 | object_id = object_id[object_id != 0] # remove background (0) 139 | 140 | # convert into N binary segmentation masks 141 | binary_segments = {} 142 | for i in object_id: 143 | bs = masks == i 144 | binary_segments[i] = torch.from_numpy(bs) 145 | 146 | return binary_segments 147 | 148 | def __len__(self): 149 | return 150 | 151 | 152 | class MultiplePNGSegmentLoader: 153 | def __init__(self, video_png_root, single_object_mode=False): 154 | """ 155 | video_png_root: the folder contains all the masks stored in png 156 | single_object_mode: whether to load only a single object at a time 157 | """ 158 | self.video_png_root = video_png_root 159 | self.single_object_mode = single_object_mode 160 | # read a mask to know the resolution of the video 161 | if self.single_object_mode: 162 | tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0] 163 | else: 164 | tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0] 165 | tmp_mask = np.array(PILImage.open(tmp_mask_path)) 166 | self.H = tmp_mask.shape[0] 167 | self.W = tmp_mask.shape[1] 168 | if self.single_object_mode: 169 | self.obj_id = ( 170 | int(video_png_root.split("/")[-1]) + 1 171 | ) # offset by 1 as bg is 0 172 | else: 173 | self.obj_id = None 174 | 175 | def load(self, frame_id): 176 | if self.single_object_mode: 177 | return self._load_single_png(frame_id) 178 | else: 179 | return self._load_multiple_pngs(frame_id) 180 | 181 | def _load_single_png(self, frame_id): 182 | """ 183 | load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png') 184 | Args: 185 | frame_id: int, define the mask path 186 | Return: 187 | binary_segments: dict 188 | """ 189 | mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png") 190 | binary_segments = {} 191 | 192 | if os.path.exists(mask_path): 193 | mask = np.array(PILImage.open(mask_path)) 194 | else: 195 | # if png doesn't exist, empty mask 196 | mask = np.zeros((self.H, self.W), dtype=bool) 197 | binary_segments[self.obj_id] = torch.from_numpy(mask > 0) 198 | return binary_segments 199 | 200 | def _load_multiple_pngs(self, frame_id): 201 | """ 202 | load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png') 203 | Args: 204 | frame_id: int, define the mask path 205 | Return: 206 | 
binary_segments: dict 207 | """ 208 | # get the path 209 | all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*"))) 210 | num_objects = len(all_objects) 211 | assert num_objects > 0 212 | 213 | # load the masks 214 | binary_segments = {} 215 | for obj_folder in all_objects: 216 | # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder 217 | obj_id = int(obj_folder.split("/")[-1]) 218 | obj_id = obj_id + 1 # offset 1 as bg is 0 219 | mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png") 220 | if os.path.exists(mask_path): 221 | mask = np.array(PILImage.open(mask_path)) 222 | else: 223 | mask = np.zeros((self.H, self.W), dtype=bool) 224 | binary_segments[obj_id] = torch.from_numpy(mask > 0) 225 | 226 | return binary_segments 227 | 228 | def __len__(self): 229 | return 230 | 231 | 232 | class LazySegments: 233 | """ 234 | Only decodes segments that are actually used. 235 | """ 236 | 237 | def __init__(self): 238 | self.segments = {} 239 | self.cache = {} 240 | 241 | def __setitem__(self, key, item): 242 | self.segments[key] = item 243 | 244 | def __getitem__(self, key): 245 | if key in self.cache: 246 | return self.cache[key] 247 | rle = self.segments[key] 248 | mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0] 249 | self.cache[key] = mask 250 | return mask 251 | 252 | def __contains__(self, key): 253 | return key in self.segments 254 | 255 | def __len__(self): 256 | return len(self.segments) 257 | 258 | def keys(self): 259 | return self.segments.keys() 260 | 261 | 262 | class SA1BSegmentLoader: 263 | def __init__( 264 | self, 265 | video_mask_path, 266 | mask_area_frac_thresh=1.1, 267 | video_frame_path=None, 268 | uncertain_iou=-1, 269 | ): 270 | with open(video_mask_path, "r") as f: 271 | self.frame_annots = json.load(f) 272 | 273 | if mask_area_frac_thresh <= 1.0: 274 | # Lazily read frame 275 | orig_w, orig_h = PILImage.open(video_frame_path).size 276 | area = orig_w * orig_h 277 | 278 | self.frame_annots = self.frame_annots["annotations"] 279 | 280 | rle_masks = [] 281 | for frame_annot in self.frame_annots: 282 | if not frame_annot["area"] > 0: 283 | continue 284 | if ("uncertain_iou" in frame_annot) and ( 285 | frame_annot["uncertain_iou"] < uncertain_iou 286 | ): 287 | # uncertain_iou is stability score 288 | continue 289 | if ( 290 | mask_area_frac_thresh <= 1.0 291 | and (frame_annot["area"] / area) >= mask_area_frac_thresh 292 | ): 293 | continue 294 | rle_masks.append(frame_annot["segmentation"]) 295 | 296 | self.segments = LazySegments() 297 | for i, rle in enumerate(rle_masks): 298 | self.segments[i] = rle 299 | 300 | def load(self, frame_idx): 301 | return self.segments 302 | 303 | 304 | class NPZSegmentLoader: 305 | def __init__(self, masks): 306 | """ 307 | Initialize the NPZSegmentLoader. 308 | 309 | Args: 310 | masks (numpy.ndarray): Array of masks with shape (img_num, H, W). 311 | """ 312 | self.masks = masks 313 | 314 | def load(self, frame_idx): 315 | """ 316 | Load the single mask for the given frame index and convert it to binary segments. 317 | 318 | Args: 319 | frame_idx (int): Index of the frame to load. 320 | 321 | Returns: 322 | dict: A dictionary where keys are object IDs and values are binary masks. 
323 | """ 324 | mask = self.masks[frame_idx] 325 | 326 | # Find unique object IDs in the mask, excluding the background (0) 327 | object_ids = np.unique(mask) 328 | object_ids = object_ids[object_ids != 0] 329 | 330 | # Convert into binary segmentation masks for each object 331 | binary_segments = {} 332 | for obj_id in object_ids: 333 | binary_mask = (mask == obj_id) 334 | binary_segments[int(obj_id)] = torch.from_numpy(binary_mask).bool() 335 | 336 | return binary_segments 337 | -------------------------------------------------------------------------------- /training/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /training/model/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/model/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /training/model/__pycache__/sam2.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/model/__pycache__/sam2.cpython-312.pyc -------------------------------------------------------------------------------- /training/scripts/sav_frame_extraction_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | import argparse 4 | import os 5 | from pathlib import Path 6 | 7 | import cv2 8 | 9 | import numpy as np 10 | import submitit 11 | import tqdm 12 | 13 | 14 | def get_args_parser(): 15 | parser = argparse.ArgumentParser( 16 | description="[SA-V Preprocessing] Extracting JPEG frames", 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 18 | ) 19 | 20 | # ------------ 21 | # DATA 22 | # ------------ 23 | data_parser = parser.add_argument_group( 24 | title="SA-V dataset data root", 25 | description="What data to load and how to process it.", 26 | ) 27 | data_parser.add_argument( 28 | "--sav-vid-dir", 29 | type=str, 30 | required=True, 31 | help=("Where to find the SAV videos"), 32 | ) 33 | data_parser.add_argument( 34 | "--sav-frame-sample-rate", 35 | type=int, 36 | default=4, 37 | help="Rate at which to sub-sample frames", 38 | ) 39 | 40 | # ------------ 41 | # LAUNCH 42 | # ------------ 43 | launch_parser = parser.add_argument_group( 44 | title="Cluster launch settings", 45 | description="Number of jobs and retry settings.", 46 | ) 47 | launch_parser.add_argument( 48 | "--n-jobs", 49 | type=int, 50 | required=True, 51 | help="Shard the run over this many jobs.", 52 | ) 53 | launch_parser.add_argument( 54 | "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes." 55 | ) 56 | launch_parser.add_argument( 57 | "--partition", type=str, required=True, help="Partition to launch on." 58 | ) 59 | launch_parser.add_argument( 60 | "--account", type=str, required=True, help="SLURM account to use."
61 | ) 62 | launch_parser.add_argument("--qos", type=str, required=True, help="QOS.") 63 | 64 | # ------------ 65 | # OUTPUT 66 | # ------------ 67 | output_parser = parser.add_argument_group( 68 | title="Setting for results output", description="Where and how to save results." 69 | ) 70 | output_parser.add_argument( 71 | "--output-dir", 72 | type=str, 73 | required=True, 74 | help=("Where to dump the extracted jpeg frames"), 75 | ) 76 | output_parser.add_argument( 77 | "--slurm-output-root-dir", 78 | type=str, 79 | required=True, 80 | help=("Where to save slurm outputs"), 81 | ) 82 | return parser 83 | 84 | 85 | def decode_video(video_path: str): 86 | assert os.path.exists(video_path) 87 | video = cv2.VideoCapture(video_path) 88 | video_frames = [] 89 | while video.isOpened(): 90 | ret, frame = video.read() 91 | if ret: 92 | video_frames.append(frame) 93 | else: 94 | break 95 | return video_frames 96 | 97 | 98 | def extract_frames(video_path, sample_rate): 99 | frames = decode_video(video_path) 100 | return frames[::sample_rate] 101 | 102 | 103 | def submitit_launch(video_paths, sample_rate, save_root): 104 | for path in tqdm.tqdm(video_paths): 105 | frames = extract_frames(path, sample_rate) 106 | output_folder = os.path.join(save_root, Path(path).stem) 107 | if not os.path.exists(output_folder): 108 | os.makedirs(output_folder) 109 | for fid, frame in enumerate(frames): 110 | frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg") 111 | cv2.imwrite(frame_path, frame) 112 | print(f"Saved output to {save_root}") 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = get_args_parser() 117 | args = parser.parse_args() 118 | 119 | sav_vid_dir = args.sav_vid_dir 120 | save_root = args.output_dir 121 | sample_rate = args.sav_frame_sample_rate 122 | 123 | # List all SA-V videos 124 | mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")]) 125 | mp4_files = np.array(mp4_files) 126 | chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)] 127 | 128 | print(f"Processing videos in: {sav_vid_dir}") 129 | print(f"Processing {len(mp4_files)} files") 130 | print(f"Beginning processing in {args.n_jobs} processes") 131 | 132 | # Submitit params 133 | jobs_dir = os.path.join(args.slurm_output_root_dir, "%j") 134 | cpus_per_task = 4 135 | executor = submitit.AutoExecutor(folder=jobs_dir) 136 | executor.update_parameters( 137 | timeout_min=args.timeout, 138 | gpus_per_node=0, 139 | tasks_per_node=1, 140 | slurm_array_parallelism=args.n_jobs, 141 | cpus_per_task=cpus_per_task, 142 | slurm_partition=args.partition, 143 | slurm_account=args.account, 144 | slurm_qos=args.qos, 145 | ) 146 | executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"]) 147 | 148 | # Launch 149 | jobs = [] 150 | with executor.batch(): 151 | for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)): 152 | job = executor.submit( 153 | submitit_launch, 154 | video_paths=mp4_chunk, 155 | sample_rate=sample_rate, 156 | save_root=save_root, 157 | ) 158 | jobs.append(job) 159 | 160 | for j in jobs: 161 | print(f"Slurm JobID: {j.job_id}") 162 | print(f"Saving outputs to {save_root}") 163 | print(f"Slurm outputs at {args.slurm_output_root_dir}") 164 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
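# --- Editor's note (not part of the repository): the script above is written for a
# SLURM cluster via submitit; this is a hedged sketch of calling its helpers locally
# for a single video. The input video path and output directory are placeholders.
import os
import cv2
from training.scripts.sav_frame_extraction_submitit import extract_frames

video_path = "/path/to/sav_000001.mp4"        # placeholder
out_dir = "/path/to/jpeg_frames/sav_000001"   # placeholder
sample_rate = 4

frames = extract_frames(video_path, sample_rate)
os.makedirs(out_dir, exist_ok=True)
for fid, frame in enumerate(frames):
    cv2.imwrite(os.path.join(out_dir, f"{fid * sample_rate:05d}.jpg"), frame)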
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import random 10 | import sys 11 | import traceback 12 | from argparse import ArgumentParser 13 | 14 | import submitit 15 | import torch 16 | 17 | from hydra import compose, initialize_config_module 18 | from hydra.utils import instantiate 19 | 20 | from iopath.common.file_io import g_pathmgr 21 | from omegaconf import OmegaConf 22 | 23 | from training.utils.train_utils import makedir, register_omegaconf_resolvers 24 | 25 | os.environ["HYDRA_FULL_ERROR"] = "1" 26 | 27 | 28 | def single_proc_run(local_rank, main_port, cfg, world_size, node_rank, master_addr): 29 | """Single GPU process""" 30 | os.environ["MASTER_ADDR"] = master_addr 31 | os.environ["MASTER_PORT"] = str(main_port) 32 | os.environ["RANK"] = str(node_rank * cfg.launcher.gpus_per_node + local_rank) 33 | os.environ["LOCAL_RANK"] = str(local_rank) 34 | os.environ["WORLD_SIZE"] = str(world_size) 35 | try: 36 | register_omegaconf_resolvers() 37 | except Exception as e: 38 | logging.info(e) 39 | 40 | trainer = instantiate(cfg.trainer, _recursive_=False) 41 | trainer.run() 42 | 43 | 44 | def single_node_runner(cfg, main_port: int, node_rank: int = 0, master_addr: str = "localhost"): 45 | num_proc = cfg.launcher.gpus_per_node 46 | world_size = cfg.launcher.gpus_per_node * cfg.launcher.num_nodes 47 | torch.multiprocessing.set_start_method( 48 | "spawn" 49 | ) # CUDA runtime does not support `fork` 50 | if num_proc == 1: 51 | # directly call single_proc so we can easily set breakpoints 52 | # mp.spawn does not let us set breakpoints 53 | single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=world_size, node_rank=node_rank, master_addr=master_addr) 54 | else: 55 | mp_runner = torch.multiprocessing.start_processes 56 | args = (main_port, cfg, world_size, node_rank, master_addr) 57 | mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn") 58 | 59 | 60 | def format_exception(e: Exception, limit=20): 61 | traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit)) 62 | return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}" 63 | 64 | 65 | class SubmititRunner(submitit.helpers.Checkpointable): 66 | """A callable which is passed to submitit to launch the jobs.""" 67 | 68 | def __init__(self, port, cfg): 69 | self.cfg = cfg 70 | self.port = port 71 | self.has_setup = False 72 | 73 | def run_trainer(self): 74 | job_env = submitit.JobEnvironment() 75 | # Need to add this again so the hydra.job.set_env PYTHONPATH 76 | # is also set when launching jobs. 77 | add_pythonpath_to_sys_path() 78 | os.environ["MASTER_ADDR"] = job_env.hostnames[0] 79 | os.environ["MASTER_PORT"] = str(self.port) 80 | os.environ["RANK"] = str(job_env.global_rank) 81 | os.environ["LOCAL_RANK"] = str(job_env.local_rank) 82 | os.environ["WORLD_SIZE"] = str(job_env.num_tasks) 83 | 84 | register_omegaconf_resolvers() 85 | cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False) 86 | cfg_resolved = OmegaConf.create(cfg_resolved) 87 | 88 | trainer = instantiate(cfg_resolved.trainer, _recursive_=False) 89 | trainer.run() 90 | 91 | def __call__(self): 92 | job_env = submitit.JobEnvironment() 93 | self.setup_job_info(job_env.job_id, job_env.global_rank) 94 | try: 95 | self.run_trainer() 96 | except Exception as e: 97 | # Log the exception. Then raise it again (as what SubmititRunner currently does). 
98 | message = format_exception(e) 99 | logging.error(message) 100 | raise e 101 | 102 | def setup_job_info(self, job_id, rank): 103 | """Set up slurm job info""" 104 | self.job_info = { 105 | "job_id": job_id, 106 | "rank": rank, 107 | "cluster": self.cfg.get("cluster", None), 108 | "experiment_log_dir": self.cfg.launcher.experiment_log_dir, 109 | } 110 | 111 | self.has_setup = True 112 | 113 | 114 | def add_pythonpath_to_sys_path(): 115 | if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]: 116 | return 117 | sys.path = os.environ["PYTHONPATH"].split(":") + sys.path 118 | 119 | 120 | def main(args, cfg) -> None: 121 | 122 | if cfg.launcher.experiment_log_dir is None: 123 | cfg.launcher.experiment_log_dir = os.path.join( 124 | os.getcwd(), "sam2_logs", args.config 125 | ) 126 | print("###################### Train App Config ####################") 127 | print(OmegaConf.to_yaml(cfg)) 128 | print("############################################################") 129 | 130 | add_pythonpath_to_sys_path() 131 | makedir(cfg.launcher.experiment_log_dir) 132 | with g_pathmgr.open( 133 | os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w" 134 | ) as f: 135 | f.write(OmegaConf.to_yaml(cfg)) 136 | 137 | cfg_resolved = OmegaConf.to_container(cfg, resolve=False) 138 | cfg_resolved = OmegaConf.create(cfg_resolved) 139 | 140 | with g_pathmgr.open( 141 | os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w" 142 | ) as f: 143 | f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True)) 144 | 145 | submitit_conf = cfg.get("submitit", None) 146 | assert submitit_conf is not None, "Missing submitit config" 147 | 148 | submitit_dir = cfg.launcher.experiment_log_dir 149 | submitit_dir = os.path.join(submitit_dir, "submitit_logs") 150 | # Prioritize command-line args 151 | cfg.launcher.gpus_per_node = ( 152 | args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node 153 | ) 154 | cfg.launcher.num_nodes = ( 155 | args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes 156 | ) 157 | submitit_conf.use_cluster = ( 158 | args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster 159 | ) 160 | if submitit_conf.use_cluster: 161 | executor = submitit.AutoExecutor(folder=submitit_dir) 162 | submitit_conf.partition = ( 163 | args.partition 164 | if args.partition is not None 165 | else submitit_conf.get("partition", None) 166 | ) 167 | submitit_conf.account = ( 168 | args.account 169 | if args.account is not None 170 | else submitit_conf.get("account", None) 171 | ) 172 | submitit_conf.qos = ( 173 | args.qos if args.qos is not None else submitit_conf.get("qos", None) 174 | ) 175 | job_kwargs = { 176 | "timeout_min": 60 * submitit_conf.timeout_hour, 177 | "name": ( 178 | submitit_conf.name if hasattr(submitit_conf, "name") else args.config 179 | ), 180 | "slurm_partition": submitit_conf.partition, 181 | "gpus_per_node": cfg.launcher.gpus_per_node, 182 | "tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU 183 | "cpus_per_task": submitit_conf.cpus_per_task, 184 | "nodes": cfg.launcher.num_nodes, 185 | "slurm_additional_parameters": { 186 | "exclude": " ".join(submitit_conf.get("exclude_nodes", [])), 187 | }, 188 | } 189 | if "include_nodes" in submitit_conf: 190 | assert ( 191 | len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes 192 | ), "Not enough nodes" 193 | job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join( 194 | submitit_conf["include_nodes"] 195 | ) 196 | if submitit_conf.account is not
None: 197 | job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account 198 | if submitit_conf.qos is not None: 199 | job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos 200 | 201 | if submitit_conf.get("mem_gb", None) is not None: 202 | job_kwargs["mem_gb"] = submitit_conf.mem_gb 203 | elif submitit_conf.get("mem", None) is not None: 204 | job_kwargs["slurm_mem"] = submitit_conf.mem 205 | 206 | if submitit_conf.get("constraints", None) is not None: 207 | job_kwargs["slurm_constraint"] = submitit_conf.constraints 208 | 209 | if submitit_conf.get("comment", None) is not None: 210 | job_kwargs["slurm_comment"] = submitit_conf.comment 211 | 212 | # Supports only cpu-bind option within srun_args. New options can be added here 213 | if submitit_conf.get("srun_args", None) is not None: 214 | job_kwargs["slurm_srun_args"] = [] 215 | if submitit_conf.srun_args.get("cpu_bind", None) is not None: 216 | job_kwargs["slurm_srun_args"].extend( 217 | ["--cpu-bind", submitit_conf.srun_args.cpu_bind] 218 | ) 219 | 220 | print("###################### SLURM Config ####################") 221 | print(job_kwargs) 222 | print("##########################################") 223 | executor.update_parameters(**job_kwargs) 224 | 225 | main_port = random.randint( 226 | submitit_conf.port_range[0], submitit_conf.port_range[1] 227 | ) 228 | runner = SubmititRunner(main_port, cfg) 229 | job = executor.submit(runner) 230 | print(f"Submitit Job ID: {job.job_id}") 231 | runner.setup_job_info(job.job_id, rank=0) 232 | else: 233 | # Handle the master_addr separately 234 | master_addr = args.master_addr if args.master_addr else "localhost" 235 | 236 | main_port = args.main_port if args.main_port else random.randint( 237 | submitit_conf.port_range[0], submitit_conf.port_range[1] 238 | ) 239 | if 'SLURM_PROCID' in os.environ: 240 | node_rank = int(os.environ['SLURM_PROCID']) 241 | else: 242 | node_rank = 0 243 | single_node_runner(cfg, main_port, node_rank=node_rank, master_addr=master_addr) 244 | 245 | 246 | if __name__ == "__main__": 247 | 248 | initialize_config_module("sam2", version_base="1.2") 249 | parser = ArgumentParser() 250 | parser.add_argument( 251 | "-c", 252 | "--config", 253 | required=True, 254 | type=str, 255 | help="path to config file (e.g. 
configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)", 256 | ) 257 | parser.add_argument( 258 | "--use-cluster", 259 | type=int, 260 | default=None, 261 | help="whether to launch on a cluster, 0: run locally, 1: run on a cluster", 262 | ) 263 | parser.add_argument("--partition", type=str, default=None, help="SLURM partition") 264 | parser.add_argument("--account", type=str, default=None, help="SLURM account") 265 | parser.add_argument("--qos", type=str, default=None, help="SLURM qos") 266 | parser.add_argument( 267 | "--num-gpus", type=int, default=None, help="number of GPUS per node" 268 | ) 269 | parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes") 270 | parser.add_argument("--master-addr", type=str, default=None, help="Master node address") 271 | parser.add_argument("--main-port", type=int, default=None, help="Main port for communication") 272 | parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset, overrides cfg.dataset.folder") 273 | parser.add_argument("--output-path", type=str, default=None, help="Path to the experiment output, overrides cfg.launcher.experiment_log_dir") 274 | args = parser.parse_args() 275 | args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None 276 | register_omegaconf_resolvers() 277 | 278 | # Override dataset folder and experiment output path if specified 279 | cfg = compose(config_name=args.config) 280 | if args.dataset_path is not None: 281 | cfg.dataset.folder = args.dataset_path 282 | if args.output_path is not None: 283 | 284 | cfg.launcher.experiment_log_dir = args.output_path 285 | 286 | main(args, cfg) -------------------------------------------------------------------------------- /training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
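# Editor's note (illustrative, not part of the original file): example invocations of
# the training entry point defined in training/train.py above. The config name is
# resolved by Hydra relative to the `sam2` config module, so the exact path below is
# an assumption, as are the SLURM placeholders.
#
#   Local run on a single node:
#     python training/train.py \
#         -c configs/sam2.1_hiera_tiny_finetune512.yaml \
#         --use-cluster 0 --num-gpus 1
#
#   Submit to SLURM via submitit:
#     python training/train.py \
#         -c configs/sam2.1_hiera_tiny_finetune512.yaml \
#         --use-cluster 1 --partition <partition> --account <account> \
#         --num-gpus 8 --num-nodes 2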
6 | -------------------------------------------------------------------------------- /training/utils/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /training/utils/__pycache__/checkpoint_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/checkpoint_utils.cpython-312.pyc -------------------------------------------------------------------------------- /training/utils/__pycache__/data_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/data_utils.cpython-312.pyc -------------------------------------------------------------------------------- /training/utils/__pycache__/distributed.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/distributed.cpython-312.pyc -------------------------------------------------------------------------------- /training/utils/__pycache__/logger.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/logger.cpython-312.pyc -------------------------------------------------------------------------------- /training/utils/__pycache__/train_utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/train_utils.cpython-312.pyc -------------------------------------------------------------------------------- /training/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Misc functions, including distributed helpers. 9 | 10 | Mostly copy-paste from torchvision references. 11 | """ 12 | 13 | from dataclasses import dataclass 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import torch 17 | 18 | from PIL import Image as PILImage 19 | from tensordict import tensorclass 20 | 21 | 22 | @tensorclass 23 | class BatchedVideoMetaData: 24 | """ 25 | This class represents metadata about a batch of videos. 26 | Attributes: 27 | unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id) 28 | frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch. 
29 | """ 30 | 31 | unique_objects_identifier: torch.LongTensor 32 | frame_orig_size: torch.LongTensor 33 | 34 | 35 | @tensorclass 36 | class BatchedVideoDatapoint: 37 | """ 38 | This class represents a batch of videos with associated annotations and metadata. 39 | Attributes: 40 | img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch. 41 | obj_to_frame_idx: A [TxOx2] tensor containing the image_batch index which the object belongs to. O is the number of objects in the batch. 42 | masks: A [TxOxHxW] tensor containing binary masks for each object in the batch. 43 | metadata: An instance of BatchedVideoMetaData containing metadata about the batch. 44 | dict_key: A string key used to identify the batch. 45 | """ 46 | 47 | img_batch: torch.FloatTensor 48 | obj_to_frame_idx: torch.IntTensor 49 | masks: torch.BoolTensor 50 | metadata: BatchedVideoMetaData 51 | 52 | dict_key: str 53 | 54 | def pin_memory(self, device=None): 55 | return self.apply(torch.Tensor.pin_memory, device=device) 56 | 57 | @property 58 | def num_frames(self) -> int: 59 | """ 60 | Returns the number of frames per video. 61 | """ 62 | return self.batch_size[0] 63 | 64 | @property 65 | def num_videos(self) -> int: 66 | """ 67 | Returns the number of videos in the batch. 68 | """ 69 | return self.img_batch.shape[1] 70 | 71 | @property 72 | def flat_obj_to_img_idx(self) -> torch.IntTensor: 73 | """ 74 | Returns a flattened tensor containing the object to img index. 75 | The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW] 76 | """ 77 | frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1) 78 | flat_idx = video_idx * self.num_frames + frame_idx 79 | return flat_idx 80 | 81 | @property 82 | def flat_img_batch(self) -> torch.FloatTensor: 83 | """ 84 | Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW] 85 | """ 86 | 87 | return self.img_batch.transpose(0, 1).flatten(0, 1) 88 | 89 | 90 | @dataclass 91 | class Object: 92 | # Id of the object in the media 93 | object_id: int 94 | # Index of the frame in the media (0 if single image) 95 | frame_index: int 96 | segment: Union[torch.Tensor, dict] # RLE dict or binary mask 97 | 98 | 99 | @dataclass 100 | class Frame: 101 | data: Union[torch.Tensor, PILImage.Image] 102 | objects: List[Object] 103 | 104 | 105 | @dataclass 106 | class VideoDatapoint: 107 | """Refers to an image/video and all its annotations""" 108 | 109 | frames: List[Frame] 110 | video_id: int 111 | size: Tuple[int, int] 112 | 113 | 114 | def collate_fn( 115 | batch: List[VideoDatapoint], 116 | dict_key, 117 | ) -> BatchedVideoDatapoint: 118 | """ 119 | Args: 120 | batch: A list of VideoDatapoint instances. 121 | dict_key (str): A string key used to identify the batch. 122 | """ 123 | img_batch = [] 124 | for video in batch: 125 | img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)] 126 | 127 | img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4)) 128 | T = img_batch.shape[0] 129 | # Prepare data structures for sequential processing. Per-frame processing but batched across videos. 
130 | step_t_objects_identifier = [[] for _ in range(T)] 131 | step_t_frame_orig_size = [[] for _ in range(T)] 132 | 133 | step_t_masks = [[] for _ in range(T)] 134 | step_t_obj_to_frame_idx = [ 135 | [] for _ in range(T) 136 | ] # List to store frame indices for each time step 137 | 138 | for video_idx, video in enumerate(batch): 139 | orig_video_id = video.video_id 140 | orig_frame_size = video.size 141 | for t, frame in enumerate(video.frames): 142 | objects = frame.objects 143 | for obj in objects: 144 | orig_obj_id = obj.object_id 145 | orig_frame_idx = obj.frame_index 146 | step_t_obj_to_frame_idx[t].append( 147 | torch.tensor([t, video_idx], dtype=torch.int) 148 | ) 149 | step_t_masks[t].append(obj.segment.to(torch.bool)) 150 | step_t_objects_identifier[t].append( 151 | torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx]) 152 | ) 153 | step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size)) 154 | 155 | obj_to_frame_idx = torch.stack( 156 | [ 157 | torch.stack(obj_to_frame_idx, dim=0) 158 | for obj_to_frame_idx in step_t_obj_to_frame_idx 159 | ], 160 | dim=0, 161 | ) 162 | masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0) 163 | objects_identifier = torch.stack( 164 | [torch.stack(id, dim=0) for id in step_t_objects_identifier], dim=0 165 | ) 166 | frame_orig_size = torch.stack( 167 | [torch.stack(id, dim=0) for id in step_t_frame_orig_size], dim=0 168 | ) 169 | return BatchedVideoDatapoint( 170 | img_batch=img_batch, 171 | obj_to_frame_idx=obj_to_frame_idx, 172 | masks=masks, 173 | metadata=BatchedVideoMetaData( 174 | unique_objects_identifier=objects_identifier, 175 | frame_orig_size=frame_orig_size, 176 | ), 177 | dict_key=dict_key, 178 | batch_size=[T], 179 | ) 180 | -------------------------------------------------------------------------------- /training/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py 8 | import atexit 9 | import functools 10 | import logging 11 | import sys 12 | import uuid 13 | from typing import Any, Dict, Optional, Union 14 | 15 | from hydra.utils import instantiate 16 | 17 | from iopath.common.file_io import g_pathmgr 18 | from numpy import ndarray 19 | from torch import Tensor 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | from training.utils.train_utils import get_machine_local_and_dist_rank, makedir 23 | 24 | Scalar = Union[Tensor, ndarray, int, float] 25 | 26 | 27 | def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any): 28 | makedir(log_dir) 29 | summary_writer_method = SummaryWriter 30 | return TensorBoardLogger( 31 | path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs 32 | ) 33 | 34 | 35 | class TensorBoardWriterWrapper: 36 | """ 37 | A wrapper around a SummaryWriter object. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | path: str, 43 | *args: Any, 44 | filename_suffix: str = None, 45 | summary_writer_method: Any = SummaryWriter, 46 | **kwargs: Any, 47 | ) -> None: 48 | """Create a new TensorBoard logger. 49 | On construction, the logger creates a new events file that logs 50 | will be written to. 
If the environment variable `RANK` is defined, 51 | the logger will only log if RANK = 0. 52 | 53 | NOTE: If using the logger with distributed training: 54 | - This logger can call collective operations 55 | - Logs will be written on rank 0 only 56 | - Logger must be constructed synchronously *after* initializing distributed process group. 57 | 58 | Args: 59 | path (str): path to write logs to 60 | *args, **kwargs: Extra arguments to pass to SummaryWriter 61 | """ 62 | self._writer: Optional[SummaryWriter] = None 63 | _, self._rank = get_machine_local_and_dist_rank() 64 | self._path: str = path 65 | if self._rank == 0: 66 | logging.info( 67 | f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" 68 | ) 69 | self._writer = summary_writer_method( 70 | log_dir=path, 71 | *args, 72 | filename_suffix=filename_suffix or str(uuid.uuid4()), 73 | **kwargs, 74 | ) 75 | else: 76 | logging.debug( 77 | f"Not logging meters on this host because env RANK: {self._rank} != 0" 78 | ) 79 | atexit.register(self.close) 80 | 81 | @property 82 | def writer(self) -> Optional[SummaryWriter]: 83 | return self._writer 84 | 85 | @property 86 | def path(self) -> str: 87 | return self._path 88 | 89 | def flush(self) -> None: 90 | """Writes pending logs to disk.""" 91 | 92 | if not self._writer: 93 | return 94 | 95 | self._writer.flush() 96 | 97 | def close(self) -> None: 98 | """Close writer, flushing pending logs to disk. 99 | Logs cannot be written after `close` is called. 100 | """ 101 | 102 | if not self._writer: 103 | return 104 | 105 | self._writer.close() 106 | self._writer = None 107 | 108 | 109 | class TensorBoardLogger(TensorBoardWriterWrapper): 110 | """ 111 | A simple logger for TensorBoard. 112 | """ 113 | 114 | def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: 115 | """Add multiple scalar values to TensorBoard. 116 | 117 | Args: 118 | payload (dict): dictionary of tag name and scalar value 119 | step (int, Optional): step value to record 120 | """ 121 | if not self._writer: 122 | return 123 | for k, v in payload.items(): 124 | self.log(k, v, step) 125 | 126 | def log(self, name: str, data: Scalar, step: int) -> None: 127 | """Add scalar data to TensorBoard. 128 | 129 | Args: 130 | name (string): tag name used to group scalars 131 | data (float/int/Tensor): scalar data to log 132 | step (int, optional): step value to record 133 | """ 134 | if not self._writer: 135 | return 136 | self._writer.add_scalar(name, data, global_step=step, new_style=True) 137 | 138 | def log_hparams( 139 | self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] 140 | ) -> None: 141 | """Add hyperparameter data to TensorBoard. 142 | 143 | Args: 144 | hparams (dict): dictionary of hyperparameter names and corresponding values 145 | meters (dict): dictionary of meter names and corresponding values 146 | """ 147 | if not self._writer: 148 | return 149 | self._writer.add_hparams(hparams, meters) 150 | 151 | 152 | class Logger: 153 | """ 154 | A logger class that can interface with multiple loggers. It currently supports TensorBoard only for simplicity, but you can extend it with your own logger.
155 | """ 156 | 157 | def __init__(self, logging_conf): 158 | # allow turning off TensorBoard with "should_log: false" in config 159 | tb_config = logging_conf.tensorboard_writer 160 | tb_should_log = tb_config and tb_config.pop("should_log", True) 161 | self.tb_logger = instantiate(tb_config) if tb_should_log else None 162 | 163 | def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: 164 | if self.tb_logger: 165 | self.tb_logger.log_dict(payload, step) 166 | 167 | def log(self, name: str, data: Scalar, step: int) -> None: 168 | if self.tb_logger: 169 | self.tb_logger.log(name, data, step) 170 | 171 | def log_hparams( 172 | self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] 173 | ) -> None: 174 | if self.tb_logger: 175 | self.tb_logger.log_hparams(hparams, meters) 176 | 177 | 178 | # cache the opened file object, so that different calls to `setup_logger` 179 | # with the same file name can safely write to the same file. 180 | @functools.lru_cache(maxsize=None) 181 | def _cached_log_stream(filename): 182 | # we tune the buffering value so that the logs are updated 183 | # frequently. 184 | log_buffer_kb = 10 * 1024 # 10KB 185 | io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb) 186 | atexit.register(io.close) 187 | return io 188 | 189 | 190 | def setup_logging( 191 | name, 192 | output_dir=None, 193 | rank=0, 194 | log_level_primary="INFO", 195 | log_level_secondary="ERROR", 196 | ): 197 | """ 198 | Setup various logging streams: stdout and file handlers. 199 | For file handlers, we only setup for the master gpu. 200 | """ 201 | # get the filename if we want to log to the file as well 202 | log_filename = None 203 | if output_dir: 204 | makedir(output_dir) 205 | if rank == 0: 206 | log_filename = f"{output_dir}/log.txt" 207 | 208 | logger = logging.getLogger(name) 209 | logger.setLevel(log_level_primary) 210 | 211 | # create formatter 212 | FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" 213 | formatter = logging.Formatter(FORMAT) 214 | 215 | # Cleanup any existing handlers 216 | for h in logger.handlers: 217 | logger.removeHandler(h) 218 | logger.root.handlers = [] 219 | 220 | # setup the console handler 221 | console_handler = logging.StreamHandler(sys.stdout) 222 | console_handler.setFormatter(formatter) 223 | logger.addHandler(console_handler) 224 | if rank == 0: 225 | console_handler.setLevel(log_level_primary) 226 | else: 227 | console_handler.setLevel(log_level_secondary) 228 | 229 | # we log to file as well if user wants 230 | if log_filename and rank == 0: 231 | file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) 232 | file_handler.setLevel(log_level_primary) 233 | file_handler.setFormatter(formatter) 234 | logger.addHandler(file_handler) 235 | 236 | logging.root = logger 237 | 238 | 239 | def shutdown_logging(): 240 | """ 241 | After training is done, we ensure to shut down all the logger streams. 242 | """ 243 | logging.info("Shutting down loggers...") 244 | handlers = logging.root.handlers 245 | for handler in handlers: 246 | handler.close() 247 | -------------------------------------------------------------------------------- /training/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import math 9 | import os 10 | import random 11 | import re 12 | from datetime import timedelta 13 | from typing import Optional 14 | 15 | import hydra 16 | 17 | import numpy as np 18 | import omegaconf 19 | import torch 20 | import torch.distributed as dist 21 | from iopath.common.file_io import g_pathmgr 22 | from omegaconf import OmegaConf 23 | 24 | 25 | def multiply_all(*args): 26 | return np.prod(np.array(args)).item() 27 | 28 | 29 | def collect_dict_keys(config): 30 | """This function recursively iterates through a dataset configuration and collects all the dict_key values that are defined""" 31 | val_keys = [] 32 | # If this config points to the collate function, then it has a key 33 | if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]): 34 | val_keys.append(config["dict_key"]) 35 | else: 36 | # Recursively proceed 37 | for v in config.values(): 38 | if isinstance(v, type(config)): 39 | val_keys.extend(collect_dict_keys(v)) 40 | elif isinstance(v, omegaconf.listconfig.ListConfig): 41 | for item in v: 42 | if isinstance(item, type(config)): 43 | val_keys.extend(collect_dict_keys(item)) 44 | return val_keys 45 | 46 | 47 | class Phase: 48 | TRAIN = "train" 49 | VAL = "val" 50 | 51 | 52 | def register_omegaconf_resolvers(): 53 | OmegaConf.register_new_resolver("get_method", hydra.utils.get_method) 54 | OmegaConf.register_new_resolver("get_class", hydra.utils.get_class) 55 | OmegaConf.register_new_resolver("add", lambda x, y: x + y) 56 | OmegaConf.register_new_resolver("times", multiply_all) 57 | OmegaConf.register_new_resolver("divide", lambda x, y: x / y) 58 | OmegaConf.register_new_resolver("pow", lambda x, y: x**y) 59 | OmegaConf.register_new_resolver("subtract", lambda x, y: x - y) 60 | OmegaConf.register_new_resolver("range", lambda x: list(range(x))) 61 | OmegaConf.register_new_resolver("int", lambda x: int(x)) 62 | OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x))) 63 | OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x)) 64 | 65 | 66 | def setup_distributed_backend(backend, timeout_mins): 67 | """ 68 | Initialize torch.distributed and set the CUDA device. 69 | Expects environment variables to be set as per 70 | https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization 71 | along with the environment variable "LOCAL_RANK", which is used to set the CUDA device. 72 | """ 73 | # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins 74 | # of waiting 75 | os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" 76 | logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins") 77 | dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins)) 78 | return dist.get_rank() 79 | 80 | 81 | def get_machine_local_and_dist_rank(): 82 | """ 83 | Get the distributed and local rank of the current GPU. 84 | """ 85 | local_rank = int(os.environ.get("LOCAL_RANK", None)) 86 | distributed_rank = int(os.environ.get("RANK", None)) 87 | assert ( 88 | local_rank is not None and distributed_rank is not None 89 | ), "Please set the RANK and LOCAL_RANK environment variables."
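    # Editor's note (descriptive comment, not part of the original file): RANK and
    # LOCAL_RANK are exported by single_proc_run() / SubmititRunner.run_trainer() in
    # training/train.py above, or by an external launcher such as torchrun, before
    # this helper is called.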
90 | return local_rank, distributed_rank 91 | 92 | 93 | def print_cfg(cfg): 94 | """ 95 | Supports printing both Hydra DictConfig and also the AttrDict config 96 | """ 97 | logging.info("Training with config:") 98 | logging.info(OmegaConf.to_yaml(cfg)) 99 | 100 | 101 | def set_seeds(seed_value, max_epochs, dist_rank): 102 | """ 103 | Set the python random, numpy and torch seed for each gpu. Also set the CUDA 104 | seeds if the CUDA is available. This ensures deterministic nature of the training. 105 | """ 106 | # Since in the pytorch sampler, we increment the seed by 1 for every epoch. 107 | seed_value = (seed_value + dist_rank) * max_epochs 108 | logging.info(f"MACHINE SEED: {seed_value}") 109 | random.seed(seed_value) 110 | np.random.seed(seed_value) 111 | torch.manual_seed(seed_value) 112 | if torch.cuda.is_available(): 113 | torch.cuda.manual_seed_all(seed_value) 114 | 115 | 116 | def makedir(dir_path): 117 | """ 118 | Create the directory if it does not exist. 119 | """ 120 | is_success = False 121 | try: 122 | if not g_pathmgr.exists(dir_path): 123 | g_pathmgr.mkdirs(dir_path) 124 | is_success = True 125 | except BaseException: 126 | logging.info(f"Error creating directory: {dir_path}") 127 | return is_success 128 | 129 | 130 | def is_dist_avail_and_initialized(): 131 | if not dist.is_available(): 132 | return False 133 | if not dist.is_initialized(): 134 | return False 135 | return True 136 | 137 | 138 | def get_amp_type(amp_type: Optional[str] = None): 139 | if amp_type is None: 140 | return None 141 | assert amp_type in ["bfloat16", "float16"], "Invalid Amp type." 142 | if amp_type == "bfloat16": 143 | return torch.bfloat16 144 | else: 145 | return torch.float16 146 | 147 | 148 | def log_env_variables(): 149 | env_keys = sorted(list(os.environ.keys())) 150 | st = "" 151 | for k in env_keys: 152 | v = os.environ[k] 153 | st += f"{k}={v}\n" 154 | logging.info("Logging ENV_VARIABLES") 155 | logging.info(st) 156 | 157 | 158 | class AverageMeter: 159 | """Computes and stores the average and current value""" 160 | 161 | def __init__(self, name, device, fmt=":f"): 162 | self.name = name 163 | self.fmt = fmt 164 | self.device = device 165 | self.reset() 166 | 167 | def reset(self): 168 | self.val = 0 169 | self.avg = 0 170 | self.sum = 0 171 | self.count = 0 172 | self._allow_updates = True 173 | 174 | def update(self, val, n=1): 175 | self.val = val 176 | self.sum += val * n 177 | self.count += n 178 | self.avg = self.sum / self.count 179 | 180 | def __str__(self): 181 | fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})" 182 | return fmtstr.format(**self.__dict__) 183 | 184 | 185 | class MemMeter: 186 | """Computes and stores the current, avg, and max of peak Mem usage per iteration""" 187 | 188 | def __init__(self, name, device, fmt=":f"): 189 | self.name = name 190 | self.fmt = fmt 191 | self.device = device 192 | self.reset() 193 | 194 | def reset(self): 195 | self.val = 0 # Per iteration max usage 196 | self.avg = 0 # Avg per iteration max usage 197 | self.peak = 0 # Peak usage for lifetime of program 198 | self.sum = 0 199 | self.count = 0 200 | self._allow_updates = True 201 | 202 | def update(self, n=1, reset_peak_usage=True): 203 | self.val = torch.cuda.max_memory_allocated() // 1e9 204 | self.sum += self.val * n 205 | self.count += n 206 | self.avg = self.sum / self.count 207 | self.peak = max(self.peak, self.val) 208 | if reset_peak_usage: 209 | torch.cuda.reset_peak_memory_stats() 210 | 211 | def __str__(self): 212 | fmtstr = ( 213 | "{name}: {val" 214 | + 
self.fmt 215 | + "} ({avg" 216 | + self.fmt 217 | + "}/{peak" 218 | + self.fmt 219 | + "})" 220 | ) 221 | return fmtstr.format(**self.__dict__) 222 | 223 | 224 | def human_readable_time(time_seconds): 225 | time = int(time_seconds) 226 | minutes, seconds = divmod(time, 60) 227 | hours, minutes = divmod(minutes, 60) 228 | days, hours = divmod(hours, 24) 229 | return f"{days:02}d {hours:02}h {minutes:02}m" 230 | 231 | 232 | class DurationMeter: 233 | def __init__(self, name, device, fmt=":f"): 234 | self.name = name 235 | self.device = device 236 | self.fmt = fmt 237 | self.val = 0 238 | 239 | def reset(self): 240 | self.val = 0 241 | 242 | def update(self, val): 243 | self.val = val 244 | 245 | def add(self, val): 246 | self.val += val 247 | 248 | def __str__(self): 249 | return f"{self.name}: {human_readable_time(self.val)}" 250 | 251 | 252 | class ProgressMeter: 253 | def __init__(self, num_batches, meters, real_meters, prefix=""): 254 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 255 | self.meters = meters 256 | self.real_meters = real_meters 257 | self.prefix = prefix 258 | 259 | def display(self, batch, enable_print=False): 260 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 261 | entries += [str(meter) for meter in self.meters] 262 | entries += [ 263 | " | ".join( 264 | [ 265 | f"{os.path.join(name, subname)}: {val:.4f}" 266 | for subname, val in meter.compute().items() 267 | ] 268 | ) 269 | for name, meter in self.real_meters.items() 270 | ] 271 | logging.info(" | ".join(entries)) 272 | if enable_print: 273 | print(" | ".join(entries)) 274 | 275 | def _get_batch_fmtstr(self, num_batches): 276 | num_digits = len(str(num_batches // 1)) 277 | fmt = "{:" + str(num_digits) + "d}" 278 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 279 | 280 | 281 | def get_resume_checkpoint(checkpoint_save_dir): 282 | if not g_pathmgr.isdir(checkpoint_save_dir): 283 | return None 284 | ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt") 285 | if not g_pathmgr.isfile(ckpt_file): 286 | return None 287 | 288 | return ckpt_file 289 | --------------------------------------------------------------------------------
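Editor's note: the snippet below is an illustrative usage sketch for a few helpers defined in training/utils/train_utils.py above; it is not part of the repository, and it assumes it is executed from the repository root with the repository's dependencies installed so that the `training` package is importable.

# Minimal sketch (assumptions noted above): exercising the meter helpers the
# trainer uses for logging, plus get_amp_type() for mixed-precision configs.
from training.utils.train_utils import (
    AverageMeter,
    DurationMeter,
    get_amp_type,
    human_readable_time,
)

# Running average of a scalar, as a training loop would track its loss.
loss_meter = AverageMeter(name="loss", device="cpu", fmt=":.4f")
for loss in [0.9, 0.7, 0.4]:
    loss_meter.update(loss, n=1)
print(loss_meter)                    # "loss: 0.4000 (0.6667)"

# Wall-clock bookkeeping; DurationMeter renders via human_readable_time().
time_meter = DurationMeter(name="elapsed", device="cpu")
time_meter.add(3 * 3600 + 25 * 60)
print(time_meter)                    # "elapsed: 00d 03h 25m"
print(human_readable_time(90_000))   # "01d 01h 00m"

# Maps the config string ("bfloat16" / "float16") to the torch autocast dtype.
print(get_amp_type("bfloat16"))      # torch.bfloat16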