├── .gitignore ├── README.md ├── cfg.py ├── configs ├── BCSS.py ├── CRAG.py └── GlaS.py ├── dataset_cfg ├── BCSS_cv.csv ├── CRAG_cv.csv └── GlaS_cv.csv ├── image_mask_dataset.py ├── imgs └── image2.png ├── losses.py ├── main.py ├── model ├── 0.0_cache_data ├── 0.0_config.yml ├── 0.0_state └── history.txt ├── network ├── __init__.py ├── get_network.py ├── hipt │ ├── extract_hipt_ft.py │ ├── hipt_mil.py │ ├── hipt_prompt.py │ ├── vision_transformer.py │ └── vision_transformer4k.py ├── kan.py ├── sam_network.py ├── sam_network_2.py └── sam_network_backup.py ├── pl_module_sam_seg.py ├── sam2_train ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_s.yaml ├── sam2_hiera_t.yaml ├── sam2_image_predictor.py ├── sam2_video_predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── misc.py │ └── transforms.py └── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling ├── __init__.py ├── common.py ├── image_encoder.py ├── mask_decoder.py ├── prompt_encoder.py ├── sam.py └── transformer.py ├── predictor.py └── utils ├── __init__.py ├── amg.py ├── onnx.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | .idea/ 162 | 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Path-SAM2: Transfer SAM2 for digital pathology semantic segmentation 2 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('-net', type=str, default='sam2', help='net type') 7 | parser.add_argument('-encoder', type=str, default='vit_b', help='encoder type') 8 | parser.add_argument('-exp_name', default='simzhang sam2', type=str, help='experiment name') 9 | parser.add_argument('-vis', type=bool, default=False, help='Generate visualisation during validation') 10 | parser.add_argument('-train_vis', type=bool, default=False, help='Generate visualisation during training') 11 | parser.add_argument('-prompt', type=str, default='bbox', help='type of prompt, bbox or click') 12 | parser.add_argument('-prompt_freq', type=int, default=2, help='frequency of giving prompt in 3D images') 13 | parser.add_argument('-pretrain', type=str, default=None, help='path of pretrained weights') 14 | parser.add_argument('-val_freq',type=int,default=5,help='interval between each validation') 15 | parser.add_argument('-gpu', type=bool, default=True, help='use gpu or not') 16 | parser.add_argument('-gpu_device', type=int, default=0, help='which gpu to use') 17 | parser.add_argument('-image_size', type=int, default=1024, help='image_size') 18 | parser.add_argument('-out_size', type=int, default=1024, help='output_size') 19 | parser.add_argument('-distributed', default='none' ,type=str,help='multi GPU ids to use') 20 | parser.add_argument('-dataset', default='btcv' ,type=str,help='dataset name') 21 | parser.add_argument('-sam_ckpt', type=str, default=None , help='sam checkpoint address') 22 | parser.add_argument('-sam_config', type=str, default=None , help='sam config file') 23 | parser.add_argument('-video_length', type=int, default=2, help='video sequence length') 24 | parser.add_argument('-b', type=int, default=1, help='batch size for dataloader') 25 | parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate') 26 | parser.add_argument('-weights', type=str, default = 0, help='the weights file you want to test') 27 | parser.add_argument('-multimask_output', type=int, default=1 , help='the number of masks output for multi-class segmentation') 28 | parser.add_argument('-memory_bank_size', type=int, default=16, help='sam 2d memory bank size') 29 | parser.add_argument( 30 | '-data_path', 31 | type=str, 32 | default='./data/btcv', 33 | help='The path of segmentation data') 34 | opt = parser.parse_args() 35 | 36 | return opt 37 | -------------------------------------------------------------------------------- /configs/BCSS.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | 3 | config = { 4 | "batch_size": 8, 5 | "accumulate_grad_batches": 2, 6 | "num_workers": 8, 7 | "out_dir": "/root/workspace/code/sam-path/SAMPath/", 8 | "opt": { 9 | "num_epochs": 60, 10 | "learning_rate": 5e-4, 11 | "weight_decay": 1e-2, #1e-2, 12 | "precision": 32, # "16-mixed" 13 | # "precision": "16-mixed", 14 | "steps": [72 * 25, 72 * 29], 15 | "warmup_steps": 72, 16 | }, 17 | "model": { 18 | "type": 'vit_b', 19 | "checkpoint": 
"/root/workspace/code/sam-path/pretrained/sam_vit_b_01ec64.pth", 20 | "freeze": { 21 | "image_encoder": True, 22 | "prompt_encoder": True, 23 | "mask_decoder": False, 24 | }, 25 | "prompt_dim": 256, 26 | "prompt_decoder": False, 27 | "dense_prompt_decoder": False, 28 | 29 | "extra_encoder": 'uni_v1', 30 | "extra_type": "fusion", 31 | "extra_checkpoint": "/root/workspace/code/sam-path/pretrained/uni/pytorch_model.bin", 32 | }, 33 | "loss": { 34 | "focal_cof": 0.25, 35 | "dice_cof": 0.75, 36 | "ce_cof": 0.0, 37 | "iou_cof": 0.0625, 38 | }, 39 | "dataset": { 40 | "dataset_root": "/root/workspace/code/sam-path/path_data/BCSS/merged_dataset", 41 | "dataset_csv_path": "/root/workspace/code/sam-path/SAMPath/dataset_cfg/BCSS_cv.csv", 42 | "data_ext": ".png", 43 | "val_fold_id": 0, 44 | "num_classes": 6, 45 | 46 | "ignored_classes": (0), 47 | "ignored_classes_metric": None, # if we do not count background, set to 1 (bg class) 48 | "image_hw": (1024, 1024), # default is 1024, 1024 49 | 50 | "feature_input": False, # or "True" for *.pt features 51 | "dataset_mean": (0.485, 0.456, 0.406), 52 | "dataset_std": (0.229, 0.224, 0.225), 53 | } 54 | } 55 | 56 | cfg = Box(config) -------------------------------------------------------------------------------- /configs/CRAG.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | 3 | config = { 4 | # "num_devices": 2, 5 | "batch_size": 4, 6 | "num_workers": 4, 7 | "out_dir": "/root/workspace/code/sam-path/SAMPath/", 8 | "opt": { 9 | "num_epochs": 120, 10 | "learning_rate": 1e-4, 11 | "weight_decay": 1e-2, #1e-2, 12 | "precision": 32, # "16-mixed" 13 | "steps": [23 * 50, 23 * 55], 14 | "warmup_steps": 46, 15 | }, 16 | "model": { 17 | "type": 'vit_b', 18 | "checkpoint": "/root/workspace/code/sam-path/pretrained/sam_vit_b_01ec64.pth", 19 | "freeze": { 20 | "image_encoder": True, 21 | "prompt_encoder": True, 22 | "mask_decoder": False, 23 | }, 24 | "prompt_dim": 256, 25 | "prompt_decoder": False, 26 | "dense_prompt_decoder": False, 27 | 28 | "extra_encoder": 'uni_v1', 29 | "extra_type": "fusion", 30 | "extra_checkpoint": "/root/workspace/code/sam-path/pretrained/uni/pytorch_model.bin", 31 | }, 32 | "loss": { 33 | "focal_cof": 0.125, 34 | "dice_cof": 0.875, 35 | "ce_cof": 0., 36 | "iou_cof": 0.0, 37 | }, 38 | "dataset": { 39 | "dataset_root": "/root/workspace/code/sam-path/path_data/CRAG/merged", 40 | "dataset_csv_path": "/root/workspace/code/sam-path/SAMPath/dataset_cfg/CRAG_cv.csv", 41 | "data_ext": ".png", 42 | "val_fold_id": 0, 43 | "num_classes": 3, 44 | 45 | "ignored_classes": None, 46 | "ignored_classes_metric": 1, # if we do not count background, set to 1 (bg class) 47 | "image_hw": (1536, 1536), # default is 1024, 1024 48 | 49 | "feature_input": False, # or "True" for *.pt features 50 | "dataset_mean": (0.485, 0.456, 0.406), 51 | "dataset_std": (0.229, 0.224, 0.225), 52 | } 53 | } 54 | 55 | cfg = Box(config) -------------------------------------------------------------------------------- /configs/GlaS.py: -------------------------------------------------------------------------------- 1 | from box import Box 2 | 3 | config = { 4 | # "num_devices": 2, 5 | "batch_size": 8, 6 | "num_workers": 8, 7 | "out_dir": "/root/workspace/code/sam-path/SAMPath/", 8 | "opt": { 9 | "num_epochs": 120, 10 | "learning_rate": 1e-5, 11 | "weight_decay": 1e-2, #1e-2, 12 | "precision": 32, # "16-mixed" 13 | "steps": [23 * 50, 23 * 55], 14 | "warmup_steps": 46, 15 | }, 16 | "model": { 17 | "type": 'vit_b', 18 | "checkpoint": 
"/root/workspace/code/sam-path/pretrained/sam_vit_b_01ec64.pth", 19 | # "checkpoint": "/root/workspace/code/sam-path/segment-anything-2/checkpoints/sam2_hiera_small.pt", 20 | "freeze": { 21 | "image_encoder": True, 22 | "prompt_encoder": True, 23 | "mask_decoder": False, #不对mask_decoder 进行frozen 操作,因为这里要训练 KAN 24 | }, 25 | "prompt_dim": 256, 26 | "prompt_decoder": False, 27 | "dense_prompt_decoder": False, 28 | 29 | "extra_encoder": 'uni_v1', 30 | "extra_type": "fusion", 31 | "extra_checkpoint": "/root/workspace/code/sam-path/pretrained/uni/pytorch_model.bin", 32 | }, 33 | "loss": { 34 | "focal_cof": 0.125, 35 | "dice_cof": 0.875, 36 | "ce_cof": 0., 37 | "iou_cof": 0.0, 38 | }, 39 | "dataset": { 40 | "dataset_root": "/root/workspace/code/sam-path/path_data/Glas_sam", 41 | "dataset_csv_path": "/root/workspace/code/sam-path/SAMPath/dataset_cfg/GlaS_cv.csv", 42 | "data_ext": ".png", 43 | "val_fold_id": 0, 44 | "num_classes": 3, 45 | 46 | "ignored_classes": None, 47 | "ignored_classes_metric": 1, # if we do not count background, set to 1 (bg class) 48 | "image_hw": (775, 775), # default is 1024, 1024 49 | 50 | "feature_input": False, # or "True" for *.pt features 51 | "dataset_mean": (0.485, 0.456, 0.406), 52 | "dataset_std": (0.229, 0.224, 0.225), 53 | } 54 | } 55 | 56 | cfg = Box(config) -------------------------------------------------------------------------------- /dataset_cfg/CRAG_cv.csv: -------------------------------------------------------------------------------- 1 | img_id,fold 2 | train_1,3 3 | train_10,3 4 | train_100,4 5 | train_101,1 6 | train_102,2 7 | train_103,2 8 | train_104,4 9 | train_105,1 10 | train_106,1 11 | train_107,3 12 | train_108,3 13 | train_109,1 14 | train_11,0 15 | train_110,0 16 | train_111,1 17 | train_112,2 18 | train_113,1 19 | train_114,2 20 | train_115,2 21 | train_116,3 22 | train_117,0 23 | train_118,1 24 | train_119,4 25 | train_12,0 26 | train_120,0 27 | train_121,0 28 | train_122,4 29 | train_123,2 30 | train_124,2 31 | train_125,3 32 | train_126,2 33 | train_127,4 34 | train_128,3 35 | train_129,4 36 | train_13,0 37 | train_130,2 38 | train_131,3 39 | train_132,2 40 | train_133,0 41 | train_134,0 42 | train_135,0 43 | train_136,1 44 | train_137,2 45 | train_138,4 46 | train_139,2 47 | train_14,3 48 | train_140,2 49 | train_141,3 50 | train_142,1 51 | train_143,4 52 | train_144,2 53 | train_145,3 54 | train_146,3 55 | train_147,2 56 | train_148,0 57 | train_149,2 58 | train_15,2 59 | train_150,0 60 | train_151,3 61 | train_152,1 62 | train_153,4 63 | train_154,3 64 | train_155,3 65 | train_156,0 66 | train_157,4 67 | train_158,2 68 | train_159,4 69 | train_16,4 70 | train_160,2 71 | train_161,3 72 | train_162,4 73 | train_163,4 74 | train_164,1 75 | train_165,0 76 | train_166,0 77 | train_167,1 78 | train_168,0 79 | train_169,0 80 | train_17,2 81 | train_170,4 82 | train_171,0 83 | train_172,1 84 | train_173,1 85 | train_18,3 86 | train_19,2 87 | train_2,4 88 | train_20,2 89 | train_21,3 90 | train_22,0 91 | train_23,0 92 | train_24,1 93 | train_25,4 94 | train_26,4 95 | train_27,3 96 | train_28,3 97 | train_29,4 98 | train_3,2 99 | train_30,1 100 | train_31,1 101 | train_32,2 102 | train_33,3 103 | train_34,3 104 | train_35,2 105 | train_36,0 106 | train_37,1 107 | train_38,4 108 | train_39,4 109 | train_4,2 110 | train_40,2 111 | train_41,4 112 | train_42,3 113 | train_43,4 114 | train_44,2 115 | train_45,2 116 | train_46,0 117 | train_47,4 118 | train_48,0 119 | train_49,1 120 | train_5,0 121 | train_50,1 122 | train_51,2 123 | train_52,3 124 | 
train_53,0 125 | train_54,0 126 | train_55,3 127 | train_56,3 128 | train_57,1 129 | train_58,1 130 | train_59,2 131 | train_6,2 132 | train_60,4 133 | train_61,1 134 | train_62,4 135 | train_63,4 136 | train_64,4 137 | train_65,0 138 | train_66,3 139 | train_67,1 140 | train_68,0 141 | train_69,4 142 | train_7,0 143 | train_70,1 144 | train_71,1 145 | train_72,0 146 | train_73,1 147 | train_74,1 148 | train_75,1 149 | train_76,4 150 | train_77,4 151 | train_78,2 152 | train_79,3 153 | train_8,1 154 | train_80,1 155 | train_81,0 156 | train_82,4 157 | train_83,3 158 | train_84,4 159 | train_85,2 160 | train_86,2 161 | train_87,3 162 | train_88,3 163 | train_89,0 164 | train_9,3 165 | train_90,1 166 | train_91,0 167 | train_92,3 168 | train_93,1 169 | train_94,0 170 | train_95,4 171 | train_96,0 172 | train_97,1 173 | train_98,3 174 | train_99,1 175 | test_1,-1 176 | test_10,-1 177 | test_11,-1 178 | test_12,-1 179 | test_13,-1 180 | test_14,-1 181 | test_15,-1 182 | test_16,-1 183 | test_17,-1 184 | test_18,-1 185 | test_19,-1 186 | test_2,-1 187 | test_20,-1 188 | test_21,-1 189 | test_22,-1 190 | test_23,-1 191 | test_24,-1 192 | test_25,-1 193 | test_26,-1 194 | test_27,-1 195 | test_28,-1 196 | test_29,-1 197 | test_3,-1 198 | test_30,-1 199 | test_31,-1 200 | test_32,-1 201 | test_33,-1 202 | test_34,-1 203 | test_35,-1 204 | test_36,-1 205 | test_37,-1 206 | test_38,-1 207 | test_39,-1 208 | test_4,-1 209 | test_40,-1 210 | test_5,-1 211 | test_6,-1 212 | test_7,-1 213 | test_8,-1 214 | test_9,-1 215 | -------------------------------------------------------------------------------- /dataset_cfg/GlaS_cv.csv: -------------------------------------------------------------------------------- 1 | img_id,fold 2 | test-0016,-1 3 | val-0012,2 4 | train-0007,1 5 | train-0039,1 6 | test-0069,-1 7 | train-0071,1 8 | val-0008,2 9 | test-0009,-1 10 | test-0041,3 11 | train-0045,1 12 | test-0075,3 13 | test-0038,3 14 | train-0013,1 15 | test-0007,3 16 | train-0049,1 17 | train-0034,1 18 | train-0033,1 19 | train-0070,1 20 | train-0044,1 21 | test-0043,3 22 | test-0047,3 23 | train-0041,1 24 | train-0012,1 25 | train-0026,1 26 | val-0007,2 27 | train-0031,1 28 | train-0059,1 29 | test-0022,3 30 | train-0019,1 31 | test-0080,3 32 | test-0039,3 33 | train-0021,1 34 | train-0047,1 35 | test-0055,3 36 | train-0055,1 37 | test-0046,3 38 | test-0021,3 39 | train-0072,1 40 | test-0053,3 41 | test-0067,3 42 | train-0052,1 43 | train-0051,1 44 | train-0042,1 45 | test-0036,3 46 | val-0006,2 47 | test-0076,3 48 | test-0011,3 49 | test-0029,3 50 | train-0002,1 51 | test-0015,3 52 | test-0014,3 53 | test-0042,3 54 | train-0024,1 55 | train-0057,1 56 | train-0023,1 57 | test-0056,3 58 | test-0058,3 59 | train-0060,1 60 | train-0038,1 61 | test-0020,3 62 | test-0035,3 63 | train-0053,1 64 | test-0010,3 65 | test-0005,3 66 | test-0059,3 67 | val-0010,2 68 | val-0004,2 69 | train-0003,1 70 | test-0072,3 71 | test-0050,3 72 | test-0049,3 73 | val-0011,2 74 | train-0054,1 75 | test-0064,3 76 | train-0035,1 77 | test-0037,3 78 | test-0063,3 79 | test-0071,3 80 | train-0010,1 81 | test-0018,3 82 | train-0020,1 83 | train-0016,1 84 | test-0045,3 85 | test-0078,3 86 | test-0077,3 87 | test-0065,3 88 | test-0017,3 89 | test-0030,3 90 | test-0057,3 91 | test-0008,3 92 | train-0005,1 93 | test-0068,3 94 | test-0032,3 95 | train-0068,1 96 | train-0050,1 97 | train-0004,1 98 | test-0013,3 99 | train-0011,1 100 | train-0022,1 101 | test-0073,-1 102 | train-0058,1 103 | test-0061,-1 104 | train-0001,1 105 | train-0062,1 106 | 
test-0048,-1 107 | train-0066,1 108 | train-0030,1 109 | val-0003,2 110 | val-0013,2 111 | test-0070,-1 112 | test-0025,-1 113 | test-0040,-1 114 | val-0001,2 115 | train-0029,1 116 | train-0032,1 117 | test-0033,-1 118 | test-0079,-1 119 | test-0003,-1 120 | train-0064,1 121 | test-0044,-1 122 | test-0002,-1 123 | test-0051,-1 124 | test-0062,-1 125 | test-0027,-1 126 | test-0034,-1 127 | test-0006,-1 128 | val-0009,2 129 | train-0065,1 130 | train-0037,1 131 | train-0063,1 132 | test-0019,-1 133 | train-0043,1 134 | train-0067,1 135 | train-0008,1 136 | train-0061,1 137 | test-0066,-1 138 | train-0036,1 139 | train-0015,1 140 | test-0012,-1 141 | test-0031,-1 142 | train-0018,1 143 | test-0004,-1 144 | train-0056,1 145 | train-0025,1 146 | test-0052,-1 147 | val-0005,2 148 | train-0014,1 149 | train-0046,1 150 | train-0009,1 151 | train-0069,1 152 | test-0074,-1 153 | test-0024,-1 154 | train-0006,1 155 | test-0060,-1 156 | train-0027,1 157 | test-0026,-1 158 | test-0023,-1 159 | test-0028,-1 160 | test-0054,-1 161 | val-0002,2 162 | train-0028,1 163 | train-0017,1 164 | train-0040,1 165 | train-0048,1 166 | test-0001,3 167 | -------------------------------------------------------------------------------- /image_mask_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import albumentations 4 | import cv2 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from albumentations.pytorch import ToTensorV2 9 | from pytorch_lightning import LightningDataModule 10 | from torch.utils.data import Dataset, DataLoader 11 | import jpeg4py 12 | from torchvision.transforms import transforms 13 | 14 | 15 | SUB_FOLDER_IMAGE = "img" 16 | SUB_FOLDER_MASK = "mask" 17 | 18 | 19 | def read_rgb_img(p): 20 | if p.lower().endswith((".jpg", ".jpeg")): 21 | try: 22 | return jpeg4py.JPEG(p).decode() 23 | except: 24 | # cv2.setNumThreads(0) 25 | return cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) 26 | else: 27 | # cv2.setNumThreads(0) 28 | return cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) 29 | 30 | 31 | def read_mask(p): 32 | return cv2.imread(p, cv2.IMREAD_UNCHANGED) #.astype(np.int_) 33 | 34 | 35 | class ImageMaskDataset(Dataset): 36 | 37 | def __init__( 38 | self, 39 | dataset_root: str, 40 | dataset_csv_path: str, 41 | data_type: str, 42 | val_fold_id: int, 43 | augmentation=None, 44 | data_ext: str =".png", 45 | dataset_mean=(0.485, 0.456, 0.406), 46 | dataset_std=(0.229, 0.224, 0.225), 47 | ignored_classes=None, # only supports None, 0 or [0, ...] 48 | ): 49 | super().__init__() 50 | 51 | self.dataset_root = dataset_root 52 | self.dataset_csv_path = dataset_csv_path 53 | self.data_ext = data_ext 54 | self.augmentation = augmentation 55 | 56 | self.setup(data_type, val_fold_id) 57 | 58 | self.tensor_transforms = albumentations.Compose([ 59 | albumentations.Normalize(mean=dataset_mean, std=dataset_std), 60 | ToTensorV2(), 61 | 62 | ]) 63 | self.ignored_classes = ignored_classes 64 | 65 | def __len__(self): 66 | return len(self.img_list) 67 | 68 | def setup(self, data_type, val_fold_id): 69 | if data_type not in ['train', 'val', 'test']: 70 | raise Exception("Not supported dataset type. 
It should be train, val or test") 71 | self.data_type = data_type 72 | self.val_fold_id = val_fold_id 73 | if data_type == 'test': 74 | self.val_fold_id = -1 75 | 76 | if val_fold_id >= 0: 77 | self.img_list = self.read_cv_dataset_csv() 78 | else: 79 | if data_type == 'val': 80 | data_type = 'test' 81 | self.data_type = data_type 82 | self.img_list = self.read_dataset_csv() 83 | 84 | def read_dataset_csv(self): 85 | df = pd.read_csv(self.dataset_csv_path, header=0) 86 | if self.data_type in ['test']: 87 | df = df[df['is_test'] > 0] 88 | else: # train 89 | df = df[df['is_test'] == 0] 90 | return df 91 | 92 | def read_cv_dataset_csv(self): 93 | df = pd.read_csv(self.dataset_csv_path, header=0) 94 | if self.data_type in ['val']: 95 | df = df[df['fold'] == self.val_fold_id] 96 | elif self.data_type in ['test']: 97 | df = df[df['fold'] < 0] 98 | else: 99 | df = df[df['fold'] > 0] 100 | df = df[df['fold'] != self.val_fold_id] 101 | return df 102 | 103 | def process_ignored_classes(self, mask): 104 | if self.ignored_classes is not None: 105 | if not isinstance(self.ignored_classes, (list, tuple)): 106 | self.ignored_classes = [self.ignored_classes] 107 | for cls in self.ignored_classes: 108 | if cls != 0: 109 | mask[mask == cls] = 0 110 | else: 111 | mask += 1 112 | return mask 113 | 114 | def __getitem__(self, i): 115 | row = self.img_list.iloc[i] 116 | img_id = row['img_id'] 117 | 118 | image_path = os.path.join(self.dataset_root, SUB_FOLDER_IMAGE, img_id + self.data_ext) 119 | image = read_rgb_img(image_path) 120 | mask = read_mask(os.path.join(self.dataset_root, SUB_FOLDER_MASK, img_id + self.data_ext)) 121 | 122 | if self.augmentation is not None: 123 | ret = self.augmentation(image=image, mask=mask) 124 | image, mask = ret["image"], ret["mask"] 125 | 126 | mask = self.process_ignored_classes(mask) 127 | 128 | ret = self.tensor_transforms(image=image, mask=mask) 129 | image, mask = ret["image"], ret["mask"] 130 | 131 | return image, mask.long() 132 | 133 | class FtMaskDataset(ImageMaskDataset): 134 | def __init__( 135 | self, 136 | dataset_root: str, 137 | dataset_csv_path: str, 138 | data_type: str, 139 | val_fold_id: int, 140 | augmentation = None, 141 | data_ext: str = ".pt", # only changed this 142 | dataset_mean = (0.485, 0.456, 0.406), 143 | dataset_std = (0.229, 0.224, 0.225), 144 | ignored_classes = None, # only supports None, 0 or [0, ...] 
145 | ): 146 | super().__init__( 147 | dataset_root, 148 | dataset_csv_path, 149 | data_type, 150 | val_fold_id, 151 | augmentation, 152 | data_ext, 153 | dataset_mean, 154 | dataset_std, 155 | ignored_classes, 156 | ) 157 | 158 | def __getitem__(self, i): 159 | row = self.img_list.iloc[i] 160 | img_id = row['img_id'] 161 | 162 | image = torch.load(os.path.join(self.dataset_root, SUB_FOLDER_IMAGE, img_id + self.data_ext), 163 | map_location='cpu') 164 | mask = read_mask(os.path.join(self.dataset_root, SUB_FOLDER_MASK, img_id + ".png")) 165 | 166 | mask = self.process_ignored_classes(mask) 167 | 168 | mask = torch.from_numpy(mask).long() 169 | 170 | return image, mask 171 | 172 | 173 | 174 | class GeneralDataModule(LightningDataModule): 175 | def __init__(self, common_cfg_dic, dataset_classs, cus_transforms, batch_size, num_workers): 176 | super().__init__() 177 | 178 | self.batch_size = batch_size 179 | self.num_workers = num_workers 180 | 181 | self.dataset_train, self.dataset_val, self.dataset_test = self.initialize_dataset(common_cfg_dic, 182 | dataset_classs, 183 | cus_transforms) 184 | 185 | def initialize_dataset(self, common_cfg, DatasetCLS, cus_transforms): 186 | if cus_transforms is None: 187 | transforms_train, transforms_eval = None, None 188 | elif isinstance(cus_transforms, (list, tuple)): 189 | transforms_train = cus_transforms[0] 190 | transforms_eval = cus_transforms[1] 191 | else: 192 | transforms_train, transforms_eval = cus_transforms, cus_transforms 193 | 194 | dataset_train = DatasetCLS(**common_cfg, data_type="train", augmentation=transforms_train) 195 | dataset_val = DatasetCLS(**common_cfg, data_type="val", augmentation=transforms_eval) 196 | dataset_test = DatasetCLS(**common_cfg, data_type="test", augmentation=transforms_eval) 197 | return dataset_train, dataset_val, dataset_test 198 | 199 | def train_dataloader(self): 200 | return DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers) 201 | 202 | def val_dataloader(self): 203 | return DataLoader(self.dataset_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers), \ 204 | DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) 205 | 206 | def test_dataloader(self): 207 | return DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) 208 | 209 | def predict_dataloader(self): 210 | return DataLoader(self.dataset_test, batch_size=1, shuffle=False, num_workers=self.num_workers), -------------------------------------------------------------------------------- /imgs/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simzhangbest/SAM2PATH/e31a4924bf3f619232dc639caaf332aec45c71d2/imgs/image2.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from monai.losses import DiceLoss, FocalLoss 5 | 6 | # To deal with unlabeled regions for losses: make prediction & target all zero 7 | # To deal with unlabeled regions for metrics: torchmetrics supports ignore_index 8 | 9 | class SAMLoss(nn.Module): 10 | 11 | def __init__(self, focal_cof: float = 20., dice_cof: float = 1., ce_cof: float = 0., iou_cof: float = 1.): 12 | super().__init__() 13 | self.focal_cof = focal_cof 14 | self.dice_cof = 
dice_cof 15 | self.ce_cof = ce_cof 16 | self.iou_cof = iou_cof 17 | 18 | self.dice_loss_fn = DiceLoss(include_background=False, to_onehot_y=False, sigmoid=False, softmax=False) 19 | self.focal_loss_fn = FocalLoss(include_background=False, to_onehot_y=False) 20 | self.ce_loss_fn = nn.CrossEntropyLoss(ignore_index=0) 21 | 22 | @torch.no_grad() 23 | def to_one_hot_label(self, targets, num_classes): 24 | targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes) 25 | targets_one_hot = torch.movedim(targets_one_hot, -1, 1) 26 | return targets_one_hot 27 | 28 | def forward(self, inputs, targets, iou_pred, ignored_masks=None): 29 | # masks for ignored regions 30 | if ignored_masks is not None: 31 | inputs = inputs * (1. - ignored_masks.expand_as(inputs)) 32 | targets = targets * (1 - ignored_masks.long().squeeze(1)) 33 | 34 | targets_one_hot = self.to_one_hot_label(targets, num_classes=inputs.shape[1]) 35 | 36 | inputs_softmax = F.softmax(inputs, dim=1) 37 | 38 | dice = self.dice_loss_fn(inputs_softmax, targets_one_hot) 39 | focal = self.focal_loss_fn(inputs, targets_one_hot) 40 | 41 | 42 | iou_true = calc_iou(inputs_softmax, targets_one_hot) 43 | iou = F.mse_loss(iou_pred[:, 1:], iou_true[:, 1:]) # ignore background 44 | # iou = 0. 45 | 46 | ce_loss = self.ce_loss_fn(inputs, targets) 47 | 48 | total_loss = self.focal_cof * focal + self.dice_cof * dice + self.ce_cof * ce_loss + self.iou_cof * iou 49 | 50 | all_loss = { 51 | "loss": total_loss, 52 | "focal": focal, 53 | "dice": dice, 54 | "ce": ce_loss, 55 | "iou": iou 56 | } 57 | # print(f"sim try get total loss: {all_loss}") 58 | 59 | return all_loss 60 | 61 | 62 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor): 63 | # both are B, N_cls, H, W 64 | # pred_mask = F.softmax(pred_mask, dim=1) 65 | pred_mask = (pred_mask >= 0.5).float() 66 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(2, 3)) 67 | union = torch.sum(pred_mask, dim=(2, 3)) + torch.sum(gt_mask, dim=(2, 3)) - intersection 68 | epsilon = 1e-7 69 | batch_iou = intersection / (union + epsilon) 70 | 71 | batch_iou = batch_iou.unsqueeze(2) 72 | return batch_iou 73 | 74 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import time 4 | # sim added 5 | import timm 6 | import torch 7 | # from lightning.pytorch import seed_everything 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.callbacks import LearningRateMonitor 10 | from pytorch_lightning.loggers import WandbLogger 11 | from torchmetrics import MetricCollection, JaccardIndex, F1Score, Dice 12 | from network.sam_network_2 import PromptSAM, PromptSAMLateFusion, PromptSAMLateFusion_NOKAN 13 | # from network.sam_network_backup import PromptSAM, PromptSAMLateFusion 14 | from pl_module_sam_seg import SamSeg 15 | import albumentations 16 | # from torch.utils.tensorboard import SummaryWriter 17 | from pytorch_lightning.loggers import TensorBoardLogger 18 | 19 | # project_path = "/mnt/zmy/code/sam-path/SAMPath" 20 | # name = "SAM-PATH-CRAG" 21 | # tb_writer_summary_path = os.path.join(project_path, "run", name, "Logs") 22 | # current_time = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 23 | # log_dir = os.path.join(tb_writer_summary_path, current_time) 24 | # logger = TensorBoardLogger(log_dir, name='my_model') 25 | 26 | # import os 27 | # os.environ["WANDB_MODE"]="offline" 28 | 29 | 30 | def 
get_augmentation(cfg): 31 | # W, H = cfg.dataset.image_hw if cfg.dataset.image_hw is not None else (1024, 1024) 32 | W, H = cfg.dataset.image_hw 33 | transform_train_fn = albumentations.Compose([ 34 | albumentations.RandomResizedCrop(H, W, scale=(0.08, 1.0), p=1.0), 35 | albumentations.Flip(p=0.75), 36 | albumentations.RandomRotate90(), 37 | albumentations.ColorJitter(0.1, 0.1, 0.1, 0.1), 38 | ]) 39 | # transform_test_fn = None #albumentations.Compose([]) 40 | transform_test_fn = albumentations.Compose([ 41 | albumentations.Resize(H, W), 42 | ]) 43 | # transform_train = lambda x: transform_train_fn(image=x[0], mask=x[1])["image"] 44 | # transform_test = lambda x: transform_test_fn(image=x)["image"] 45 | # return transform_train, transform_test 46 | return transform_train_fn, transform_test_fn 47 | 48 | 49 | def get_metrics(cfg): 50 | num_classes = cfg.dataset.num_classes + 1 # Note that we have an extra class 51 | # if cfg.dataset.ignored_classes_metric is not None: 52 | # ignore_index = [0, cfg.dataset.ignored_classes_metric] 53 | # else: 54 | ignore_index = 0 55 | metrics = MetricCollection({ 56 | "IOU_Jaccard_Bal": JaccardIndex(num_classes=num_classes, ignore_index=ignore_index, task='multiclass'), 57 | "IOU_Jaccard": JaccardIndex(num_classes=num_classes, ignore_index=ignore_index, task='multiclass', 58 | average="micro"), 59 | "F1": F1Score(num_classes=num_classes, ignore_index=ignore_index, task='multiclass', average="micro"), 60 | "Dice": Dice(num_classes=num_classes, ignore_index=ignore_index, average="micro"), 61 | "Dice_Bal": Dice(num_classes=num_classes, ignore_index=ignore_index, average="macro"), 62 | }) 63 | return metrics 64 | 65 | 66 | 67 | def get_model(cfg): 68 | if cfg.model.extra_encoder is not None: # cfg.model.extra_encoder hipt 69 | print("Using %s as an extra encoder" % cfg.model.extra_encoder) 70 | neck = True if cfg.model.extra_type == 'plus' else False 71 | if cfg.model.extra_encoder == 'hipt': 72 | from network.get_network import get_hipt 73 | extra_encoder = get_hipt(cfg.model.extra_checkpoint, neck=neck) 74 | elif cfg.model.extra_encoder == 'uni_v1': # sim added 75 | from network.get_network import get_uni 76 | extra_encoder = get_uni(cfg.model.extra_checkpoint, neck=neck) 77 | else: 78 | raise NotImplementedError 79 | else: 80 | extra_encoder = None 81 | if cfg.model.extra_type in ['plus']: 82 | MODEL = PromptSAM 83 | elif cfg.model.extra_type in ['fusion']: 84 | MODEL = PromptSAMLateFusion 85 | else: 86 | raise NotImplementedError 87 | 88 | model = MODEL( 89 | model_type = cfg.model.type, 90 | checkpoint = cfg.model.checkpoint, 91 | prompt_dim = cfg.model.prompt_dim, 92 | num_classes = cfg.dataset.num_classes, 93 | extra_encoder = extra_encoder, 94 | freeze_image_encoder = cfg.model.freeze.image_encoder, 95 | freeze_prompt_encoder = cfg.model.freeze.prompt_encoder, 96 | freeze_mask_decoder = cfg.model.freeze.mask_decoder, 97 | mask_HW = cfg.dataset.image_hw, 98 | feature_input = cfg.dataset.feature_input, 99 | prompt_decoder = cfg.model.prompt_decoder, 100 | dense_prompt_decoder=cfg.model.dense_prompt_decoder, 101 | no_sam=cfg.model.no_sam if "no_sam" in cfg.model else None 102 | ) 103 | return model 104 | 105 | def get_data_module(cfg): 106 | from image_mask_dataset import GeneralDataModule, ImageMaskDataset, FtMaskDataset 107 | augs = get_augmentation(cfg) 108 | common_cfg_dic = { 109 | "dataset_root": cfg.dataset.dataset_root, 110 | "dataset_csv_path": cfg.dataset.dataset_csv_path, 111 | "val_fold_id": cfg.dataset.val_fold_id, 112 | "data_ext": ".jpg" if 
"data_ext" not in cfg.dataset else cfg.dataset.data_ext, 113 | "dataset_mean": cfg.dataset.dataset_mean, 114 | "dataset_std": cfg.dataset.dataset_std, 115 | "ignored_classes": cfg.dataset.ignored_classes, # only supports None, 0 or [0, ...] 116 | } 117 | if cfg.dataset.feature_input is True: 118 | dataset_cls = FtMaskDataset 119 | else: 120 | dataset_cls = ImageMaskDataset 121 | 122 | data_module = GeneralDataModule(common_cfg_dic, dataset_cls, cus_transforms=augs, 123 | batch_size=cfg.batch_size, num_workers=cfg.num_workers) 124 | return data_module 125 | 126 | 127 | 128 | def get_pl_module(cfg, model, metrics): 129 | pl_module = SamSeg( 130 | cfg = cfg, 131 | sam_model = model, 132 | metrics = metrics, 133 | num_classes = cfg.dataset.num_classes, 134 | focal_cof = cfg.loss.focal_cof, 135 | dice_cof = cfg.loss.dice_cof, 136 | ce_cof=cfg.loss.ce_cof, 137 | iou_cof = cfg.loss.iou_cof, 138 | lr = cfg.opt.learning_rate, 139 | weight_decay = cfg.opt.weight_decay, 140 | lr_steps = cfg.opt.steps, 141 | warmup_steps=cfg.opt.warmup_steps, 142 | ignored_index=cfg.dataset.ignored_classes_metric, 143 | ) 144 | return pl_module 145 | 146 | def main(cfg): 147 | 148 | # sim added: logger module 149 | project_path = cfg.out_dir 150 | name = cfg.name 151 | tb_writer_summary_path = os.path.join(project_path, "run", name, "Logs") 152 | current_time = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 153 | log_dir = os.path.join(tb_writer_summary_path, current_time) 154 | logger = TensorBoardLogger(log_dir, name='v100_model') 155 | 156 | 157 | data_module = get_data_module(cfg) 158 | 159 | sam_model = get_model(cfg) 160 | 161 | metrics = get_metrics(cfg=cfg) 162 | 163 | pl_module = get_pl_module(cfg, model=sam_model, metrics=metrics) # 包装给 pl 164 | 165 | # logger = WandbLogger(project=cfg.project, name=cfg.name, save_dir=cfg.out_dir, log_model=True) 166 | 167 | 168 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 169 | 170 | accumulate_grad_batches = cfg.accumulate_grad_batches if "accumulate_grad_batches" in cfg else 1 171 | 172 | trainer = Trainer(default_root_dir=cfg.out_dir, logger=logger, 173 | devices=cfg.devices, 174 | max_epochs=cfg.opt.num_epochs, 175 | accelerator="gpu", #strategy="auto", 176 | #strategy='ddp_find_unused_parameters_true', 177 | log_every_n_steps=20, num_sanity_val_steps=0, 178 | precision=cfg.opt.precision, 179 | callbacks=[lr_monitor], 180 | accumulate_grad_batches=accumulate_grad_batches, 181 | fast_dev_run=False) 182 | 183 | trainer.fit(pl_module, data_module) 184 | 185 | 186 | if __name__ == '__main__': 187 | 188 | # python main.py --config configs.CRAG --devices 1 --project sampath --name crag_run0 189 | # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 190 | # torch.use_deterministic_algorithms(False) 191 | parser = ArgumentParser() 192 | parser.add_argument("--config", default="configs.CRAG") 193 | parser.add_argument('--devices', type=lambda s: [int(item) for item in s.split(',')], default=[1]) 194 | parser.add_argument('--project', type=str, default="sim-sampath-crag") 195 | parser.add_argument('--name', type=str, default="crag_run4_kan_sam") 196 | # parser.add_argument('--seed', type=int, default=42) 197 | args = parser.parse_args() 198 | 199 | module = __import__(args.config, globals(), locals(), ['cfg']) 200 | cfg = module.cfg 201 | 202 | cfg["project"] = args.project 203 | cfg["devices"] = args.devices 204 | cfg["name"] = args.name 205 | # cfg["seed"] = args.seed 206 | 207 | # seed_everything(cfg["seed"]) 208 | print(cfg) 209 | main(cfg) 210 | # print(cfg) 
211 | -------------------------------------------------------------------------------- /model/0.0_cache_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simzhangbest/SAM2PATH/e31a4924bf3f619232dc639caaf332aec45c71d2/model/0.0_cache_data -------------------------------------------------------------------------------- /model/0.0_state: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simzhangbest/SAM2PATH/e31a4924bf3f619232dc639caaf332aec45c71d2/model/0.0_state -------------------------------------------------------------------------------- /model/history.txt: -------------------------------------------------------------------------------- 1 | ### Round 0 ### 2 | init => 0.0 3 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simzhangbest/SAM2PATH/e31a4924bf3f619232dc639caaf332aec45c71d2/network/__init__.py -------------------------------------------------------------------------------- /network/get_network.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import partial 3 | from segment_anything.modeling.common import LayerNorm2d 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class EncoderWrapper(nn.Module): 10 | def __init__(self, model, ft_dim, out_dim, neck=True, re_norm = False, mean=None, std=None): 11 | super().__init__() 12 | self.model = model 13 | if neck: 14 | self.neck = nn.Sequential( 15 | nn.Conv2d( 16 | ft_dim, 17 | out_dim, 18 | kernel_size=1, 19 | bias=False, 20 | ), 21 | LayerNorm2d(out_dim), 22 | nn.Conv2d( 23 | out_dim, 24 | out_dim, 25 | kernel_size=3, 26 | padding=1, 27 | bias=False, 28 | ), 29 | LayerNorm2d(out_dim), 30 | ) 31 | else: 32 | self.neck = None 33 | 34 | self.re_norm = re_norm 35 | 36 | self.register_buffer("in_mean", torch.Tensor((0.485, 0.456, 0.406)).view(-1, 1, 1), False) 37 | self.register_buffer("in_std", torch.Tensor((0.229, 0.224, 0.225)).view(-1, 1, 1), False) 38 | self.register_buffer("mean", torch.Tensor(mean).view(-1, 1, 1), False) 39 | self.register_buffer("std", torch.Tensor(std).view(-1, 1, 1), False) 40 | 41 | def forward(self, x, no_grad=True): # x [6, 3, 1024, 1024] 42 | if self.re_norm: 43 | x = x * self.in_std + self.in_mean 44 | x = (x - self.mean) / self.std 45 | if no_grad: 46 | with torch.no_grad(): 47 | x = self.model(x, dense=True) 48 | else: 49 | x = self.model(x, dense=True) 50 | # x should be B, 4096, dim 51 | if self.neck is not None: 52 | x = x.permute(0, 2, 1) 53 | x = x.reshape(x.shape[0], -1, 64, 64) 54 | x = self.neck(x) 55 | return x # [6, 4096, 384] 56 | 57 | 58 | def get_hipt(pretrained=None, neck=True): 59 | # from .hipt.vision_transformer import vit_small 60 | # from .hipt.hipt_prompt import load_ssl_weights 61 | 62 | # debug use 63 | from hipt.vision_transformer import vit_small 64 | from hipt.hipt_prompt import load_ssl_weights 65 | 66 | model = vit_small(patch_size=16) 67 | model = load_ssl_weights(model, pretrained) 68 | 69 | model = EncoderWrapper(model, 384, 256, neck=neck, 70 | re_norm=True, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 71 | 72 | return model 73 | 74 | 75 | import os 76 | import timm 77 | os.environ['UNI_CKPT_PATH'] = "/mnt/zm/code/CLAM/checkpoints/uni/pytorch_model.bin" 78 | 79 | def has_uni(): 80 | 
HAS_UNI = False 81 | UNI_CKPT_PATH = '' 82 | # check if UNI_CKPT_PATH is set, catch exception if not 83 | try: 84 | # check if UNI_CKPT_PATH is set 85 | if 'UNI_CKPT_PATH' not in os.environ: 86 | raise ValueError('UNI_CKPT_PATH not set') 87 | HAS_UNI = True 88 | UNI_CKPT_PATH = os.environ['UNI_CKPT_PATH'] 89 | except Exception as e: 90 | print(e) 91 | return HAS_UNI, UNI_CKPT_PATH 92 | 93 | def get_uni(pretrained=None, neck=True): 94 | from .hipt.vision_transformer import vit_large 95 | from .hipt.hipt_prompt import load_uni_weights 96 | model = vit_large() 97 | model = load_uni_weights(model, pretrained) 98 | 99 | model = EncoderWrapper(model, 1024, 256, neck=neck, 100 | re_norm=True, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 101 | return model 102 | 103 | # sim added for model encoder test 104 | from torchsummary import summary 105 | if __name__ == '__main__': 106 | # x should be B, 4096, dim 107 | # pretrained = "/mnt/zmy/code/sam-path/pretrained/vit256_small_dino.pth" 108 | # encoder = get_hipt(pretrained, neck=True) 109 | # # x [6, 3, 1024, 1024] 110 | # input_size = (6, 3, 1024, 1024) 111 | 112 | # input_feature = torch.randn(input_size) 113 | # out = encoder(input_feature) 114 | # print(out.shape) # [6, 4096, 384] neck=False # [6, 256, 64, 64] neck=True 115 | 116 | 117 | 118 | # test uni 119 | _, pretrained = has_uni() 120 | encoder = get_uni(pretrained, neck=True) 121 | input_size = (6, 3, 1024, 1024) 122 | input_feature = torch.randn(input_size) 123 | out = encoder(input_feature) 124 | print(out.shape) # [6, 4096, 1024] neck=False # -------------------------------------------------------------------------------- /network/hipt/extract_hipt_ft.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import pandas as pd 5 | import torch 6 | from einops import rearrange, repeat 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | 11 | import vision_transformer as vits 12 | import vision_transformer4k as vits4k 13 | 14 | 15 | def eval_transforms(is_imagenet=False, patch_size=256): 16 | if is_imagenet: 17 | mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 18 | else: 19 | mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 20 | eval_t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 21 | return eval_t 22 | 23 | def generate_mask(img_arr): 24 | sat = cv2.cvtColor(img_arr, cv2.COLOR_RGB2HSV)[:, :, 1] 25 | 26 | sat[sat <= 15] = 0 27 | sat[sat > 15] = 1 28 | return sat 29 | 30 | 31 | def get_vit256(pretrained_weights, arch='vit_small'): 32 | r""" 33 | Builds ViT-256 Model. 34 | 35 | Args: 36 | - pretrained_weights (str): Path to ViT-256 Model Checkpoint. 37 | - arch (str): Which model architecture. 38 | - device (torch): Torch device to save model. 39 | 40 | Returns: 41 | - model256 (torch.nn): Initialized model. 
42 | """ 43 | 44 | checkpoint_key = 'teacher' 45 | 46 | model256 = vits.__dict__[arch](patch_size=16, num_classes=0) 47 | # for p in model256.parameters(): 48 | # p.requires_grad = False 49 | # model256.eval() 50 | 51 | if os.path.isfile(pretrained_weights): 52 | state_dict = torch.load(pretrained_weights, map_location="cpu") 53 | if checkpoint_key is not None and checkpoint_key in state_dict: 54 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 55 | state_dict = state_dict[checkpoint_key] 56 | # remove `module.` prefix 57 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 58 | # remove `backbone.` prefix induced by multicrop wrapper 59 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 60 | msg = model256.load_state_dict(state_dict, strict=False) 61 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 62 | 63 | return model256 64 | 65 | 66 | def get_vit4k(pretrained_weights, arch='vit4k_xs'): 67 | r""" 68 | Builds ViT-4K Model. 69 | 70 | Args: 71 | - pretrained_weights (str): Path to ViT-4K Model Checkpoint. 72 | - arch (str): Which model architecture. 73 | - device (torch): Torch device to save model. 74 | 75 | Returns: 76 | - model256 (torch.nn): Initialized model. 77 | """ 78 | 79 | checkpoint_key = 'teacher' 80 | model4k = vits4k.__dict__[arch](num_classes=0) 81 | # for p in model4k.parameters(): 82 | # p.requires_grad = False 83 | # model4k.eval() 84 | 85 | if os.path.isfile(pretrained_weights): 86 | state_dict = torch.load(pretrained_weights, map_location="cpu") 87 | if checkpoint_key is not None and checkpoint_key in state_dict: 88 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 89 | state_dict = state_dict[checkpoint_key] 90 | # remove `module.` prefix 91 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 92 | # remove `backbone.` prefix induced by multicrop wrapper 93 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 94 | msg = model4k.load_state_dict(state_dict, strict=False) 95 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 96 | 97 | return model4k 98 | 99 | 100 | class HIPT_4K(torch.nn.Module): 101 | """ 102 | HIPT Model (ViT_4K-256) for encoding non-square images (with [256 x 256] patch tokens), with 103 | [256 x 256] patch tokens encoded via ViT_256-16 using [16 x 16] patch tokens. 104 | """ 105 | 106 | def __init__(self, ck_dir, feature_4k=True): 107 | super().__init__() 108 | model256_path = os.path.join(ck_dir, 'vit256_small_dino.pth') 109 | model4k_path = os.path.join(ck_dir, 'vit4k_xs_dino.pth') 110 | self.model256 = get_vit256(pretrained_weights=model256_path) 111 | self.model4k = get_vit4k(pretrained_weights=model4k_path) 112 | # self.patch_filter_params = patch_filter_params 113 | 114 | self.feature_4k = feature_4k 115 | 116 | 117 | def forward(self, x, mask=None): 118 | """ 119 | Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K. 120 | 1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K (e.g. - 256 x 256). 121 | 2. x then gets unfolded into a "batch" of [256 x 256] images. 122 | 3. A pretrained ViT_256-16 model extracts the CLS token from each [256 x 256] image in the batch. 123 | 4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".) 124 | 5. 
This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K. 125 | 126 | Args: 127 | - x (torch.Tensor): [1 x C x W' x H'] image tensor. 128 | 129 | Return: 130 | - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default). 131 | """ 132 | # batch_256, w_256, h_256 = self.prepare_img_tensor(x) # 1. [1 x 3 x W x H]. 133 | batch_256 = x 134 | w_256, h_256 = 16, 16 135 | 136 | batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256) # 2. [1 x 3 x w_256 x h_256 x 256 x 256] 137 | batch_256 = rearrange(batch_256, 138 | 'b c p1 p2 w h -> (b p1 p2) c w h') # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256) 139 | 140 | if mask is not None: 141 | if len(mask.shape) < 4: 142 | mask = mask.unsqueeze(dim=1) 143 | mask_256 = mask.unfold(2, 256, 256).unfold(3, 256, 256) # 2. [1 x 3 x w_256 x h_256 x 256 x 256] 144 | mask_256 = rearrange(mask_256, 145 | 'b c p1 p2 w h -> (b p1 p2) c w h') # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256) 146 | mask_sum = torch.sum(mask_256, dim=(1, 2, 3)) 147 | 148 | batch_256 = batch_256[mask_sum > 0.1 * 256 * 256, ...] 149 | 150 | features_cls256 = [] 151 | for mini_bs in range(0, batch_256.shape[0], 152 | 256): # 3. B may be too large for ViT_256. We further take minibatches of 256. 153 | minibatch_256 = batch_256[mini_bs:mini_bs + 256] # .to(self.device256, non_blocking=True) 154 | features_cls256.append(self.model256( 155 | minibatch_256).detach()) # 3. Extracting ViT_256 features from [256 x 3 x 256 x 256] image batches. 156 | 157 | features_cls256 = torch.vstack(features_cls256) # 3. [B x 384], where 384 == dim of ViT-256 [ClS] token. 158 | 159 | if self.feature_4k: 160 | features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0, 1).transpose(0, 2).unsqueeze(dim=0) 161 | # features_cls256 = features_cls256.to(self.device4k, non_blocking=True) # 4. [1 x 384 x w_256 x h_256] 162 | features_cls4k = self.model4k.forward(features_cls256) # 5. [1 x 192], where 192 == dim of ViT_4K [ClS] token. 
163 | else: 164 | features_cls4k = features_cls256 165 | return features_cls4k -------------------------------------------------------------------------------- /network/hipt/hipt_mil.py: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # HIPT Implementation (With Local-Global Pretraining) # 3 | ###################################### 4 | import torch 5 | from torch import nn 6 | 7 | import torch.nn.functional as F 8 | 9 | """ 10 | Attention Network with Sigmoid Gating (3 fc layers) 11 | args: 12 | L: input feature dimension 13 | D: hidden layer dimension 14 | dropout: whether to use dropout (p = 0.25) 15 | n_classes: number of classes 16 | """ 17 | class Attn_Net_Gated(nn.Module): 18 | def __init__(self, L=1024, D=256, dropout=False, n_classes=1): 19 | r""" 20 | Attention Network with Sigmoid Gating (3 fc layers) 21 | args: 22 | L (int): input feature dimension 23 | D (int): hidden layer dimension 24 | dropout (bool): whether to apply dropout (p = 0.25) 25 | n_classes (int): number of classes 26 | """ 27 | super(Attn_Net_Gated, self).__init__() 28 | self.attention_a = [ 29 | nn.Linear(L, D), 30 | nn.Tanh()] 31 | 32 | self.attention_b = [nn.Linear(L, D), nn.Sigmoid()] 33 | if dropout: 34 | self.attention_a.append(nn.Dropout(0.25)) 35 | self.attention_b.append(nn.Dropout(0.25)) 36 | 37 | self.attention_a = nn.Sequential(*self.attention_a) 38 | self.attention_b = nn.Sequential(*self.attention_b) 39 | self.attention_c = nn.Linear(D, n_classes) 40 | 41 | def forward(self, x): 42 | a = self.attention_a(x) 43 | b = self.attention_b(x) 44 | A = a.mul(b) 45 | A = self.attention_c(A) # N x n_classes 46 | return A, x 47 | 48 | 49 | class HIPT_GP_FC(nn.Module): 50 | def __init__(self, size_arg="small", dropout=0.25, n_classes=4, cus_size=None): 51 | super(HIPT_GP_FC, self).__init__() 52 | self.size_dict_path = {"small": [384, 192, 192], "big": [1024, 512, 384]} 53 | if cus_size is None: 54 | size = self.size_dict_path[size_arg] 55 | else: 56 | size = cus_size 57 | 58 | self.global_phi = nn.Sequential(nn.Linear(192, 192), nn.ReLU(), nn.Dropout(dropout)) 59 | self.global_transformer = nn.TransformerEncoder( 60 | nn.TransformerEncoderLayer( 61 | d_model=192, nhead=3, dim_feedforward=192, dropout=0.25, activation='relu' 62 | ), 63 | num_layers=2 64 | ) 65 | self.global_attn_pool = Attn_Net_Gated(L=size[1], D=size[1], dropout=True, n_classes=1) 66 | self.global_rho = nn.Sequential(*[nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]) 67 | 68 | self.classifier = nn.Linear(size[1], n_classes) 69 | 70 | def forward(self, h_4096, **kwargs): 71 | ### Global 72 | h_4096 = self.global_phi(h_4096) 73 | h_4096 = self.global_transformer(h_4096.unsqueeze(1)).squeeze(1) 74 | A_4096, h_4096 = self.global_attn_pool(h_4096) 75 | A_4096 = torch.transpose(A_4096, 1, 0) 76 | A_4096 = F.softmax(A_4096, dim=1) 77 | h_path = torch.mm(A_4096, h_4096) 78 | h_WSI = self.global_rho(h_path) 79 | 80 | logits = self.classifier(h_WSI) 81 | # Y_hat = torch.topk(logits, 1, dim=1)[1] 82 | # return logits, F.softmax(logits, dim=1), Y_hat, None, None 83 | return logits -------------------------------------------------------------------------------- /network/kan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | 5 | 6 | class KANLinear(torch.nn.Module): 7 | def __init__( 8 | self, 9 | in_features, 10 | out_features, 11 | grid_size=5, 12 | 
spline_order=3, 13 | scale_noise=0.1, 14 | scale_base=1.0, 15 | scale_spline=1.0, 16 | enable_standalone_scale_spline=True, 17 | base_activation=torch.nn.SiLU, 18 | grid_eps=0.02, 19 | grid_range=[-1, 1], 20 | ): 21 | super(KANLinear, self).__init__() 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.grid_size = grid_size 25 | self.spline_order = spline_order 26 | 27 | h = (grid_range[1] - grid_range[0]) / grid_size 28 | grid = ( 29 | ( 30 | torch.arange(-spline_order, grid_size + spline_order + 1) * h 31 | + grid_range[0] 32 | ) 33 | .expand(in_features, -1) 34 | .contiguous() 35 | ) 36 | self.register_buffer("grid", grid) 37 | 38 | self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) 39 | self.spline_weight = torch.nn.Parameter( 40 | torch.Tensor(out_features, in_features, grid_size + spline_order) 41 | ) 42 | if enable_standalone_scale_spline: 43 | self.spline_scaler = torch.nn.Parameter( 44 | torch.Tensor(out_features, in_features) 45 | ) 46 | 47 | self.scale_noise = scale_noise 48 | self.scale_base = scale_base 49 | self.scale_spline = scale_spline 50 | self.enable_standalone_scale_spline = enable_standalone_scale_spline 51 | self.base_activation = base_activation() 52 | self.grid_eps = grid_eps 53 | 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self): 57 | torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) 58 | with torch.no_grad(): 59 | noise = ( 60 | ( 61 | torch.rand(self.grid_size + 1, self.in_features, self.out_features) 62 | - 1 / 2 63 | ) 64 | * self.scale_noise 65 | / self.grid_size 66 | ) 67 | self.spline_weight.data.copy_( 68 | (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) 69 | * self.curve2coeff( 70 | self.grid.T[self.spline_order : -self.spline_order], 71 | noise, 72 | ) 73 | ) 74 | if self.enable_standalone_scale_spline: 75 | # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) 76 | torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) 77 | 78 | def b_splines(self, x: torch.Tensor): 79 | """ 80 | Compute the B-spline bases for the given input tensor. 81 | 82 | Args: 83 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 84 | 85 | Returns: 86 | torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). 87 | """ 88 | assert x.dim() == 2 and x.size(1) == self.in_features 89 | 90 | grid: torch.Tensor = ( 91 | self.grid 92 | ) # (in_features, grid_size + 2 * spline_order + 1) 93 | x = x.unsqueeze(-1) 94 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) 95 | for k in range(1, self.spline_order + 1): 96 | bases = ( 97 | (x - grid[:, : -(k + 1)]) 98 | / (grid[:, k:-1] - grid[:, : -(k + 1)]) 99 | * bases[:, :, :-1] 100 | ) + ( 101 | (grid[:, k + 1 :] - x) 102 | / (grid[:, k + 1 :] - grid[:, 1:(-k)]) 103 | * bases[:, :, 1:] 104 | ) 105 | 106 | assert bases.size() == ( 107 | x.size(0), 108 | self.in_features, 109 | self.grid_size + self.spline_order, 110 | ) 111 | return bases.contiguous() 112 | 113 | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): 114 | """ 115 | Compute the coefficients of the curve that interpolates the given points. 116 | 117 | Args: 118 | x (torch.Tensor): Input tensor of shape (batch_size, in_features). 119 | y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). 
120 | 121 | Returns: 122 | torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). 123 | """ 124 | assert x.dim() == 2 and x.size(1) == self.in_features 125 | assert y.size() == (x.size(0), self.in_features, self.out_features) 126 | 127 | A = self.b_splines(x).transpose( 128 | 0, 1 129 | ) # (in_features, batch_size, grid_size + spline_order) 130 | B = y.transpose(0, 1) # (in_features, batch_size, out_features) 131 | solution = torch.linalg.lstsq( 132 | A, B 133 | ).solution # (in_features, grid_size + spline_order, out_features) 134 | result = solution.permute( 135 | 2, 0, 1 136 | ) # (out_features, in_features, grid_size + spline_order) 137 | 138 | assert result.size() == ( 139 | self.out_features, 140 | self.in_features, 141 | self.grid_size + self.spline_order, 142 | ) 143 | return result.contiguous() 144 | 145 | @property 146 | def scaled_spline_weight(self): 147 | return self.spline_weight * ( 148 | self.spline_scaler.unsqueeze(-1) 149 | if self.enable_standalone_scale_spline 150 | else 1.0 151 | ) 152 | 153 | def forward(self, x: torch.Tensor): 154 | assert x.dim() == 2 and x.size(1) == self.in_features 155 | 156 | base_output = F.linear(self.base_activation(x), self.base_weight) 157 | spline_output = F.linear( 158 | self.b_splines(x).view(x.size(0), -1), 159 | self.scaled_spline_weight.view(self.out_features, -1), 160 | ) 161 | return base_output + spline_output 162 | 163 | @torch.no_grad() 164 | def update_grid(self, x: torch.Tensor, margin=0.01): 165 | assert x.dim() == 2 and x.size(1) == self.in_features 166 | batch = x.size(0) 167 | 168 | splines = self.b_splines(x) # (batch, in, coeff) 169 | splines = splines.permute(1, 0, 2) # (in, batch, coeff) 170 | orig_coeff = self.scaled_spline_weight # (out, in, coeff) 171 | orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) 172 | unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) 173 | unreduced_spline_output = unreduced_spline_output.permute( 174 | 1, 0, 2 175 | ) # (batch, in, out) 176 | 177 | # sort each channel individually to collect data distribution 178 | x_sorted = torch.sort(x, dim=0)[0] 179 | grid_adaptive = x_sorted[ 180 | torch.linspace( 181 | 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device 182 | ) 183 | ] 184 | 185 | uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size 186 | grid_uniform = ( 187 | torch.arange( 188 | self.grid_size + 1, dtype=torch.float32, device=x.device 189 | ).unsqueeze(1) 190 | * uniform_step 191 | + x_sorted[0] 192 | - margin 193 | ) 194 | 195 | grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive 196 | grid = torch.concatenate( 197 | [ 198 | grid[:1] 199 | - uniform_step 200 | * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), 201 | grid, 202 | grid[-1:] 203 | + uniform_step 204 | * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), 205 | ], 206 | dim=0, 207 | ) 208 | 209 | self.grid.copy_(grid.T) 210 | self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) 211 | 212 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 213 | """ 214 | Compute the regularization loss. 
215 | 216 | This is a dumb simulation of the original L1 regularization as stated in the 217 | paper, since the original one requires computing absolutes and entropy from the 218 | expanded (batch, in_features, out_features) intermediate tensor, which is hidden 219 | behind the F.linear function if we want an memory efficient implementation. 220 | 221 | The L1 regularization is now computed as mean absolute value of the spline 222 | weights. The authors implementation also includes this term in addition to the 223 | sample-based regularization. 224 | """ 225 | l1_fake = self.spline_weight.abs().mean(-1) 226 | regularization_loss_activation = l1_fake.sum() 227 | p = l1_fake / regularization_loss_activation 228 | regularization_loss_entropy = -torch.sum(p * p.log()) 229 | return ( 230 | regularize_activation * regularization_loss_activation 231 | + regularize_entropy * regularization_loss_entropy 232 | ) 233 | 234 | 235 | class KAN(torch.nn.Module): 236 | def __init__( 237 | self, 238 | layers_hidden, 239 | grid_size=5, 240 | spline_order=3, 241 | scale_noise=0.1, 242 | scale_base=1.0, 243 | scale_spline=1.0, 244 | base_activation=torch.nn.SiLU, 245 | grid_eps=0.02, 246 | grid_range=[-1, 1], 247 | ): 248 | super(KAN, self).__init__() 249 | self.grid_size = grid_size 250 | self.spline_order = spline_order 251 | 252 | self.layers = torch.nn.ModuleList() 253 | for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): 254 | self.layers.append( 255 | KANLinear( 256 | in_features, 257 | out_features, 258 | grid_size=grid_size, 259 | spline_order=spline_order, 260 | scale_noise=scale_noise, 261 | scale_base=scale_base, 262 | scale_spline=scale_spline, 263 | base_activation=base_activation, 264 | grid_eps=grid_eps, 265 | grid_range=grid_range, 266 | ) 267 | ) 268 | 269 | def forward(self, x: torch.Tensor, update_grid=False): 270 | for layer in self.layers: 271 | if update_grid: 272 | layer.update_grid(x) 273 | x = layer(x) 274 | return x 275 | 276 | def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): 277 | return sum( 278 | layer.regularization_loss(regularize_activation, regularize_entropy) 279 | for layer in self.layers 280 | ) 281 | -------------------------------------------------------------------------------- /pl_module_sam_seg.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from pytorch_lightning import LightningModule 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torchmetrics import MetricCollection 8 | import time 9 | 10 | from losses import SAMLoss 11 | 12 | class SamSeg(LightningModule): 13 | 14 | def __init__( 15 | self, 16 | cfg, 17 | sam_model: nn.Module, 18 | metrics: MetricCollection, 19 | num_classes: int, 20 | focal_cof: float = 20., 21 | dice_cof: float = 1., 22 | iou_cof: float = 1., 23 | ce_cof: float = 0., 24 | lr: float = 0.0001, 25 | weight_decay: float = 0.01, 26 | lr_steps: list = (10, 20), 27 | warmup_steps: int = 0, 28 | ignored_index=None, 29 | ): 30 | super().__init__() 31 | self.save_hyperparameters(ignore=["sam_model", "metrics"]) # 这将自动记录所有通过 __init__ 传入的参数 32 | self.model = sam_model 33 | self.num_classes = num_classes 34 | 35 | self.loss = SAMLoss(focal_cof, dice_cof, ce_cof, iou_cof) 36 | 37 | self.train_metrics = metrics.clone(postfix='/train') 38 | self.valid_metrics = nn.ModuleList([metrics.clone(postfix='/val'), metrics.clone(postfix='/test')]) 39 | self.test_metrics = 
metrics.clone(prefix='final_test/') 40 | 41 | self.lr = lr 42 | 43 | self.ignored_index = ignored_index 44 | 45 | self.time_and_cnt = [0., 0] 46 | 47 | def forward(self, images): 48 | # use forward for inference/predictions 49 | pred_masks, iou_predictions = self.model(images) 50 | 51 | # pred_masks and iou_predictions are lists 将list 变成 torch 张量 52 | pred_masks = torch.stack(pred_masks, dim=0) 53 | iou_predictions = torch.stack(iou_predictions, dim=0) 54 | 55 | return pred_masks, iou_predictions 56 | 57 | def calc_loss(self, pred_masks, gt_masks, iou_predictions, ignored_masks): 58 | loss_dict = self.loss(pred_masks, gt_masks, iou_predictions, ignored_masks=ignored_masks) 59 | assert "loss" in loss_dict 60 | return loss_dict 61 | 62 | @torch.no_grad() 63 | def process_masks(self, gt_masks): 64 | # gt_cls_masks = [gt_masks == i for i in range(0, self.num_classes + 1)] 65 | 66 | ignored_masks = gt_masks == 0 67 | # gt_cls_masks = torch.stack(gt_cls_masks[1:], dim=1).float() 68 | ignored_masks = ignored_masks.unsqueeze(1).long() 69 | return gt_masks, ignored_masks 70 | 71 | def predict_mask(self, pred_masks, gt_masks, ignored_masks): 72 | # pred_masks = [batch_size, #classes, h, w] 73 | # note class 0 is always for ignored classes 74 | pred_masks = torch.argmax(pred_masks[:, 1:, ...], dim=1) + 1 75 | pred_masks = pred_masks * (1 - ignored_masks.squeeze(1)) 76 | 77 | if self.ignored_index is not None: 78 | pred_masks[pred_masks == self.ignored_index] = 0 79 | gt_masks[gt_masks == self.ignored_index] = 0 80 | 81 | return pred_masks, gt_masks 82 | 83 | def training_step(self, batch, batch_idx): 84 | images, gt_masks = batch 85 | gt_masks, ignored_masks = self.process_masks(gt_masks) 86 | 87 | pred_masks, iou_predictions = self(images) 88 | losses = self.calc_loss(pred_masks, gt_masks, iou_predictions, ignored_masks=ignored_masks) 89 | 90 | self.log_losses(losses, "train") 91 | 92 | mask_cls_pred, gt_masks = self.predict_mask(pred_masks, gt_masks, ignored_masks=ignored_masks) 93 | self.train_metrics.update(mask_cls_pred, gt_masks) 94 | # self.train_metrics(mask_cls_pred, gt_masks) 95 | 96 | self.log_dict(self.train_metrics.compute(), on_step=False, on_epoch=True) 97 | 98 | return losses["loss"] 99 | 100 | def on_train_epoch_end(self): 101 | self.log_dict(self.train_metrics.compute()) 102 | self.train_metrics.reset() 103 | 104 | def validation_step(self, batch, batch_idx, dataloader_idx=None): 105 | images, gt_masks = batch 106 | gt_masks, ignored_masks = self.process_masks(gt_masks) 107 | 108 | prefix = get_prefix_from_val_id(dataloader_idx) 109 | metrics_idx = dataloader_idx if dataloader_idx is not None else 0 110 | 111 | pred_masks, iou_predictions = self(images) 112 | losses = self.calc_loss(pred_masks, gt_masks, iou_predictions, ignored_masks=ignored_masks) 113 | 114 | mask_cls_pred, gt_masks = self.predict_mask(pred_masks, gt_masks, ignored_masks=ignored_masks) 115 | 116 | if not self.trainer.sanity_checking: 117 | self.log_losses(losses, prefix) 118 | self.valid_metrics[metrics_idx].update(mask_cls_pred, gt_masks) 119 | # self.valid_metrics[metrics_idx](mask_cls_pred, gt_masks) 120 | # self.log_dict(self.valid_metrics[metrics_idx], on_step=False, on_epoch=True, 121 | # add_dataloader_idx=False) 122 | 123 | def on_validation_epoch_end(self): 124 | if not self.trainer.sanity_checking: 125 | for valM in self.valid_metrics: 126 | self.log_dict(valM.compute(), add_dataloader_idx=False) 127 | valM.reset() 128 | 129 | def predict_step(self, batch, batch_idx, dataloader_idx: int = 0): 130 | 
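        # Inference-only path: the forward pass is timed with time.perf_counter(),
        # the per-sample mask / IoU lists returned by the model are stacked into
        # tensors, and the argmax'd class masks (with ignored pixels zeroed out)
        # are returned for downstream evaluation.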
images, gt_masks = batch 131 | gt_masks, ignored_masks = self.process_masks(gt_masks) 132 | 133 | 134 | # pred_masks, iou_predictions = self(images) 135 | with torch.no_grad(): 136 | time_start = time.perf_counter() 137 | pred_masks, iou_predictions = self.model(images) 138 | time_predict = time.perf_counter() - time_start 139 | 140 | pred_masks = torch.stack(pred_masks, dim=0) 141 | iou_predictions = torch.stack(iou_predictions, dim=0) 142 | 143 | self.time_and_cnt[0] += time_predict 144 | self.time_and_cnt[1] += 1 145 | print("Average prediction time: %f" % (self.time_and_cnt[0] / self.time_and_cnt[1])) 146 | 147 | mask_cls_pred, gt_masks = self.predict_mask(pred_masks, gt_masks, ignored_masks=ignored_masks) 148 | return mask_cls_pred 149 | 150 | def log_losses(self, losses, prefiex): 151 | if prefiex == "train": 152 | for t in losses: 153 | self.log("Loss/%s_%s" % (prefiex, t), losses[t], on_epoch=True, on_step=True, sync_dist=True) 154 | else: 155 | for t in losses: 156 | self.log("Loss/%s_%s" % (prefiex, t), losses[t], on_epoch=True, on_step=False, sync_dist=True, 157 | add_dataloader_idx=False) 158 | 159 | def configure_optimizers(self): 160 | # self.hparams available because we called self.save_hyperparameters() 161 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.hparams.weight_decay) 162 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, self.hparams.lr_steps, verbose=False) 163 | def lr_lambda(step): 164 | if step < self.hparams.warmup_steps: 165 | return step / self.hparams.warmup_steps 166 | elif step < self.hparams.lr_steps[0]: 167 | return 1.0 168 | elif step < self.hparams.lr_steps[1]: 169 | return 0.1 170 | else: 171 | return 0.01 172 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda, verbose=False) 173 | return { 174 | 'optimizer': optimizer, 175 | 'lr_scheduler': { 176 | 'scheduler': scheduler, 177 | 'interval': 'step' 178 | } 179 | }#[optimizer], [scheduler] 180 | 181 | def get_prefix_from_val_id(dataloader_idx): 182 | if dataloader_idx is None or dataloader_idx == 0: 183 | return "val" 184 | elif dataloader_idx == 1: 185 | return "test" 186 | else: 187 | raise NotImplementedError 188 | 189 | -------------------------------------------------------------------------------- /sam2_train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | 9 | initialize_config_module("sam2_train", version_base="1.2") 10 | -------------------------------------------------------------------------------- /sam2_train/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
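# Usage sketch (illustrative; the checkpoint path below is an assumption, not
# something shipped with this repository):
#
#   from sam2_train.build_sam import build_sam2
#
#   model = build_sam2(
#       config_file="sam2_hiera_b+",              # resolved by Hydra from sam2_train/
#       ckpt_path="checkpoints/sam2_hiera_base_plus.pt",
#       device="cuda",
#       mode="eval",
#   )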
6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def build_sam2( 16 | config_file="sam2_hiera_b+", 17 | ckpt_path=None, 18 | device="cuda", 19 | mode="eval", 20 | hydra_overrides_extra=[], 21 | apply_postprocessing=True, 22 | ): 23 | 24 | if apply_postprocessing: 25 | hydra_overrides_extra = hydra_overrides_extra.copy() 26 | hydra_overrides_extra += [ 27 | # dynamically fall back to multi-mask if the single mask is not stable 28 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 30 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 31 | ] 32 | 33 | # Read config and init model 34 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 35 | OmegaConf.resolve(cfg) 36 | model = instantiate(cfg.model, _recursive_=True) 37 | _load_checkpoint(model, ckpt_path) 38 | model = model.to(device) 39 | if mode == "eval": 40 | model.eval() 41 | return model 42 | 43 | 44 | def build_sam2_video_predictor( 45 | config_file, 46 | ckpt_path=None, 47 | device="cuda", 48 | mode="eval", 49 | hydra_overrides_extra=[], 50 | apply_postprocessing=True, 51 | ): 52 | hydra_overrides = [ 53 | "++model._target_=sam2_train.sam2_video_predictor.SAM2VideoPredictor", 54 | ] 55 | if apply_postprocessing: 56 | hydra_overrides_extra = hydra_overrides_extra.copy() 57 | hydra_overrides_extra += [ 58 | # dynamically fall back to multi-mask if the single mask is not stable 59 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 61 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 62 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 63 | "++model.binarize_mask_from_pts_for_mem_enc=true", 64 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 65 | "++model.fill_hole_area=8", 66 | ] 67 | hydra_overrides.extend(hydra_overrides_extra) 68 | 69 | # Read config and init model 70 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 71 | OmegaConf.resolve(cfg) 72 | model = instantiate(cfg.model, _recursive_=True) 73 | _load_checkpoint(model, ckpt_path) 74 | model = model.to(device) 75 | if mode == "eval": 76 | model.eval() 77 | return model 78 | 79 | 80 | def _load_checkpoint(model, ckpt_path): 81 | if ckpt_path is not None: 82 | sd = torch.load(ckpt_path, map_location="cpu")["model"] 83 | missing_keys, unexpected_keys = model.load_state_dict(sd) 84 | if missing_keys: 85 | logging.error(missing_keys) 86 | raise RuntimeError() 87 | if unexpected_keys: 88 | logging.error(unexpected_keys) 89 | raise RuntimeError() 90 | logging.info("Loaded checkpoint sucessfully") 91 | -------------------------------------------------------------------------------- /sam2_train/csrc/connected_components.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
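// Overview: GPU connected-components labeling for binary masks. Each thread
// owns a 2x2 block; labels are merged with a lock-free union-find (atomicMin),
// path-compressed, and every foreground pixel finally receives its component
// id together with that component's pixel count.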
6 | 7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch 8 | // with license found in the LICENSE_cctorch file in the root directory. 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // 2d 17 | #define BLOCK_ROWS 16 18 | #define BLOCK_COLS 16 19 | 20 | namespace cc2d { 21 | 22 | template 23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { 24 | return (bitmap >> pos) & 1; 25 | } 26 | 27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) { 28 | while (s_buf[n] != n) 29 | n = s_buf[n]; 30 | return n; 31 | } 32 | 33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { 34 | const int32_t id = n; 35 | while (s_buf[n] != n) { 36 | n = s_buf[n]; 37 | s_buf[id] = n; 38 | } 39 | return n; 40 | } 41 | 42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { 43 | bool done; 44 | do { 45 | a = find(s_buf, a); 46 | b = find(s_buf, b); 47 | 48 | if (a < b) { 49 | int32_t old = atomicMin(s_buf + b, a); 50 | done = (old == b); 51 | b = old; 52 | } else if (b < a) { 53 | int32_t old = atomicMin(s_buf + a, b); 54 | done = (old == a); 55 | a = old; 56 | } else 57 | done = true; 58 | 59 | } while (!done); 60 | } 61 | 62 | __global__ void 63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { 64 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 65 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 66 | const uint32_t idx = row * W + col; 67 | 68 | if (row < H && col < W) 69 | label[idx] = idx; 70 | } 71 | 72 | __global__ void 73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { 74 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 75 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 76 | const uint32_t idx = row * W + col; 77 | 78 | if (row >= H || col >= W) 79 | return; 80 | 81 | uint32_t P = 0; 82 | 83 | if (img[idx]) 84 | P |= 0x777; 85 | if (row + 1 < H && img[idx + W]) 86 | P |= 0x777 << 4; 87 | if (col + 1 < W && img[idx + 1]) 88 | P |= 0x777 << 1; 89 | 90 | if (col == 0) 91 | P &= 0xEEEE; 92 | if (col + 1 >= W) 93 | P &= 0x3333; 94 | else if (col + 2 >= W) 95 | P &= 0x7777; 96 | 97 | if (row == 0) 98 | P &= 0xFFF0; 99 | if (row + 1 >= H) 100 | P &= 0xFF; 101 | 102 | if (P > 0) { 103 | // If need check about top-left pixel(if flag the first bit) and hit the 104 | // top-left pixel 105 | if (hasBit(P, 0) && img[idx - W - 1]) { 106 | union_(label, idx, idx - 2 * W - 2); // top left block 107 | } 108 | 109 | if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) 110 | union_(label, idx, idx - 2 * W); // top bottom block 111 | 112 | if (hasBit(P, 3) && img[idx + 2 - W]) 113 | union_(label, idx, idx - 2 * W + 2); // top right block 114 | 115 | if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) 116 | union_(label, idx, idx - 2); // just left block 117 | } 118 | } 119 | 120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) { 121 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 122 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 123 | const uint32_t idx = row * W + col; 124 | 125 | if (row < H && col < W) 126 | find_n_compress(label, idx); 127 | } 128 | 129 | __global__ void final_labeling( 130 | const uint8_t* img, 131 | int32_t* label, 132 | const int32_t W, 133 | const int32_t H) { 134 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 135 | 
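  // Each thread finalizes one 2x2 block (hence the *2 on row/col): foreground
  // pixels get the block root's index + 1 as their label, background pixels 0.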
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 136 | const uint32_t idx = row * W + col; 137 | 138 | if (row >= H || col >= W) 139 | return; 140 | 141 | int32_t y = label[idx] + 1; 142 | 143 | if (img[idx]) 144 | label[idx] = y; 145 | else 146 | label[idx] = 0; 147 | 148 | if (col + 1 < W) { 149 | if (img[idx + 1]) 150 | label[idx + 1] = y; 151 | else 152 | label[idx + 1] = 0; 153 | 154 | if (row + 1 < H) { 155 | if (img[idx + W + 1]) 156 | label[idx + W + 1] = y; 157 | else 158 | label[idx + W + 1] = 0; 159 | } 160 | } 161 | 162 | if (row + 1 < H) { 163 | if (img[idx + W]) 164 | label[idx + W] = y; 165 | else 166 | label[idx + W] = 0; 167 | } 168 | } 169 | 170 | __global__ void init_counting( 171 | const int32_t* label, 172 | int32_t* count_init, 173 | const int32_t W, 174 | const int32_t H) { 175 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 176 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 177 | const uint32_t idx = row * W + col; 178 | 179 | if (row >= H || col >= W) 180 | return; 181 | 182 | int32_t y = label[idx]; 183 | if (y > 0) { 184 | int32_t count_idx = y - 1; 185 | atomicAdd(count_init + count_idx, 1); 186 | } 187 | } 188 | 189 | __global__ void final_counting( 190 | const int32_t* label, 191 | const int32_t* count_init, 192 | int32_t* count_final, 193 | const int32_t W, 194 | const int32_t H) { 195 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 196 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 197 | const uint32_t idx = row * W + col; 198 | 199 | if (row >= H || col >= W) 200 | return; 201 | 202 | int32_t y = label[idx]; 203 | if (y > 0) { 204 | int32_t count_idx = y - 1; 205 | count_final[idx] = count_init[count_idx]; 206 | } else { 207 | count_final[idx] = 0; 208 | } 209 | } 210 | 211 | } // namespace cc2d 212 | 213 | std::vector get_connected_componnets( 214 | const torch::Tensor& inputs) { 215 | AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); 216 | AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); 217 | AT_ASSERTM( 218 | inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); 219 | 220 | const uint32_t N = inputs.size(0); 221 | const uint32_t C = inputs.size(1); 222 | const uint32_t H = inputs.size(2); 223 | const uint32_t W = inputs.size(3); 224 | 225 | AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); 226 | AT_ASSERTM((H % 2) == 0, "height must be a even number"); 227 | AT_ASSERTM((W % 2) == 0, "width must be a even number"); 228 | 229 | // label must be uint32_t 230 | auto label_options = 231 | torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); 232 | torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); 233 | torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); 234 | torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); 235 | 236 | dim3 grid = dim3( 237 | ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, 238 | ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); 239 | dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); 240 | dim3 grid_count = 241 | dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); 242 | dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); 243 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 244 | 245 | for (int n = 0; n < N; n++) { 246 | uint32_t offset = n * H * W; 247 | 248 | cc2d::init_labeling<<>>( 249 | labels.data_ptr() + offset, W, H); 250 | cc2d::merge<<>>( 251 | inputs.data_ptr() + offset, 252 | labels.data_ptr() + offset, 253 | 
W, 254 | H); 255 | cc2d::compression<<>>( 256 | labels.data_ptr() + offset, W, H); 257 | cc2d::final_labeling<<>>( 258 | inputs.data_ptr() + offset, 259 | labels.data_ptr() + offset, 260 | W, 261 | H); 262 | 263 | // get the counting of each pixel 264 | cc2d::init_counting<<>>( 265 | labels.data_ptr() + offset, 266 | counts_init.data_ptr() + offset, 267 | W, 268 | H); 269 | cc2d::final_counting<<>>( 270 | labels.data_ptr() + offset, 271 | counts_init.data_ptr() + offset, 272 | counts_final.data_ptr() + offset, 273 | W, 274 | H); 275 | } 276 | 277 | // returned values are [labels, counts] 278 | std::vector outputs; 279 | outputs.push_back(labels); 280 | outputs.push_back(counts_final); 281 | return outputs; 282 | } 283 | 284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 285 | m.def( 286 | "get_connected_componnets", 287 | &get_connected_componnets, 288 | "get_connected_componnets"); 289 | } 290 | -------------------------------------------------------------------------------- /sam2_train/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2_train/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2_train/modeling/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
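# Shape sketch (illustrative; the default hyper-parameters below are assumed):
#
#   import torch
#   trunk = Hiera()                              # embed_dim=96, stages=(2, 3, 16, 3)
#   feats = trunk(torch.randn(1, 3, 1024, 1024))
#   # feats is a list of 4 feature maps at strides 4/8/16/32 with
#   # 96/192/384/768 channels (dim doubles at every stage shift).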
6 | 7 | from functools import partial 8 | from typing import List, Tuple, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2_train.modeling.backbones.utils import ( 15 | PatchEmbed, 16 | window_partition, 17 | window_unpartition, 18 | ) 19 | 20 | from sam2_train.modeling.sam2_utils import DropPath, MLP 21 | 22 | 23 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 24 | if pool is None: 25 | return x 26 | # (B, H, W, C) -> (B, C, H, W) 27 | x = x.permute(0, 3, 1, 2) 28 | x = pool(x) 29 | # (B, C, H', W') -> (B, H', W', C) 30 | x = x.permute(0, 2, 3, 1) 31 | if norm: 32 | x = norm(x) 33 | 34 | return x 35 | 36 | 37 | class MultiScaleAttention(nn.Module): 38 | def __init__( 39 | self, 40 | dim: int, 41 | dim_out: int, 42 | num_heads: int, 43 | q_pool: nn.Module = None, 44 | ): 45 | super().__init__() 46 | 47 | self.dim = dim 48 | self.dim_out = dim_out 49 | 50 | self.num_heads = num_heads 51 | head_dim = dim_out // num_heads 52 | self.scale = head_dim**-0.5 53 | 54 | self.q_pool = q_pool 55 | self.qkv = nn.Linear(dim, dim_out * 3) 56 | self.proj = nn.Linear(dim_out, dim_out) 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | B, H, W, _ = x.shape 60 | # qkv with shape (B, H * W, 3, nHead, C) 61 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 62 | # q, k, v with shape (B, H * W, nheads, C) 63 | q, k, v = torch.unbind(qkv, 2) 64 | 65 | # Q pooling (for downsample at stage changes) 66 | if self.q_pool: 67 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 68 | H, W = q.shape[1:3] # downsampled shape 69 | q = q.reshape(B, H * W, self.num_heads, -1) 70 | 71 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 72 | x = F.scaled_dot_product_attention( 73 | q.transpose(1, 2), 74 | k.transpose(1, 2), 75 | v.transpose(1, 2), 76 | ) 77 | # Transpose back 78 | x = x.transpose(1, 2) 79 | x = x.reshape(B, H, W, -1) 80 | 81 | x = self.proj(x) 82 | 83 | return x 84 | 85 | 86 | class MultiScaleBlock(nn.Module): 87 | def __init__( 88 | self, 89 | dim: int, 90 | dim_out: int, 91 | num_heads: int, 92 | mlp_ratio: float = 4.0, 93 | drop_path: float = 0.0, 94 | norm_layer: Union[nn.Module, str] = "LayerNorm", 95 | q_stride: Tuple[int, int] = None, 96 | act_layer: nn.Module = nn.GELU, 97 | window_size: int = 0, 98 | ): 99 | super().__init__() 100 | 101 | if isinstance(norm_layer, str): 102 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 103 | 104 | self.dim = dim 105 | self.dim_out = dim_out 106 | self.norm1 = norm_layer(dim) 107 | 108 | self.window_size = window_size 109 | 110 | self.pool, self.q_stride = None, q_stride 111 | if self.q_stride: 112 | self.pool = nn.MaxPool2d( 113 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 114 | ) 115 | 116 | self.attn = MultiScaleAttention( 117 | dim, 118 | dim_out, 119 | num_heads=num_heads, 120 | q_pool=self.pool, 121 | ) 122 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 123 | 124 | self.norm2 = norm_layer(dim_out) 125 | self.mlp = MLP( 126 | dim_out, 127 | int(dim_out * mlp_ratio), 128 | dim_out, 129 | num_layers=2, 130 | activation=act_layer, 131 | ) 132 | 133 | if dim != dim_out: 134 | self.proj = nn.Linear(dim, dim_out) 135 | 136 | def forward(self, x: torch.Tensor) -> torch.Tensor: 137 | shortcut = x # B, H, W, C 138 | x = self.norm1(x) 139 | 140 | # Skip connection 141 | if self.dim != self.dim_out: 142 | shortcut = do_pool(self.proj(x), self.pool) 143 | 144 | # Window partition 145 | 
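        # Tokens are padded up to a multiple of window_size and split into
        # (window_size x window_size) windows so attention stays local; blocks
        # listed in global_att_blocks use window_size == 0 and attend globally.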
window_size = self.window_size 146 | if window_size > 0: 147 | H, W = x.shape[1], x.shape[2] 148 | x, pad_hw = window_partition(x, window_size) 149 | 150 | # Window Attention + Q Pooling (if stage change) 151 | x = self.attn(x) 152 | if self.q_stride: 153 | # Shapes have changed due to Q pooling 154 | window_size = self.window_size // self.q_stride[0] 155 | H, W = shortcut.shape[1:3] 156 | 157 | pad_h = (window_size - H % window_size) % window_size 158 | pad_w = (window_size - W % window_size) % window_size 159 | pad_hw = (H + pad_h, W + pad_w) 160 | 161 | # Reverse window partition 162 | if self.window_size > 0: 163 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 164 | 165 | x = shortcut + self.drop_path(x) 166 | # MLP 167 | x = x + self.drop_path(self.mlp(self.norm2(x))) 168 | return x 169 | 170 | 171 | class Hiera(nn.Module): 172 | """ 173 | Reference: https://arxiv.org/abs/2306.00989 174 | """ 175 | 176 | def __init__( 177 | self, 178 | embed_dim: int = 96, # initial embed dim 179 | num_heads: int = 1, # initial number of heads 180 | drop_path_rate: float = 0.0, # stochastic depth 181 | q_pool: int = 3, # number of q_pool stages 182 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 183 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 184 | dim_mul: float = 2.0, # dim_mul factor at stage shift 185 | head_mul: float = 2.0, # head_mul factor at stage shift 186 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 187 | # window size per stage, when not using global att. 188 | window_spec: Tuple[int, ...] = ( 189 | 8, 190 | 4, 191 | 14, 192 | 7, 193 | ), 194 | # global attn in these blocks 195 | global_att_blocks: Tuple[int, ...] = ( 196 | 12, 197 | 16, 198 | 20, 199 | ), 200 | return_interm_layers=True, # return feats from every stage 201 | ): 202 | super().__init__() 203 | 204 | assert len(stages) == len(window_spec) 205 | self.window_spec = window_spec 206 | 207 | depth = sum(stages) 208 | self.q_stride = q_stride 209 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 210 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 211 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 212 | self.return_interm_layers = return_interm_layers 213 | 214 | self.patch_embed = PatchEmbed( 215 | embed_dim=embed_dim, 216 | ) 217 | # Which blocks have global att? 
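        # (With the defaults above, blocks 12, 16 and 20 all fall inside the
        # 16-block third stage.)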
218 | self.global_att_blocks = global_att_blocks 219 | 220 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 221 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 222 | self.pos_embed = nn.Parameter( 223 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 224 | ) 225 | self.pos_embed_window = nn.Parameter( 226 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 227 | ) 228 | 229 | dpr = [ 230 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 231 | ] # stochastic depth decay rule 232 | 233 | cur_stage = 1 234 | self.blocks = nn.ModuleList() 235 | 236 | for i in range(depth): 237 | dim_out = embed_dim 238 | # lags by a block, so first block of 239 | # next stage uses an initial window size 240 | # of previous stage and final window size of current stage 241 | window_size = self.window_spec[cur_stage - 1] 242 | 243 | if self.global_att_blocks is not None: 244 | window_size = 0 if i in self.global_att_blocks else window_size 245 | 246 | if i - 1 in self.stage_ends: 247 | dim_out = int(embed_dim * dim_mul) 248 | num_heads = int(num_heads * head_mul) 249 | cur_stage += 1 250 | 251 | block = MultiScaleBlock( 252 | dim=embed_dim, 253 | dim_out=dim_out, 254 | num_heads=num_heads, 255 | drop_path=dpr[i], 256 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 257 | window_size=window_size, 258 | ) 259 | 260 | embed_dim = dim_out 261 | self.blocks.append(block) 262 | 263 | self.channel_list = ( 264 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 265 | if return_interm_layers 266 | else [self.blocks[-1].dim_out] 267 | ) 268 | 269 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 270 | h, w = hw 271 | window_embed = self.pos_embed_window 272 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 273 | pos_embed = pos_embed + window_embed.tile( 274 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 275 | ) 276 | pos_embed = pos_embed.permute(0, 2, 3, 1) 277 | return pos_embed 278 | 279 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 280 | x = self.patch_embed(x) 281 | # x: (B, H, W, C) 282 | 283 | # Add pos embed 284 | x = x + self._get_pos_embed(x.shape[1:3]) 285 | 286 | outputs = [] 287 | for i, blk in enumerate(self.blocks): 288 | x = blk(x) 289 | if (i == self.stage_ends[-1]) or ( 290 | i in self.stage_ends and self.return_interm_layers 291 | ): 292 | feats = x.permute(0, 3, 1, 2) 293 | outputs.append(feats) 294 | 295 | return outputs 296 | -------------------------------------------------------------------------------- /sam2_train/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
96 | if fpn_top_down_levels is None: 97 | # default is to have top-down features on all levels 98 | fpn_top_down_levels = range(len(self.convs)) 99 | self.fpn_top_down_levels = list(fpn_top_down_levels) 100 | 101 | def forward(self, xs: List[torch.Tensor]): 102 | 103 | out = [None] * len(self.convs) 104 | pos = [None] * len(self.convs) 105 | assert len(xs) == len(self.convs) 106 | # fpn forward pass 107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 108 | prev_features = None 109 | # forward in top-down order (from low to high resolution) 110 | n = len(self.convs) - 1 111 | for i in range(n, -1, -1): 112 | x = xs[i] 113 | lateral_features = self.convs[n - i](x) 114 | if i in self.fpn_top_down_levels and prev_features is not None: 115 | top_down_features = F.interpolate( 116 | prev_features.to(dtype=torch.float32), 117 | scale_factor=2.0, 118 | mode=self.fpn_interp_model, 119 | align_corners=( 120 | None if self.fpn_interp_model == "nearest" else False 121 | ), 122 | antialias=False, 123 | ) 124 | prev_features = lateral_features + top_down_features 125 | if self.fuse_type == "avg": 126 | prev_features /= 2 127 | else: 128 | prev_features = lateral_features 129 | x_out = prev_features 130 | out[i] = x_out 131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 132 | 133 | return out, pos 134 | -------------------------------------------------------------------------------- /sam2_train/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
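    Example (illustrative round trip with window_partition above):
        windows, (Hp, Wp) = window_partition(x, window_size)
        x = window_unpartition(windows, window_size, (Hp, Wp), (H, W))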
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /sam2_train/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2_train.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2_train.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | 
return tgt 65 | 66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2_train/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2_train.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
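        kernel_size (int): Kernel size of the depthwise conv. Default: 7
        padding (int): Padding of the depthwise conv. Default: 3
        use_dwconv (bool): If False, use a dense conv instead of a depthwise one. Default: True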
72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /sam2_train/modeling/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
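# The sinusoidal embedding below follows "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / temperature**(2i / d))
#   PE(pos, 2i+1) = cos(pos / temperature**(2i / d))
# with d channels per axis, computed separately for the (normalized) y and x
# coordinates and concatenated along the channel dimension.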
6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention is all you need paper, generalized to work on images. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_pos_feats, 25 | temperature: int = 10000, 26 | normalize: bool = True, 27 | scale: Optional[float] = None, 28 | ): 29 | super().__init__() 30 | assert num_pos_feats % 2 == 0, "Expecting even model width" 31 | self.num_pos_feats = num_pos_feats // 2 32 | self.temperature = temperature 33 | self.normalize = normalize 34 | if scale is not None and normalize is False: 35 | raise ValueError("normalize should be True if scale is passed") 36 | if scale is None: 37 | scale = 2 * math.pi 38 | self.scale = scale 39 | 40 | self.cache = {} 41 | 42 | def _encode_xy(self, x, y): 43 | # The positions are expected to be normalized 44 | assert len(x) == len(y) and x.ndim == y.ndim == 1 45 | x_embed = x * self.scale 46 | y_embed = y * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, None] / dim_t 52 | pos_y = y_embed[:, None] / dim_t 53 | pos_x = torch.stack( 54 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 55 | ).flatten(1) 56 | pos_y = torch.stack( 57 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 58 | ).flatten(1) 59 | return pos_x, pos_y 60 | 61 | @torch.no_grad() 62 | def encode_boxes(self, x, y, w, h): 63 | pos_x, pos_y = self._encode_xy(x, y) 64 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 65 | return pos 66 | 67 | encode = encode_boxes # Backwards compatibility 68 | 69 | @torch.no_grad() 70 | def encode_points(self, x, y, labels): 71 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 72 | assert bx == by and nx == ny and bx == bl and nx == nl 73 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 74 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 75 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 76 | return pos 77 | 78 | @torch.no_grad() 79 | def forward(self, x: torch.Tensor): 80 | cache_key = (x.shape[-2], x.shape[-1]) 81 | if cache_key in self.cache: 82 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) 83 | y_embed = ( 84 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) 85 | .view(1, -1, 1) 86 | .repeat(x.shape[0], 1, x.shape[-1]) 87 | ) 88 | x_embed = ( 89 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) 90 | .view(1, 1, -1) 91 | .repeat(x.shape[0], x.shape[-2], 1) 92 | ) 93 | 94 | if self.normalize: 95 | eps = 1e-6 96 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 97 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 98 | 99 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 100 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 101 | 102 | pos_x = x_embed[:, :, :, None] / dim_t 103 | pos_y = y_embed[:, :, :, None] / dim_t 104 | pos_x = torch.stack( 105 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 106 | ).flatten(3) 107 | pos_y = torch.stack( 108 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 109 | ).flatten(3) 110 | pos = torch.cat((pos_y, pos_x), 
dim=3).permute(0, 3, 1, 2) 111 | self.cache[cache_key] = pos[0] 112 | return pos 113 | 114 | 115 | class PositionEmbeddingRandom(nn.Module): 116 | """ 117 | Positional encoding using random spatial frequencies. 118 | """ 119 | 120 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 121 | super().__init__() 122 | if scale is None or scale <= 0.0: 123 | scale = 1.0 124 | self.register_buffer( 125 | "positional_encoding_gaussian_matrix", 126 | scale * torch.randn((2, num_pos_feats)), 127 | ) 128 | 129 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 130 | """Positionally encode points that are normalized to [0,1].""" 131 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 132 | coords = 2 * coords - 1 133 | coords = coords @ self.positional_encoding_gaussian_matrix 134 | coords = 2 * np.pi * coords 135 | # outputs d_1 x ... x d_n x C shape 136 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 137 | 138 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 139 | """Generate positional encoding for a grid of the specified size.""" 140 | h, w = size 141 | device: Any = self.positional_encoding_gaussian_matrix.device 142 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 143 | y_embed = grid.cumsum(dim=0) - 0.5 144 | x_embed = grid.cumsum(dim=1) - 0.5 145 | y_embed = y_embed / h 146 | x_embed = x_embed / w 147 | 148 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 149 | return pe.permute(2, 0, 1) # C x H x W 150 | 151 | def forward_with_coords( 152 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 153 | ) -> torch.Tensor: 154 | """Positionally encode points that are not normalized to [0,1].""" 155 | coords = coords_input.clone() 156 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 157 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 158 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 159 | 160 | 161 | # Rotary Positional Encoding, adapted from: 162 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 163 | # 2. https://github.com/naver-ai/rope-vit 164 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 165 | 166 | 167 | def init_t_xy(end_x: int, end_y: int): 168 | t = torch.arange(end_x * end_y, dtype=torch.float32) 169 | t_x = (t % end_x).float() 170 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 171 | return t_x, t_y 172 | 173 | 174 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 175 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 176 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 177 | 178 | t_x, t_y = init_t_xy(end_x, end_y) 179 | freqs_x = torch.outer(t_x, freqs_x) 180 | freqs_y = torch.outer(t_y, freqs_y) 181 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 182 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 183 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 184 | 185 | 186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 187 | ndim = x.ndim 188 | assert 0 <= 1 < ndim 189 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 190 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 191 | return freqs_cis.view(*shape) 192 | 193 | 194 | def apply_rotary_enc( 195 | xq: torch.Tensor, 196 | xk: torch.Tensor, 197 | freqs_cis: torch.Tensor, 198 | repeat_freqs_k: bool = False, 199 | ): 200 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 201 | xk_ = ( 202 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 203 | if xk.shape[-2] != 0 204 | else None 205 | ) 206 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 207 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 208 | if xk_ is None: 209 | # no keys to rotate, due to dropout 210 | return xq_out.type_as(xq).to(xq.device), xk 211 | # repeat freqs along seq_len dim to match k seq_len 212 | if repeat_freqs_k: 213 | r = xk_.shape[-2] // xq_.shape[-2] 214 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 215 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 216 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 217 | -------------------------------------------------------------------------------- /sam2_train/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2_train/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2_train.modeling.position_encoding import PositionEmbeddingRandom 13 | 14 | from sam2_train.modeling.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 
28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | point_embedding[labels == 2] += self.point_embeddings[2].weight 100 | point_embedding[labels == 3] += self.point_embeddings[3].weight 101 | return point_embedding 102 | 103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 104 | """Embeds box prompts.""" 105 | boxes = boxes + 0.5 # Shift to center of pixel 106 | coords = boxes.reshape(-1, 2, 2) 107 | corner_embedding = self.pe_layer.forward_with_coords( 108 | coords, self.input_image_size 109 | ) 110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 112 | return corner_embedding 113 | 114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 115 | """Embeds mask inputs.""" 116 | mask_embedding = self.mask_downscaling(masks) 117 | return 
mask_embedding 118 | 119 | def _get_batch_size( 120 | self, 121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 122 | boxes: Optional[torch.Tensor], 123 | masks: Optional[torch.Tensor], 124 | ) -> int: 125 | """ 126 | Gets the batch size of the output given the batch size of the input prompts. 127 | """ 128 | if points is not None: 129 | return points[0].shape[0] 130 | elif boxes is not None: 131 | return boxes.shape[0] 132 | elif masks is not None: 133 | return masks.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | batch_size: Optional[int], 146 | ) -> Tuple[torch.Tensor, torch.Tensor]: 147 | """ 148 | Embeds different types of prompts, returning both sparse and dense 149 | embeddings. 150 | 151 | Arguments: 152 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 153 | and labels to embed. 154 | boxes (torch.Tensor or none): boxes to embed 155 | masks (torch.Tensor or none): masks to embed 156 | 157 | Returns: 158 | torch.Tensor: sparse embeddings for the points and boxes, with shape 159 | BxNx(embed_dim), where N is determined by the number of input points 160 | and boxes. 161 | torch.Tensor: dense embeddings for the masks, in the shape 162 | Bx(embed_dim)x(embed_H)x(embed_W) 163 | """ 164 | bs = self._get_batch_size(points, boxes, masks) 165 | if batch_size is not None: 166 | bs = batch_size 167 | sparse_embeddings = torch.empty( 168 | (bs, 0, self.embed_dim), device=self._get_device() 169 | ) 170 | if points is not None: 171 | coords, labels = points 172 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 173 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 174 | if boxes is not None: 175 | box_embeddings = self._embed_boxes(boxes) 176 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 177 | 178 | if masks is not None: 179 | dense_embeddings = self._embed_masks(masks) 180 | else: 181 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 182 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 183 | ) 184 | 185 | return sparse_embeddings, dense_embeddings 186 | -------------------------------------------------------------------------------- /sam2_train/modeling/sam2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import copy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): 16 | """ 17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` 18 | that are temporally closest to the current frame at `frame_idx`. Here, we take 19 | - a) the closest conditioning frame before `frame_idx` (if any); 20 | - b) the closest conditioning frame after `frame_idx` (if any); 21 | - c) any other temporally closest conditioning frames until reaching a total 22 | of `max_cond_frame_num` conditioning frames.
23 | 24 | Outputs: 25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. 26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 27 | """ 28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: 29 | selected_outputs = cond_frame_outputs 30 | unselected_outputs = {} 31 | else: 32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" 33 | selected_outputs = {} 34 | 35 | # the closest conditioning frame before `frame_idx` (if any) 36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) 37 | if idx_before is not None: 38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before] 39 | 40 | # the closest conditioning frame after `frame_idx` (if any) 41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) 42 | if idx_after is not None: 43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after] 44 | 45 | # add other temporally closest conditioning frames until reaching a total 46 | # of `max_cond_frame_num` conditioning frames. 47 | num_remain = max_cond_frame_num - len(selected_outputs) 48 | inds_remain = sorted( 49 | (t for t in cond_frame_outputs if t not in selected_outputs), 50 | key=lambda x: abs(x - frame_idx), 51 | )[:num_remain] 52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) 53 | unselected_outputs = { 54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs 55 | } 56 | 57 | return selected_outputs, unselected_outputs 58 | 59 | 60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000): 61 | """ 62 | Get 1D sine positional embedding as in the original Transformer paper. 63 | """ 64 | pe_dim = dim // 2 65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) 66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) 67 | 68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t 69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) 70 | return pos_embed 71 | 72 | 73 | def get_activation_fn(activation): 74 | """Return an activation function given a string""" 75 | if activation == "relu": 76 | return F.relu 77 | if activation == "gelu": 78 | return F.gelu 79 | if activation == "glu": 80 | return F.glu 81 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 82 | 83 | 84 | def get_clones(module, N): 85 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 86 | 87 | 88 | class DropPath(nn.Module): 89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py 90 | def __init__(self, drop_prob=0.0, scale_by_keep=True): 91 | super(DropPath, self).__init__() 92 | self.drop_prob = drop_prob 93 | self.scale_by_keep = scale_by_keep 94 | 95 | def forward(self, x): 96 | if self.drop_prob == 0.0 or not self.training: 97 | return x 98 | keep_prob = 1 - self.drop_prob 99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 101 | if keep_prob > 0.0 and self.scale_by_keep: 102 | random_tensor.div_(keep_prob) 103 | return x * random_tensor 104 | 105 | 106 | # Lightly adapted from 107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 108 | class MLP(nn.Module): 109 | def __init__( 110 | self, 111 | input_dim: int, 112 | hidden_dim: int, 113 | output_dim: int, 114 | num_layers: int, 115 | activation: nn.Module = nn.ReLU, 116 | sigmoid_output: bool = 
False, 117 | ) -> None: 118 | super().__init__() 119 | self.num_layers = num_layers 120 | h = [hidden_dim] * (num_layers - 1) 121 | self.layers = nn.ModuleList( 122 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 123 | ) 124 | self.sigmoid_output = sigmoid_output 125 | self.act = activation() 126 | 127 | def forward(self, x): 128 | for i, layer in enumerate(self.layers): 129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) 130 | if self.sigmoid_output: 131 | x = F.sigmoid(x) 132 | return x 133 | 134 | 135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 137 | class LayerNorm2d(nn.Module): 138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 139 | super().__init__() 140 | self.weight = nn.Parameter(torch.ones(num_channels)) 141 | self.bias = nn.Parameter(torch.zeros(num_channels)) 142 | self.eps = eps 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | u = x.mean(1, keepdim=True) 146 | s = (x - u).pow(2).mean(1, keepdim=True) 147 | x = (x - u) / torch.sqrt(s + self.eps) 148 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 149 | return x 150 | -------------------------------------------------------------------------------- /sam2_train/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2_train.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2_train.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2_train.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2_train.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2_train.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | # backbone_channel_list: [768, 384, 192, 96] 24 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 25 | fpn_interp_model: nearest 26 | 27 | memory_attention: 28 | _target_: sam2_train.modeling.memory_attention.MemoryAttention 29 | d_model: 256 30 | pos_enc_at_input: true 31 | layer: 32 | _target_: sam2_train.modeling.memory_attention.MemoryAttentionLayer 33 | activation: relu 34 | dim_feedforward: 2048 35 | dropout: 0.1 36 | pos_enc_at_attn: false 37 | self_attention: 38 | _target_: sam2_train.modeling.sam.transformer.RoPEAttention 39 | rope_theta: 10000.0 40 | feat_sizes: [32, 32] 41 | embedding_dim: 256 42 | num_heads: 1 43 | downsample_rate: 1 44 | dropout: 0.1 45 | d_model: 256 46 | pos_enc_at_cross_attn_keys: true 47 | pos_enc_at_cross_attn_queries: false 48 | cross_attention: 49 | _target_: sam2_train.modeling.sam.transformer.RoPEAttention 50 | rope_theta: 10000.0 51 | feat_sizes: [32, 32] 52 | rope_k_repeat: True 53 | embedding_dim: 256 54 | num_heads: 1 55 | downsample_rate: 1 56 | dropout: 0.1 57 | kv_in_dim: 64 58 | num_layers: 4 59 | 60 | memory_encoder: 61 | _target_: sam2_train.modeling.memory_encoder.MemoryEncoder 62 | out_dim: 64 63 | position_encoding: 64 | _target_: sam2_train.modeling.position_encoding.PositionEmbeddingSine 65 | num_pos_feats: 64 66 | normalize: true 67 | scale: null 68 
| temperature: 10000 69 | mask_downsampler: 70 | _target_: sam2_train.modeling.memory_encoder.MaskDownSampler 71 | kernel_size: 3 72 | stride: 2 73 | padding: 1 74 | fuser: 75 | _target_: sam2_train.modeling.memory_encoder.Fuser 76 | layer: 77 | _target_: sam2_train.modeling.memory_encoder.CXBlock 78 | dim: 256 79 | kernel_size: 7 80 | padding: 3 81 | layer_scale_init_value: 1e-6 82 | use_dwconv: True # depth-wise convs 83 | num_layers: 2 84 | 85 | num_maskmem: 7 86 | image_size: 1024 87 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 88 | sigmoid_scale_for_mem_enc: 20.0 89 | sigmoid_bias_for_mem_enc: -10.0 90 | use_mask_input_as_output_without_sam: true 91 | # Memory 92 | directly_add_no_mem_embed: true 93 | # use high-resolution feature map in the SAM mask decoder 94 | use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | add_tpos_enc_to_obj_ptrs: false 102 | only_obj_ptrs_in_the_past_for_eval: true 103 | # object occlusion prediction 104 | pred_obj_scores: true 105 | pred_obj_scores_mlp: true 106 | fixed_no_obj_ptr: true 107 | # multimask tracking settings 108 | multimask_output_for_tracking: true 109 | use_multimask_token_for_obj_ptr: true 110 | multimask_min_pt_num: 0 111 | multimask_max_pt_num: 1 112 | use_mlp_for_obj_ptr_proj: true 113 | # Compilation flag 114 | compile_image_encoder: False 115 | -------------------------------------------------------------------------------- /sam2_train/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2_train.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2_train.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2_train.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2_train.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2_train.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2_train.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2_train.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2_train.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2_train.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | 
embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2_train.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2_train.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2_train.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2_train.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2_train.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2_train/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2_train.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2_train.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2_train.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2_train.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2_train.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2_train.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2_train.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2_train.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 
42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2_train.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2_train.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2_train.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2_train.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2_train.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2_train.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /sam2_train/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2_train/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
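# A minimal usage sketch for the helpers defined below; the tensor shapes are
# illustrative assumptions only:
#
#   old_gpu, use_flash_attn, math_kernel_on = get_sdpa_settings()
#
#   masks = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
#   masks[0, 0, 2:5, 3:7] = True
#   boxes = mask_to_box(masks)   # -> tensor([[[3, 2, 6, 4]]]) as (x_min, y_min, x_max, y_max)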
6 | 7 | import os 8 | import warnings 9 | from threading import Thread 10 | 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | 17 | def get_sdpa_settings(): 18 | if torch.cuda.is_available(): 19 | old_gpu = torch.cuda.get_device_properties(0).major < 7 20 | # only use Flash Attention on Ampere (8.0) or newer GPUs 21 | use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 22 | if not use_flash_attn: 23 | warnings.warn( 24 | "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", 25 | category=UserWarning, 26 | stacklevel=2, 27 | ) 28 | # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only 29 | # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) 30 | pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) 31 | if pytorch_version < (2, 2): 32 | warnings.warn( 33 | f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " 34 | "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", 35 | category=UserWarning, 36 | stacklevel=2, 37 | ) 38 | math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn 39 | else: 40 | old_gpu = True 41 | use_flash_attn = False 42 | math_kernel_on = True 43 | 44 | return old_gpu, use_flash_attn, math_kernel_on 45 | 46 | 47 | def get_connected_components(mask): 48 | """ 49 | Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). 50 | 51 | Inputs: 52 | - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is 53 | background. 54 | 55 | Outputs: 56 | - labels: A tensor of shape (N, 1, H, W) containing the connected component labels 57 | for foreground pixels and 0 for background pixels. 58 | - counts: A tensor of shape (N, 1, H, W) containing the area of the connected 59 | components for foreground pixels and 0 for background pixels. 
60 | """ 61 | from sam2_train import _C 62 | 63 | return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) 64 | 65 | 66 | def mask_to_box(masks: torch.Tensor): 67 | """ 68 | compute bounding box given an input mask 69 | 70 | Inputs: 71 | - masks: [B, 1, H, W] boxes, dtype=torch.Tensor 72 | 73 | Returns: 74 | - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor 75 | """ 76 | B, _, h, w = masks.shape 77 | device = masks.device 78 | xs = torch.arange(w, device=device, dtype=torch.int32) 79 | ys = torch.arange(h, device=device, dtype=torch.int32) 80 | grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") 81 | grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) 82 | grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) 83 | min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) 84 | max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) 85 | min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) 86 | max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) 87 | bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) 88 | 89 | return bbox_coords 90 | 91 | 92 | def _load_img_as_tensor(img_path, image_size): 93 | img_pil = Image.open(img_path) 94 | img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) 95 | if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images 96 | img_np = img_np / 255.0 97 | else: 98 | raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") 99 | img = torch.from_numpy(img_np).permute(2, 0, 1) 100 | video_width, video_height = img_pil.size # the original video size 101 | return img, video_height, video_width 102 | 103 | 104 | class AsyncVideoFrameLoader: 105 | """ 106 | A list of video frames to be load asynchronously without blocking session start. 
107 | """ 108 | 109 | def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): 110 | self.img_paths = img_paths 111 | self.image_size = image_size 112 | self.offload_video_to_cpu = offload_video_to_cpu 113 | self.img_mean = img_mean 114 | self.img_std = img_std 115 | # items in `self._images` will be loaded asynchronously 116 | self.images = [None] * len(img_paths) 117 | # catch and raise any exceptions in the async loading thread 118 | self.exception = None 119 | # video_height and video_width be filled when loading the first image 120 | self.video_height = None 121 | self.video_width = None 122 | 123 | # load the first frame to fill video_height and video_width and also 124 | # to cache it (since it's most likely where the user will click) 125 | self.__getitem__(0) 126 | 127 | # load the rest of frames asynchronously without blocking the session start 128 | def _load_frames(): 129 | try: 130 | for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): 131 | self.__getitem__(n) 132 | except Exception as e: 133 | self.exception = e 134 | 135 | self.thread = Thread(target=_load_frames, daemon=True) 136 | self.thread.start() 137 | 138 | def __getitem__(self, index): 139 | if self.exception is not None: 140 | raise RuntimeError("Failure in frame loading thread") from self.exception 141 | 142 | img = self.images[index] 143 | if img is not None: 144 | return img 145 | 146 | img, video_height, video_width = _load_img_as_tensor( 147 | self.img_paths[index], self.image_size 148 | ) 149 | self.video_height = video_height 150 | self.video_width = video_width 151 | # normalize by mean and std 152 | img -= self.img_mean 153 | img /= self.img_std 154 | if not self.offload_video_to_cpu: 155 | img = img.cuda(non_blocking=True) 156 | self.images[index] = img 157 | return img 158 | 159 | def __len__(self): 160 | return len(self.images) 161 | 162 | 163 | def load_video_frames( 164 | video_path, 165 | image_size, 166 | offload_video_to_cpu, 167 | img_mean=(0.485, 0.456, 0.406), 168 | img_std=(0.229, 0.224, 0.225), 169 | async_loading_frames=False, 170 | ): 171 | """ 172 | Load the video frames from a directory of JPEG files (".jpg" format). 173 | 174 | The frames are resized to image_size x image_size and are loaded to GPU if 175 | `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. 176 | 177 | You can load a frame asynchronously by setting `async_loading_frames` to `True`. 
178 | """ 179 | if isinstance(video_path, str) and os.path.isdir(video_path): 180 | jpg_folder = video_path 181 | else: 182 | raise NotImplementedError("Only JPEG frames are supported at this moment") 183 | 184 | frame_names = [ 185 | p 186 | for p in os.listdir(jpg_folder) 187 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 188 | ] 189 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 190 | num_frames = len(frame_names) 191 | if num_frames == 0: 192 | raise RuntimeError(f"no images found in {jpg_folder}") 193 | img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] 194 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] 195 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] 196 | 197 | if async_loading_frames: 198 | lazy_images = AsyncVideoFrameLoader( 199 | img_paths, image_size, offload_video_to_cpu, img_mean, img_std 200 | ) 201 | return lazy_images, lazy_images.video_height, lazy_images.video_width 202 | 203 | images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) 204 | for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): 205 | images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) 206 | if not offload_video_to_cpu: 207 | images = images.cuda() 208 | img_mean = img_mean.cuda() 209 | img_std = img_std.cuda() 210 | # normalize by mean and std 211 | images -= img_mean 212 | images /= img_std 213 | return images, video_height, video_width 214 | 215 | def load_video_frames_from_data( 216 | imgs_tensor, 217 | offload_video_to_cpu, 218 | img_mean=(0.485, 0.456, 0.406), 219 | img_std=(0.229, 0.224, 0.225), 220 | async_loading_frames=False, 221 | ): 222 | """ 223 | Load the video frames from a directory of JPEG files (".jpg" format). 224 | 225 | The frames are resized to image_size x image_size and are loaded to GPU if 226 | `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. 227 | 228 | You can load a frame asynchronously by setting `async_loading_frames` to `True`. 229 | """ 230 | 231 | num_frames = imgs_tensor.shape[0] 232 | 233 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] 234 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] 235 | 236 | images = imgs_tensor / 255.0 237 | if not offload_video_to_cpu: 238 | images = images.cuda() 239 | img_mean = img_mean.cuda() 240 | img_std = img_std.cuda() 241 | # normalize by mean and std 242 | images -= img_mean 243 | images /= img_std 244 | return images 245 | 246 | 247 | def fill_holes_in_mask_scores(mask, max_area): 248 | """ 249 | A post processor to fill small holes in mask scores with area under `max_area`. 250 | """ 251 | # Holes are those connected components in background with area <= self.max_area 252 | # (background regions are those with mask scores <= 0) 253 | assert max_area > 0, "max_area must be positive" 254 | labels, areas = get_connected_components(mask <= 0) 255 | is_hole = (labels > 0) & (areas <= max_area) 256 | # We fill holes with a small positive mask score (0.1) to change them to foreground. 
257 | mask = torch.where(is_hole, 0.1, mask) 258 | return mask 259 | 260 | 261 | def concat_points(old_point_inputs, new_points, new_labels): 262 | """Add new points and labels to previous point inputs (add at the end).""" 263 | if old_point_inputs is None: 264 | points, labels = new_points, new_labels 265 | else: 266 | points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) 267 | labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) 268 | 269 | return {"point_coords": points, "point_labels": labels} 270 | -------------------------------------------------------------------------------- /sam2_train/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Normalize, Resize, ToTensor 11 | 12 | 13 | class SAM2Transforms(nn.Module): 14 | def __init__( 15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 16 | ): 17 | """ 18 | Transforms for SAM2. 19 | """ 20 | super().__init__() 21 | self.resolution = resolution 22 | self.mask_threshold = mask_threshold 23 | self.max_hole_area = max_hole_area 24 | self.max_sprinkle_area = max_sprinkle_area 25 | self.mean = [0.485, 0.456, 0.406] 26 | self.std = [0.229, 0.224, 0.225] 27 | self.to_tensor = ToTensor() 28 | self.transforms = torch.jit.script( 29 | nn.Sequential( 30 | Resize((self.resolution, self.resolution)), 31 | Normalize(self.mean, self.std), 32 | ) 33 | ) 34 | 35 | def __call__(self, x): 36 | x = self.to_tensor(x) 37 | return self.transforms(x) 38 | 39 | def forward_batch(self, img_list): 40 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 41 | img_batch = torch.stack(img_batch, dim=0) 42 | return img_batch 43 | 44 | def transform_coords( 45 | self, coords: torch.Tensor, normalize=False, orig_hw=None 46 | ) -> torch.Tensor: 47 | """ 48 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates. 49 | If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required. 50 | 51 | Returns 52 | Coordinates scaled to the model input resolution, i.e. in the range [0, self.resolution], as expected by the SAM2 prompt encoder. 53 | """ 54 | if normalize: 55 | assert orig_hw is not None 56 | h, w = orig_hw 57 | coords = coords.clone() 58 | coords[..., 0] = coords[..., 0] / w 59 | coords[..., 1] = coords[..., 1] / h 60 | 61 | coords = coords * self.resolution # unnormalize coords 62 | return coords 63 | 64 | def transform_boxes( 65 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 66 | ) -> torch.Tensor: 67 | """ 68 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates. 69 | If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required. 70 | """ 71 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 72 | return boxes 73 | 74 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 75 | """ 76 | Perform post-processing on output masks.
77 | """ 78 | from sam2_train.utils.misc import get_connected_components 79 | 80 | masks = masks.float() 81 | if self.max_hole_area > 0: 82 | # Holes are those connected components in background with area <= self.fill_hole_area 83 | # (background regions are those with mask scores <= self.mask_threshold) 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold) 86 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 87 | is_hole = is_hole.reshape_as(masks) 88 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 89 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 90 | 91 | if self.max_sprinkle_area > 0: 92 | labels, areas = get_connected_components(mask_flat > self.mask_threshold) 93 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 94 | is_hole = is_hole.reshape_as(masks) 95 | # We fill holes with negative mask score (-10.0) to change them to background. 96 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 97 | 98 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 99 | return masks 100 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | --------------------------------------------------------------------------------