├── LICENSE
├── README.md
├── app.py
├── checkpoints
│   └── README.md
├── download.sh
├── examples
│   └── infer_CT_LUNA25.py
├── gitignore
├── medsam2_infer_3D_CT.py
├── medsam2_infer_video.py
├── multi_node_train.sh
├── notebooks
│   ├── MedSAM2_Inference_Video.ipynb
│   └── MedSAM2_inference_CT_Lesion.ipynb
├── pyproject.toml
├── sam2
│   ├── _C.so
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   ├── build_sam.cpython-312.pyc
│   │   ├── sam2_image_predictor.cpython-312.pyc
│   │   └── sam2_video_predictor_npz.cpython-312.pyc
│   ├── build_sam.py
│   ├── configs
│   │   ├── sam2.1_hiera_t512.yaml
│   │   └── sam2.1_hiera_tiny_finetune512.yaml
│   ├── csrc
│   │   └── connected_components.cu
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   ├── memory_attention.cpython-312.pyc
│   │   │   ├── memory_encoder.cpython-312.pyc
│   │   │   ├── position_encoding.cpython-312.pyc
│   │   │   ├── sam2_base.cpython-312.pyc
│   │   │   └── sam2_utils.cpython-312.pyc
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── hieradet.cpython-312.pyc
│   │   │   │   ├── image_encoder.cpython-312.pyc
│   │   │   │   └── utils.cpython-312.pyc
│   │   │   ├── hieradet.py
│   │   │   ├── image_encoder.py
│   │   │   └── utils.py
│   │   ├── memory_attention.py
│   │   ├── memory_encoder.py
│   │   ├── position_encoding.py
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── mask_decoder.cpython-312.pyc
│   │   │   │   ├── prompt_encoder.cpython-312.pyc
│   │   │   │   └── transformer.cpython-312.pyc
│   │   │   ├── mask_decoder.py
│   │   │   ├── prompt_encoder.py
│   │   │   └── transformer.py
│   │   ├── sam2_base.py
│   │   └── sam2_utils.py
│   ├── sam2_image_predictor.py
│   ├── sam2_video_predictor.py
│   ├── sam2_video_predictor_npz.py
│   ├── sam2_video_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-312.pyc
│       │   ├── misc.cpython-312.pyc
│       │   └── transforms.cpython-312.pyc
│       ├── amg.py
│       ├── misc.py
│       └── transforms.py
├── setup.py
└── training
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-312.pyc
    │   ├── loss_fns.cpython-312.pyc
    │   ├── optimizer.cpython-312.pyc
    │   └── trainer.cpython-312.pyc
    ├── assets
    │   ├── MOSE_sample_train_list.txt
    │   └── MOSE_sample_val_list.txt
    ├── dataset
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-312.pyc
    │   │   ├── sam2_datasets.cpython-312.pyc
    │   │   ├── transforms.cpython-312.pyc
    │   │   ├── utils.cpython-312.pyc
    │   │   ├── vos_dataset.cpython-312.pyc
    │   │   ├── vos_raw_dataset.cpython-312.pyc
    │   │   ├── vos_sampler.cpython-312.pyc
    │   │   └── vos_segment_loader.cpython-312.pyc
    │   ├── sam2_datasets.py
    │   ├── transforms.py
    │   ├── utils.py
    │   ├── vos_dataset.py
    │   ├── vos_raw_dataset.py
    │   ├── vos_sampler.py
    │   └── vos_segment_loader.py
    ├── loss_fns.py
    ├── model
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-312.pyc
    │   │   └── sam2.cpython-312.pyc
    │   └── sam2.py
    ├── optimizer.py
    ├── scripts
    │   └── sav_frame_extraction_submitit.py
    ├── train.py
    ├── trainer.py
    └── utils
        ├── __init__.py
        ├── __pycache__
        │   ├── __init__.cpython-312.pyc
        │   ├── checkpoint_utils.cpython-312.pyc
        │   ├── data_utils.cpython-312.pyc
        │   ├── distributed.cpython-312.pyc
        │   ├── logger.cpython-312.pyc
        │   └── train_utils.cpython-312.pyc
        ├── checkpoint_utils.py
        ├── data_utils.py
        ├── distributed.py
        ├── logger.py
        └── train_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MedSAM2
2 | Segment Anything in 3D Medical Images and Videos
3 |
4 |
26 |
27 | You are welcome to join our [mailing list](https://forms.gle/bLxGb5SEpdLCUChQ7) to get updates. We’re also actively looking to collaborate on annotating new large-scale 3D datasets. If you have unlabeled medical images or videos and want to share them with the community, let’s connect!
28 |
29 |
30 | ## Installation
31 |
32 | - Create a virtual environment: `conda create -n medsam2 python=3.12 -y` and `conda activate medsam2`
33 | - Install [PyTorch](https://pytorch.org/get-started/locally/): `pip3 install torch torchvision` (Linux CUDA 12.4)
34 | - Download code `git clone https://github.com/bowang-lab/MedSAM2.git && cd MedSAM2` and run `pip install -e ".[dev]"`
35 | - Download checkpoints: `sh download.sh`
36 | - Optional: install the following dependencies for the Gradio demo
37 |
38 | ```bash
39 | sudo apt-get update
40 | sudo apt-get install ffmpeg
41 | pip install gradio==3.38.0
42 | pip install numpy==1.26.3
43 | pip install ffmpeg-python
44 | pip install moviepy
45 | ```
46 |
47 | ## Download annotated datasets
48 |
49 | - [CT_DeepLesion-MedSAM2](https://huggingface.co/datasets/wanglab/CT_DeepLesion-MedSAM2)
50 |
51 |
52 |
53 | - [LLD-MMRI-MedSAM2](https://huggingface.co/datasets/wanglab/LLD-MMRI-MedSAM2)
54 |
55 | Note: Please also cite the raw [DeepLesion](https://doi.org/10.1117/1.JMI.5.3.036501) and [LLD-MMRI](https://www.sciencedirect.com/science/article/pii/S0893608025001078) dataset papers when using these datasets.
56 |
57 | - [RVENET](https://rvenet.github.io/dataset/): Waiting for authors' approval to release the mask.
58 |
59 |
60 | ## Inference
61 |
62 | ### 3D medical image segmentation
63 |
64 | - [Colab](https://colab.research.google.com/drive/1MKna9Sg9c78LNcrVyG58cQQmaePZq2k2?usp=sharing): [MedSAM2_inference_CT_Lesion_Demo.ipynb](notebooks/MedSAM2_inference_CT_Lesion.ipynb)
65 |
66 | - CMD
67 |
68 | ```bash
69 | python medsam2_infer_3D_CT.py -i CT_DeepLesion/images -o CT_DeepLesion/segmentation
70 | ```
71 |
72 | ### Medical video segmentation
73 |
74 | - [Colab](https://colab.research.google.com/drive/16niRHqdDZMCGV7lKuagNq_r_CEHtKY1f?usp=sharing): [MedSAM2_Inference_Video_Demo.ipynb](notebooks/MedSAM2_Inference_Video.ipynb)
75 |
76 |
77 | - CMD
78 |
79 | ```bash
80 | python medsam2_infer_video.py -i input_video_path -m input_mask_path -o output_video_path
81 | ```
82 |
83 |
84 |
85 |
86 | ### Gradio demo
87 |
88 | ```bash
89 | python app.py
90 | ```
91 |
92 | ## Training
93 |
94 | Specify the dataset path in `sam2/configs/sam2.1_hiera_tiny_finetune512.yaml`, then submit the training job:
95 |
96 | ```bash
97 | sbatch multi_node_train.sh
98 | ```
99 |
100 | ## Acknowledgements
101 |
102 | - We highly appreciate all the challenge organizers and dataset owners for making the public datasets available to the community.
103 | - We thank Meta AI for making the source code of [SAM2](https://github.com/facebookresearch/sam2) publicly available. Please also cite the SAM2 paper when using MedSAM2.
104 |
105 |
106 | ## Bibtex
107 |
108 | ```bibtex
109 | @article{MedSAM2,
110 | title={MedSAM2: Segment Anything in 3D Medical Images and Videos},
111 | author={Ma, Jun and Yang, Zongxin and Kim, Sumin and Chen, Bihui and Baharoon, Mohammed and Fallahpour, Adibvafa and Asakereh, Reza and Lyu, Hongwei and Wang, Bo},
112 | journal={arXiv preprint arXiv:2504.03600},
113 | year={2025}
114 | }
115 | ```
116 | Please also cite SAM2:
117 | ```
118 | @article{ravi2024sam2,
119 | title={SAM 2: Segment Anything in Images and Videos},
120 | author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
121 | journal={arXiv preprint arXiv:2408.00714},
122 | url={https://arxiv.org/abs/2408.00714},
123 | year={2024}
124 | }
125 | ```
126 |
127 |
--------------------------------------------------------------------------------
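For readers who want to call MedSAM2 from Python rather than through the command-line scripts in the README above, the sketch below mirrors the flow of `medsam2_infer_3D_CT.py`: build the NPZ video predictor, prompt one slice with a box, and propagate through the volume. The checkpoint/config paths, the dummy volume, and the example box are placeholders, not values taken from the repository.

```python
# Minimal sketch of the programmatic 3D inference flow used by medsam2_infer_3D_CT.py.
# Paths, the dummy volume, and the box prompt are placeholders.
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor_npz

predictor = build_sam2_video_predictor_npz(
    "configs/sam2.1_hiera_t512.yaml",   # resolved via the Hydra config module registered in sam2/__init__.py
    "checkpoints/MedSAM2_latest.pt",    # fetched by download.sh
)

# (depth, 3, 512, 512) float tensor, already windowed and normalized with ImageNet mean/std
imgs = torch.zeros((40, 3, 512, 512), dtype=torch.float32).cuda()
orig_h, orig_w = 512, 512
box = np.array([100, 120, 180, 200], dtype=np.float32)  # x_min, y_min, x_max, y_max on the key slice

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(imgs, orig_h, orig_w)
    predictor.add_new_points_or_box(inference_state=state, frame_idx=20, obj_id=1, box=box)
    seg = np.zeros((40, orig_h, orig_w), dtype=np.uint8)
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        seg[frame_idx][(mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
```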
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 |
2 | Download the checkpoints with `sh download.sh`
3 |
4 | - `MedSAM2_2411.pt`: The base model trained in Nov. 2024
5 | - `MedSAM2_US_Heart.pt`: Fine-tuned model for heart ultrasound video segmentation
6 | - `MedSAM2_MRI_LiverLesion.pt`: Fine-tuned model for liver lesion MRI segmentation
7 | - `MedSAM2_CTLesion.pt`: Fine-tuned model for CT lesion segmentation
8 | - `MedSAM2_latest.pt` (recommended): Latest model trained on the combination of existing public datasets and newly annotated datasets
9 |
10 |
11 |
--------------------------------------------------------------------------------
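If a download is interrupted, a quick sanity check is to load the checkpoint on CPU and confirm it carries the `model` state dict that `sam2/build_sam.py` expects. A short sketch, assuming `MedSAM2_latest.pt` is the file you fetched:

```python
# Sketch: verify a downloaded checkpoint exposes the "model" state dict used by _load_checkpoint.
import torch

ckpt = torch.load("checkpoints/MedSAM2_latest.pt", map_location="cpu", weights_only=True)
state_dict = ckpt["model"]
print(f"{len(state_dict)} tensors; first key: {next(iter(state_dict))}")
```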
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Script to download MedSAM2 model checkpoints
3 | # Create checkpoints directory if it doesn't exist
4 | mkdir -p checkpoints
5 | # Use either wget or curl to download the checkpoints
6 | if command -v wget > /dev/null 2>&1; then
7 | CMD="wget -P checkpoints"
8 | elif command -v curl > /dev/null 2>&1; then
9 | CMD="curl -L -o"
10 | CURL=1
11 | else
12 | echo "Please install wget or curl to download the checkpoints."
13 | exit 1
14 | fi
15 | # Define the base URL for MedSAM2 models on Hugging Face
16 | HF_BASE_URL="https://huggingface.co/wanglab/MedSAM2/resolve/main"
17 | # Define the model checkpoint files (as separate variables instead of an array)
18 | MODEL1="MedSAM2_2411.pt"
19 | MODEL2="MedSAM2_US_Heart.pt"
20 | MODEL3="MedSAM2_MRI_LiverLesion.pt"
21 | MODEL4="MedSAM2_CTLesion.pt"
22 | MODEL5="MedSAM2_latest.pt"
23 |
24 | # Download each checkpoint
25 | for model in $MODEL1 $MODEL2 $MODEL3 $MODEL4 $MODEL5; do
26 | echo "Downloading ${model}..."
27 | model_url="${HF_BASE_URL}/${model}"
28 |
29 | if [ -n "$CURL" ]; then
30 | $CMD "checkpoints/${model}" "$model_url" || { echo "Failed to download checkpoint from $model_url"; exit 1; }
31 | else
32 | $CMD "$model_url" || { echo "Failed to download checkpoint from $model_url"; exit 1; }
33 | fi
34 | done
35 | echo "All MedSAM2 model checkpoints have been downloaded successfully to the 'checkpoints' directory."
--------------------------------------------------------------------------------
/examples/infer_CT_LUNA25.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is used to run inference on the LUNA25 dataset using the MedSAM2 CT lesion model with point prompts.
3 |
4 | Manually refined masks: https://huggingface.co/datasets/wanglab/LUNA25-MedSAM2
5 | image: https://zenodo.org/records/14223624
6 | annotation: https://zenodo.org/records/14673658
7 | """
8 |
9 | from tqdm import tqdm
10 | import os
11 | from os.path import join
12 | import pandas as pd
13 | import numpy as np
14 | import argparse
15 |
16 | from PIL import Image
17 | import SimpleITK as sitk
18 | import torch
19 | import torch.multiprocessing as mp
20 | from sam2.build_sam import build_sam2_video_predictor_npz
21 |
22 | torch.set_float32_matmul_precision('high')
23 | torch.manual_seed(2024)
24 | torch.cuda.manual_seed(2024)
25 | np.random.seed(2024)
26 |
27 | parser = argparse.ArgumentParser()
28 |
29 | parser.add_argument(
30 | '--checkpoint',
31 | type=str,
32 | default="./hg_checkpoints/MedSAM2_CTLesion.pt",
33 | help='checkpoint path',
34 | )
35 | parser.add_argument(
36 | '--cfg',
37 | type=str,
38 | default="configs/sam2.1/sam2.1_hiera_t512.yaml",
39 | help='model config',
40 | )
41 | parser.add_argument(
42 | '-i',
43 | '--imgs_path',
44 | type=str,
45 | default="/path/to/luna25_images",
46 | help='imgs path',
47 | )
48 | parser.add_argument(
49 | '-o',
50 | '--pred_save_dir',
51 | type=str,
52 | default="./segs/MedSAM2_release",
53 | help='segs path',
54 | )
55 | parser.add_argument(
56 | '--num_workers',
57 | type=int,
58 | default=2,
59 | )
60 | parser.add_argument(
61 | '--df_path',
62 | type=str,
63 | default='/path/to/LUNA25_Public_Training_Development_Data.csv',
64 | )
65 |
66 | args = parser.parse_args()
67 | imsize = 512
68 | df = pd.read_csv(args.df_path)
69 | df = df[['SeriesInstanceUID', 'CoordX', 'CoordY', 'CoordZ']]
70 |
71 | checkpoint = args.checkpoint
72 | model_cfg = args.cfg
73 | imgs_path = args.imgs_path
74 | pred_save_dir = args.pred_save_dir
75 | num_workers = args.num_workers
76 | predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint)
77 | os.makedirs(pred_save_dir, exist_ok=True)
78 |
79 |
80 | def preprocess(image_data, modality="CT", window_level=-750, window_width=1500):
81 | if modality == "CT":
82 | assert window_level is not None and window_width is not None, "CT modality requires window_level and window_width"
83 | lower_bound = window_level - window_width / 2
84 | upper_bound = window_level + window_width / 2
85 | image_data_pre = np.clip(image_data, lower_bound, upper_bound)
86 | image_data_pre = (
87 | (image_data_pre - np.min(image_data_pre))
88 | / (np.max(image_data_pre) - np.min(image_data_pre))
89 | * 255.0
90 | )
91 | else:
92 | lower_bound, upper_bound = np.percentile(
93 | image_data[image_data > 0], 0.5
94 | ), np.percentile(image_data[image_data > 0], 99.5)
95 | image_data_pre = np.clip(image_data, lower_bound, upper_bound)
96 | image_data_pre = (
97 | (image_data_pre - np.min(image_data_pre))
98 | / (np.max(image_data_pre) - np.min(image_data_pre))
99 | * 255.0
100 | )
101 | image_data_pre[image_data == 0] = 0
102 |
103 | return image_data_pre
104 |
105 |
106 | def resize_grayscale_to_rgb_and_resize(array, image_size):
107 | """
108 | Resize a 3D grayscale NumPy array to an RGB image and then resize it.
109 |
110 | Parameters:
111 | array (np.ndarray): Input array of shape (d, h, w).
112 | image_size (int): Desired size for the width and height.
113 |
114 | Returns:
115 | np.ndarray: Resized array of shape (d, 3, image_size, image_size).
116 | """
117 | d, h, w = array.shape
118 | resized_array = np.zeros((d, 3, image_size, image_size))
119 |
120 | for i in range(d):
121 | img_pil = Image.fromarray(array[i].astype(np.uint8))
122 | img_rgb = img_pil.convert("RGB")
123 | img_resized = img_rgb.resize((image_size, image_size))
124 | img_array = np.array(img_resized).transpose(2, 0, 1) # (3, image_size, image_size)
125 | resized_array[i] = img_array
126 |
127 | return resized_array
128 |
129 |
130 | @torch.inference_mode()
131 | def infer_3d(mha_name):
132 | print(f'processing {mha_name}')
133 | # get the corresponding keypoints of mha_name
134 | df_file = df[df['SeriesInstanceUID'] == mha_name.replace('.mha', '')]
135 |
136 | # read and preprocess the image
137 | sitk_img = sitk.ReadImage(join(imgs_path, mha_name))
138 | img_3D = preprocess(sitk.GetArrayFromImage(sitk_img))
139 | assert np.max(img_3D) < 256, f'input data should be in range [0, 255], but got {np.unique(img_3D)}'
140 |
141 | # initialize segmentation mask
142 | segs_3D = np.zeros(img_3D.shape, dtype=np.uint8)
143 |
144 | # resize and normalize the image
145 | video_height = img_3D.shape[1]
146 | video_width = img_3D.shape[2]
147 | if video_height != imsize or video_width != imsize:
148 | img_resized = resize_grayscale_to_rgb_and_resize(img_3D, imsize) #d, 3, 512, 512
149 | else:
150 | img_resized = img_3D[:,None].repeat(3, axis=1) # d, 3, h, w
151 | img_resized = img_resized / 255.0
152 | img_resized = torch.from_numpy(img_resized).cuda()
153 | img_mean=(0.485, 0.456, 0.406)
154 | img_std=(0.229, 0.224, 0.225)
155 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None].cuda()
156 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None].cuda()
157 | img_resized -= img_mean
158 | img_resized /= img_std
159 | z_mids = []
160 | coords = []
161 |
162 | # for each point in the dataframe, get the corresponding 3D mask using keypoint prompts
163 | for index, (_, row) in enumerate(df_file.iterrows(), 1):
164 |
165 | x = row['CoordX']
166 | y = row['CoordY']
167 | z = row['CoordZ']
168 | # convert the coordinates to voxel coordinates
169 | voxel_x, voxel_y, voxel_z = sitk_img.TransformPhysicalPointToIndex((x, y, z))
170 | coords.append([voxel_x, voxel_y, voxel_z])
171 | z_mids.append(voxel_z)
172 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
173 | inference_state = predictor.init_state(img_resized, video_height, video_width)
174 |
175 | points = np.array([[voxel_x, voxel_y]], dtype=np.float32)
176 | # for labels, `1` means positive click and `0` means negative click
177 | labels = np.array([1], np.int32)
178 | # add point prompt
179 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
180 | inference_state=inference_state,
181 | frame_idx=voxel_z,
182 | obj_id=1,
183 | points=points,
184 | labels=labels,
185 | )
186 | mask_prompt = (out_mask_logits[0] > 0.0).squeeze(0).cpu().numpy().astype(np.uint8)
187 |
188 |
189 | frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=voxel_z, obj_id=1, mask=mask_prompt)
190 | segs_3D[voxel_z, ((masks[0] > 0.0).cpu().numpy())[0]] = index
191 | # propagate in the video
192 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=voxel_z, reverse=False):
193 | segs_3D[(out_frame_idx), (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = index
194 |
195 | # reverse process, delete old memory and initialize new predictor
196 | predictor.reset_state(inference_state)
197 | inference_state = predictor.init_state(img_resized, video_height, video_width)
198 | frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=voxel_z, obj_id=1, mask=mask_prompt)
199 |
200 |
201 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, start_frame_idx=voxel_z, reverse=True):
202 | segs_3D[(out_frame_idx), (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = index
203 |
204 | predictor.reset_state(inference_state)
205 |
206 | sitk.WriteImage(sitk.GetImageFromArray(segs_3D), join(pred_save_dir, mha_name.replace('.mha', '.nii.gz')))
207 |
208 | return
209 |
210 |
211 |
212 | if __name__ == '__main__':
213 | img_mha_files = os.listdir(imgs_path)
214 | img_mha_files = [x for x in img_mha_files if x.endswith('.mha')]
215 | process_files = list(set(df['SeriesInstanceUID'].values))
216 | img_mha_files = [x for x in img_mha_files if x.replace('.mha', '') in process_files]
217 |
218 | print(f'number of files to process: {len(img_mha_files)}')
219 |
220 | mp.set_start_method('spawn', force=True)
221 | with mp.Pool(processes=num_workers) as pool:
222 | with tqdm(total=len(img_mha_files)) as pbar:
223 | for i, ret in tqdm(enumerate(pool.imap_unordered(infer_3d, img_mha_files))):
224 | pbar.update()
225 |
226 |
227 |
228 |
--------------------------------------------------------------------------------
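A detail worth isolating from the script above: LUNA25 nodule annotations are stored in physical (world) coordinates, so the script maps each CSV row to voxel indices before placing the point prompt. A minimal sketch with placeholder paths:

```python
# Sketch: map a LUNA25 annotation from world coordinates to voxel indices, as done before prompting.
# The .mha and .csv paths are placeholders.
import pandas as pd
import SimpleITK as sitk

sitk_img = sitk.ReadImage("/path/to/luna25_images/example_series.mha")
df = pd.read_csv("/path/to/LUNA25_Public_Training_Development_Data.csv")
row = df.iloc[0]

voxel_x, voxel_y, voxel_z = sitk_img.TransformPhysicalPointToIndex(
    (row["CoordX"], row["CoordY"], row["CoordZ"])
)
# voxel_z picks the prompt slice (frame_idx); (voxel_x, voxel_y) is the positive click on that slice.
```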
/gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .DS_Store
3 | __pycache__/
4 | *-checkpoint.ipynb
5 | .venv
6 | *.egg*
7 | build/*
8 | _C.*
9 | *.nii.gz
10 | *.csv
11 | outputs/*
12 | checkpoints/*.pt
13 | *.pt
14 |
--------------------------------------------------------------------------------
/medsam2_infer_3D_CT.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from tqdm import tqdm
3 | import os
4 | from os.path import join, basename
5 | import re
6 | import matplotlib.pyplot as plt
7 | from collections import OrderedDict
8 | import pandas as pd
9 | import numpy as np
10 | import argparse
11 |
12 | from PIL import Image
13 | import SimpleITK as sitk
14 | import torch
15 | import torch.multiprocessing as mp
16 | from sam2.build_sam import build_sam2_video_predictor_npz
17 | import SimpleITK as sitk
18 | from skimage import measure, morphology
19 |
20 | torch.set_float32_matmul_precision('high')
21 | torch.manual_seed(2024)
22 | torch.cuda.manual_seed(2024)
23 | np.random.seed(2024)
24 |
25 | parser = argparse.ArgumentParser()
26 |
27 | parser.add_argument(
28 | '--checkpoint',
29 | type=str,
30 | default="checkpoints/MedSAM2_latest.pt",
31 | help='checkpoint path',
32 | )
33 | parser.add_argument(
34 | '--cfg',
35 | type=str,
36 | default="configs/sam2.1_hiera_t512.yaml",
37 | help='model config',
38 | )
39 |
40 | parser.add_argument(
41 | '-i',
42 | '--imgs_path',
43 | type=str,
44 | default="CT_DeepLesion/images",
45 | help='imgs path',
46 | )
47 | parser.add_argument(
48 | '--gts_path',
49 | default=None,
50 | help='simulate prompts based on ground truth',
51 | )
52 | parser.add_argument(
53 | '-o',
54 | '--pred_save_dir',
55 | type=str,
56 | default="./DeeLesion_results",
57 | help='path to save segmentation results',
58 | )
59 | # add option to propagate with either box or mask
60 | parser.add_argument(
61 | '--propagate_with_box',
62 | default=True,
63 | action='store_true',
64 | help='whether to propagate with box'
65 | )
66 |
67 | args = parser.parse_args()
68 | checkpoint = args.checkpoint
69 | model_cfg = args.cfg
70 | imgs_path = args.imgs_path
71 | gts_path = args.gts_path
72 | pred_save_dir = args.pred_save_dir
73 | os.makedirs(pred_save_dir, exist_ok=True)
74 | propagate_with_box = args.propagate_with_box
75 |
76 | def getLargestCC(segmentation):
77 | labels = measure.label(segmentation)
78 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
79 | return largestCC
80 |
81 | def dice_multi_class(preds, targets):
82 | smooth = 1.0
83 | assert preds.shape == targets.shape
84 | labels = np.unique(targets)[1:]
85 | dices = []
86 | for label in labels:
87 | pred = preds == label
88 | target = targets == label
89 | intersection = (pred * target).sum()
90 | dices.append((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))
91 | return np.mean(dices)
92 |
93 | def show_mask(mask, ax, mask_color=None, alpha=0.5):
94 | """
95 | show mask on the image
96 |
97 | Parameters
98 | ----------
99 | mask : numpy.ndarray
100 | mask of the image
101 | ax : matplotlib.axes.Axes
102 | axes to plot the mask
103 | mask_color : numpy.ndarray
104 | color of the mask
105 | alpha : float
106 | transparency of the mask
107 | """
108 | if mask_color is not None:
109 | color = np.concatenate([mask_color, np.array([alpha])], axis=0)
110 | else:
111 | color = np.array([251/255, 252/255, 30/255, alpha])
112 | h, w = mask.shape[-2:]
113 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
114 | ax.imshow(mask_image)
115 |
116 |
117 | def show_box(box, ax, edgecolor='blue'):
118 | """
119 | show bounding box on the image
120 |
121 | Parameters
122 | ----------
123 | box : numpy.ndarray
124 | bounding box coordinates in the original image
125 | ax : matplotlib.axes.Axes
126 | axes to plot the bounding box
127 | edgecolor : str
128 | color of the bounding box
129 | """
130 | x0, y0 = box[0], box[1]
131 | w, h = box[2] - box[0], box[3] - box[1]
132 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2))
133 |
134 |
135 | def resize_grayscale_to_rgb_and_resize(array, image_size):
136 | """
137 | Resize a 3D grayscale NumPy array to an RGB image and then resize it.
138 |
139 | Parameters:
140 | array (np.ndarray): Input array of shape (d, h, w).
141 | image_size (int): Desired size for the width and height.
142 |
143 | Returns:
144 | np.ndarray: Resized array of shape (d, 3, image_size, image_size).
145 | """
146 | d, h, w = array.shape
147 | resized_array = np.zeros((d, 3, image_size, image_size))
148 |
149 | for i in range(d):
150 | img_pil = Image.fromarray(array[i].astype(np.uint8))
151 | img_rgb = img_pil.convert("RGB")
152 | img_resized = img_rgb.resize((image_size, image_size))
153 | img_array = np.array(img_resized).transpose(2, 0, 1) # (3, image_size, image_size)
154 | resized_array[i] = img_array
155 |
156 | return resized_array
157 |
158 | def mask2D_to_bbox(gt2D, max_shift=20):
159 | y_indices, x_indices = np.where(gt2D > 0)
160 | x_min, x_max = np.min(x_indices), np.max(x_indices)
161 | y_min, y_max = np.min(y_indices), np.max(y_indices)
162 | H, W = gt2D.shape
163 | bbox_shift = np.random.randint(0, max_shift + 1, 1)[0]
164 | x_min = max(0, x_min - bbox_shift)
165 | x_max = min(W-1, x_max + bbox_shift)
166 | y_min = max(0, y_min - bbox_shift)
167 | y_max = min(H-1, y_max + bbox_shift)
168 | boxes = np.array([x_min, y_min, x_max, y_max])
169 | return boxes
170 |
171 | def mask3D_to_bbox(gt3D, max_shift=20):
172 | z_indices, y_indices, x_indices = np.where(gt3D > 0)
173 | x_min, x_max = np.min(x_indices), np.max(x_indices)
174 | y_min, y_max = np.min(y_indices), np.max(y_indices)
175 | z_min, z_max = np.min(z_indices), np.max(z_indices)
176 | D, H, W = gt3D.shape
177 | bbox_shift = np.random.randint(0, max_shift + 1, 1)[0]
178 | x_min = max(0, x_min - bbox_shift)
179 | x_max = min(W-1, x_max + bbox_shift)
180 | y_min = max(0, y_min - bbox_shift)
181 | y_max = min(H-1, y_max + bbox_shift)
182 | z_min = max(0, z_min)
183 | z_max = min(D-1, z_max)
184 | boxes3d = np.array([x_min, y_min, z_min, x_max, y_max, z_max])
185 | return boxes3d
186 |
187 |
188 | DL_info = pd.read_csv('CT_DeepLesion/DeepLesion_Dataset_Info.csv')
189 | nii_fnames = sorted(os.listdir(imgs_path))
190 | nii_fnames = [i for i in nii_fnames if i.endswith('.nii.gz')]
191 | nii_fnames = [i for i in nii_fnames if not i.startswith('._')]
192 | print(f'Processing {len(nii_fnames)} nii files')
193 | seg_info = OrderedDict()
194 | seg_info['nii_name'] = []
195 | seg_info['key_slice_index'] = []
196 | seg_info['DICOM_windows'] = []
197 | # initialize the predictor
198 | predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint)
199 |
200 | for nii_fname in tqdm(nii_fnames):
201 | # get corresponding case info
202 | range_suffix = re.findall(r'\d{3}-\d{3}', nii_fname)[0]
203 | slice_range = range_suffix.split('-')
204 | slice_range = [str(int(s)) for s in slice_range]
205 | slice_range = ', '.join(slice_range)
206 | nii_image = sitk.ReadImage(join(imgs_path, nii_fname))
207 | nii_image_data = sitk.GetArrayFromImage(nii_image)
208 |
209 | case_name = re.findall(r'^(\d{6}_\d{2}_\d{2})', nii_fname)[0]
210 | case_df = DL_info[
211 | DL_info['File_name'].str.contains(case_name) &
212 | DL_info['Slice_range'].str.contains(slice_range)
213 | ].copy()
214 |
215 | segs_3D = np.zeros(nii_image_data.shape, dtype=np.uint8)
216 |
217 | for row_id, row in case_df.iterrows():
218 | # print(f'Processing {case_name} tumor {tumor_idx}')
219 | # get the key slice info
220 | lower_bound, upper_bound = row['DICOM_windows'].split(',')
221 | lower_bound, upper_bound = float(lower_bound), float(upper_bound)
222 | nii_image_data_pre = np.clip(nii_image_data, lower_bound, upper_bound)
223 | nii_image_data_pre = (nii_image_data_pre - np.min(nii_image_data_pre))/(np.max(nii_image_data_pre)-np.min(nii_image_data_pre))*255.0
224 | nii_image_data_pre = np.uint8(nii_image_data_pre)
225 | key_slice_idx = row['Key_slice_index']
226 | key_slice_idx = int(key_slice_idx)
227 | slice_range = row['Slice_range']
228 | slice_idx_start, slice_idx_end = slice_range.split(',')
229 | slice_idx_start, slice_idx_end = int(slice_idx_start), int(slice_idx_end)
230 | bbox_coords = row['Bounding_boxes']
231 | bbox_coords = bbox_coords.split(',')
232 | bbox_coords = [int(float(coord)) for coord in bbox_coords]
233 | #bbox_coords = expand_box(bbox_coords)
234 | bbox = np.array(bbox_coords) # y_min, x_min, y_max, x_max
235 | bbox = np.array([bbox[1], bbox[0], bbox[3], bbox[2]])
236 |
237 | key_slice_idx_offset = key_slice_idx - slice_idx_start
238 | key_slice_img = nii_image_data_pre[key_slice_idx_offset, :,:]
239 |
240 | img_3D_ori = nii_image_data_pre
241 | assert np.max(img_3D_ori) < 256, f'input data should be in range [0, 255], but got {np.unique(img_3D_ori)}'
242 |
243 | video_height = key_slice_img.shape[0]
244 | video_width = key_slice_img.shape[1]
245 | img_resized = resize_grayscale_to_rgb_and_resize(img_3D_ori, 512)
246 | img_resized = img_resized / 255.0
247 | img_resized = torch.from_numpy(img_resized).cuda()
248 | img_mean=(0.485, 0.456, 0.406)
249 | img_std=(0.229, 0.224, 0.225)
250 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None].cuda()
251 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None].cuda()
252 | img_resized -= img_mean
253 | img_resized /= img_std
254 | z_mids = []
255 |
256 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
257 | inference_state = predictor.init_state(img_resized, video_height, video_width)
258 | if propagate_with_box:
259 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
260 | inference_state=inference_state,
261 | frame_idx=key_slice_idx_offset,
262 | obj_id=1,
263 | box=bbox,
264 | )
265 | else: # gt
266 | pass
267 |
268 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
269 | segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
270 | predictor.reset_state(inference_state)
271 | if propagate_with_box:
272 | _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
273 | inference_state=inference_state,
274 | frame_idx=key_slice_idx_offset,
275 | obj_id=1,
276 | box=bbox,
277 | )
278 | else: # gt
279 | pass
280 |
281 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, reverse=True):
282 | segs_3D[out_frame_idx, (out_mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
283 | predictor.reset_state(inference_state)
284 | if np.max(segs_3D) > 0:
285 | segs_3D = getLargestCC(segs_3D)
286 | segs_3D = np.uint8(segs_3D)
287 | sitk_image = sitk.GetImageFromArray(img_3D_ori)
288 | sitk_image.CopyInformation(nii_image)
289 | sitk_mask = sitk.GetImageFromArray(segs_3D)
290 | sitk_mask.CopyInformation(nii_image)
291 | # save single lesion
292 | key_slice_idx = row['Key_slice_index']
293 | save_seg_name = nii_fname.split('.nii.gz')[0] + f'_k{key_slice_idx}_mask.nii.gz'
294 | sitk.WriteImage(sitk_image, os.path.join(pred_save_dir, nii_fname.replace('.nii.gz', '_img.nii.gz')))
295 | sitk.WriteImage(sitk_mask, os.path.join(pred_save_dir, save_seg_name))
296 | seg_info['nii_name'].append(save_seg_name)
297 | seg_info['key_slice_index'].append(key_slice_idx)
298 | seg_info['DICOM_windows'].append(row['DICOM_windows'])
299 |
300 | seg_info_df = pd.DataFrame(seg_info)
301 | seg_info_df.to_csv(join(pred_save_dir, 'tiny_seg_info202412.csv'), index=False)
302 |
303 |
304 |
305 |
--------------------------------------------------------------------------------
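The per-lesion loop in the script above re-windows the whole volume with the `DICOM_windows` values from the DeepLesion metadata before resizing and normalizing it. The helper below (a hypothetical name, not part of the repo) restates that step in isolation; the window values are only an example.

```python
# Sketch of the intensity windowing performed in the per-lesion loop above:
# clip to [lower, upper], then rescale the clipped range to [0, 255] uint8.
import numpy as np

def apply_window(volume: np.ndarray, lower: float, upper: float) -> np.ndarray:  # hypothetical helper
    clipped = np.clip(volume, lower, upper)
    scaled = (clipped - clipped.min()) / (clipped.max() - clipped.min()) * 255.0
    return scaled.astype(np.uint8)

volume = np.random.uniform(-1000, 1000, size=(40, 512, 512)).astype(np.float32)  # dummy CT volume (HU)
windowed = apply_window(volume, lower=-175.0, upper=275.0)  # example soft-tissue window, not from the CSV
```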
/multi_node_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -t 7-00:0:0
3 | #SBATCH -J medsam2-tr-tiny
4 | #SBATCH --mem=450G
5 | #SBATCH -c 60
6 | #SBATCH -N 3
7 | #SBATCH --ntasks-per-node=1
8 | #SBATCH --gres=gpu:4
9 | #SBATCH -o out_mnodes_tiny.out
10 |
11 | export PATH=/usr/local/cuda/bin:$PATH
12 | timestamp=$(date +"%Y%m%d-%H%M")
13 |
14 | # Set the master node address (first node in the allocation)
15 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
16 | # export MASTER_PORT=29500
17 | export MASTER_PORT=$(python - <
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=61.0",
4 | "torch>=2.5.1",
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
--------------------------------------------------------------------------------
/sam2/_C.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/_C.so
--------------------------------------------------------------------------------
/sam2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from hydra import initialize_config_module
8 | from hydra.core.global_hydra import GlobalHydra
9 |
10 | if not GlobalHydra.instance().is_initialized():
11 | initialize_config_module("sam2", version_base="1.2")
12 |
--------------------------------------------------------------------------------
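Because importing `sam2` registers the package as a Hydra config module, the builders in `sam2/build_sam.py` can resolve config files by name relative to the package; that is why the inference scripts pass `configs/sam2.1_hiera_t512.yaml` rather than a filesystem path. A minimal sketch (the checkpoint path is a placeholder):

```python
# Sketch: config names are resolved by Hydra relative to the sam2 package once it is imported.
from sam2.build_sam import build_sam2  # importing sam2 runs initialize_config_module("sam2", ...)

model = build_sam2(
    "configs/sam2.1_hiera_t512.yaml",          # sam2/configs/sam2.1_hiera_t512.yaml, addressed by name
    ckpt_path="checkpoints/MedSAM2_latest.pt",
)
```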
/sam2/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/__pycache__/build_sam.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/build_sam.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/__pycache__/sam2_image_predictor.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/sam2_image_predictor.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/__pycache__/sam2_video_predictor_npz.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/__pycache__/sam2_video_predictor_npz.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 |
9 | import torch
10 | from hydra import compose
11 | from hydra.utils import instantiate
12 | from omegaconf import OmegaConf
13 |
14 | HF_MODEL_ID_TO_FILENAMES = {
15 | "facebook/sam2-hiera-tiny": (
16 | "configs/sam2/sam2_hiera_t.yaml",
17 | "sam2_hiera_tiny.pt",
18 | ),
19 | "facebook/sam2-hiera-small": (
20 | "configs/sam2/sam2_hiera_s.yaml",
21 | "sam2_hiera_small.pt",
22 | ),
23 | "facebook/sam2-hiera-base-plus": (
24 | "configs/sam2/sam2_hiera_b+.yaml",
25 | "sam2_hiera_base_plus.pt",
26 | ),
27 | "facebook/sam2-hiera-large": (
28 | "configs/sam2/sam2_hiera_l.yaml",
29 | "sam2_hiera_large.pt",
30 | ),
31 | "facebook/sam2.1-hiera-tiny": (
32 | "configs/sam2.1/sam2.1_hiera_t.yaml",
33 | "sam2.1_hiera_tiny.pt",
34 | ),
35 | "facebook/sam2.1-hiera-small": (
36 | "configs/sam2.1/sam2.1_hiera_s.yaml",
37 | "sam2.1_hiera_small.pt",
38 | ),
39 | "facebook/sam2.1-hiera-base-plus": (
40 | "configs/sam2.1/sam2.1_hiera_b+.yaml",
41 | "sam2.1_hiera_base_plus.pt",
42 | ),
43 | "facebook/sam2.1-hiera-large": (
44 | "configs/sam2.1/sam2.1_hiera_l.yaml",
45 | "sam2.1_hiera_large.pt",
46 | ),
47 | }
48 |
49 |
50 | def get_best_available_device():
51 | """
52 | Get the best available device in the order: CUDA, MPS, CPU
53 | Returns: device string for torch.device
54 | """
55 | if torch.cuda.is_available():
56 | return "cuda"
57 | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
58 | return "mps"
59 | else:
60 | return "cpu"
61 |
62 |
63 | def build_sam2(
64 | config_file,
65 | ckpt_path=None,
66 | device=None,
67 | mode="eval",
68 | hydra_overrides_extra=[],
69 | apply_postprocessing=True,
70 | **kwargs,
71 | ):
72 | # Use the provided device or get the best available one
73 | device = device or get_best_available_device()
74 | logging.info(f"Using device: {device}")
75 |
76 | if apply_postprocessing:
77 | hydra_overrides_extra = hydra_overrides_extra.copy()
78 | hydra_overrides_extra += [
79 | # dynamically fall back to multi-mask if the single mask is not stable
80 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
81 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
82 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
83 | ]
84 | # Read config and init model
85 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
86 | OmegaConf.resolve(cfg)
87 | model = instantiate(cfg.model, _recursive_=True)
88 | _load_checkpoint(model, ckpt_path)
89 | model = model.to(device)
90 | if mode == "eval":
91 | model.eval()
92 | return model
93 |
94 |
95 | def build_sam2_video_predictor(
96 | config_file,
97 | ckpt_path=None,
98 | device=None,
99 | mode="eval",
100 | hydra_overrides_extra=[],
101 | apply_postprocessing=True,
102 | **kwargs,
103 | ):
104 | # Use the provided device or get the best available one
105 | device = device or get_best_available_device()
106 | logging.info(f"Using device: {device}")
107 |
108 | hydra_overrides = [
109 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
110 | ]
111 | if apply_postprocessing:
112 | hydra_overrides_extra = hydra_overrides_extra.copy()
113 | hydra_overrides_extra += [
114 | # dynamically fall back to multi-mask if the single mask is not stable
115 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
116 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
117 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
118 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
119 | "++model.binarize_mask_from_pts_for_mem_enc=true",
120 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
121 | "++model.fill_hole_area=8",
122 | ]
123 | hydra_overrides.extend(hydra_overrides_extra)
124 |
125 | # Read config and init model
126 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
127 | OmegaConf.resolve(cfg)
128 | model = instantiate(cfg.model, _recursive_=True)
129 | _load_checkpoint(model, ckpt_path)
130 | model = model.to(device)
131 | if mode == "eval":
132 | model.eval()
133 | return model
134 |
135 | def build_sam2_video_predictor_npz(
136 | config_file,
137 | ckpt_path=None,
138 | device=None,
139 | mode="eval",
140 | hydra_overrides_extra=[],
141 | apply_postprocessing=True,
142 | **kwargs,
143 | ):
144 | # Use the provided device or get the best available one
145 | device = device or get_best_available_device()
146 | logging.info(f"Using device: {device}")
147 |
148 | hydra_overrides = [
149 | "++model._target_=sam2.sam2_video_predictor_npz.SAM2VideoPredictorNPZ",
150 | ]
151 | if apply_postprocessing:
152 | hydra_overrides_extra = hydra_overrides_extra.copy()
153 | hydra_overrides_extra += [
154 | # dynamically fall back to multi-mask if the single mask is not stable
155 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
156 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
157 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
158 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
159 | "++model.binarize_mask_from_pts_for_mem_enc=true",
160 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
161 | "++model.fill_hole_area=8",
162 | ]
163 | hydra_overrides.extend(hydra_overrides_extra)
164 |
165 | # Read config and init model
166 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
167 | OmegaConf.resolve(cfg)
168 | model = instantiate(cfg.model, _recursive_=True)
169 | _load_checkpoint(model, ckpt_path)
170 | model = model.to(device)
171 | if mode == "eval":
172 | model.eval()
173 | return model
174 |
175 |
176 |
177 | def _hf_download(model_id):
178 | from huggingface_hub import hf_hub_download
179 |
180 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
181 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
182 | return config_name, ckpt_path
183 |
184 |
185 | def build_sam2_hf(model_id, **kwargs):
186 | config_name, ckpt_path = _hf_download(model_id)
187 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
188 |
189 |
190 | def build_sam2_video_predictor_hf(model_id, **kwargs):
191 | config_name, ckpt_path = _hf_download(model_id)
192 | return build_sam2_video_predictor(
193 | config_file=config_name, ckpt_path=ckpt_path, **kwargs
194 | )
195 |
196 |
197 | def _load_checkpoint(model, ckpt_path):
198 | if ckpt_path is not None:
199 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
200 | missing_keys, unexpected_keys = model.load_state_dict(sd)
201 | if missing_keys:
202 | logging.error(missing_keys)
203 | raise RuntimeError()
204 | if unexpected_keys:
205 | logging.error(unexpected_keys)
206 | raise RuntimeError()
207 | logging.info("Loaded checkpoint successfully")
208 |
--------------------------------------------------------------------------------
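For the stock SAM2 weights listed in `HF_MODEL_ID_TO_FILENAMES`, the `*_hf` helpers above fetch the matching config/checkpoint pair from Hugging Face, and extra Hydra overrides can be appended in the same `++model.<field>=<value>` form the builders use internally. A short sketch; the chosen model id, device, and override are only examples:

```python
# Sketch: build a stock SAM2.1 model from Hugging Face with an explicit device and an extra override.
from sam2.build_sam import build_sam2_hf

model = build_sam2_hf(
    "facebook/sam2.1-hiera-tiny",                                   # key from HF_MODEL_ID_TO_FILENAMES
    device="cpu",
    hydra_overrides_extra=["++model.compile_image_encoder=False"],
)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```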
/sam2/configs/sam2.1_hiera_t512.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 512
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | no_obj_embed_spatial: true
97 | # use high-resolution feature map in the SAM mask decoder
98 | use_high_res_features_in_sam: true
99 | # output 3 masks on the first click on initial conditioning frames
100 | multimask_output_in_sam: true
101 | # SAM heads
102 | iou_prediction_use_sigmoid: True
103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104 | use_obj_ptrs_in_encoder: true
105 | add_tpos_enc_to_obj_ptrs: true
106 | proj_tpos_enc_in_obj_ptrs: true
107 | use_signed_tpos_enc_to_obj_ptrs: true
108 | only_obj_ptrs_in_the_past_for_eval: true
109 | # object occlusion prediction
110 | pred_obj_scores: true
111 | pred_obj_scores_mlp: true
112 | fixed_no_obj_ptr: true
113 | # multimask tracking settings
114 | multimask_output_for_tracking: true
115 | use_multimask_token_for_obj_ptr: true
116 | multimask_min_pt_num: 0
117 | multimask_max_pt_num: 1
118 | use_mlp_for_obj_ptr_proj: true
119 | # Compilation flag
120 | # HieraT does not currently support compilation, should always be set to False
121 | compile_image_encoder: False
122 |
--------------------------------------------------------------------------------
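The `image_size: 512` entry in this config is why the inference scripts resize every slice to 512x512 before calling `init_state`. If you change it, the preprocessing must follow; one way to keep the two in sync is to read the value from the composed config, as sketched below (same Hydra setup as `sam2/build_sam.py`):

```python
# Sketch: read image_size from the composed config so preprocessing matches the model input size.
from hydra import compose, initialize_config_module
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

if not GlobalHydra.instance().is_initialized():
    initialize_config_module("sam2", version_base="1.2")
cfg = compose(config_name="configs/sam2.1_hiera_t512.yaml")
OmegaConf.resolve(cfg)
print(cfg.model.image_size)  # 512 -> the resize target used in the inference scripts
```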
/sam2/csrc/connected_components.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Meta Platforms, Inc. and affiliates.
2 | // All rights reserved.
3 |
4 | // This source code is licensed under the license found in the
5 | // LICENSE file in the root directory of this source tree.
6 |
7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch
8 | // with license found in the LICENSE_cctorch file in the root directory.
9 | #include <ATen/cuda/CUDAContext.h>
10 | #include <cuda.h>
11 | #include <cuda_runtime.h>
12 | #include <torch/extension.h>
13 | #include <torch/script.h>
14 | #include <vector>
15 |
16 | // 2d
17 | #define BLOCK_ROWS 16
18 | #define BLOCK_COLS 16
19 |
20 | namespace cc2d {
21 |
22 | template <typename T>
23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24 | return (bitmap >> pos) & 1;
25 | }
26 |
27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28 | while (s_buf[n] != n)
29 | n = s_buf[n];
30 | return n;
31 | }
32 |
33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34 | const int32_t id = n;
35 | while (s_buf[n] != n) {
36 | n = s_buf[n];
37 | s_buf[id] = n;
38 | }
39 | return n;
40 | }
41 |
42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43 | bool done;
44 | do {
45 | a = find(s_buf, a);
46 | b = find(s_buf, b);
47 |
48 | if (a < b) {
49 | int32_t old = atomicMin(s_buf + b, a);
50 | done = (old == b);
51 | b = old;
52 | } else if (b < a) {
53 | int32_t old = atomicMin(s_buf + a, b);
54 | done = (old == a);
55 | a = old;
56 | } else
57 | done = true;
58 |
59 | } while (!done);
60 | }
61 |
62 | __global__ void
63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66 | const uint32_t idx = row * W + col;
67 |
68 | if (row < H && col < W)
69 | label[idx] = idx;
70 | }
71 |
72 | __global__ void
73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76 | const uint32_t idx = row * W + col;
77 |
78 | if (row >= H || col >= W)
79 | return;
80 |
81 | uint32_t P = 0;
82 |
83 | if (img[idx])
84 | P |= 0x777;
85 | if (row + 1 < H && img[idx + W])
86 | P |= 0x777 << 4;
87 | if (col + 1 < W && img[idx + 1])
88 | P |= 0x777 << 1;
89 |
90 | if (col == 0)
91 | P &= 0xEEEE;
92 | if (col + 1 >= W)
93 | P &= 0x3333;
94 | else if (col + 2 >= W)
95 | P &= 0x7777;
96 |
97 | if (row == 0)
98 | P &= 0xFFF0;
99 | if (row + 1 >= H)
100 | P &= 0xFF;
101 |
102 | if (P > 0) {
103 |     // if the top-left flag bit is set and the top-left pixel is foreground,
104 |     // merge with the top-left block
105 | if (hasBit(P, 0) && img[idx - W - 1]) {
106 | union_(label, idx, idx - 2 * W - 2); // top left block
107 | }
108 |
109 | if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110 | union_(label, idx, idx - 2 * W); // top bottom block
111 |
112 | if (hasBit(P, 3) && img[idx + 2 - W])
113 | union_(label, idx, idx - 2 * W + 2); // top right block
114 |
115 | if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116 | union_(label, idx, idx - 2); // just left block
117 | }
118 | }
119 |
120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123 | const uint32_t idx = row * W + col;
124 |
125 | if (row < H && col < W)
126 | find_n_compress(label, idx);
127 | }
128 |
129 | __global__ void final_labeling(
130 | const uint8_t* img,
131 | int32_t* label,
132 | const int32_t W,
133 | const int32_t H) {
134 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136 | const uint32_t idx = row * W + col;
137 |
138 | if (row >= H || col >= W)
139 | return;
140 |
141 | int32_t y = label[idx] + 1;
142 |
143 | if (img[idx])
144 | label[idx] = y;
145 | else
146 | label[idx] = 0;
147 |
148 | if (col + 1 < W) {
149 | if (img[idx + 1])
150 | label[idx + 1] = y;
151 | else
152 | label[idx + 1] = 0;
153 |
154 | if (row + 1 < H) {
155 | if (img[idx + W + 1])
156 | label[idx + W + 1] = y;
157 | else
158 | label[idx + W + 1] = 0;
159 | }
160 | }
161 |
162 | if (row + 1 < H) {
163 | if (img[idx + W])
164 | label[idx + W] = y;
165 | else
166 | label[idx + W] = 0;
167 | }
168 | }
169 |
170 | __global__ void init_counting(
171 | const int32_t* label,
172 | int32_t* count_init,
173 | const int32_t W,
174 | const int32_t H) {
175 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177 | const uint32_t idx = row * W + col;
178 |
179 | if (row >= H || col >= W)
180 | return;
181 |
182 | int32_t y = label[idx];
183 | if (y > 0) {
184 | int32_t count_idx = y - 1;
185 | atomicAdd(count_init + count_idx, 1);
186 | }
187 | }
188 |
189 | __global__ void final_counting(
190 | const int32_t* label,
191 | const int32_t* count_init,
192 | int32_t* count_final,
193 | const int32_t W,
194 | const int32_t H) {
195 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197 | const uint32_t idx = row * W + col;
198 |
199 | if (row >= H || col >= W)
200 | return;
201 |
202 | int32_t y = label[idx];
203 | if (y > 0) {
204 | int32_t count_idx = y - 1;
205 | count_final[idx] = count_init[count_idx];
206 | } else {
207 | count_final[idx] = 0;
208 | }
209 | }
210 |
211 | } // namespace cc2d
212 |
213 | std::vector<torch::Tensor> get_connected_componnets(
214 | const torch::Tensor& inputs) {
215 | AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216 | AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217 | AT_ASSERTM(
218 | inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219 |
220 | const uint32_t N = inputs.size(0);
221 | const uint32_t C = inputs.size(1);
222 | const uint32_t H = inputs.size(2);
223 | const uint32_t W = inputs.size(3);
224 |
225 | AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226 | AT_ASSERTM((H % 2) == 0, "height must be an even number");
227 | AT_ASSERTM((W % 2) == 0, "width must be an even number");
228 |
229 | // label must be uint32_t
230 | auto label_options =
231 | torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232 | torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233 | torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234 | torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235 |
236 | dim3 grid = dim3(
237 | ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238 | ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239 | dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240 | dim3 grid_count =
241 | dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242 | dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244 |
245 | for (int n = 0; n < N; n++) {
246 | uint32_t offset = n * H * W;
247 |
248 |     cc2d::init_labeling<<<grid, block, 0, stream>>>(
249 |         labels.data_ptr<int32_t>() + offset, W, H);
250 |     cc2d::merge<<<grid, block, 0, stream>>>(
251 |         inputs.data_ptr<uint8_t>() + offset,
252 |         labels.data_ptr<int32_t>() + offset,
253 |         W,
254 |         H);
255 |     cc2d::compression<<<grid, block, 0, stream>>>(
256 |         labels.data_ptr<int32_t>() + offset, W, H);
257 |     cc2d::final_labeling<<<grid, block, 0, stream>>>(
258 |         inputs.data_ptr<uint8_t>() + offset,
259 |         labels.data_ptr<int32_t>() + offset,
260 |         W,
261 |         H);
262 |
263 |     // get the counting of each pixel
264 |     cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265 |         labels.data_ptr<int32_t>() + offset,
266 |         counts_init.data_ptr<int32_t>() + offset,
267 |         W,
268 |         H);
269 |     cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270 |         labels.data_ptr<int32_t>() + offset,
271 |         counts_init.data_ptr<int32_t>() + offset,
272 |         counts_final.data_ptr<int32_t>() + offset,
273 |         W,
274 |         H);
275 | }
276 |
277 | // returned values are [labels, counts]
278 |   std::vector<torch::Tensor> outputs;
279 | outputs.push_back(labels);
280 | outputs.push_back(counts_final);
281 | return outputs;
282 | }
283 |
284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285 | m.def(
286 | "get_connected_componnets",
287 | &get_connected_componnets,
288 | "get_connected_componnets");
289 | }
290 |
--------------------------------------------------------------------------------
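The extension above is exposed through the PYBIND11_MODULE block as `get_connected_componnets` and, per its AT_ASSERTM checks, expects a uint8 CUDA tensor of shape [N, 1, H, W] with even H and W; it returns per-pixel component labels and per-pixel component sizes. A minimal usage sketch, assuming the compiled library is importable as `sam2._C`:

import torch
from sam2 import _C  # assumed import path for the built extension

# Random binary mask; must be uint8, on GPU, shaped [N, 1, H, W] with even H and W.
mask = (torch.rand(1, 1, 64, 64, device="cuda") > 0.5).to(torch.uint8).contiguous()
labels, counts = _C.get_connected_componnets(mask)
print(labels.shape, labels.dtype)   # torch.Size([1, 1, 64, 64]) torch.int32
print(counts.max().item())          # size of the largest connected component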
/sam2/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/__pycache__/memory_attention.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/memory_attention.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/__pycache__/position_encoding.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/position_encoding.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/__pycache__/sam2_base.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/sam2_base.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/backbones/hieradet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from functools import partial
9 | from typing import List, Tuple, Union
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from iopath.common.file_io import g_pathmgr
15 |
16 | from sam2.modeling.backbones.utils import (
17 | PatchEmbed,
18 | window_partition,
19 | window_unpartition,
20 | )
21 |
22 | from sam2.modeling.sam2_utils import DropPath, MLP
23 |
24 |
25 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26 | if pool is None:
27 | return x
28 | # (B, H, W, C) -> (B, C, H, W)
29 | x = x.permute(0, 3, 1, 2)
30 | x = pool(x)
31 | # (B, C, H', W') -> (B, H', W', C)
32 | x = x.permute(0, 2, 3, 1)
33 | if norm:
34 | x = norm(x)
35 |
36 | return x
37 |
38 |
39 | class MultiScaleAttention(nn.Module):
40 | def __init__(
41 | self,
42 | dim: int,
43 | dim_out: int,
44 | num_heads: int,
45 | q_pool: nn.Module = None,
46 | ):
47 | super().__init__()
48 |
49 | self.dim = dim
50 | self.dim_out = dim_out
51 | self.num_heads = num_heads
52 | self.q_pool = q_pool
53 | self.qkv = nn.Linear(dim, dim_out * 3)
54 | self.proj = nn.Linear(dim_out, dim_out)
55 |
56 | def forward(self, x: torch.Tensor) -> torch.Tensor:
57 | B, H, W, _ = x.shape
58 | # qkv with shape (B, H * W, 3, nHead, C)
59 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
60 | # q, k, v with shape (B, H * W, nheads, C)
61 | q, k, v = torch.unbind(qkv, 2)
62 |
63 | # Q pooling (for downsample at stage changes)
64 | if self.q_pool:
65 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
66 | H, W = q.shape[1:3] # downsampled shape
67 | q = q.reshape(B, H * W, self.num_heads, -1)
68 |
69 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
70 | x = F.scaled_dot_product_attention(
71 | q.transpose(1, 2),
72 | k.transpose(1, 2),
73 | v.transpose(1, 2),
74 | )
75 | # Transpose back
76 | x = x.transpose(1, 2)
77 | x = x.reshape(B, H, W, -1)
78 |
79 | x = self.proj(x)
80 |
81 | return x
82 |
83 |
84 | class MultiScaleBlock(nn.Module):
85 | def __init__(
86 | self,
87 | dim: int,
88 | dim_out: int,
89 | num_heads: int,
90 | mlp_ratio: float = 4.0,
91 | drop_path: float = 0.0,
92 | norm_layer: Union[nn.Module, str] = "LayerNorm",
93 | q_stride: Tuple[int, int] = None,
94 | act_layer: nn.Module = nn.GELU,
95 | window_size: int = 0,
96 | ):
97 | super().__init__()
98 |
99 | if isinstance(norm_layer, str):
100 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101 |
102 | self.dim = dim
103 | self.dim_out = dim_out
104 | self.norm1 = norm_layer(dim)
105 |
106 | self.window_size = window_size
107 |
108 | self.pool, self.q_stride = None, q_stride
109 | if self.q_stride:
110 | self.pool = nn.MaxPool2d(
111 | kernel_size=q_stride, stride=q_stride, ceil_mode=False
112 | )
113 |
114 | self.attn = MultiScaleAttention(
115 | dim,
116 | dim_out,
117 | num_heads=num_heads,
118 | q_pool=self.pool,
119 | )
120 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121 |
122 | self.norm2 = norm_layer(dim_out)
123 | self.mlp = MLP(
124 | dim_out,
125 | int(dim_out * mlp_ratio),
126 | dim_out,
127 | num_layers=2,
128 | activation=act_layer,
129 | )
130 |
131 | if dim != dim_out:
132 | self.proj = nn.Linear(dim, dim_out)
133 |
134 | def forward(self, x: torch.Tensor) -> torch.Tensor:
135 | shortcut = x # B, H, W, C
136 | x = self.norm1(x)
137 |
138 | # Skip connection
139 | if self.dim != self.dim_out:
140 | shortcut = do_pool(self.proj(x), self.pool)
141 |
142 | # Window partition
143 | window_size = self.window_size
144 | if window_size > 0:
145 | H, W = x.shape[1], x.shape[2]
146 | x, pad_hw = window_partition(x, window_size)
147 |
148 | # Window Attention + Q Pooling (if stage change)
149 | x = self.attn(x)
150 | if self.q_stride:
151 | # Shapes have changed due to Q pooling
152 | window_size = self.window_size // self.q_stride[0]
153 | H, W = shortcut.shape[1:3]
154 |
155 | pad_h = (window_size - H % window_size) % window_size
156 | pad_w = (window_size - W % window_size) % window_size
157 | pad_hw = (H + pad_h, W + pad_w)
158 |
159 | # Reverse window partition
160 | if self.window_size > 0:
161 | x = window_unpartition(x, window_size, pad_hw, (H, W))
162 |
163 | x = shortcut + self.drop_path(x)
164 | # MLP
165 | x = x + self.drop_path(self.mlp(self.norm2(x)))
166 | return x
167 |
168 |
169 | class Hiera(nn.Module):
170 | """
171 | Reference: https://arxiv.org/abs/2306.00989
172 | """
173 |
174 | def __init__(
175 | self,
176 | embed_dim: int = 96, # initial embed dim
177 | num_heads: int = 1, # initial number of heads
178 | drop_path_rate: float = 0.0, # stochastic depth
179 | q_pool: int = 3, # number of q_pool stages
180 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
181 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
182 | dim_mul: float = 2.0, # dim_mul factor at stage shift
183 | head_mul: float = 2.0, # head_mul factor at stage shift
184 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
185 | # window size per stage, when not using global att.
186 | window_spec: Tuple[int, ...] = (
187 | 8,
188 | 4,
189 | 14,
190 | 7,
191 | ),
192 | # global attn in these blocks
193 | global_att_blocks: Tuple[int, ...] = (
194 | 12,
195 | 16,
196 | 20,
197 | ),
198 | weights_path=None,
199 | return_interm_layers=True, # return feats from every stage
200 | ):
201 | super().__init__()
202 |
203 | assert len(stages) == len(window_spec)
204 | self.window_spec = window_spec
205 |
206 | depth = sum(stages)
207 | self.q_stride = q_stride
208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209 | assert 0 <= q_pool <= len(self.stage_ends[:-1])
210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211 | self.return_interm_layers = return_interm_layers
212 |
213 | self.patch_embed = PatchEmbed(
214 | embed_dim=embed_dim,
215 | )
216 | # Which blocks have global att?
217 | self.global_att_blocks = global_att_blocks
218 |
219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221 | self.pos_embed = nn.Parameter(
222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223 | )
224 | self.pos_embed_window = nn.Parameter(
225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226 | )
227 |
228 | dpr = [
229 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
230 | ] # stochastic depth decay rule
231 |
232 | cur_stage = 1
233 | self.blocks = nn.ModuleList()
234 |
235 | for i in range(depth):
236 | dim_out = embed_dim
237 |             # the window size lags the stage transition by one block: the first
238 |             # block of a new stage still uses the previous stage's window size,
239 |             # since cur_stage is only advanced below
240 | window_size = self.window_spec[cur_stage - 1]
241 |
242 | if self.global_att_blocks is not None:
243 | window_size = 0 if i in self.global_att_blocks else window_size
244 |
245 | if i - 1 in self.stage_ends:
246 | dim_out = int(embed_dim * dim_mul)
247 | num_heads = int(num_heads * head_mul)
248 | cur_stage += 1
249 |
250 | block = MultiScaleBlock(
251 | dim=embed_dim,
252 | dim_out=dim_out,
253 | num_heads=num_heads,
254 | drop_path=dpr[i],
255 | q_stride=self.q_stride if i in self.q_pool_blocks else None,
256 | window_size=window_size,
257 | )
258 |
259 | embed_dim = dim_out
260 | self.blocks.append(block)
261 |
262 | self.channel_list = (
263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264 | if return_interm_layers
265 | else [self.blocks[-1].dim_out]
266 | )
267 |
268 | if weights_path is not None:
269 | with g_pathmgr.open(weights_path, "rb") as f:
270 | chkpt = torch.load(f, map_location="cpu")
271 |             logging.info("loading Hiera: %s", self.load_state_dict(chkpt, strict=False))
272 |
273 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
274 | h, w = hw
275 | window_embed = self.pos_embed_window
276 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277 | pos_embed = pos_embed + window_embed.tile(
278 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279 | )
280 | pos_embed = pos_embed.permute(0, 2, 3, 1)
281 | return pos_embed
282 |
283 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
284 | x = self.patch_embed(x)
285 | # x: (B, H, W, C)
286 |
287 | # Add pos embed
288 | x = x + self._get_pos_embed(x.shape[1:3])
289 |
290 | outputs = []
291 | for i, blk in enumerate(self.blocks):
292 | x = blk(x)
293 | if (i == self.stage_ends[-1]) or (
294 | i in self.stage_ends and self.return_interm_layers
295 | ):
296 | feats = x.permute(0, 3, 1, 2)
297 | outputs.append(feats)
298 |
299 | return outputs
300 |
301 | def get_layer_id(self, layer_name):
302 | # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
303 | num_layers = self.get_num_layers()
304 |
305 | if layer_name.find("rel_pos") != -1:
306 | return num_layers + 1
307 | elif layer_name.find("pos_embed") != -1:
308 | return 0
309 | elif layer_name.find("patch_embed") != -1:
310 | return 0
311 | elif layer_name.find("blocks") != -1:
312 | return int(layer_name.split("blocks")[1].split(".")[1]) + 1
313 | else:
314 | return num_layers + 1
315 |
316 | def get_num_layers(self) -> int:
317 | return len(self.blocks)
318 |
--------------------------------------------------------------------------------
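A minimal sketch of running the Hiera trunk on a dummy image; the default hyper-parameters used here (embed_dim=96, stages=(2, 3, 16, 3)) are for illustration only and may differ from the values in the YAML configs above:

import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera()                       # default configuration, illustrative only
x = torch.randn(1, 3, 512, 512)       # (B, C, H, W)
feats = trunk(x)                      # one feature map per stage end, high to low resolution
for f in feats:
    print(tuple(f.shape))             # channel width doubles while resolution halves across stages
print(trunk.channel_list)             # e.g. [768, 384, 192, 96] with defaults: deepest stage first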
/sam2/modeling/backbones/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class ImageEncoder(nn.Module):
15 | def __init__(
16 | self,
17 | trunk: nn.Module,
18 | neck: nn.Module,
19 | scalp: int = 0,
20 | ):
21 | super().__init__()
22 | self.trunk = trunk
23 | self.neck = neck
24 | self.scalp = scalp
25 | assert (
26 | self.trunk.channel_list == self.neck.backbone_channel_list
27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28 |
29 | def forward(self, sample: torch.Tensor):
30 | # Forward through backbone
31 | features, pos = self.neck(self.trunk(sample))
32 | if self.scalp > 0:
33 | # Discard the lowest resolution features
34 | features, pos = features[: -self.scalp], pos[: -self.scalp]
35 |
36 | src = features[-1]
37 | output = {
38 | "vision_features": src,
39 | "vision_pos_enc": pos,
40 | "backbone_fpn": features,
41 | }
42 | return output
43 |
44 |
45 | class FpnNeck(nn.Module):
46 | """
47 | A modified variant of Feature Pyramid Network (FPN) neck
48 | (we remove output conv and also do bicubic interpolation similar to ViT
49 | pos embed interpolation)
50 | """
51 |
52 | def __init__(
53 | self,
54 | position_encoding: nn.Module,
55 | d_model: int,
56 | backbone_channel_list: List[int],
57 | kernel_size: int = 1,
58 | stride: int = 1,
59 | padding: int = 0,
60 | fpn_interp_model: str = "bilinear",
61 | fuse_type: str = "sum",
62 | fpn_top_down_levels: Optional[List[int]] = None,
63 | ):
64 |         """Initialize the neck
65 |         :param position_encoding: the positional encoding to use
66 |         :param d_model: the output channel dimension of the neck
67 |         :param backbone_channel_list: channel dimensions of the backbone feature maps
68 |         :param fpn_top_down_levels: the levels that receive top-down features
69 |         """
70 | super().__init__()
71 | self.position_encoding = position_encoding
72 | self.convs = nn.ModuleList()
73 | self.backbone_channel_list = backbone_channel_list
74 | self.d_model = d_model
75 | for dim in backbone_channel_list:
76 | current = nn.Sequential()
77 | current.add_module(
78 | "conv",
79 | nn.Conv2d(
80 | in_channels=dim,
81 | out_channels=d_model,
82 | kernel_size=kernel_size,
83 | stride=stride,
84 | padding=padding,
85 | ),
86 | )
87 |
88 | self.convs.append(current)
89 | self.fpn_interp_model = fpn_interp_model
90 | assert fuse_type in ["sum", "avg"]
91 | self.fuse_type = fuse_type
92 |
93 | # levels to have top-down features in its outputs
94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95 | # have top-down propagation, while outputs of level 0 and level 1 have only
96 | # lateral features from the same backbone level.
97 | if fpn_top_down_levels is None:
98 | # default is to have top-down features on all levels
99 | fpn_top_down_levels = range(len(self.convs))
100 | self.fpn_top_down_levels = list(fpn_top_down_levels)
101 |
102 | def forward(self, xs: List[torch.Tensor]):
103 |
104 | out = [None] * len(self.convs)
105 | pos = [None] * len(self.convs)
106 | assert len(xs) == len(self.convs)
107 | # fpn forward pass
108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109 | prev_features = None
110 | # forward in top-down order (from low to high resolution)
111 | n = len(self.convs) - 1
112 | for i in range(n, -1, -1):
113 | x = xs[i]
114 | lateral_features = self.convs[n - i](x)
115 | if i in self.fpn_top_down_levels and prev_features is not None:
116 | top_down_features = F.interpolate(
117 | prev_features.to(dtype=torch.float32),
118 | scale_factor=2.0,
119 | mode=self.fpn_interp_model,
120 | align_corners=(
121 | None if self.fpn_interp_model == "nearest" else False
122 | ),
123 | antialias=False,
124 | )
125 | prev_features = lateral_features + top_down_features
126 | if self.fuse_type == "avg":
127 | prev_features /= 2
128 | else:
129 | prev_features = lateral_features
130 | x_out = prev_features
131 | out[i] = x_out
132 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133 |
134 | return out, pos
135 |
--------------------------------------------------------------------------------
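ImageEncoder simply wires a trunk and a neck together and asserts that their channel lists line up. A minimal sketch pairing the Hiera trunk with FpnNeck; d_model=256 and scalp=1 are illustrative assumptions, not values read from this repository's configs:

import torch
from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = Hiera()
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),
    d_model=256,
    backbone_channel_list=trunk.channel_list,  # must match the trunk, or __init__ asserts
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)  # scalp=1 drops the lowest-resolution level
out = encoder(torch.randn(1, 3, 512, 512))
print(out["vision_features"].shape)                      # lowest-resolution level kept after scalp
print([tuple(f.shape) for f in out["backbone_fpn"]])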
/sam2/modeling/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some utilities for backbones, in particular for windowing"""
8 |
9 | from typing import Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def window_partition(x, window_size):
17 | """
18 | Partition into non-overlapping windows with padding if needed.
19 | Args:
20 | x (tensor): input tokens with [B, H, W, C].
21 | window_size (int): window size.
22 | Returns:
23 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
24 | (Hp, Wp): padded height and width before partition
25 | """
26 | B, H, W, C = x.shape
27 |
28 | pad_h = (window_size - H % window_size) % window_size
29 | pad_w = (window_size - W % window_size) % window_size
30 | if pad_h > 0 or pad_w > 0:
31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32 | Hp, Wp = H + pad_h, W + pad_w
33 |
34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35 | windows = (
36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | )
38 | return windows, (Hp, Wp)
39 |
40 |
41 | def window_unpartition(windows, window_size, pad_hw, hw):
42 | """
43 |     Window unpartition into original sequences and remove padding.
44 |     Args:
45 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46 | window_size (int): window size.
47 | pad_hw (Tuple): padded height and width (Hp, Wp).
48 | hw (Tuple): original height and width (H, W) before padding.
49 | Returns:
50 | x: unpartitioned sequences with [B, H, W, C].
51 | """
52 | Hp, Wp = pad_hw
53 | H, W = hw
54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55 | x = windows.view(
56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57 | )
58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59 |
60 | if Hp > H or Wp > W:
61 | x = x[:, :H, :W, :].contiguous()
62 | return x
63 |
64 |
65 | class PatchEmbed(nn.Module):
66 | """
67 | Image to Patch Embedding.
68 | """
69 |
70 | def __init__(
71 | self,
72 | kernel_size: Tuple[int, ...] = (7, 7),
73 | stride: Tuple[int, ...] = (4, 4),
74 | padding: Tuple[int, ...] = (3, 3),
75 | in_chans: int = 3,
76 | embed_dim: int = 768,
77 | ):
78 | """
79 | Args:
80 | kernel_size (Tuple): kernel size of the projection layer.
81 | stride (Tuple): stride of the projection layer.
82 | padding (Tuple): padding size of the projection layer.
83 | in_chans (int): Number of input image channels.
84 |             embed_dim (int): Patch embedding dimension.
85 | """
86 | super().__init__()
87 | self.proj = nn.Conv2d(
88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89 | )
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | x = self.proj(x)
93 | # B C H W -> B H W C
94 | x = x.permute(0, 2, 3, 1)
95 | return x
96 |
--------------------------------------------------------------------------------
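window_partition pads the token grid up to a multiple of the window size before splitting, and window_unpartition strips that padding again, so the pair round-trips exactly. A small sketch:

import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 30, 30, 96)                 # (B, H, W, C); 30 is not a multiple of 8
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(windows.shape, (Hp, Wp))                 # (32, 8, 8, 96), grid padded to (32, 32)
x_back = window_unpartition(windows, 8, (Hp, Wp), (30, 30))
print(torch.equal(x, x_back))                  # True: padding removed, values unchanged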
/sam2/modeling/memory_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import nn, Tensor
11 |
12 | from sam2.modeling.sam.transformer import RoPEAttention
13 |
14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15 |
16 |
17 | class MemoryAttentionLayer(nn.Module):
18 |
19 | def __init__(
20 | self,
21 | activation: str,
22 | cross_attention: nn.Module,
23 | d_model: int,
24 | dim_feedforward: int,
25 | dropout: float,
26 | pos_enc_at_attn: bool,
27 | pos_enc_at_cross_attn_keys: bool,
28 | pos_enc_at_cross_attn_queries: bool,
29 | self_attention: nn.Module,
30 | ):
31 | super().__init__()
32 | self.d_model = d_model
33 | self.dim_feedforward = dim_feedforward
34 | self.dropout_value = dropout
35 | self.self_attn = self_attention
36 | self.cross_attn_image = cross_attention
37 |
38 | # Implementation of Feedforward model
39 | self.linear1 = nn.Linear(d_model, dim_feedforward)
40 | self.dropout = nn.Dropout(dropout)
41 | self.linear2 = nn.Linear(dim_feedforward, d_model)
42 |
43 | self.norm1 = nn.LayerNorm(d_model)
44 | self.norm2 = nn.LayerNorm(d_model)
45 | self.norm3 = nn.LayerNorm(d_model)
46 | self.dropout1 = nn.Dropout(dropout)
47 | self.dropout2 = nn.Dropout(dropout)
48 | self.dropout3 = nn.Dropout(dropout)
49 |
50 | self.activation_str = activation
51 | self.activation = get_activation_fn(activation)
52 |
53 | # Where to add pos enc
54 | self.pos_enc_at_attn = pos_enc_at_attn
55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57 |
58 | def _forward_sa(self, tgt, query_pos):
59 | # Self-Attention
60 | tgt2 = self.norm1(tgt)
61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62 | tgt2 = self.self_attn(q, k, v=tgt2)
63 | tgt = tgt + self.dropout1(tgt2)
64 | return tgt
65 |
66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67 | kwds = {}
68 | if num_k_exclude_rope > 0:
69 | assert isinstance(self.cross_attn_image, RoPEAttention)
70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71 |
72 | # Cross-Attention
73 | tgt2 = self.norm2(tgt)
74 | tgt2 = self.cross_attn_image(
75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77 | v=memory,
78 | **kwds,
79 | )
80 | tgt = tgt + self.dropout2(tgt2)
81 | return tgt
82 |
83 | def forward(
84 | self,
85 | tgt,
86 | memory,
87 | pos: Optional[Tensor] = None,
88 | query_pos: Optional[Tensor] = None,
89 | num_k_exclude_rope: int = 0,
90 | ) -> torch.Tensor:
91 |
92 | # Self-Attn, Cross-Attn
93 | tgt = self._forward_sa(tgt, query_pos)
94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95 | # MLP
96 | tgt2 = self.norm3(tgt)
97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98 | tgt = tgt + self.dropout3(tgt2)
99 | return tgt
100 |
101 |
102 | class MemoryAttention(nn.Module):
103 | def __init__(
104 | self,
105 | d_model: int,
106 | pos_enc_at_input: bool,
107 | layer: nn.Module,
108 | num_layers: int,
109 | batch_first: bool = True, # Do layers expect batch first input?
110 | ):
111 | super().__init__()
112 | self.d_model = d_model
113 | self.layers = get_clones(layer, num_layers)
114 | self.num_layers = num_layers
115 | self.norm = nn.LayerNorm(d_model)
116 | self.pos_enc_at_input = pos_enc_at_input
117 | self.batch_first = batch_first
118 |
119 | def forward(
120 | self,
121 | curr: torch.Tensor, # self-attention inputs
122 | memory: torch.Tensor, # cross-attention inputs
123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126 | ):
127 | if isinstance(curr, list):
128 | assert isinstance(curr_pos, list)
129 | assert len(curr) == len(curr_pos) == 1
130 | curr, curr_pos = (
131 | curr[0],
132 | curr_pos[0],
133 | )
134 |
135 | assert (
136 | curr.shape[1] == memory.shape[1]
137 | ), "Batch size must be the same for curr and memory"
138 |
139 | output = curr
140 | if self.pos_enc_at_input and curr_pos is not None:
141 | output = output + 0.1 * curr_pos
142 |
143 | if self.batch_first:
144 | # Convert to batch first
145 | output = output.transpose(0, 1)
146 | curr_pos = curr_pos.transpose(0, 1)
147 | memory = memory.transpose(0, 1)
148 | memory_pos = memory_pos.transpose(0, 1)
149 |
150 | for layer in self.layers:
151 | kwds = {}
152 | if isinstance(layer.cross_attn_image, RoPEAttention):
153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154 |
155 | output = layer(
156 | tgt=output,
157 | memory=memory,
158 | pos=memory_pos,
159 | query_pos=curr_pos,
160 | **kwds,
161 | )
162 | normed_output = self.norm(output)
163 |
164 | if self.batch_first:
165 | # Convert back to seq first
166 | normed_output = normed_output.transpose(0, 1)
167 | curr_pos = curr_pos.transpose(0, 1)
168 |
169 | return normed_output
170 |
--------------------------------------------------------------------------------
/sam2/modeling/memory_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Tuple
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15 |
16 |
17 | class MaskDownSampler(nn.Module):
18 | """
19 | Progressively downsample a mask by total_stride, each time by stride.
20 | Note that LayerNorm is applied per *token*, like in ViT.
21 |
22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23 | In the end, we linearly project to embed_dim channels.
24 | """
25 |
26 | def __init__(
27 | self,
28 | embed_dim=256,
29 | kernel_size=4,
30 | stride=4,
31 | padding=0,
32 | total_stride=16,
33 | activation=nn.GELU,
34 | ):
35 | super().__init__()
36 | num_layers = int(math.log2(total_stride) // math.log2(stride))
37 | assert stride**num_layers == total_stride
38 | self.encoder = nn.Sequential()
39 | mask_in_chans, mask_out_chans = 1, 1
40 | for _ in range(num_layers):
41 | mask_out_chans = mask_in_chans * (stride**2)
42 | self.encoder.append(
43 | nn.Conv2d(
44 | mask_in_chans,
45 | mask_out_chans,
46 | kernel_size=kernel_size,
47 | stride=stride,
48 | padding=padding,
49 | )
50 | )
51 | self.encoder.append(LayerNorm2d(mask_out_chans))
52 | self.encoder.append(activation())
53 | mask_in_chans = mask_out_chans
54 |
55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56 |
57 | def forward(self, x):
58 | return self.encoder(x)
59 |
60 |
61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62 | class CXBlock(nn.Module):
63 | r"""ConvNeXt Block. There are two equivalent implementations:
64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66 | We use (2) as we find it slightly faster in PyTorch
67 |
68 | Args:
69 | dim (int): Number of input channels.
70 | drop_path (float): Stochastic depth rate. Default: 0.0
71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72 | """
73 |
74 | def __init__(
75 | self,
76 | dim,
77 | kernel_size=7,
78 | padding=3,
79 | drop_path=0.0,
80 | layer_scale_init_value=1e-6,
81 | use_dwconv=True,
82 | ):
83 | super().__init__()
84 | self.dwconv = nn.Conv2d(
85 | dim,
86 | dim,
87 | kernel_size=kernel_size,
88 | padding=padding,
89 | groups=dim if use_dwconv else 1,
90 | ) # depthwise conv
91 | self.norm = LayerNorm2d(dim, eps=1e-6)
92 | self.pwconv1 = nn.Linear(
93 | dim, 4 * dim
94 | ) # pointwise/1x1 convs, implemented with linear layers
95 | self.act = nn.GELU()
96 | self.pwconv2 = nn.Linear(4 * dim, dim)
97 | self.gamma = (
98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99 | if layer_scale_init_value > 0
100 | else None
101 | )
102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103 |
104 | def forward(self, x):
105 | input = x
106 | x = self.dwconv(x)
107 | x = self.norm(x)
108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109 | x = self.pwconv1(x)
110 | x = self.act(x)
111 | x = self.pwconv2(x)
112 | if self.gamma is not None:
113 | x = self.gamma * x
114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115 |
116 | x = input + self.drop_path(x)
117 | return x
118 |
119 |
120 | class Fuser(nn.Module):
121 | def __init__(self, layer, num_layers, dim=None, input_projection=False):
122 | super().__init__()
123 | self.proj = nn.Identity()
124 | self.layers = get_clones(layer, num_layers)
125 |
126 | if input_projection:
127 | assert dim is not None
128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129 |
130 | def forward(self, x):
131 | # normally x: (N, C, H, W)
132 | x = self.proj(x)
133 | for layer in self.layers:
134 | x = layer(x)
135 | return x
136 |
137 |
138 | class MemoryEncoder(nn.Module):
139 | def __init__(
140 | self,
141 | out_dim,
142 | mask_downsampler,
143 | fuser,
144 | position_encoding,
145 | in_dim=256, # in_dim of pix_feats
146 | ):
147 | super().__init__()
148 |
149 | self.mask_downsampler = mask_downsampler
150 |
151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152 | self.fuser = fuser
153 | self.position_encoding = position_encoding
154 | self.out_proj = nn.Identity()
155 | if out_dim != in_dim:
156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157 |
158 | def forward(
159 | self,
160 | pix_feat: torch.Tensor,
161 | masks: torch.Tensor,
162 | skip_mask_sigmoid: bool = False,
163 | ) -> Tuple[torch.Tensor, torch.Tensor]:
164 | ## Process masks
165 |         # apply sigmoid so the mask logits better match the binary ground-truth masks
166 | if not skip_mask_sigmoid:
167 | masks = F.sigmoid(masks)
168 | masks = self.mask_downsampler(masks)
169 |
170 | ## Fuse pix_feats and downsampled masks
171 |         # in case the visual features are on CPU, move them to the masks' device
172 | pix_feat = pix_feat.to(masks.device)
173 |
174 | x = self.pix_feat_proj(pix_feat)
175 | x = x + masks
176 | x = self.fuser(x)
177 | x = self.out_proj(x)
178 |
179 | pos = self.position_encoding(x).to(x.dtype)
180 |
181 | return {"vision_features": x, "vision_pos_enc": [pos]}
182 |
--------------------------------------------------------------------------------
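A minimal sketch of assembling MemoryEncoder with the hyper-parameters given in the YAML config earlier (out_dim=64, a stride-2 MaskDownSampler, a two-layer CXBlock Fuser); the 512-pixel mask and 32x32 feature resolutions are illustrative assumptions:

import torch
from sam2.modeling.memory_encoder import CXBlock, Fuser, MaskDownSampler, MemoryEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine

mem_enc = MemoryEncoder(
    out_dim=64,
    mask_downsampler=MaskDownSampler(kernel_size=3, stride=2, padding=1),
    fuser=Fuser(CXBlock(dim=256, kernel_size=7, padding=3, use_dwconv=True), num_layers=2),
    position_encoding=PositionEmbeddingSine(num_pos_feats=64),
)
pix_feat = torch.randn(1, 256, 32, 32)     # per-frame backbone features (in_dim defaults to 256)
mask_logits = torch.randn(1, 1, 512, 512)  # predicted mask logits at image resolution
out = mem_enc(pix_feat, mask_logits)
print(out["vision_features"].shape)        # (1, 64, 32, 32): compressed memory features
print(out["vision_pos_enc"][0].shape)      # (1, 64, 32, 32): matching positional encoding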
/sam2/modeling/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Any, Optional, Tuple
9 |
10 | import numpy as np
11 |
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class PositionEmbeddingSine(nn.Module):
17 | """
18 | This is a more standard version of the position embedding, very similar to the one
19 | used by the Attention Is All You Need paper, generalized to work on images.
20 | """
21 |
22 | def __init__(
23 | self,
24 | num_pos_feats,
25 | temperature: int = 10000,
26 | normalize: bool = True,
27 | scale: Optional[float] = None,
28 | ):
29 | super().__init__()
30 | assert num_pos_feats % 2 == 0, "Expecting even model width"
31 | self.num_pos_feats = num_pos_feats // 2
32 | self.temperature = temperature
33 | self.normalize = normalize
34 | if scale is not None and normalize is False:
35 | raise ValueError("normalize should be True if scale is passed")
36 | if scale is None:
37 | scale = 2 * math.pi
38 | self.scale = scale
39 |
40 | self.cache = {}
41 |
42 | def _encode_xy(self, x, y):
43 | # The positions are expected to be normalized
44 | assert len(x) == len(y) and x.ndim == y.ndim == 1
45 | x_embed = x * self.scale
46 | y_embed = y * self.scale
47 |
48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
50 |
51 | pos_x = x_embed[:, None] / dim_t
52 | pos_y = y_embed[:, None] / dim_t
53 | pos_x = torch.stack(
54 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
55 | ).flatten(1)
56 | pos_y = torch.stack(
57 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
58 | ).flatten(1)
59 | return pos_x, pos_y
60 |
61 | @torch.no_grad()
62 | def encode_boxes(self, x, y, w, h):
63 | pos_x, pos_y = self._encode_xy(x, y)
64 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
65 | return pos
66 |
67 | encode = encode_boxes # Backwards compatibility
68 |
69 | @torch.no_grad()
70 | def encode_points(self, x, y, labels):
71 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
72 | assert bx == by and nx == ny and bx == bl and nx == nl
73 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
74 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
75 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
76 | return pos
77 |
78 | @torch.no_grad()
79 | def forward(self, x: torch.Tensor):
80 | cache_key = (x.shape[-2], x.shape[-1])
81 | if cache_key in self.cache:
82 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
83 | y_embed = (
84 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
85 | .view(1, -1, 1)
86 | .repeat(x.shape[0], 1, x.shape[-1])
87 | )
88 | x_embed = (
89 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
90 | .view(1, 1, -1)
91 | .repeat(x.shape[0], x.shape[-2], 1)
92 | )
93 |
94 | if self.normalize:
95 | eps = 1e-6
96 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
97 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
98 |
99 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
100 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
101 |
102 | pos_x = x_embed[:, :, :, None] / dim_t
103 | pos_y = y_embed[:, :, :, None] / dim_t
104 | pos_x = torch.stack(
105 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
106 | ).flatten(3)
107 | pos_y = torch.stack(
108 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
109 | ).flatten(3)
110 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
111 | self.cache[cache_key] = pos[0]
112 | return pos
113 |
114 |
115 | class PositionEmbeddingRandom(nn.Module):
116 | """
117 | Positional encoding using random spatial frequencies.
118 | """
119 |
120 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
121 | super().__init__()
122 | if scale is None or scale <= 0.0:
123 | scale = 1.0
124 | self.register_buffer(
125 | "positional_encoding_gaussian_matrix",
126 | scale * torch.randn((2, num_pos_feats)),
127 | )
128 |
129 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
130 | """Positionally encode points that are normalized to [0,1]."""
131 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
132 | coords = 2 * coords - 1
133 | coords = coords @ self.positional_encoding_gaussian_matrix
134 | coords = 2 * np.pi * coords
135 | # outputs d_1 x ... x d_n x C shape
136 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
137 |
138 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
139 | """Generate positional encoding for a grid of the specified size."""
140 | h, w = size
141 | device: Any = self.positional_encoding_gaussian_matrix.device
142 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
143 | y_embed = grid.cumsum(dim=0) - 0.5
144 | x_embed = grid.cumsum(dim=1) - 0.5
145 | y_embed = y_embed / h
146 | x_embed = x_embed / w
147 |
148 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
149 | return pe.permute(2, 0, 1) # C x H x W
150 |
151 | def forward_with_coords(
152 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
153 | ) -> torch.Tensor:
154 | """Positionally encode points that are not normalized to [0,1]."""
155 | coords = coords_input.clone()
156 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
157 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
158 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
159 |
160 |
161 | # Rotary Positional Encoding, adapted from:
162 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
163 | # 2. https://github.com/naver-ai/rope-vit
164 | # 3. https://github.com/lucidrains/rotary-embedding-torch
165 |
166 |
167 | def init_t_xy(end_x: int, end_y: int):
168 | t = torch.arange(end_x * end_y, dtype=torch.float32)
169 | t_x = (t % end_x).float()
170 | t_y = torch.div(t, end_x, rounding_mode="floor").float()
171 | return t_x, t_y
172 |
173 |
174 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
175 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
176 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
177 |
178 | t_x, t_y = init_t_xy(end_x, end_y)
179 | freqs_x = torch.outer(t_x, freqs_x)
180 | freqs_y = torch.outer(t_y, freqs_y)
181 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
182 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
183 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
184 |
185 |
186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
187 | ndim = x.ndim
188 | assert 0 <= 1 < ndim
189 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
190 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
191 | return freqs_cis.view(*shape)
192 |
193 |
194 | def apply_rotary_enc(
195 | xq: torch.Tensor,
196 | xk: torch.Tensor,
197 | freqs_cis: torch.Tensor,
198 | repeat_freqs_k: bool = False,
199 | ):
200 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
201 | xk_ = (
202 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
203 | if xk.shape[-2] != 0
204 | else None
205 | )
206 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
207 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
208 | if xk_ is None:
209 | # no keys to rotate, due to dropout
210 | return xq_out.type_as(xq).to(xq.device), xk
211 | # repeat freqs along seq_len dim to match k seq_len
212 | if repeat_freqs_k:
213 | r = xk_.shape[-2] // xq_.shape[-2]
214 | if freqs_cis.is_cuda:
215 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
216 | else:
217 | # torch.repeat on complex numbers may not be supported on non-CUDA devices
218 | # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
219 | freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
220 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
221 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
222 |
--------------------------------------------------------------------------------
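The axial RoPE helpers build one complex rotation per (x, y) grid position and apply it to query/key tensors laid out as (B, n_heads, H*W, head_dim). A small sketch of the shapes involved:

import torch
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

head_dim, H, W = 64, 32, 32
freqs_cis = compute_axial_cis(dim=head_dim, end_x=W, end_y=H, theta=10000.0)
print(freqs_cis.shape, freqs_cis.dtype)        # (1024, 32) complex64: one rotation per position

q = torch.randn(1, 4, H * W, head_dim)
k = torch.randn(1, 4, H * W, head_dim)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis)
print(q_rot.shape, k_rot.shape)                # shapes unchanged; values rotated pairwise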
/sam2/modeling/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/modeling/sam/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom
13 |
14 | from sam2.modeling.sam2_utils import LayerNorm2d
15 |
16 |
17 | class PromptEncoder(nn.Module):
18 | def __init__(
19 | self,
20 | embed_dim: int,
21 | image_embedding_size: Tuple[int, int],
22 | input_image_size: Tuple[int, int],
23 | mask_in_chans: int,
24 | activation: Type[nn.Module] = nn.GELU,
25 | ) -> None:
26 | """
27 | Encodes prompts for input to SAM's mask decoder.
28 |
29 | Arguments:
30 | embed_dim (int): The prompts' embedding dimension
31 | image_embedding_size (tuple(int, int)): The spatial size of the
32 | image embedding, as (H, W).
33 |           input_image_size (tuple(int, int)): The padded size of the image as input
34 | to the image encoder, as (H, W).
35 | mask_in_chans (int): The number of hidden channels used for
36 | encoding input masks.
37 | activation (nn.Module): The activation to use when encoding
38 | input masks.
39 | """
40 | super().__init__()
41 | self.embed_dim = embed_dim
42 | self.input_image_size = input_image_size
43 | self.image_embedding_size = image_embedding_size
44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45 |
46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47 | point_embeddings = [
48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49 | ]
50 | self.point_embeddings = nn.ModuleList(point_embeddings)
51 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
52 |
53 | self.mask_input_size = (
54 | 4 * image_embedding_size[0],
55 | 4 * image_embedding_size[1],
56 | )
57 | self.mask_downscaling = nn.Sequential(
58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59 | LayerNorm2d(mask_in_chans // 4),
60 | activation(),
61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62 | LayerNorm2d(mask_in_chans),
63 | activation(),
64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65 | )
66 | self.no_mask_embed = nn.Embedding(1, embed_dim)
67 |
68 | def get_dense_pe(self) -> torch.Tensor:
69 | """
70 | Returns the positional encoding used to encode point prompts,
71 | applied to a dense set of points the shape of the image encoding.
72 |
73 | Returns:
74 | torch.Tensor: Positional encoding with shape
75 | 1x(embed_dim)x(embedding_h)x(embedding_w)
76 | """
77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78 |
79 | def _embed_points(
80 | self,
81 | points: torch.Tensor,
82 | labels: torch.Tensor,
83 | pad: bool,
84 | ) -> torch.Tensor:
85 | """Embeds point prompts."""
86 | points = points + 0.5 # Shift to center of pixel
87 | if pad:
88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90 | points = torch.cat([points, padding_point], dim=1)
91 | labels = torch.cat([labels, padding_label], dim=1)
92 | point_embedding = self.pe_layer.forward_with_coords(
93 | points, self.input_image_size
94 | )
95 | point_embedding[labels == -1] = 0.0
96 | point_embedding[labels == -1] += self.not_a_point_embed.weight
97 | point_embedding[labels == 0] += self.point_embeddings[0].weight
98 | point_embedding[labels == 1] += self.point_embeddings[1].weight
99 | point_embedding[labels == 2] += self.point_embeddings[2].weight
100 | point_embedding[labels == 3] += self.point_embeddings[3].weight
101 | return point_embedding
102 |
103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
104 | """Embeds box prompts."""
105 | boxes = boxes + 0.5 # Shift to center of pixel
106 | coords = boxes.reshape(-1, 2, 2)
107 | corner_embedding = self.pe_layer.forward_with_coords(
108 | coords, self.input_image_size
109 | )
110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
112 | return corner_embedding
113 |
114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
115 | """Embeds mask inputs."""
116 | mask_embedding = self.mask_downscaling(masks)
117 | return mask_embedding
118 |
119 | def _get_batch_size(
120 | self,
121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
122 | boxes: Optional[torch.Tensor],
123 | masks: Optional[torch.Tensor],
124 | ) -> int:
125 | """
126 | Gets the batch size of the output given the batch size of the input prompts.
127 | """
128 | if points is not None:
129 | return points[0].shape[0]
130 | elif boxes is not None:
131 | return boxes.shape[0]
132 | elif masks is not None:
133 | return masks.shape[0]
134 | else:
135 | return 1
136 |
137 | def _get_device(self) -> torch.device:
138 | return self.point_embeddings[0].weight.device
139 |
140 | def forward(
141 | self,
142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143 | boxes: Optional[torch.Tensor],
144 | masks: Optional[torch.Tensor],
145 | ) -> Tuple[torch.Tensor, torch.Tensor]:
146 | """
147 | Embeds different types of prompts, returning both sparse and dense
148 | embeddings.
149 |
150 | Arguments:
151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
152 | and labels to embed.
153 | boxes (torch.Tensor or none): boxes to embed
154 | masks (torch.Tensor or none): masks to embed
155 |
156 | Returns:
157 | torch.Tensor: sparse embeddings for the points and boxes, with shape
158 | BxNx(embed_dim), where N is determined by the number of input points
159 | and boxes.
160 | torch.Tensor: dense embeddings for the masks, in the shape
161 | Bx(embed_dim)x(embed_H)x(embed_W)
162 | """
163 | bs = self._get_batch_size(points, boxes, masks)
164 | sparse_embeddings = torch.empty(
165 | (bs, 0, self.embed_dim), device=self._get_device()
166 | )
167 | if points is not None:
168 | coords, labels = points
169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
171 | if boxes is not None:
172 | box_embeddings = self._embed_boxes(boxes)
173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
174 |
175 | if masks is not None:
176 | dense_embeddings = self._embed_masks(masks)
177 | else:
178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
180 | )
181 |
182 | return sparse_embeddings, dense_embeddings
183 |
--------------------------------------------------------------------------------
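Usage note on the prompt encoder above: the following is a minimal sketch of calling its forward() with a single foreground click. It assumes the class name PromptEncoder and the constructor signature (embed_dim, image_embedding_size, input_image_size, mask_in_chans) carried over from the upstream SAM code; the 1024/64 sizes are illustrative, not values taken from this repo's configs.

import torch
from sam2.modeling.sam.prompt_encoder import PromptEncoder

# Illustrative sizes (assumed): 1024x1024 input images, 64x64 image embedding.
pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

coords = torch.tensor([[[500.0, 375.0]]])  # B=1, N=1 point as (x, y)
labels = torch.tensor([[1]])               # 1 = foreground click
sparse, dense = pe(points=(coords, labels), boxes=None, masks=None)
print(sparse.shape)  # (1, 2, 256): the click plus one padding point (pad=True when boxes is None)
print(dense.shape)   # (1, 256, 64, 64): the learned no-mask embedding, since masks is None
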
/sam2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/utils/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/utils/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/utils/__pycache__/misc.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/utils/__pycache__/misc.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/utils/__pycache__/transforms.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/sam2/utils/__pycache__/transforms.cpython-312.pyc
--------------------------------------------------------------------------------
/sam2/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import warnings
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchvision.transforms import Normalize, Resize, ToTensor
13 |
14 |
15 | class SAM2Transforms(nn.Module):
16 | def __init__(
17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
18 | ):
19 | """
20 | Transforms for SAM2.
21 | """
22 | super().__init__()
23 | self.resolution = resolution
24 | self.mask_threshold = mask_threshold
25 | self.max_hole_area = max_hole_area
26 | self.max_sprinkle_area = max_sprinkle_area
27 | self.mean = [0.485, 0.456, 0.406]
28 | self.std = [0.229, 0.224, 0.225]
29 | self.to_tensor = ToTensor()
30 | self.transforms = torch.jit.script(
31 | nn.Sequential(
32 | Resize((self.resolution, self.resolution)),
33 | Normalize(self.mean, self.std),
34 | )
35 | )
36 |
37 | def __call__(self, x):
38 | x = self.to_tensor(x)
39 | return self.transforms(x)
40 |
41 | def forward_batch(self, img_list):
42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
43 | img_batch = torch.stack(img_batch, dim=0)
44 | return img_batch
45 |
46 | def transform_coords(
47 | self, coords: torch.Tensor, normalize=False, orig_hw=None
48 | ) -> torch.Tensor:
49 | """
50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates;
51 | if the coords are in absolute image coordinates, set normalize to True and provide the original image size.
52 |
53 | Returns
54 | Coordinates rescaled to the model's input resolution (range [0, self.resolution]), which is what the SAM2 model expects.
55 | """
56 | if normalize:
57 | assert orig_hw is not None
58 | h, w = orig_hw
59 | coords = coords.clone()
60 | coords[..., 0] = coords[..., 0] / w
61 | coords[..., 1] = coords[..., 1] / h
62 |
63 | coords = coords * self.resolution # unnormalize coords
64 | return coords
65 |
66 | def transform_boxes(
67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None
68 | ) -> torch.Tensor:
69 | """
70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates;
71 | if the coords are in absolute image coordinates, set normalize to True and provide the original image size.
72 | """
73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
74 | return boxes
75 |
76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
77 | """
78 | Perform PostProcessing on output masks.
79 | """
80 | from sam2.utils.misc import get_connected_components
81 |
82 | masks = masks.float()
83 | input_masks = masks
84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
85 | try:
86 | if self.max_hole_area > 0:
87 | # Holes are those connected components in background with area <= self.max_hole_area
88 | # (background regions are those with mask scores <= self.mask_threshold)
89 | labels, areas = get_connected_components(
90 | mask_flat <= self.mask_threshold
91 | )
92 | is_hole = (labels > 0) & (areas <= self.max_hole_area)
93 | is_hole = is_hole.reshape_as(masks)
94 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
96 |
97 | if self.max_sprinkle_area > 0:
98 | labels, areas = get_connected_components(
99 | mask_flat > self.mask_threshold
100 | )
101 | is_sprinkle = (labels > 0) & (areas <= self.max_sprinkle_area)
102 | is_sprinkle = is_sprinkle.reshape_as(masks)
103 | # We remove small sprinkles with a negative mask score (-10.0) to change them to background.
104 | masks = torch.where(is_sprinkle, self.mask_threshold - 10.0, masks)
105 | except Exception as e:
106 | # Skip the post-processing step if the CUDA kernel fails
107 | warnings.warn(
108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can "
109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
110 | "functionality may be limited (which doesn't affect the results in most cases; see "
111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
112 | category=UserWarning,
113 | stacklevel=2,
114 | )
115 | masks = input_masks
116 |
117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
118 | return masks
119 |
--------------------------------------------------------------------------------
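A short sketch of the coordinate utilities in SAM2Transforms above: a box given in absolute pixel coordinates of the original image is normalized by (H, W) and rescaled to the square model input. resolution=512 is chosen here only because the t512 configs in this repo suggest it; it is an illustrative value, not a requirement of the class.

import torch
from sam2.utils.transforms import SAM2Transforms

tfm = SAM2Transforms(resolution=512, mask_threshold=0.0)

orig_hw = (320, 480)  # (H, W) of the original image
# One box in absolute (x1, y1, x2, y2) pixel coordinates.
box = torch.tensor([[60.0, 40.0, 300.0, 200.0]])
box_model = tfm.transform_boxes(box, normalize=True, orig_hw=orig_hw)
print(box_model.shape)  # (1, 2, 2): two corners per box, rescaled to the 512x512 model input
print(box_model)
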
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 |
8 | from setuptools import find_packages, setup
9 |
10 | # Package metadata
11 | NAME = "MedSAM2"
12 | VERSION = "1.0"
13 | DESCRIPTION = "MedSAM2 was adapted from SAM2 (https://github.com/facebookresearch/sam2) for medical image segmentation."
14 | URL = "https://github.com/bowang-lab/MedSAM2"
15 | AUTHOR = "WangLab"
16 | AUTHOR_EMAIL = "medseg20s@gmail.com"
17 | LICENSE = "Apache 2.0"
18 |
19 | # Read the contents of README file
20 | with open("README.md", "r", encoding="utf-8") as f:
21 | LONG_DESCRIPTION = f.read()
22 |
23 | # Required dependencies
24 | REQUIRED_PACKAGES = [
25 | "torch>=2.5.1",
26 | "torchvision>=0.20.1",
27 | "numpy>=2.0.1",
28 | "tqdm>=4.66.5",
29 | "hydra-core>=1.3.2",
30 | "iopath>=0.1.10",
31 | "pillow>=10.4.0",
32 | "SimpleITK>=2.4.0",
33 | ]
34 |
35 | EXTRA_PACKAGES = {
36 | "notebooks": [
37 | "matplotlib>=3.9.1",
38 | "jupyter>=1.0.0",
39 | "opencv-python>=4.10.0",
40 | "eva-decord>=0.6.1",
41 | ],
42 | "interactive-demo": [
43 | "Flask>=3.0.3",
44 | "Flask-Cors>=5.0.0",
45 | "av>=13.0.0",
46 | "dataclasses-json>=0.6.7",
47 | "eva-decord>=0.6.1",
48 | "gunicorn>=23.0.0",
49 | "imagesize>=1.4.1",
50 | "pycocotools>=2.0.8",
51 | "strawberry-graphql>=0.239.2",
52 | ],
53 | "dev": [
54 | "matplotlib>=3.9.1",
55 | "jupyter>=1.0.0",
56 | "black==24.2.0",
57 | "usort==1.0.2",
58 | "ufmt==2.0.0b2",
59 | "fvcore>=0.1.5.post20221221",
60 | "pandas>=2.2.3",
61 | "scikit-image>=0.24.0",
62 | "tensorboard>=2.17.0",
63 | "pycocotools>=2.0.8",
64 | "tensordict>=0.5.0",
65 | "opencv-python>=4.10.0",
66 | "submitit>=1.5.1",
67 | ],
68 | }
69 |
70 | # By default, we also build the SAM 2 CUDA extension.
71 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
72 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
73 | # By default, we allow SAM 2 installation to proceed even with build errors.
74 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
75 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
76 |
77 | # Catch and skip errors during extension building and print a warning message
78 | # (note that this message only shows up under verbose build mode
79 | # "pip install -v -e ." or "python setup.py build_ext -v")
80 | CUDA_ERROR_MSG = (
81 | "{}\n\n"
82 | "Failed to build the SAM 2 CUDA extension due to the error above. "
83 | "You can still use SAM 2 and it's OK to ignore the error above, although some "
84 | "post-processing functionality may be limited (which doesn't affect the results in most cases; "
85 | "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
86 | )
87 |
88 |
89 | def get_extensions():
90 | if not BUILD_CUDA:
91 | return []
92 |
93 | try:
94 | from torch.utils.cpp_extension import CUDAExtension
95 |
96 | srcs = ["sam2/csrc/connected_components.cu"]
97 | compile_args = {
98 | "cxx": [],
99 | "nvcc": [
100 | "-DCUDA_HAS_FP16=1",
101 | "-D__CUDA_NO_HALF_OPERATORS__",
102 | "-D__CUDA_NO_HALF_CONVERSIONS__",
103 | "-D__CUDA_NO_HALF2_OPERATORS__",
104 | ],
105 | }
106 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
107 | except Exception as e:
108 | if BUILD_ALLOW_ERRORS:
109 | print(CUDA_ERROR_MSG.format(e))
110 | ext_modules = []
111 | else:
112 | raise e
113 |
114 | return ext_modules
115 |
116 |
117 | try:
118 | from torch.utils.cpp_extension import BuildExtension
119 |
120 | class BuildExtensionIgnoreErrors(BuildExtension):
121 |
122 | def finalize_options(self):
123 | try:
124 | super().finalize_options()
125 | except Exception as e:
126 | print(CUDA_ERROR_MSG.format(e))
127 | self.extensions = []
128 |
129 | def build_extensions(self):
130 | try:
131 | super().build_extensions()
132 | except Exception as e:
133 | print(CUDA_ERROR_MSG.format(e))
134 | self.extensions = []
135 |
136 | def get_ext_filename(self, ext_name):
137 | try:
138 | return super().get_ext_filename(ext_name)
139 | except Exception as e:
140 | print(CUDA_ERROR_MSG.format(e))
141 | self.extensions = []
142 | return "_C.so"
143 |
144 | cmdclass = {
145 | "build_ext": (
146 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
147 | if BUILD_ALLOW_ERRORS
148 | else BuildExtension.with_options(no_python_abi_suffix=True)
149 | )
150 | }
151 | except Exception as e:
152 | cmdclass = {}
153 | if BUILD_ALLOW_ERRORS:
154 | print(CUDA_ERROR_MSG.format(e))
155 | else:
156 | raise e
157 |
158 |
159 | # Setup configuration
160 | setup(
161 | name=NAME,
162 | version=VERSION,
163 | description=DESCRIPTION,
164 | long_description=LONG_DESCRIPTION,
165 | long_description_content_type="text/markdown",
166 | url=URL,
167 | author=AUTHOR,
168 | author_email=AUTHOR_EMAIL,
169 | license=LICENSE,
170 | packages=find_packages(exclude="notebooks"),
171 | include_package_data=True,
172 | install_requires=REQUIRED_PACKAGES,
173 | extras_require=EXTRA_PACKAGES,
174 | python_requires=">=3.10.0",
175 | ext_modules=get_extensions(),
176 | cmdclass=cmdclass,
177 | )
178 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/training/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/training/__pycache__/loss_fns.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/loss_fns.cpython-312.pyc
--------------------------------------------------------------------------------
/training/__pycache__/optimizer.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/optimizer.cpython-312.pyc
--------------------------------------------------------------------------------
/training/__pycache__/trainer.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/__pycache__/trainer.cpython-312.pyc
--------------------------------------------------------------------------------
/training/assets/MOSE_sample_val_list.txt:
--------------------------------------------------------------------------------
1 | 32e5d721
2 | 5bad0bab
3 | 267bfd6c
4 | 0a43a414
5 | 56c56ca9
6 | 9a1146b3
7 | c6ad7aaf
8 | 78a1f4b1
9 | fc455e73
10 | 072e7b3f
11 | 77ccb57d
12 | a76ee415
13 | 8cdcfc17
14 | 5d518b42
15 | 376dd830
16 | 0e843fc8
17 | 2af0e766
18 | 2bd4e845
19 | de2f2a6a
20 | ade9ee91
21 | 001ca3cb
22 | fc4c1c67
23 | 8ef55579
24 | b84ce852
25 | 4cc8528a
26 | 767ffaaa
27 | 112a2ef0
28 | a338c8aa
29 | cbd144f5
30 | 5ff72128
31 | 86a949e2
32 | 9f2323ac
33 | 1fab1d1c
34 | 75924351
35 | ef55817b
36 | 02deca50
37 | 4d979d99
38 | 4d65f873
39 | 28470fa0
40 | 0d1575fe
41 | 06ea172e
42 | 29a6ddc2
43 | 797f1bec
44 | 780e7a99
45 | b9ed5b44
46 | 02a236b4
47 | 607d8ff5
48 | af5666b2
49 | 0558d0ed
50 | a938c6b2
51 | 103df575
52 | 77110e80
53 | 739e5a07
54 | 6763a576
55 | 06ebc138
56 | ba4b3b09
57 | b35cc2f3
58 | 4e0597a0
59 | 5949ee84
60 | 5348d547
61 | 323c4236
62 | b3b51117
63 | 55727ddd
64 | ab2714f3
65 | d2878895
66 | c0734cb3
67 | 94f7c53e
68 | 2a2745e5
69 | 442ffb54
70 | 3592425a
71 | 50ae03b0
72 | 5f150435
73 | 3067f9fa
74 | 9ffb2818
75 | adeaf5aa
76 | 31caacec
77 | 1cd99b86
78 | aa22f9d0
79 | 8fa50320
80 | e6348d2c
81 | 42ff84a5
82 | 8c8b7913
83 | c96adcbc
84 | 495be321
85 | db735509
86 | ee113fc4
87 | a678cdab
88 | c409ca4d
89 | 68d2b259
90 | 592b4dee
91 | 4e2b4dc7
92 | eb4d26e1
93 | 2009a00f
94 | bec5c89d
95 | 67191f24
96 | a3e85b4b
97 | da7080cd
98 | 80d978e9
99 | 36dcb93f
100 | a41e8c44
101 | 12fdc864
102 | 46d140ea
103 | 657c9dd9
104 | a86f84ee
105 | 90c1c43d
106 | 33015509
107 | afc7664d
108 | 23df06e1
109 | 291d4799
110 | 0ab75563
111 | 251bf059
112 | bcefdcc4
113 | ce9a2796
114 | 94d3403a
115 | 8f2e04bc
116 | f9cda066
117 | 9dfa2cc5
118 | 66924c91
119 | e765a09e
120 | 15654ee1
121 | 48e0bd39
122 | ee095221
123 | 2463609b
124 | 544d0d1f
125 | 51b8c2e1
126 | d321dde4
127 | 4cb11a5f
128 | d7058a0d
129 | 37af282a
130 | fabae187
131 | 7be91184
132 | 181ec185
133 | 2d16ceeb
134 | b56be4b1
135 | 6699eff0
136 | 79acac96
137 | d61c4665
138 | 0c13e1e7
139 | 100f6ecf
140 | 71217dfc
141 | 82df0888
142 | 4c42c747
143 | c9fdf703
144 | d2efeb4b
145 | 69ed9d14
146 | 64914fb6
147 | 255bedbc
148 | 4ea934d8
149 | a034feb2
150 | e4f4ddae
151 | e36a3026
152 | c1489591
153 | 111bb373
154 | e1d9fb32
155 | 93e22d48
156 | c1ec4b26
157 | d9638e69
158 | 60ab04c5
159 | cfe7773a
160 | 62132822
161 | 2f5fb2a3
162 | 7bdd197d
163 | 033333fd
164 | 130fcdbe
165 | 12e509c2
166 | 67138c33
167 | 6f90cc5f
168 | 4e3020fe
169 | bbdd8bb7
170 | b399ccdb
171 | fecd10d2
172 | 2e0967f7
173 | f509054f
174 | 792c6ff7
175 | 48e2afc5
176 | d904c048
177 | 111e0a5c
178 | b83024e2
179 | e6a7b79c
180 | bdc5ccf7
181 | b8146d00
182 | 9d394f1a
183 | 645b84f9
184 | 95ab2d0f
185 | e6f8a31d
186 | b4f876fb
187 | dc2c570d
188 | 3afd02d7
189 | 5c80c82c
190 | b1b32ddd
191 | 9f25fc61
192 | ba538072
193 | f8916fef
194 | 43c04ad2
195 | a658e949
196 | 2861dd53
197 | f6e40aba
198 | 09d305d1
199 | aac33bff
200 | 8d9d4c08
201 |
--------------------------------------------------------------------------------
/training/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/training/dataset/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/sam2_datasets.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/sam2_datasets.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/transforms.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/transforms.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/utils.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/vos_dataset.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_dataset.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/vos_raw_dataset.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_raw_dataset.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/vos_sampler.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_sampler.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/__pycache__/vos_segment_loader.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/dataset/__pycache__/vos_segment_loader.cpython-312.pyc
--------------------------------------------------------------------------------
/training/dataset/sam2_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import math
9 | from typing import Callable, Iterable, List, Optional, Sequence
10 |
11 | import torch
12 |
13 | from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
14 |
15 | from torch.utils.data.distributed import DistributedSampler
16 |
17 |
18 | class MixedDataLoader:
19 | def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
20 | """
21 | Args:
22 | dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
23 | mixing_prob (torch.FloatTensor): Probability of sampling from each dataloader.
24 |
25 | """
26 | assert len(dataloaders) == mixing_prob.shape[0]
27 | self.dataloaders = dataloaders
28 | self.mixing_prob = mixing_prob
29 | # Iterator state
30 | self._iter_dls = None
31 | self._iter_mixing_prob = None
32 | self.random_generator = torch.Generator()
33 |
34 | def __len__(self):
35 | return sum([len(d) for d in self.dataloaders])
36 |
37 | def __iter__(self):
38 | # Synchronize dataloader seeds
39 | self.random_generator.manual_seed(42)
40 | self._iter_dls = [iter(loader) for loader in self.dataloaders]
41 | self._iter_mixing_prob = self.mixing_prob.clone()
42 | return self
43 |
44 | def __next__(self):
45 | """
46 | Sample a dataloader to sample from based on mixing probabilities. If one of the dataloaders is exhausted, we continue sampling from the other loaders until all are exhausted.
47 | """
48 | if self._iter_dls is None:
49 | raise TypeError(f"{type(self).__name__} object is not an iterator")
50 |
51 | while self._iter_mixing_prob.any(): # at least one dataloader with non-zero prob.
52 | dataset_idx = self._iter_mixing_prob.multinomial(
53 | 1, generator=self.random_generator
54 | ).item()
55 | try:
56 | item = next(self._iter_dls[dataset_idx])
57 | return item
58 | except StopIteration:
59 | # No more iterations for this dataset, set its mixing probability to zero and try again.
60 | self._iter_mixing_prob[dataset_idx] = 0
61 | except Exception as e:
62 | # log and raise any other unexpected error.
63 | logging.error(e)
64 | raise e
65 |
66 | # Exhausted all iterators
67 | raise StopIteration
68 |
69 |
70 | class TorchTrainMixedDataset:
71 | def __init__(
72 | self,
73 | datasets: List[Dataset],
74 | batch_sizes: List[int],
75 | num_workers: int,
76 | shuffle: bool,
77 | pin_memory: bool,
78 | drop_last: bool,
79 | collate_fn: Optional[Callable] = None,
80 | worker_init_fn: Optional[Callable] = None,
81 | phases_per_epoch: int = 1,
82 | dataset_prob: Optional[List[float]] = None,
83 | ) -> None:
84 | """
85 | Args:
86 | datasets (List[Dataset]): List of Datasets to be mixed.
87 | batch_sizes (List[int]): Batch sizes for each dataset in the list.
88 | num_workers (int): Number of workers per dataloader.
89 | shuffle (bool): Whether or not to shuffle data.
90 | pin_memory (bool): If True, use pinned memory when loading tensors from disk.
91 | drop_last (bool): Whether or not to drop the last batch of data.
92 | collate_fn (Callable): Function to merge a list of samples into a mini-batch.
93 | worker_init_fn (Callable): Function to init each dataloader worker.
94 | phases_per_epoch (int): Number of phases per epoch.
95 | dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0
96 | """
97 |
98 | self.datasets = datasets
99 | self.batch_sizes = batch_sizes
100 | self.num_workers = num_workers
101 | self.shuffle = shuffle
102 | self.pin_memory = pin_memory
103 | self.drop_last = drop_last
104 | self.collate_fn = collate_fn
105 | self.worker_init_fn = worker_init_fn
106 | assert len(self.datasets) > 0
107 | for dataset in self.datasets:
108 | assert not isinstance(dataset, IterableDataset), "Not supported"
109 | # `RepeatFactorWrapper` requires calling set_epoch first to get its length
110 | self._set_dataset_epoch(dataset, 0)
111 | self.phases_per_epoch = phases_per_epoch
112 | self.chunks = [None] * len(datasets)
113 | if dataset_prob is None:
114 | # If not provided, assign each dataset a probability proportional to its length.
115 | dataset_lens = [
116 | (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
117 | for d, bs in zip(datasets, batch_sizes)
118 | ]
119 | total_len = sum(dataset_lens)
120 | dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
121 | else:
122 | assert len(dataset_prob) == len(datasets)
123 | dataset_prob = torch.tensor(dataset_prob)
124 |
125 | logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
126 | assert abs(dataset_prob.sum().item() - 1.0) < 1e-6, "Probabilities should sum to 1.0"
127 | self.dataset_prob = dataset_prob
128 |
129 | def _set_dataset_epoch(self, dataset, epoch: int) -> None:
130 | if hasattr(dataset, "epoch"):
131 | dataset.epoch = epoch
132 | if hasattr(dataset, "set_epoch"):
133 | dataset.set_epoch(epoch)
134 |
135 | def get_loader(self, epoch) -> Iterable:
136 | dataloaders = []
137 | for d_idx, (dataset, batch_size) in enumerate(
138 | zip(self.datasets, self.batch_sizes)
139 | ):
140 | if self.phases_per_epoch > 1:
141 | # Major epoch that loops over the entire dataset
142 | # len(main_epoch) == phases_per_epoch * len(epoch)
143 | main_epoch = epoch // self.phases_per_epoch
144 |
145 | # Phase within the main epoch
146 | local_phase = epoch % self.phases_per_epoch
147 |
148 | # Start of a new data-epoch, or the job resumed after preemption.
149 | if local_phase == 0 or self.chunks[d_idx] is None:
150 | # set seed for dataset epoch
151 | # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
152 | self._set_dataset_epoch(dataset, main_epoch)
153 |
154 | # Separate random generator for subset sampling
155 | g = torch.Generator()
156 | g.manual_seed(main_epoch)
157 | self.chunks[d_idx] = torch.chunk(
158 | torch.randperm(len(dataset), generator=g),
159 | self.phases_per_epoch,
160 | )
161 |
162 | dataset = Subset(dataset, self.chunks[d_idx][local_phase])
163 | else:
164 | self._set_dataset_epoch(dataset, epoch)
165 |
166 | sampler = DistributedSampler(dataset, shuffle=self.shuffle)
167 | sampler.set_epoch(epoch)
168 |
169 | batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
170 | dataloaders.append(
171 | DataLoader(
172 | dataset,
173 | num_workers=self.num_workers,
174 | pin_memory=self.pin_memory,
175 | batch_sampler=batch_sampler,
176 | collate_fn=self.collate_fn,
177 | worker_init_fn=self.worker_init_fn,
178 | )
179 | )
180 | return MixedDataLoader(dataloaders, self.dataset_prob)
181 |
--------------------------------------------------------------------------------
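To make the mixing behaviour of MixedDataLoader concrete, here is a toy sketch with two in-memory loaders; it only exercises the probability-weighted draw and the exhaustion handling. The full TorchTrainMixedDataset additionally builds a DistributedSampler per dataset and therefore expects torch.distributed to be initialized, which this sketch deliberately avoids; it also assumes the repo root is on PYTHONPATH so the training package imports cleanly.

import torch
from torch.utils.data import DataLoader, TensorDataset
from training.dataset.sam2_datasets import MixedDataLoader

ds_a = TensorDataset(torch.zeros(8, 1))   # toy "dataset A"
ds_b = TensorDataset(torch.ones(4, 1))    # toy "dataset B"
loaders = [DataLoader(ds_a, batch_size=2), DataLoader(ds_b, batch_size=2)]

# Draw from loader A with prob 0.75 and loader B with prob 0.25; once a loader
# is exhausted its probability is zeroed and sampling continues from the rest.
mixed = MixedDataLoader(loaders, mixing_prob=torch.tensor([0.75, 0.25]))
for batch in mixed:
    print(batch[0].squeeze(-1))
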
/training/dataset/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some wrapping utilities extended from pytorch's to support repeat factor sampling in particular"""
8 |
9 | from typing import Iterable
10 |
11 | import torch
12 | from torch.utils.data import (
13 | ConcatDataset as TorchConcatDataset,
14 | Dataset,
15 | Subset as TorchSubset,
16 | )
17 |
18 |
19 | class ConcatDataset(TorchConcatDataset):
20 | def __init__(self, datasets: Iterable[Dataset]) -> None:
21 | super(ConcatDataset, self).__init__(datasets)
22 |
23 | self.repeat_factors = torch.cat([d.repeat_factors for d in datasets])
24 |
25 | def set_epoch(self, epoch: int):
26 | for dataset in self.datasets:
27 | if hasattr(dataset, "epoch"):
28 | dataset.epoch = epoch
29 | if hasattr(dataset, "set_epoch"):
30 | dataset.set_epoch(epoch)
31 |
32 |
33 | class Subset(TorchSubset):
34 | def __init__(self, dataset, indices) -> None:
35 | super(Subset, self).__init__(dataset, indices)
36 |
37 | self.repeat_factors = dataset.repeat_factors[indices]
38 | assert len(indices) == len(self.repeat_factors)
39 |
40 |
41 | # Adapted from Detectron2
42 | class RepeatFactorWrapper(Dataset):
43 | """
44 | Thin wrapper around a dataset to implement repeat factor sampling.
45 | The underlying dataset must have a repeat_factors member to indicate the per-image factor.
46 | Set it to uniformly ones to disable repeat factor sampling
47 | """
48 |
49 | def __init__(self, dataset, seed: int = 0):
50 | self.dataset = dataset
51 | self.epoch_ids = None
52 | self._seed = seed
53 |
54 | # Split into whole number (_int_part) and fractional (_frac_part) parts.
55 | self._int_part = torch.trunc(dataset.repeat_factors)
56 | self._frac_part = dataset.repeat_factors - self._int_part
57 |
58 | def _get_epoch_indices(self, generator):
59 | """
60 | Create a list of dataset indices (with repeats) to use for one epoch.
61 |
62 | Args:
63 | generator (torch.Generator): pseudo random number generator used for
64 | stochastic rounding.
65 |
66 | Returns:
67 | torch.Tensor: list of dataset indices to use in one epoch. Each index
68 | is repeated based on its calculated repeat factor.
69 | """
70 | # Since repeat factors are fractional, we use stochastic rounding so
71 | # that the target repeat factor is achieved in expectation over the
72 | # course of training
73 | rands = torch.rand(len(self._frac_part), generator=generator)
74 | rep_factors = self._int_part + (rands < self._frac_part).float()
75 | # Construct a list of indices in which we repeat images as specified
76 | indices = []
77 | for dataset_index, rep_factor in enumerate(rep_factors):
78 | indices.extend([dataset_index] * int(rep_factor.item()))
79 | return torch.tensor(indices, dtype=torch.int64)
80 |
81 | def __len__(self):
82 | if self.epoch_ids is None:
83 | # Here we raise an error instead of returning the original len(self.dataset) to avoid
84 | # accidentally using the unwrapped length. Otherwise it's error-prone since the
85 | # length changes to `len(self.epoch_ids)` after set_epoch is called.
86 | raise RuntimeError("please call set_epoch first to get wrapped length")
87 | # return len(self.dataset)
88 |
89 | return len(self.epoch_ids)
90 |
91 | def set_epoch(self, epoch: int):
92 | g = torch.Generator()
93 | g.manual_seed(self._seed + epoch)
94 | self.epoch_ids = self._get_epoch_indices(g)
95 | if hasattr(self.dataset, "set_epoch"):
96 | self.dataset.set_epoch(epoch)
97 |
98 | def __getitem__(self, idx):
99 | if self.epoch_ids is None:
100 | raise RuntimeError(
101 | "Repeat ids haven't been computed. Did you forget to call set_epoch?"
102 | )
103 |
104 | return self.dataset[self.epoch_ids[idx]]
105 |
--------------------------------------------------------------------------------
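A toy sketch of RepeatFactorWrapper above: the wrapped dataset only needs a repeat_factors tensor, and set_epoch must be called before len() or indexing, as the wrapper enforces. The dataset below is hypothetical and exists only to show the stochastic-rounding repeat behaviour; it assumes the repo root is on PYTHONPATH.

import torch
from torch.utils.data import Dataset
from training.dataset.utils import RepeatFactorWrapper

class ToyDataset(Dataset):
    """Hypothetical dataset: item 2 is repeated ~2.5x per epoch in expectation."""

    def __init__(self):
        self.repeat_factors = torch.tensor([1.0, 1.0, 2.5])

    def __len__(self):
        return len(self.repeat_factors)

    def __getitem__(self, idx):
        return int(idx)

wrapped = RepeatFactorWrapper(ToyDataset(), seed=0)
wrapped.set_epoch(0)  # required before len() / __getitem__
print(len(wrapped), [wrapped[i] for i in range(len(wrapped))])
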
/training/dataset/vos_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import random
9 | from copy import deepcopy
10 |
11 | import numpy as np
12 |
13 | import torch
14 | from iopath.common.file_io import g_pathmgr
15 | from PIL import Image as PILImage
16 | from torchvision.datasets.vision import VisionDataset
17 |
18 | from training.dataset.vos_raw_dataset import VOSRawDataset
19 | from training.dataset.vos_sampler import VOSSampler
20 | from training.dataset.vos_segment_loader import JSONSegmentLoader
21 |
22 | from training.utils.data_utils import Frame, Object, VideoDatapoint
23 |
24 | MAX_RETRIES = 100
25 |
26 |
27 | class VOSDataset(VisionDataset):
28 | def __init__(
29 | self,
30 | transforms,
31 | training: bool,
32 | video_dataset: VOSRawDataset,
33 | sampler: VOSSampler,
34 | multiplier: int,
35 | always_target=True,
36 | target_segments_available=True,
37 | ):
38 | self._transforms = transforms
39 | self.training = training
40 | self.video_dataset = video_dataset
41 | self.sampler = sampler
42 |
43 | self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
44 | self.repeat_factors *= multiplier
45 | print(f"Raw dataset length = {len(self.video_dataset)}")
46 |
47 | self.curr_epoch = 0 # Used in case data loader behavior changes across epochs
48 | self.always_target = always_target
49 | self.target_segments_available = target_segments_available
50 |
51 | def _get_datapoint(self, idx):
52 |
53 | for retry in range(MAX_RETRIES):
54 | try:
55 | if isinstance(idx, torch.Tensor):
56 | idx = idx.item()
57 | # sample a video
58 | video, segment_loader = self.video_dataset.get_video(idx)
59 | # sample frames and object indices to be used in a datapoint
60 | sampled_frms_and_objs = self.sampler.sample(
61 | video, segment_loader, epoch=self.curr_epoch
62 | )
63 | break # Successfully loaded video
64 | except Exception as e:
65 | if self.training:
66 | logging.warning(
67 | f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
68 | )
69 | idx = random.randrange(0, len(self.video_dataset))
70 | else:
71 | # Shouldn't fail to load a val video
72 | raise e
73 |
74 | datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
75 | for transform in self._transforms:
76 | datapoint = transform(datapoint, epoch=self.curr_epoch)
77 | return datapoint
78 |
79 | def construct(self, video, sampled_frms_and_objs, segment_loader):
80 | """
81 | Constructs a VideoDatapoint sample to pass to transforms
82 | """
83 | sampled_frames = sampled_frms_and_objs.frames
84 | sampled_object_ids = sampled_frms_and_objs.object_ids
85 |
86 | images = []
87 | rgb_images = load_images(sampled_frames)
88 | # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
89 | for frame_idx, frame in enumerate(sampled_frames):
90 | w, h = rgb_images[frame_idx].size
91 | images.append(
92 | Frame(
93 | data=rgb_images[frame_idx],
94 | objects=[],
95 | )
96 | )
97 | # We load the gt segments associated with the current frame
98 | if isinstance(segment_loader, JSONSegmentLoader):
99 | segments = segment_loader.load(
100 | frame.frame_idx, obj_ids=sampled_object_ids
101 | )
102 | else:
103 | segments = segment_loader.load(frame.frame_idx)
104 | for obj_id in sampled_object_ids:
105 | # Extract the segment
106 | if obj_id in segments:
107 | assert (
108 | segments[obj_id] is not None
109 | ), "None targets are not supported"
110 | # segment is uint8 and remains uint8 throughout the transforms
111 | segment = segments[obj_id].to(torch.uint8)
112 | else:
113 | # There is no target; we either use a zero mask target or drop this object
114 | if not self.always_target:
115 | continue
116 | segment = torch.zeros(h, w, dtype=torch.uint8)
117 |
118 | images[frame_idx].objects.append(
119 | Object(
120 | object_id=obj_id,
121 | frame_index=frame.frame_idx,
122 | segment=segment,
123 | )
124 | )
125 | return VideoDatapoint(
126 | frames=images,
127 | video_id=video.video_id,
128 | size=(h, w),
129 | )
130 |
131 | def __getitem__(self, idx):
132 | return self._get_datapoint(idx)
133 |
134 | def __len__(self):
135 | return len(self.video_dataset)
136 |
137 |
138 | def load_images(frames):
139 | all_images = []
140 | cache = {}
141 | for frame in frames:
142 | if frame.data is None:
143 | # Load the frame rgb data from file
144 | path = frame.image_path
145 | if path in cache:
146 | all_images.append(deepcopy(all_images[cache[path]]))
147 | continue
148 | with g_pathmgr.open(path, "rb") as fopen:
149 | all_images.append(PILImage.open(fopen).convert("RGB"))
150 | cache[path] = len(all_images) - 1
151 | else:
152 | # The frame rgb data has already been loaded
153 | # Convert it to a PILImage
154 | all_images.append(tensor_2_PIL(frame.data))
155 |
156 | return all_images
157 |
158 |
159 | def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
160 | data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
161 | data = data.astype(np.uint8)
162 | return PILImage.fromarray(data)
163 |
--------------------------------------------------------------------------------
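A quick sketch of the tensor_2_PIL helper above, which load_images uses when a frame's RGB data is already in memory: it expects a CHW float tensor scaled to [0, 1] and returns an RGB PIL image. This assumes the repo root is on PYTHONPATH and the training dependencies are installed so that training.dataset.vos_dataset imports cleanly.

import torch
from training.dataset.vos_dataset import tensor_2_PIL

chw = torch.rand(3, 64, 64)   # toy CHW float tensor in [0, 1]
img = tensor_2_PIL(chw)
print(img.size, img.mode)     # (64, 64) RGB
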
/training/dataset/vos_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 | from dataclasses import dataclass
9 | from typing import List
10 |
11 | from training.dataset.vos_segment_loader import LazySegments
12 |
13 | MAX_RETRIES = 1000
14 |
15 |
16 | @dataclass
17 | class SampledFramesAndObjects:
18 | frames: List[int]
19 | object_ids: List[int]
20 |
21 |
22 | class VOSSampler:
23 | def __init__(self, sort_frames=True):
24 | # frames are ordered by frame id when sort_frames is True
25 | self.sort_frames = sort_frames
26 |
27 | def sample(self, video):
28 | raise NotImplementedError()
29 |
30 |
31 | class RandomUniformSampler(VOSSampler):
32 | def __init__(
33 | self,
34 | num_frames,
35 | max_num_objects,
36 | reverse_time_prob=0.0,
37 | ):
38 | self.num_frames = num_frames
39 | self.max_num_objects = max_num_objects
40 | self.reverse_time_prob = reverse_time_prob
41 |
42 | def sample(self, video, segment_loader, epoch=None):
43 |
44 | for retry in range(MAX_RETRIES):
45 | if len(video.frames) < self.num_frames:
46 | raise Exception(
47 | f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {len(video.frames)} annotated frames."
48 | )
49 | start = random.randrange(0, len(video.frames) - self.num_frames + 1)
50 | frames = [video.frames[start + step] for step in range(self.num_frames)]
51 | if random.uniform(0, 1) < self.reverse_time_prob:
52 | # Reverse time
53 | frames = frames[::-1]
54 |
55 | # Get first frame object ids
56 | visible_object_ids = []
57 | loaded_segms = segment_loader.load(frames[0].frame_idx)
58 | if isinstance(loaded_segms, LazySegments):
59 | # LazySegments for SA1BRawDataset
60 | visible_object_ids = list(loaded_segms.keys())
61 | else:
62 | for object_id, segment in segment_loader.load(
63 | frames[0].frame_idx
64 | ).items():
65 | if segment.sum():
66 | visible_object_ids.append(object_id)
67 |
68 | # First frame needs to have at least a target to track
69 | if len(visible_object_ids) > 0:
70 | break
71 | if retry >= MAX_RETRIES - 1:
72 | raise Exception("No visible objects")
73 |
74 | object_ids = random.sample(
75 | visible_object_ids,
76 | min(len(visible_object_ids), self.max_num_objects),
77 | )
78 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
79 |
80 |
81 | class EvalSampler(VOSSampler):
82 | """
83 | VOS Sampler for evaluation: sampling all the frames and all the objects in a video
84 | """
85 |
86 | def __init__(
87 | self,
88 | ):
89 | super().__init__()
90 |
91 | def sample(self, video, segment_loader, epoch=None):
92 | """
93 | Sampling all the frames and all the objects
94 | """
95 | if self.sort_frames:
96 | # ordered by frame id
97 | frames = sorted(video.frames, key=lambda x: x.frame_idx)
98 | else:
99 | # use the original order
100 | frames = video.frames
101 | object_ids = segment_loader.load(frames[0].frame_idx).keys()
102 | if len(object_ids) == 0:
103 | raise Exception("First frame of the video has no objects")
104 |
105 | return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
106 |
--------------------------------------------------------------------------------
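To show what RandomUniformSampler expects from its inputs, here is a toy sketch with stand-in video and segment-loader objects (hypothetical shapes, only to exercise the sampling logic): frames carry a frame_idx, and the loader returns a dict of per-object masks for a given frame. It assumes the repo's training dependencies are installed.

import torch
from types import SimpleNamespace
from training.dataset.vos_sampler import RandomUniformSampler

frames = [SimpleNamespace(frame_idx=i) for i in range(5)]
video = SimpleNamespace(frames=frames, video_name="toy_video")

class ToySegmentLoader:
    def load(self, frame_idx):
        return {
            0: torch.zeros(4, 4, dtype=torch.uint8),  # empty mask -> not visible
            1: torch.ones(4, 4, dtype=torch.uint8),   # visible object
        }

sampler = RandomUniformSampler(num_frames=3, max_num_objects=2)
out = sampler.sample(video, ToySegmentLoader())
print([f.frame_idx for f in out.frames], out.object_ids)  # 3 consecutive frames, [1]
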
/training/dataset/vos_segment_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import glob
8 | import json
9 | import os
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import torch
14 |
15 | from PIL import Image as PILImage
16 |
17 | try:
18 | from pycocotools import mask as mask_utils
19 | except ImportError:
20 | pass
21 |
22 |
23 | class JSONSegmentLoader:
24 | def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
25 | # Annotations in the json are provided every ann_every th frame
26 | self.ann_every = ann_every
27 | # Ids of the objects to consider when sampling this video
28 | self.valid_obj_ids = valid_obj_ids
29 | with open(video_json_path, "r") as f:
30 | data = json.load(f)
31 | if isinstance(data, list):
32 | self.frame_annots = data
33 | elif isinstance(data, dict):
34 | masklet_field_name = "masklet" if "masklet" in data else "masks"
35 | self.frame_annots = data[masklet_field_name]
36 | if "fps" in data:
37 | if isinstance(data["fps"], list):
38 | annotations_fps = int(data["fps"][0])
39 | else:
40 | annotations_fps = int(data["fps"])
41 | assert frames_fps % annotations_fps == 0
42 | self.ann_every = frames_fps // annotations_fps
43 | else:
44 | raise NotImplementedError
45 |
46 | def load(self, frame_id, obj_ids=None):
47 | assert frame_id % self.ann_every == 0
48 | rle_mask = self.frame_annots[frame_id // self.ann_every]
49 |
50 | valid_objs_ids = set(range(len(rle_mask)))
51 | if self.valid_obj_ids is not None:
52 | # Remove the masklets that have been filtered out for this video
53 | valid_objs_ids &= set(self.valid_obj_ids)
54 | if obj_ids is not None:
55 | # Only keep the objects that have been sampled
56 | valid_objs_ids &= set(obj_ids)
57 | valid_objs_ids = sorted(list(valid_objs_ids))
58 |
59 | # Construct rle_masks_filtered that only contains the rle masks we are interested in
60 | id_2_idx = {}
61 | rle_mask_filtered = []
62 | for obj_id in valid_objs_ids:
63 | if rle_mask[obj_id] is not None:
64 | id_2_idx[obj_id] = len(rle_mask_filtered)
65 | rle_mask_filtered.append(rle_mask[obj_id])
66 | else:
67 | id_2_idx[obj_id] = None
68 |
69 | # Decode the masks
70 | raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
71 | 2, 0, 1
72 | ) # (num_obj, h, w)
73 | segments = {}
74 | for obj_id in valid_objs_ids:
75 | if id_2_idx[obj_id] is None:
76 | segments[obj_id] = None
77 | else:
78 | idx = id_2_idx[obj_id]
79 | segments[obj_id] = raw_segments[idx]
80 | return segments
81 |
82 | def get_valid_obj_frames_ids(self, num_frames_min=None):
83 | # For each object, find all the frames with a valid (not None) mask
84 | num_objects = len(self.frame_annots[0])
85 |
86 | # The result dict associates each obj_id with the id of its valid frames
87 | res = {obj_id: [] for obj_id in range(num_objects)}
88 |
89 | for annot_idx, annot in enumerate(self.frame_annots):
90 | for obj_id in range(num_objects):
91 | if annot[obj_id] is not None:
92 | res[obj_id].append(int(annot_idx * self.ann_every))
93 |
94 | if num_frames_min is not None:
95 | # Remove masklets that have less than num_frames_min valid masks
96 | for obj_id, valid_frames in list(res.items()):
97 | if len(valid_frames) < num_frames_min:
98 | res.pop(obj_id)
99 |
100 | return res
101 |
102 |
103 | class PalettisedPNGSegmentLoader:
104 | def __init__(self, video_png_root, sample_rate=1):
105 | """
106 | SegmentLoader for datasets with masks stored as palettised PNGs.
107 | video_png_root: the folder that contains all the masks stored as PNGs
108 | """
109 | self.video_png_root = video_png_root
110 | self.sample_rate = sample_rate
111 | # build a mapping from frame id to their PNG mask path
112 | # note that in some datasets, the PNG paths could have more
113 | # than 5 digits, e.g. "00000000.png" instead of "00000.png"
114 | png_filenames = sorted(glob.glob(os.path.join(self.video_png_root, "*.png"))) # os.listdir(self.video_png_root)
115 | self.frame_id_to_png_filename = {}
116 | for idx, filename in enumerate(png_filenames[::self.sample_rate]):
117 | frame_id = idx # int(os.path.basename(filename).split(".")[0])
118 | self.frame_id_to_png_filename[frame_id] = filename
119 |
120 | def load(self, frame_id):
121 | """
122 | load a single palettised mask from disk, using the frame_id -> PNG filename mapping built in __init__
123 | Args:
124 | frame_id: int, define the mask path
125 | Return:
126 | binary_segments: dict
127 | """
128 | # check the path
129 | mask_path = os.path.join(
130 | self.video_png_root, self.frame_id_to_png_filename[frame_id]
131 | )
132 |
133 | # load the mask
134 | masks = PILImage.open(mask_path).convert("P")
135 | masks = np.array(masks)
136 |
137 | object_id = pd.unique(masks.flatten())
138 | object_id = object_id[object_id != 0] # remove background (0)
139 |
140 | # convert into N binary segmentation masks
141 | binary_segments = {}
142 | for i in object_id:
143 | bs = masks == i
144 | binary_segments[i] = torch.from_numpy(bs)
145 |
146 | return binary_segments
147 |
148 | def __len__(self):
149 | return
150 |
151 |
152 | class MultiplePNGSegmentLoader:
153 | def __init__(self, video_png_root, single_object_mode=False):
154 | """
155 | video_png_root: the folder that contains all the masks stored as PNGs
156 | single_object_mode: whether to load only a single object at a time
157 | """
158 | self.video_png_root = video_png_root
159 | self.single_object_mode = single_object_mode
160 | # read a mask to know the resolution of the video
161 | if self.single_object_mode:
162 | tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
163 | else:
164 | tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
165 | tmp_mask = np.array(PILImage.open(tmp_mask_path))
166 | self.H = tmp_mask.shape[0]
167 | self.W = tmp_mask.shape[1]
168 | if self.single_object_mode:
169 | self.obj_id = (
170 | int(video_png_root.split("/")[-1]) + 1
171 | ) # offset by 1 as bg is 0
172 | else:
173 | self.obj_id = None
174 |
175 | def load(self, frame_id):
176 | if self.single_object_mode:
177 | return self._load_single_png(frame_id)
178 | else:
179 | return self._load_multiple_pngs(frame_id)
180 |
181 | def _load_single_png(self, frame_id):
182 | """
183 | load single png from the disk (path: f'{self.obj_id}/{frame_id:05d}.png')
184 | Args:
185 | frame_id: int, define the mask path
186 | Return:
187 | binary_segments: dict
188 | """
189 | mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
190 | binary_segments = {}
191 |
192 | if os.path.exists(mask_path):
193 | mask = np.array(PILImage.open(mask_path))
194 | else:
195 | # if png doesn't exist, empty mask
196 | mask = np.zeros((self.H, self.W), dtype=bool)
197 | binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
198 | return binary_segments
199 |
200 | def _load_multiple_pngs(self, frame_id):
201 | """
202 | load multiple png masks from the disk (path: f'{obj_id}/{frame_id:05d}.png')
203 | Args:
204 | frame_id: int, define the mask path
205 | Return:
206 | binary_segments: dict
207 | """
208 | # get the path
209 | all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
210 | num_objects = len(all_objects)
211 | assert num_objects > 0
212 |
213 | # load the masks
214 | binary_segments = {}
215 | for obj_folder in all_objects:
216 | # obj_folder is {video_name}/{obj_id}, obj_id is specified by the name of the folder
217 | obj_id = int(obj_folder.split("/")[-1])
218 | obj_id = obj_id + 1 # offset 1 as bg is 0
219 | mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
220 | if os.path.exists(mask_path):
221 | mask = np.array(PILImage.open(mask_path))
222 | else:
223 | mask = np.zeros((self.H, self.W), dtype=bool)
224 | binary_segments[obj_id] = torch.from_numpy(mask > 0)
225 |
226 | return binary_segments
227 |
228 | def __len__(self):
229 | return
230 |
231 |
232 | class LazySegments:
233 | """
234 | Only decodes segments that are actually used.
235 | """
236 |
237 | def __init__(self):
238 | self.segments = {}
239 | self.cache = {}
240 |
241 | def __setitem__(self, key, item):
242 | self.segments[key] = item
243 |
244 | def __getitem__(self, key):
245 | if key in self.cache:
246 | return self.cache[key]
247 | rle = self.segments[key]
248 | mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0]
249 | self.cache[key] = mask
250 | return mask
251 |
252 | def __contains__(self, key):
253 | return key in self.segments
254 |
255 | def __len__(self):
256 | return len(self.segments)
257 |
258 | def keys(self):
259 | return self.segments.keys()
260 |
261 |
262 | class SA1BSegmentLoader:
263 | def __init__(
264 | self,
265 | video_mask_path,
266 | mask_area_frac_thresh=1.1,
267 | video_frame_path=None,
268 | uncertain_iou=-1,
269 | ):
270 | with open(video_mask_path, "r") as f:
271 | self.frame_annots = json.load(f)
272 |
273 | if mask_area_frac_thresh <= 1.0:
274 | # Lazily read frame
275 | orig_w, orig_h = PILImage.open(video_frame_path).size
276 | area = orig_w * orig_h
277 |
278 | self.frame_annots = self.frame_annots["annotations"]
279 |
280 | rle_masks = []
281 | for frame_annot in self.frame_annots:
282 | if not frame_annot["area"] > 0:
283 | continue
284 | if ("uncertain_iou" in frame_annot) and (
285 | frame_annot["uncertain_iou"] < uncertain_iou
286 | ):
287 | # uncertain_iou is stability score
288 | continue
289 | if (
290 | mask_area_frac_thresh <= 1.0
291 | and (frame_annot["area"] / area) >= mask_area_frac_thresh
292 | ):
293 | continue
294 | rle_masks.append(frame_annot["segmentation"])
295 |
296 | self.segments = LazySegments()
297 | for i, rle in enumerate(rle_masks):
298 | self.segments[i] = rle
299 |
300 | def load(self, frame_idx):
301 | return self.segments
302 |
303 |
304 | class NPZSegmentLoader:
305 | def __init__(self, masks):
306 | """
307 | Initialize the NPZSegmentLoader.
308 |
309 | Args:
310 | masks (numpy.ndarray): Array of masks with shape (img_num, H, W).
311 | """
312 | self.masks = masks
313 |
314 | def load(self, frame_idx):
315 | """
316 | Load the single mask for the given frame index and convert it to binary segments.
317 |
318 | Args:
319 | frame_idx (int): Index of the frame to load.
320 |
321 | Returns:
322 | dict: A dictionary where keys are object IDs and values are binary masks.
323 | """
324 | mask = self.masks[frame_idx]
325 |
326 | # Find unique object IDs in the mask, excluding the background (0)
327 | object_ids = np.unique(mask)
328 | object_ids = object_ids[object_ids != 0]
329 |
330 | # Convert into binary segmentation masks for each object
331 | binary_segments = {}
332 | for obj_id in object_ids:
333 | binary_mask = (mask == obj_id)
334 | binary_segments[int(obj_id)] = torch.from_numpy(binary_mask).bool()
335 |
336 | return binary_segments
337 |
--------------------------------------------------------------------------------
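A sketch of NPZSegmentLoader, the simplest of the loaders above: it takes a (num_frames, H, W) integer label volume where 0 is background, and load() splits one frame into per-object boolean masks. The label values below are toy data.

import numpy as np
from training.dataset.vos_segment_loader import NPZSegmentLoader

masks = np.zeros((2, 8, 8), dtype=np.uint8)  # toy label volume: 2 frames of 8x8
masks[0, :4, :4] = 1                         # object 1 in frame 0
masks[0, 4:, 4:] = 2                         # object 2 in frame 0

loader = NPZSegmentLoader(masks)
segments = loader.load(0)
print(sorted(segments.keys()))               # [1, 2]
print(segments[1].shape, segments[1].dtype)  # torch.Size([8, 8]) torch.bool
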
/training/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/training/model/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/model/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/training/model/__pycache__/sam2.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/model/__pycache__/sam2.cpython-312.pyc
--------------------------------------------------------------------------------
/training/scripts/sav_frame_extraction_submitit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | import argparse
4 | import os
5 | from pathlib import Path
6 |
7 | import cv2
8 |
9 | import numpy as np
10 | import submitit
11 | import tqdm
12 |
13 |
14 | def get_args_parser():
15 | parser = argparse.ArgumentParser(
16 | description="[SA-V Preprocessing] Extracting JPEG frames",
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
18 | )
19 |
20 | # ------------
21 | # DATA
22 | # ------------
23 | data_parser = parser.add_argument_group(
24 | title="SA-V dataset data root",
25 | description="What data to load and how to process it.",
26 | )
27 | data_parser.add_argument(
28 | "--sav-vid-dir",
29 | type=str,
30 | required=True,
31 | help=("Where to find the SAV videos"),
32 | )
33 | data_parser.add_argument(
34 | "--sav-frame-sample-rate",
35 | type=int,
36 | default=4,
37 | help="Rate at which to sub-sample frames",
38 | )
39 |
40 | # ------------
41 | # LAUNCH
42 | # ------------
43 | launch_parser = parser.add_argument_group(
44 | title="Cluster launch settings",
45 | description="Number of jobs and retry settings.",
46 | )
47 | launch_parser.add_argument(
48 | "--n-jobs",
49 | type=int,
50 | required=True,
51 | help="Shard the run over this many jobs.",
52 | )
53 | launch_parser.add_argument(
54 | "--timeout", type=int, required=True, help="SLURM timeout parameter in minutes."
55 | )
56 | launch_parser.add_argument(
57 | "--partition", type=str, required=True, help="Partition to launch on."
58 | )
59 | launch_parser.add_argument(
60 | "--account", type=str, required=True, help="Partition to launch on."
61 | )
62 | launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")
63 |
64 | # ------------
65 | # OUTPUT
66 | # ------------
67 | output_parser = parser.add_argument_group(
68 | title="Setting for results output", description="Where and how to save results."
69 | )
70 | output_parser.add_argument(
71 | "--output-dir",
72 | type=str,
73 | required=True,
74 | help=("Where to dump the extracted jpeg frames"),
75 | )
76 | output_parser.add_argument(
77 | "--slurm-output-root-dir",
78 | type=str,
79 | required=True,
80 | help=("Where to save slurm outputs"),
81 | )
82 | return parser
83 |
84 |
85 | def decode_video(video_path: str):
86 | assert os.path.exists(video_path)
87 | video = cv2.VideoCapture(video_path)
88 | video_frames = []
89 | while video.isOpened():
90 | ret, frame = video.read()
91 | if ret:
92 | video_frames.append(frame)
93 | else:
94 | break
95 | return video_frames
96 |
97 |
98 | def extract_frames(video_path, sample_rate):
99 | frames = decode_video(video_path)
100 | return frames[::sample_rate]
101 |
102 |
103 | def submitit_launch(video_paths, sample_rate, save_root):
104 | for path in tqdm.tqdm(video_paths):
105 | frames = extract_frames(path, sample_rate)
106 | output_folder = os.path.join(save_root, Path(path).stem)
107 | if not os.path.exists(output_folder):
108 | os.makedirs(output_folder)
109 | for fid, frame in enumerate(frames):
110 | frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
111 | cv2.imwrite(frame_path, frame)
112 | print(f"Saved output to {save_root}")
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = get_args_parser()
117 | args = parser.parse_args()
118 |
119 | sav_vid_dir = args.sav_vid_dir
120 | save_root = args.output_dir
121 | sample_rate = args.sav_frame_sample_rate
122 |
123 | # List all SA-V videos
124 | mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")])
125 | mp4_files = np.array(mp4_files)
126 | chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)]
127 |
128 | print(f"Processing videos in: {sav_vid_dir}")
129 | print(f"Processing {len(mp4_files)} files")
130 | print(f"Beginning processing in {args.n_jobs} processes")
131 |
132 | # Submitit params
133 | jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
134 | cpus_per_task = 4
135 | executor = submitit.AutoExecutor(folder=jobs_dir)
136 | executor.update_parameters(
137 | timeout_min=args.timeout,
138 | gpus_per_node=0,
139 | tasks_per_node=1,
140 | slurm_array_parallelism=args.n_jobs,
141 | cpus_per_task=cpus_per_task,
142 | slurm_partition=args.partition,
143 | slurm_account=args.account,
144 | slurm_qos=args.qos,
145 | )
146 | executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])
147 |
148 | # Launch
149 | jobs = []
150 | with executor.batch():
151 |         for mp4_chunk in tqdm.tqdm(chunked_mp4_files):
152 | job = executor.submit(
153 | submitit_launch,
154 | video_paths=mp4_chunk,
155 | sample_rate=sample_rate,
156 | save_root=save_root,
157 | )
158 | jobs.append(job)
159 |
160 | for j in jobs:
161 | print(f"Slurm JobID: {j.job_id}")
162 | print(f"Saving outputs to {save_root}")
163 | print(f"Slurm outputs at {args.slurm_output_root_dir}")
164 |
--------------------------------------------------------------------------------
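For reference, the same decode / sub-sample / save logic can be run locally without SLURM or submitit. The sketch below mirrors decode_video, extract_frames, and the per-chunk loop of submitit_launch; the paths are placeholders, the sample rate of 4 matches the --sav-frame-sample-rate default, and a video.release() call is added to free the OpenCV capture.

# Local, single-process sketch of the frame-extraction logic above (no SLURM).
# The paths "/data/sav/videos" and "/data/sav/frames" are placeholders.
import os
from pathlib import Path

import cv2
import numpy as np


def decode_video(video_path: str):
    # Read every frame of the video into memory with OpenCV.
    video = cv2.VideoCapture(video_path)
    frames = []
    while video.isOpened():
        ret, frame = video.read()
        if not ret:
            break
        frames.append(frame)
    video.release()
    return frames


def extract_and_save(video_path: str, sample_rate: int, save_root: str):
    # Keep every `sample_rate`-th frame and save it as <original_frame_idx>.jpg,
    # mirroring the file naming used by submitit_launch.
    frames = decode_video(video_path)[::sample_rate]
    out_dir = os.path.join(save_root, Path(video_path).stem)
    os.makedirs(out_dir, exist_ok=True)
    for fid, frame in enumerate(frames):
        cv2.imwrite(os.path.join(out_dir, f"{fid * sample_rate:05d}.jpg"), frame)


if __name__ == "__main__":
    mp4_files = sorted(str(p) for p in Path("/data/sav/videos").glob("*/*.mp4"))
    # np.array_split is how the script shards work across --n-jobs SLURM jobs;
    # with a single chunk everything is processed in this process.
    for chunk in np.array_split(np.array(mp4_files), 1):
        for path in chunk.tolist():
            extract_and_save(path, sample_rate=4, save_root="/data/sav/frames")
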
/training/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import random
10 | import sys
11 | import traceback
12 | from argparse import ArgumentParser
13 |
14 | import submitit
15 | import torch
16 |
17 | from hydra import compose, initialize_config_module
18 | from hydra.utils import instantiate
19 |
20 | from iopath.common.file_io import g_pathmgr
21 | from omegaconf import OmegaConf
22 |
23 | from training.utils.train_utils import makedir, register_omegaconf_resolvers
24 |
25 | os.environ["HYDRA_FULL_ERROR"] = "1"
26 |
27 |
28 | def single_proc_run(local_rank, main_port, cfg, world_size, node_rank, master_addr):
29 | """Single GPU process"""
30 | os.environ["MASTER_ADDR"] = master_addr
31 | os.environ["MASTER_PORT"] = str(main_port)
32 | os.environ["RANK"] = str(node_rank * cfg.launcher.gpus_per_node + local_rank)
33 | os.environ["LOCAL_RANK"] = str(local_rank)
34 | os.environ["WORLD_SIZE"] = str(world_size)
35 | try:
36 | register_omegaconf_resolvers()
37 | except Exception as e:
38 | logging.info(e)
39 |
40 | trainer = instantiate(cfg.trainer, _recursive_=False)
41 | trainer.run()
42 |
43 |
44 | def single_node_runner(cfg, main_port: int, node_rank: int = 0, master_addr: str = "localhost"):
45 | num_proc = cfg.launcher.gpus_per_node
46 | world_size = cfg.launcher.gpus_per_node * cfg.launcher.num_nodes
47 | torch.multiprocessing.set_start_method(
48 | "spawn"
49 | ) # CUDA runtime does not support `fork`
50 | if num_proc == 1:
51 | # directly call single_proc so we can easily set breakpoints
52 | # mp.spawn does not let us set breakpoints
53 | single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=world_size, node_rank=node_rank, master_addr=master_addr)
54 | else:
55 | mp_runner = torch.multiprocessing.start_processes
56 | args = (main_port, cfg, world_size, node_rank, master_addr)
57 | mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
58 |
59 |
60 | def format_exception(e: Exception, limit=20):
61 | traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit))
62 | return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}"
63 |
64 |
65 | class SubmititRunner(submitit.helpers.Checkpointable):
66 | """A callable which is passed to submitit to launch the jobs."""
67 |
68 | def __init__(self, port, cfg):
69 | self.cfg = cfg
70 | self.port = port
71 | self.has_setup = False
72 |
73 | def run_trainer(self):
74 | job_env = submitit.JobEnvironment()
75 | # Need to add this again so the hydra.job.set_env PYTHONPATH
76 | # is also set when launching jobs.
77 | add_pythonpath_to_sys_path()
78 | os.environ["MASTER_ADDR"] = job_env.hostnames[0]
79 | os.environ["MASTER_PORT"] = str(self.port)
80 | os.environ["RANK"] = str(job_env.global_rank)
81 | os.environ["LOCAL_RANK"] = str(job_env.local_rank)
82 | os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
83 |
84 | register_omegaconf_resolvers()
85 | cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)
86 | cfg_resolved = OmegaConf.create(cfg_resolved)
87 |
88 | trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
89 | trainer.run()
90 |
91 | def __call__(self):
92 | job_env = submitit.JobEnvironment()
93 | self.setup_job_info(job_env.job_id, job_env.global_rank)
94 | try:
95 | self.run_trainer()
96 | except Exception as e:
97 |             # Log the exception, then re-raise it (as SubmititRunner would by default).
98 | message = format_exception(e)
99 | logging.error(message)
100 | raise e
101 |
102 | def setup_job_info(self, job_id, rank):
103 | """Set up slurm job info"""
104 | self.job_info = {
105 | "job_id": job_id,
106 | "rank": rank,
107 | "cluster": self.cfg.get("cluster", None),
108 | "experiment_log_dir": self.cfg.launcher.experiment_log_dir,
109 | }
110 |
111 | self.has_setup = True
112 |
113 |
114 | def add_pythonpath_to_sys_path():
115 | if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]:
116 | return
117 | sys.path = os.environ["PYTHONPATH"].split(":") + sys.path
118 |
119 |
120 | def main(args, cfg) -> None:
121 |
122 | if cfg.launcher.experiment_log_dir is None:
123 | cfg.launcher.experiment_log_dir = os.path.join(
124 | os.getcwd(), "sam2_logs", args.config
125 | )
126 | print("###################### Train App Config ####################")
127 | print(OmegaConf.to_yaml(cfg))
128 | print("############################################################")
129 |
130 | add_pythonpath_to_sys_path()
131 | makedir(cfg.launcher.experiment_log_dir)
132 | with g_pathmgr.open(
133 | os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
134 | ) as f:
135 | f.write(OmegaConf.to_yaml(cfg))
136 |
137 | cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
138 | cfg_resolved = OmegaConf.create(cfg_resolved)
139 |
140 | with g_pathmgr.open(
141 | os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
142 | ) as f:
143 | f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))
144 |
145 | submitit_conf = cfg.get("submitit", None)
146 | assert submitit_conf is not None, "Missing submitit config"
147 |
148 | submitit_dir = cfg.launcher.experiment_log_dir
149 | submitit_dir = os.path.join(submitit_dir, "submitit_logs")
150 |     # Prioritize command-line args
151 | cfg.launcher.gpus_per_node = (
152 | args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
153 | )
154 | cfg.launcher.num_nodes = (
155 | args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
156 | )
157 | submitit_conf.use_cluster = (
158 | args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
159 | )
160 | if submitit_conf.use_cluster:
161 | executor = submitit.AutoExecutor(folder=submitit_dir)
162 | submitit_conf.partition = (
163 | args.partition
164 | if args.partition is not None
165 | else submitit_conf.get("partition", None)
166 | )
167 | submitit_conf.account = (
168 | args.account
169 | if args.account is not None
170 | else submitit_conf.get("account", None)
171 | )
172 | submitit_conf.qos = (
173 | args.qos if args.qos is not None else submitit_conf.get("qos", None)
174 | )
175 | job_kwargs = {
176 | "timeout_min": 60 * submitit_conf.timeout_hour,
177 | "name": (
178 | submitit_conf.name if hasattr(submitit_conf, "name") else args.config
179 | ),
180 | "slurm_partition": submitit_conf.partition,
181 | "gpus_per_node": cfg.launcher.gpus_per_node,
182 | "tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU
183 | "cpus_per_task": submitit_conf.cpus_per_task,
184 | "nodes": cfg.launcher.num_nodes,
185 | "slurm_additional_parameters": {
186 | "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),
187 | },
188 | }
189 | if "include_nodes" in submitit_conf:
190 | assert (
191 | len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
192 | ), "Not enough nodes"
193 | job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
194 | submitit_conf["include_nodes"]
195 | )
196 | if submitit_conf.account is not None:
197 | job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
198 | if submitit_conf.qos is not None:
199 | job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos
200 |
201 | if submitit_conf.get("mem_gb", None) is not None:
202 | job_kwargs["mem_gb"] = submitit_conf.mem_gb
203 | elif submitit_conf.get("mem", None) is not None:
204 | job_kwargs["slurm_mem"] = submitit_conf.mem
205 |
206 | if submitit_conf.get("constraints", None) is not None:
207 | job_kwargs["slurm_constraint"] = submitit_conf.constraints
208 |
209 | if submitit_conf.get("comment", None) is not None:
210 | job_kwargs["slurm_comment"] = submitit_conf.comment
211 |
212 |         # Only the cpu-bind option within srun_args is supported; new options can be added here
213 | if submitit_conf.get("srun_args", None) is not None:
214 | job_kwargs["slurm_srun_args"] = []
215 | if submitit_conf.srun_args.get("cpu_bind", None) is not None:
216 | job_kwargs["slurm_srun_args"].extend(
217 | ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
218 | )
219 |
220 | print("###################### SLURM Config ####################")
221 | print(job_kwargs)
222 | print("##########################################")
223 | executor.update_parameters(**job_kwargs)
224 |
225 | main_port = random.randint(
226 | submitit_conf.port_range[0], submitit_conf.port_range[1]
227 | )
228 | runner = SubmititRunner(main_port, cfg)
229 | job = executor.submit(runner)
230 | print(f"Submitit Job ID: {job.job_id}")
231 | runner.setup_job_info(job.job_id, rank=0)
232 | else:
233 | # Handle the master_addr separately
234 | master_addr = args.master_addr if args.master_addr else "localhost"
235 |
236 | main_port = args.main_port if args.main_port else random.randint(
237 | submitit_conf.port_range[0], submitit_conf.port_range[1]
238 | )
239 | if 'SLURM_PROCID' in os.environ:
240 | node_rank = int(os.environ['SLURM_PROCID'])
241 | else:
242 | node_rank = 0
243 | single_node_runner(cfg, main_port, node_rank=node_rank, master_addr=master_addr)
244 |
245 |
246 | if __name__ == "__main__":
247 |
248 | initialize_config_module("sam2", version_base="1.2")
249 | parser = ArgumentParser()
250 | parser.add_argument(
251 | "-c",
252 | "--config",
253 | required=True,
254 | type=str,
255 | help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
256 | )
257 | parser.add_argument(
258 | "--use-cluster",
259 | type=int,
260 | default=None,
261 | help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
262 | )
263 | parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
264 | parser.add_argument("--account", type=str, default=None, help="SLURM account")
265 | parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
266 | parser.add_argument(
267 |         "--num-gpus", type=int, default=None, help="number of GPUs per node"
268 | )
269 | parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
270 | parser.add_argument("--master-addr", type=str, default=None, help="Master node address")
271 | parser.add_argument("--main-port", type=int, default=None, help="Main port for communication")
272 | parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset, overrides cfg.dataset.folder")
273 | parser.add_argument("--output-path", type=str, default=None, help="Path to the experiment output, overrides cfg.launcher.experiment_log_dir")
274 | args = parser.parse_args()
275 | args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
276 | register_omegaconf_resolvers()
277 |
278 | # Override dataset folder and experiment output path if specified
279 | cfg = compose(config_name=args.config)
280 | if args.dataset_path is not None:
281 | cfg.dataset.folder = args.dataset_path
282 |     if args.output_path is not None:
283 |         cfg.launcher.experiment_log_dir = args.output_path
284 |
285 |
286 | main(args, cfg)
--------------------------------------------------------------------------------
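For orientation, the local (--use-cluster 0) path of train.py essentially composes the Hydra config, sets the distributed environment variables, and instantiates cfg.trainer. The sketch below condenses that for a one-GPU run; the config name and port are assumptions, and the repository root is assumed to be on PYTHONPATH.

# Condensed single-GPU version of the local launch path in train.py.
# Assumptions: the repo root is on PYTHONPATH, the chosen YAML defines a
# `trainer` _target_, and port 29500 is free.
import os

from hydra import compose, initialize_config_module
from hydra.utils import instantiate

from training.utils.train_utils import register_omegaconf_resolvers

register_omegaconf_resolvers()
initialize_config_module("sam2", version_base="1.2")
cfg = compose(config_name="configs/sam2.1_hiera_tiny_finetune512.yaml")

# single_proc_run sets these before building the trainer; hard-code them for a
# one-GPU, one-node run (train.py normally picks a random port).
os.environ.update(
    {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "29500",
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
    }
)

# cfg.trainer is a Hydra _target_ entry; instantiate builds the Trainer and
# run() starts the training loop, exactly as in single_proc_run.
trainer = instantiate(cfg.trainer, _recursive_=False)
trainer.run()
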
/training/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/training/utils/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/training/utils/__pycache__/checkpoint_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/checkpoint_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/training/utils/__pycache__/data_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/data_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/training/utils/__pycache__/distributed.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/distributed.cpython-312.pyc
--------------------------------------------------------------------------------
/training/utils/__pycache__/logger.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/logger.cpython-312.pyc
--------------------------------------------------------------------------------
/training/utils/__pycache__/train_utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bowang-lab/MedSAM2/8f160bc226d81eca0b6bca03f43149ee89b0293c/training/utils/__pycache__/train_utils.cpython-312.pyc
--------------------------------------------------------------------------------
/training/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Misc functions, including distributed helpers.
9 |
10 | Mostly copy-paste from torchvision references.
11 | """
12 |
13 | from dataclasses import dataclass
14 | from typing import List, Optional, Tuple, Union
15 |
16 | import torch
17 |
18 | from PIL import Image as PILImage
19 | from tensordict import tensorclass
20 |
21 |
22 | @tensorclass
23 | class BatchedVideoMetaData:
24 | """
25 | This class represents metadata about a batch of videos.
26 | Attributes:
27 | unique_objects_identifier: A tensor of shape Bx3 containing unique identifiers for each object in the batch. Index consists of (video_id, obj_id, frame_id)
28 | frame_orig_size: A tensor of shape Bx2 containing the original size of each frame in the batch.
29 | """
30 |
31 | unique_objects_identifier: torch.LongTensor
32 | frame_orig_size: torch.LongTensor
33 |
34 |
35 | @tensorclass
36 | class BatchedVideoDatapoint:
37 | """
38 | This class represents a batch of videos with associated annotations and metadata.
39 | Attributes:
40 | img_batch: A [TxBxCxHxW] tensor containing the image data for each frame in the batch, where T is the number of frames per video, and B is the number of videos in the batch.
41 |         obj_to_frame_idx: A [TxOx2] tensor mapping each object to the (frame_idx, video_idx) entry of img_batch it belongs to. O is the number of objects in the batch.
42 | masks: A [TxOxHxW] tensor containing binary masks for each object in the batch.
43 | metadata: An instance of BatchedVideoMetaData containing metadata about the batch.
44 | dict_key: A string key used to identify the batch.
45 | """
46 |
47 | img_batch: torch.FloatTensor
48 | obj_to_frame_idx: torch.IntTensor
49 | masks: torch.BoolTensor
50 | metadata: BatchedVideoMetaData
51 |
52 | dict_key: str
53 |
54 | def pin_memory(self, device=None):
55 | return self.apply(torch.Tensor.pin_memory, device=device)
56 |
57 | @property
58 | def num_frames(self) -> int:
59 | """
60 | Returns the number of frames per video.
61 | """
62 | return self.batch_size[0]
63 |
64 | @property
65 | def num_videos(self) -> int:
66 | """
67 | Returns the number of videos in the batch.
68 | """
69 | return self.img_batch.shape[1]
70 |
71 | @property
72 | def flat_obj_to_img_idx(self) -> torch.IntTensor:
73 | """
74 | Returns a flattened tensor containing the object to img index.
75 | The flat index can be used to access a flattened img_batch of shape [(T*B)xCxHxW]
76 | """
77 | frame_idx, video_idx = self.obj_to_frame_idx.unbind(dim=-1)
78 | flat_idx = video_idx * self.num_frames + frame_idx
79 | return flat_idx
80 |
81 | @property
82 | def flat_img_batch(self) -> torch.FloatTensor:
83 | """
84 | Returns a flattened img_batch_tensor of shape [(B*T)xCxHxW]
85 | """
86 |
87 | return self.img_batch.transpose(0, 1).flatten(0, 1)
88 |
89 |
90 | @dataclass
91 | class Object:
92 | # Id of the object in the media
93 | object_id: int
94 | # Index of the frame in the media (0 if single image)
95 | frame_index: int
96 | segment: Union[torch.Tensor, dict] # RLE dict or binary mask
97 |
98 |
99 | @dataclass
100 | class Frame:
101 | data: Union[torch.Tensor, PILImage.Image]
102 | objects: List[Object]
103 |
104 |
105 | @dataclass
106 | class VideoDatapoint:
107 | """Refers to an image/video and all its annotations"""
108 |
109 | frames: List[Frame]
110 | video_id: int
111 | size: Tuple[int, int]
112 |
113 |
114 | def collate_fn(
115 | batch: List[VideoDatapoint],
116 | dict_key,
117 | ) -> BatchedVideoDatapoint:
118 | """
119 | Args:
120 | batch: A list of VideoDatapoint instances.
121 | dict_key (str): A string key used to identify the batch.
122 | """
123 | img_batch = []
124 | for video in batch:
125 | img_batch += [torch.stack([frame.data for frame in video.frames], dim=0)]
126 |
127 | img_batch = torch.stack(img_batch, dim=0).permute((1, 0, 2, 3, 4))
128 | T = img_batch.shape[0]
129 | # Prepare data structures for sequential processing. Per-frame processing but batched across videos.
130 | step_t_objects_identifier = [[] for _ in range(T)]
131 | step_t_frame_orig_size = [[] for _ in range(T)]
132 |
133 | step_t_masks = [[] for _ in range(T)]
134 | step_t_obj_to_frame_idx = [
135 | [] for _ in range(T)
136 | ] # List to store frame indices for each time step
137 |
138 | for video_idx, video in enumerate(batch):
139 | orig_video_id = video.video_id
140 | orig_frame_size = video.size
141 | for t, frame in enumerate(video.frames):
142 | objects = frame.objects
143 | for obj in objects:
144 | orig_obj_id = obj.object_id
145 | orig_frame_idx = obj.frame_index
146 | step_t_obj_to_frame_idx[t].append(
147 | torch.tensor([t, video_idx], dtype=torch.int)
148 | )
149 | step_t_masks[t].append(obj.segment.to(torch.bool))
150 | step_t_objects_identifier[t].append(
151 | torch.tensor([orig_video_id, orig_obj_id, orig_frame_idx])
152 | )
153 | step_t_frame_orig_size[t].append(torch.tensor(orig_frame_size))
154 |
155 | obj_to_frame_idx = torch.stack(
156 | [
157 | torch.stack(obj_to_frame_idx, dim=0)
158 | for obj_to_frame_idx in step_t_obj_to_frame_idx
159 | ],
160 | dim=0,
161 | )
162 | masks = torch.stack([torch.stack(masks, dim=0) for masks in step_t_masks], dim=0)
163 |     objects_identifier = torch.stack(
164 |         [torch.stack(ids, dim=0) for ids in step_t_objects_identifier], dim=0
165 |     )
166 |     frame_orig_size = torch.stack(
167 |         [torch.stack(sizes, dim=0) for sizes in step_t_frame_orig_size], dim=0
168 |     )
169 | return BatchedVideoDatapoint(
170 | img_batch=img_batch,
171 | obj_to_frame_idx=obj_to_frame_idx,
172 | masks=masks,
173 | metadata=BatchedVideoMetaData(
174 | unique_objects_identifier=objects_identifier,
175 | frame_orig_size=frame_orig_size,
176 | ),
177 | dict_key=dict_key,
178 | batch_size=[T],
179 | )
180 |
--------------------------------------------------------------------------------
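To see what collate_fn produces, the toy sketch below builds two 3-frame videos with one object per frame and inspects the resulting BatchedVideoDatapoint. Tensor contents are random and only the shapes matter; it assumes the training package is importable.

# Toy sketch: build two 3-frame "videos" with one object per frame and collate
# them, assuming the training package is on PYTHONPATH.
import torch

from training.utils.data_utils import (
    BatchedVideoDatapoint,
    Frame,
    Object,
    VideoDatapoint,
    collate_fn,
)

T, H, W = 3, 64, 64
videos = []
for vid in range(2):
    frames = [
        Frame(
            data=torch.rand(3, H, W),
            objects=[Object(object_id=0, frame_index=t, segment=torch.zeros(H, W))],
        )
        for t in range(T)
    ]
    videos.append(VideoDatapoint(frames=frames, video_id=vid, size=(H, W)))

batch: BatchedVideoDatapoint = collate_fn(videos, dict_key="toy")
print(batch.img_batch.shape)            # [T, B, C, H, W]  -> torch.Size([3, 2, 3, 64, 64])
print(batch.masks.shape)                # [T, O, H, W]     -> torch.Size([3, 2, 64, 64])
print(batch.flat_img_batch.shape)       # [(B*T), C, H, W] -> torch.Size([6, 3, 64, 64])
print(batch.flat_obj_to_img_idx.shape)  # [T, O]           -> torch.Size([3, 2])
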
/training/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
8 | import atexit
9 | import functools
10 | import logging
11 | import sys
12 | import uuid
13 | from typing import Any, Dict, Optional, Union
14 |
15 | from hydra.utils import instantiate
16 |
17 | from iopath.common.file_io import g_pathmgr
18 | from numpy import ndarray
19 | from torch import Tensor
20 | from torch.utils.tensorboard import SummaryWriter
21 |
22 | from training.utils.train_utils import get_machine_local_and_dist_rank, makedir
23 |
24 | Scalar = Union[Tensor, ndarray, int, float]
25 |
26 |
27 | def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
28 | makedir(log_dir)
29 | summary_writer_method = SummaryWriter
30 | return TensorBoardLogger(
31 | path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
32 | )
33 |
34 |
35 | class TensorBoardWriterWrapper:
36 | """
37 | A wrapper around a SummaryWriter object.
38 | """
39 |
40 | def __init__(
41 | self,
42 | path: str,
43 | *args: Any,
44 | filename_suffix: str = None,
45 | summary_writer_method: Any = SummaryWriter,
46 | **kwargs: Any,
47 | ) -> None:
48 | """Create a new TensorBoard logger.
49 | On construction, the logger creates a new events file that logs
50 | will be written to. If the environment variable `RANK` is defined,
51 | logger will only log if RANK = 0.
52 |
53 | NOTE: If using the logger with distributed training:
54 | - This logger can call collective operations
55 | - Logs will be written on rank 0 only
56 | - Logger must be constructed synchronously *after* initializing distributed process group.
57 |
58 | Args:
59 | path (str): path to write logs to
60 | *args, **kwargs: Extra arguments to pass to SummaryWriter
61 | """
62 | self._writer: Optional[SummaryWriter] = None
63 | _, self._rank = get_machine_local_and_dist_rank()
64 | self._path: str = path
65 | if self._rank == 0:
66 | logging.info(
67 | f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
68 | )
69 | self._writer = summary_writer_method(
70 | log_dir=path,
71 | *args,
72 | filename_suffix=filename_suffix or str(uuid.uuid4()),
73 | **kwargs,
74 | )
75 | else:
76 | logging.debug(
77 | f"Not logging meters on this host because env RANK: {self._rank} != 0"
78 | )
79 | atexit.register(self.close)
80 |
81 | @property
82 | def writer(self) -> Optional[SummaryWriter]:
83 | return self._writer
84 |
85 | @property
86 | def path(self) -> str:
87 | return self._path
88 |
89 | def flush(self) -> None:
90 | """Writes pending logs to disk."""
91 |
92 | if not self._writer:
93 | return
94 |
95 | self._writer.flush()
96 |
97 | def close(self) -> None:
98 | """Close writer, flushing pending logs to disk.
99 | Logs cannot be written after `close` is called.
100 | """
101 |
102 | if not self._writer:
103 | return
104 |
105 | self._writer.close()
106 | self._writer = None
107 |
108 |
109 | class TensorBoardLogger(TensorBoardWriterWrapper):
110 | """
111 | A simple logger for TensorBoard.
112 | """
113 |
114 | def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
115 | """Add multiple scalar values to TensorBoard.
116 |
117 | Args:
118 | payload (dict): dictionary of tag name and scalar value
119 | step (int, Optional): step value to record
120 | """
121 | if not self._writer:
122 | return
123 | for k, v in payload.items():
124 | self.log(k, v, step)
125 |
126 | def log(self, name: str, data: Scalar, step: int) -> None:
127 | """Add scalar data to TensorBoard.
128 |
129 | Args:
130 | name (string): tag name used to group scalars
131 | data (float/int/Tensor): scalar data to log
132 | step (int, optional): step value to record
133 | """
134 | if not self._writer:
135 | return
136 | self._writer.add_scalar(name, data, global_step=step, new_style=True)
137 |
138 | def log_hparams(
139 | self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
140 | ) -> None:
141 | """Add hyperparameter data to TensorBoard.
142 |
143 | Args:
144 | hparams (dict): dictionary of hyperparameter names and corresponding values
145 |             meters (dict): dictionary of meter names and corresponding values
146 | """
147 | if not self._writer:
148 | return
149 | self._writer.add_hparams(hparams, meters)
150 |
151 |
152 | class Logger:
153 | """
154 |     A logger class that can interface with multiple loggers. It currently supports only TensorBoard for simplicity, but can be extended with other loggers.
155 | """
156 |
157 | def __init__(self, logging_conf):
158 | # allow turning off TensorBoard with "should_log: false" in config
159 | tb_config = logging_conf.tensorboard_writer
160 | tb_should_log = tb_config and tb_config.pop("should_log", True)
161 | self.tb_logger = instantiate(tb_config) if tb_should_log else None
162 |
163 | def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
164 | if self.tb_logger:
165 | self.tb_logger.log_dict(payload, step)
166 |
167 | def log(self, name: str, data: Scalar, step: int) -> None:
168 | if self.tb_logger:
169 | self.tb_logger.log(name, data, step)
170 |
171 | def log_hparams(
172 | self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
173 | ) -> None:
174 | if self.tb_logger:
175 | self.tb_logger.log_hparams(hparams, meters)
176 |
177 |
178 | # cache the opened file object, so that different calls to `setup_logger`
179 | # with the same file name can safely write to the same file.
180 | @functools.lru_cache(maxsize=None)
181 | def _cached_log_stream(filename):
182 | # we tune the buffering value so that the logs are updated
183 | # frequently.
184 |     log_buffer_kb = 10 * 1024  # 10 KB (open() buffering is given in bytes)
185 | io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb)
186 | atexit.register(io.close)
187 | return io
188 |
189 |
190 | def setup_logging(
191 | name,
192 | output_dir=None,
193 | rank=0,
194 | log_level_primary="INFO",
195 | log_level_secondary="ERROR",
196 | ):
197 | """
198 |     Set up the logging streams: stdout and file handlers.
199 |     File handlers are only set up on the master GPU (rank 0).
200 | """
201 | # get the filename if we want to log to the file as well
202 | log_filename = None
203 | if output_dir:
204 | makedir(output_dir)
205 | if rank == 0:
206 | log_filename = f"{output_dir}/log.txt"
207 |
208 | logger = logging.getLogger(name)
209 | logger.setLevel(log_level_primary)
210 |
211 | # create formatter
212 | FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
213 | formatter = logging.Formatter(FORMAT)
214 |
215 | # Cleanup any existing handlers
216 | for h in logger.handlers:
217 | logger.removeHandler(h)
218 | logger.root.handlers = []
219 |
220 | # setup the console handler
221 | console_handler = logging.StreamHandler(sys.stdout)
222 | console_handler.setFormatter(formatter)
223 | logger.addHandler(console_handler)
224 | if rank == 0:
225 | console_handler.setLevel(log_level_primary)
226 | else:
227 | console_handler.setLevel(log_level_secondary)
228 |
229 | # we log to file as well if user wants
230 | if log_filename and rank == 0:
231 | file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
232 | file_handler.setLevel(log_level_primary)
233 | file_handler.setFormatter(formatter)
234 | logger.addHandler(file_handler)
235 |
236 | logging.root = logger
237 |
238 |
239 | def shutdown_logging():
240 | """
241 | After training is done, we ensure to shut down all the logger streams.
242 | """
243 | logging.info("Shutting down loggers...")
244 | handlers = logging.root.handlers
245 | for handler in handlers:
246 | handler.close()
247 |
--------------------------------------------------------------------------------
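Outside the Trainer, the TensorBoard logger and the logging setup can be exercised on their own. Below is a minimal sketch for a single-process run with a placeholder log directory; since TensorBoardWriterWrapper reads RANK/LOCAL_RANK via get_machine_local_and_dist_rank, they are set explicitly here.

# Standalone use of the logging utilities for a single-process run.
# "./toy_logs" is a placeholder directory.
import os

from training.utils.logger import (
    make_tensorboard_logger,
    setup_logging,
    shutdown_logging,
)

# get_machine_local_and_dist_rank (used by the writer wrapper) expects these
# to be set; torchrun/submitit normally provide them.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")

setup_logging(__name__, output_dir="./toy_logs", rank=0)

tb_logger = make_tensorboard_logger("./toy_logs/tensorboard")
for step in range(10):
    tb_logger.log_dict({"Losses/train_loss": 1.0 / (step + 1)}, step=step)
tb_logger.log_hparams({"lr": 1e-4}, {"final_loss": 0.1})

shutdown_logging()
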
/training/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import math
9 | import os
10 | import random
11 | import re
12 | from datetime import timedelta
13 | from typing import Optional
14 |
15 | import hydra
16 |
17 | import numpy as np
18 | import omegaconf
19 | import torch
20 | import torch.distributed as dist
21 | from iopath.common.file_io import g_pathmgr
22 | from omegaconf import OmegaConf
23 |
24 |
25 | def multiply_all(*args):
26 | return np.prod(np.array(args)).item()
27 |
28 |
29 | def collect_dict_keys(config):
30 |     """Recursively iterate through a dataset configuration and collect all the dict_key values that are defined."""
31 | val_keys = []
32 |     # If this config points to the collate function, then it has a key
33 | if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
34 | val_keys.append(config["dict_key"])
35 | else:
36 | # Recursively proceed
37 | for v in config.values():
38 | if isinstance(v, type(config)):
39 | val_keys.extend(collect_dict_keys(v))
40 | elif isinstance(v, omegaconf.listconfig.ListConfig):
41 | for item in v:
42 | if isinstance(item, type(config)):
43 | val_keys.extend(collect_dict_keys(item))
44 | return val_keys
45 |
46 |
47 | class Phase:
48 | TRAIN = "train"
49 | VAL = "val"
50 |
51 |
52 | def register_omegaconf_resolvers():
53 | OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
54 | OmegaConf.register_new_resolver("get_class", hydra.utils.get_class)
55 | OmegaConf.register_new_resolver("add", lambda x, y: x + y)
56 | OmegaConf.register_new_resolver("times", multiply_all)
57 | OmegaConf.register_new_resolver("divide", lambda x, y: x / y)
58 | OmegaConf.register_new_resolver("pow", lambda x, y: x**y)
59 | OmegaConf.register_new_resolver("subtract", lambda x, y: x - y)
60 | OmegaConf.register_new_resolver("range", lambda x: list(range(x)))
61 | OmegaConf.register_new_resolver("int", lambda x: int(x))
62 | OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x)))
63 | OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x))
64 |
65 |
66 | def setup_distributed_backend(backend, timeout_mins):
67 | """
68 | Initialize torch.distributed and set the CUDA device.
69 | Expects environment variables to be set as per
70 | https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
71 |     along with the environment variable "LOCAL_RANK", which is used to set the CUDA device.
72 | """
73 | # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins
74 | # of waiting
75 | os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
76 | logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
77 | dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins))
78 | return dist.get_rank()
79 |
80 |
81 | def get_machine_local_and_dist_rank():
82 | """
83 | Get the distributed and local rank of the current gpu.
84 | """
85 | local_rank = int(os.environ.get("LOCAL_RANK", None))
86 | distributed_rank = int(os.environ.get("RANK", None))
87 | assert (
88 | local_rank is not None and distributed_rank is not None
89 |     ), "Please set the RANK and LOCAL_RANK environment variables."
90 | return local_rank, distributed_rank
91 |
92 |
93 | def print_cfg(cfg):
94 | """
95 |     Supports printing both Hydra DictConfig and AttrDict configs.
96 | """
97 | logging.info("Training with config:")
98 | logging.info(OmegaConf.to_yaml(cfg))
99 |
100 |
101 | def set_seeds(seed_value, max_epochs, dist_rank):
102 | """
103 |     Set the python random, numpy and torch seeds for each gpu. Also set the CUDA
104 |     seeds if CUDA is available. This helps make training deterministic.
105 |     """
106 |     # The pytorch sampler increments the seed by 1 every epoch; scaling by max_epochs keeps per-rank seed ranges disjoint.
107 | seed_value = (seed_value + dist_rank) * max_epochs
108 | logging.info(f"MACHINE SEED: {seed_value}")
109 | random.seed(seed_value)
110 | np.random.seed(seed_value)
111 | torch.manual_seed(seed_value)
112 | if torch.cuda.is_available():
113 | torch.cuda.manual_seed_all(seed_value)
114 |
115 |
116 | def makedir(dir_path):
117 | """
118 | Create the directory if it does not exist.
119 | """
120 | is_success = False
121 | try:
122 | if not g_pathmgr.exists(dir_path):
123 | g_pathmgr.mkdirs(dir_path)
124 | is_success = True
125 | except BaseException:
126 | logging.info(f"Error creating directory: {dir_path}")
127 | return is_success
128 |
129 |
130 | def is_dist_avail_and_initialized():
131 | if not dist.is_available():
132 | return False
133 | if not dist.is_initialized():
134 | return False
135 | return True
136 |
137 |
138 | def get_amp_type(amp_type: Optional[str] = None):
139 | if amp_type is None:
140 | return None
141 | assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
142 | if amp_type == "bfloat16":
143 | return torch.bfloat16
144 | else:
145 | return torch.float16
146 |
147 |
148 | def log_env_variables():
149 | env_keys = sorted(list(os.environ.keys()))
150 | st = ""
151 | for k in env_keys:
152 | v = os.environ[k]
153 | st += f"{k}={v}\n"
154 | logging.info("Logging ENV_VARIABLES")
155 | logging.info(st)
156 |
157 |
158 | class AverageMeter:
159 | """Computes and stores the average and current value"""
160 |
161 | def __init__(self, name, device, fmt=":f"):
162 | self.name = name
163 | self.fmt = fmt
164 | self.device = device
165 | self.reset()
166 |
167 | def reset(self):
168 | self.val = 0
169 | self.avg = 0
170 | self.sum = 0
171 | self.count = 0
172 | self._allow_updates = True
173 |
174 | def update(self, val, n=1):
175 | self.val = val
176 | self.sum += val * n
177 | self.count += n
178 | self.avg = self.sum / self.count
179 |
180 | def __str__(self):
181 | fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
182 | return fmtstr.format(**self.__dict__)
183 |
184 |
185 | class MemMeter:
186 |     """Computes and stores the current, average, and max peak memory usage per iteration"""
187 |
188 | def __init__(self, name, device, fmt=":f"):
189 | self.name = name
190 | self.fmt = fmt
191 | self.device = device
192 | self.reset()
193 |
194 | def reset(self):
195 | self.val = 0 # Per iteration max usage
196 | self.avg = 0 # Avg per iteration max usage
197 | self.peak = 0 # Peak usage for lifetime of program
198 | self.sum = 0
199 | self.count = 0
200 | self._allow_updates = True
201 |
202 | def update(self, n=1, reset_peak_usage=True):
203 | self.val = torch.cuda.max_memory_allocated() // 1e9
204 | self.sum += self.val * n
205 | self.count += n
206 | self.avg = self.sum / self.count
207 | self.peak = max(self.peak, self.val)
208 | if reset_peak_usage:
209 | torch.cuda.reset_peak_memory_stats()
210 |
211 | def __str__(self):
212 | fmtstr = (
213 | "{name}: {val"
214 | + self.fmt
215 | + "} ({avg"
216 | + self.fmt
217 | + "}/{peak"
218 | + self.fmt
219 | + "})"
220 | )
221 | return fmtstr.format(**self.__dict__)
222 |
223 |
224 | def human_readable_time(time_seconds):
225 | time = int(time_seconds)
226 | minutes, seconds = divmod(time, 60)
227 | hours, minutes = divmod(minutes, 60)
228 | days, hours = divmod(hours, 24)
229 | return f"{days:02}d {hours:02}h {minutes:02}m"
230 |
231 |
232 | class DurationMeter:
233 | def __init__(self, name, device, fmt=":f"):
234 | self.name = name
235 | self.device = device
236 | self.fmt = fmt
237 | self.val = 0
238 |
239 | def reset(self):
240 | self.val = 0
241 |
242 | def update(self, val):
243 | self.val = val
244 |
245 | def add(self, val):
246 | self.val += val
247 |
248 | def __str__(self):
249 | return f"{self.name}: {human_readable_time(self.val)}"
250 |
251 |
252 | class ProgressMeter:
253 | def __init__(self, num_batches, meters, real_meters, prefix=""):
254 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
255 | self.meters = meters
256 | self.real_meters = real_meters
257 | self.prefix = prefix
258 |
259 | def display(self, batch, enable_print=False):
260 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
261 | entries += [str(meter) for meter in self.meters]
262 | entries += [
263 | " | ".join(
264 | [
265 | f"{os.path.join(name, subname)}: {val:.4f}"
266 | for subname, val in meter.compute().items()
267 | ]
268 | )
269 | for name, meter in self.real_meters.items()
270 | ]
271 | logging.info(" | ".join(entries))
272 | if enable_print:
273 | print(" | ".join(entries))
274 |
275 | def _get_batch_fmtstr(self, num_batches):
276 |         num_digits = len(str(num_batches))
277 | fmt = "{:" + str(num_digits) + "d}"
278 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
279 |
280 |
281 | def get_resume_checkpoint(checkpoint_save_dir):
282 | if not g_pathmgr.isdir(checkpoint_save_dir):
283 | return None
284 | ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt")
285 | if not g_pathmgr.isfile(ckpt_file):
286 | return None
287 |
288 | return ckpt_file
289 |
--------------------------------------------------------------------------------
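Two of these helpers are easy to demo in isolation: the custom OmegaConf resolvers registered by register_omegaconf_resolvers, and AverageMeter. The sketch below uses made-up values and assumes the training package is importable.

# Quick sketch of the custom OmegaConf resolvers and AverageMeter in isolation.
# All values are made up for illustration.
from omegaconf import OmegaConf

from training.utils.train_utils import AverageMeter, register_omegaconf_resolvers

register_omegaconf_resolvers()
cfg = OmegaConf.create(
    {
        "base_lr": 5e-4,
        "num_gpus": 2,
        "batch_size_per_gpu": 4,
        # "times" maps to multiply_all, so these entries resolve to products.
        "scaled_lr": "${times:${base_lr},${num_gpus}}",
        "global_batch_size": "${times:${batch_size_per_gpu},${num_gpus}}",
    }
)
print(OmegaConf.to_yaml(cfg, resolve=True))  # scaled_lr: 0.001, global_batch_size: 8

# AverageMeter keeps a running (weighted) average, e.g. of a per-batch loss.
loss_meter = AverageMeter("Loss", device="cpu", fmt=":.4f")
for batch_loss, batch_size in [(0.9, 8), (0.7, 8), (0.5, 4)]:
    loss_meter.update(batch_loss, n=batch_size)
print(loss_meter)  # Loss: 0.5000 (0.7400)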