├── .gitignore ├── DataLoader.py ├── LICENSE ├── README.md ├── app.ipynb ├── assets ├── README.md ├── SAM-Med2D_wechat_group.jpeg ├── SAM-Med2D_wechat_group.jpg ├── SAM_Med2D_wechat_group.png ├── cover_SA-Med2D-20M.png ├── dataset.png ├── framwork.png ├── result.png └── visualization.png ├── data_demo ├── image2label_train.json ├── images │ ├── amos_0004_75.png │ ├── amos_0006_90.png │ ├── amos_0507_31.png │ ├── s0114_111.png │ └── s0619_32.png ├── label2image_test.json └── masks │ ├── amos_0004_75_aorta_000.png │ ├── amos_0004_75_inferior_vena_cava_000.png │ ├── amos_0004_75_liver_000.png │ ├── amos_0006_90_aorta_000.png │ ├── amos_0006_90_inferior_vena_cava_000.png │ ├── amos_0006_90_liver_000.png │ ├── amos_0006_90_spleen_000.png │ ├── s0114_111_aorta_000.png │ ├── s0114_111_autochthon_left_000.png │ ├── s0114_111_autochthon_right_000.png │ ├── s0114_111_heart_atrium_left_000.png │ ├── s0114_111_heart_atrium_right_000.png │ ├── s0114_111_heart_myocardium_000.png │ ├── s0114_111_heart_ventricle_left_000.png │ ├── s0114_111_heart_ventricle_right_000.png │ ├── s0114_111_lung_lower_lobe_left_000.png │ ├── s0114_111_lung_lower_lobe_right_000.png │ ├── s0114_111_lung_middle_lobe_right_000.png │ ├── s0114_111_lung_upper_lobe_left_000.png │ ├── s0114_111_rib_left_9_000.png │ ├── s0114_111_rib_right_9_000.png │ ├── s0114_111_vertebrae_T9_000.png │ ├── s0619_32_colon_000.png │ ├── s0619_32_femur_right_000.png │ ├── s0619_32_gluteus_maximus_left_000.png │ ├── s0619_32_gluteus_maximus_right_000.png │ ├── s0619_32_hip_left_000.png │ ├── s0619_32_hip_left_001.png │ ├── s0619_32_hip_right_000.png │ └── s0619_32_hip_right_001.png ├── examples └── SAM-Med2D-onnxruntime │ └── main.py ├── metrics.py ├── predictor_example.ipynb ├── scripts ├── amg.py ├── export_onnx_encoder_model.py └── export_onnx_model.py ├── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── sam_model.py │ └── transformer.py ├── predictor.py ├── predictor_sammed.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /DataLoader.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from torch.utils.data import Dataset 4 | import albumentations as A 5 | from albumentations.pytorch import ToTensorV2 6 | import cv2 7 | import torch 8 | import numpy as np 9 | from torch.nn import functional as F 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from utils import train_transforms, get_boxes_from_mask, init_point_sampling 13 | import json 14 | import random 15 | 16 | 17 | class TestingDataset(Dataset): 18 | 19 | def __init__(self, data_path, image_size=256, mode='test', requires_name=True, point_num=1, return_ori_mask=True, prompt_path=None): 20 | """ 21 | Initializes a TestingDataset object. 22 | Args: 23 | data_path (str): The path to the data. 24 | image_size (int, optional): The size of the image. Defaults to 256. 25 | mode (str, optional): The mode of the dataset. Defaults to 'test'. 26 | requires_name (bool, optional): Indicates whether the dataset requires image names. Defaults to True. 27 | point_num (int, optional): The number of points to retrieve. Defaults to 1. 28 | return_ori_mask (bool, optional): Indicates whether to return the original mask. Defaults to True. 29 | prompt_path (str, optional): The path to the prompt file. Defaults to None. 30 | """ 31 | self.image_size = image_size 32 | self.return_ori_mask = return_ori_mask 33 | self.prompt_path = prompt_path 34 | self.prompt_list = {} if prompt_path is None else json.load(open(prompt_path, "r")) 35 | self.requires_name = requires_name 36 | self.point_num = point_num 37 | 38 | json_file = open(os.path.join(data_path, f'label2image_{mode}.json'), "r") 39 | dataset = json.load(json_file) 40 | 41 | self.image_paths = list(dataset.values()) 42 | self.label_paths = list(dataset.keys()) 43 | 44 | self.pixel_mean = [123.675, 116.28, 103.53] 45 | self.pixel_std = [58.395, 57.12, 57.375] 46 | 47 | def __getitem__(self, index): 48 | """ 49 | Retrieves and preprocesses an item from the dataset. 50 | Args: 51 | index (int): The index of the item to retrieve. 52 | Returns: 53 | dict: A dictionary containing the preprocessed image and associated information. 54 | """ 55 | image_input = {} 56 | try: 57 | image = cv2.imread(self.image_paths[index]) 58 | image = (image - self.pixel_mean) / self.pixel_std 59 | except: 60 | print(self.image_paths[index]) 61 | 62 | mask_path = self.label_paths[index] 63 | ori_np_mask = cv2.imread(mask_path, 0) 64 | 65 | if ori_np_mask.max() == 255: 66 | ori_np_mask = ori_np_mask / 255 67 | 68 | assert np.array_equal(ori_np_mask, ori_np_mask.astype(bool)), f"Mask should only contain binary values 0 and 1. 
{self.label_paths[index]}" 69 | 70 | h, w = ori_np_mask.shape 71 | ori_mask = torch.tensor(ori_np_mask).unsqueeze(0) 72 | 73 | transforms = train_transforms(self.image_size, h, w) 74 | augments = transforms(image=image, mask=ori_np_mask) 75 | image, mask = augments['image'], augments['mask'].to(torch.int64) 76 | 77 | if self.prompt_path is None: 78 | boxes = get_boxes_from_mask(mask, max_pixel = 0) 79 | point_coords, point_labels = init_point_sampling(mask, self.point_num) 80 | else: 81 | prompt_key = mask_path.split('/')[-1] 82 | boxes = torch.as_tensor(self.prompt_list[prompt_key]["boxes"], dtype=torch.float) 83 | point_coords = torch.as_tensor(self.prompt_list[prompt_key]["point_coords"], dtype=torch.float) 84 | point_labels = torch.as_tensor(self.prompt_list[prompt_key]["point_labels"], dtype=torch.int) 85 | 86 | image_input["image"] = image 87 | image_input["label"] = mask.unsqueeze(0) 88 | image_input["point_coords"] = point_coords 89 | image_input["point_labels"] = point_labels 90 | image_input["boxes"] = boxes 91 | image_input["original_size"] = (h, w) 92 | image_input["label_path"] = '/'.join(mask_path.split('/')[:-1]) 93 | 94 | if self.return_ori_mask: 95 | image_input["ori_label"] = ori_mask 96 | 97 | image_name = self.label_paths[index].split('/')[-1] 98 | if self.requires_name: 99 | image_input["name"] = image_name 100 | return image_input 101 | else: 102 | return image_input 103 | 104 | def __len__(self): 105 | return len(self.label_paths) 106 | 107 | 108 | class TrainingDataset(Dataset): 109 | def __init__(self, data_dir, image_size=256, mode='train', requires_name=True, point_num=1, mask_num=5): 110 | """ 111 | Initializes a training dataset. 112 | Args: 113 | data_dir (str): Directory containing the dataset. 114 | image_size (int, optional): Desired size for the input images. Defaults to 256. 115 | mode (str, optional): Mode of the dataset. Defaults to 'train'. 116 | requires_name (bool, optional): Indicates whether to include image names in the output. Defaults to True. 117 | num_points (int, optional): Number of points to sample. Defaults to 1. 118 | num_masks (int, optional): Number of masks to sample. Defaults to 5. 119 | """ 120 | self.image_size = image_size 121 | self.requires_name = requires_name 122 | self.point_num = point_num 123 | self.mask_num = mask_num 124 | self.pixel_mean = [123.675, 116.28, 103.53] 125 | self.pixel_std = [58.395, 57.12, 57.375] 126 | 127 | dataset = json.load(open(os.path.join(data_dir, f'image2label_{mode}.json'), "r")) 128 | self.image_paths = list(dataset.keys()) 129 | self.label_paths = list(dataset.values()) 130 | 131 | def __getitem__(self, index): 132 | """ 133 | Returns a sample from the dataset. 134 | Args: 135 | index (int): Index of the sample. 136 | Returns: 137 | dict: A dictionary containing the sample data. 
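            Shape sketch (inferred from the code below, not an official contract;
            assumes the defaults image_size=256, mask_num=5, point_num=1):
            "image" is (1, 3, 256, 256) and "label" is (5, 1, 256, 256); "boxes",
            "point_coords" and "point_labels" carry one entry per sampled mask.
            A hypothetical usage, mirroring the __main__ block at the end of this
            file, where stack_dict_batched folds the per-sample mask dimension
            into the batch dimension:

                ds = TrainingDataset("data_demo", image_size=256, mask_num=5)
                sample = ds[0]                      # sample["image"]: (1, 3, 256, 256)
                loader = DataLoader(ds, batch_size=2, shuffle=True)
                batch = stack_dict_batched(next(iter(loader)))
                # batch["image"]: (2, 3, 256, 256), batch["label"]: (10, 1, 256, 256)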
138 | """ 139 | 140 | image_input = {} 141 | try: 142 | image = cv2.imread(self.image_paths[index]) 143 | image = (image - self.pixel_mean) / self.pixel_std 144 | except: 145 | print(self.image_paths[index]) 146 | 147 | h, w, _ = image.shape 148 | transforms = train_transforms(self.image_size, h, w) 149 | 150 | masks_list = [] 151 | boxes_list = [] 152 | point_coords_list, point_labels_list = [], [] 153 | mask_path = random.choices(self.label_paths[index], k=self.mask_num) 154 | for m in mask_path: 155 | pre_mask = cv2.imread(m, 0) 156 | if pre_mask.max() == 255: 157 | pre_mask = pre_mask / 255 158 | 159 | augments = transforms(image=image, mask=pre_mask) 160 | image_tensor, mask_tensor = augments['image'], augments['mask'].to(torch.int64) 161 | 162 | boxes = get_boxes_from_mask(mask_tensor) 163 | point_coords, point_label = init_point_sampling(mask_tensor, self.point_num) 164 | 165 | masks_list.append(mask_tensor) 166 | boxes_list.append(boxes) 167 | point_coords_list.append(point_coords) 168 | point_labels_list.append(point_label) 169 | 170 | mask = torch.stack(masks_list, dim=0) 171 | boxes = torch.stack(boxes_list, dim=0) 172 | point_coords = torch.stack(point_coords_list, dim=0) 173 | point_labels = torch.stack(point_labels_list, dim=0) 174 | 175 | image_input["image"] = image_tensor.unsqueeze(0) 176 | image_input["label"] = mask.unsqueeze(1) 177 | image_input["boxes"] = boxes 178 | image_input["point_coords"] = point_coords 179 | image_input["point_labels"] = point_labels 180 | 181 | image_name = self.image_paths[index].split('/')[-1] 182 | if self.requires_name: 183 | image_input["name"] = image_name 184 | return image_input 185 | else: 186 | return image_input 187 | def __len__(self): 188 | return len(self.image_paths) 189 | 190 | 191 | def stack_dict_batched(batched_input): 192 | out_dict = {} 193 | for k,v in batched_input.items(): 194 | if isinstance(v, list): 195 | out_dict[k] = v 196 | else: 197 | out_dict[k] = v.reshape(-1, *v.shape[2:]) 198 | return out_dict 199 | 200 | 201 | if __name__ == "__main__": 202 | train_dataset = TrainingDataset("data_demo", image_size=256, mode='train', requires_name=True, point_num=1, mask_num=5) 203 | print("Dataset:", len(train_dataset)) 204 | train_batch_sampler = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True, num_workers=4) 205 | for i, batched_image in enumerate(tqdm(train_batch_sampler)): 206 | batched_image = stack_dict_batched(batched_image) 207 | print(batched_image["image"].shape, batched_image["label"].shape) 208 | 209 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /app.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gradio as gr\n", 10 | "import numpy as np\n", 11 | "from PIL import Image, ImageDraw, ImageFont\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import cv2\n", 14 | "from segment_anything import sam_model_registry\n", 15 | "from segment_anything.predictor_sammed import SammedPredictor\n", 16 | "from argparse import Namespace\n", 17 | "import torch\n", 18 | "import torchvision\n", 19 | "import os, sys\n", 20 | "import random\n", 21 | "import warnings\n", 22 | "from scipy import ndimage\n", 23 | "import functools\n", 24 | "\n", 25 | "\n", 26 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 27 | "args = Namespace()\n", 28 | "args.device = device\n", 29 | "args.image_size = 256\n", 30 | "args.encoder_adapter = True\n", 31 | "args.sam_checkpoint = \"pretrain_model/sam-med2d_b.pth\" #sam_vit_b.pth sam-med2d_b.pth" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def load_model(args):\n", 41 | " model = sam_model_registry[\"vit_b\"](args).to(args.device)\n", 42 | " model.eval()\n", 43 | " predictor = SammedPredictor(model)\n", 44 | " return predictor\n", 45 | "\n", 46 | "\n", 47 | "predictor_with_adapter = load_model(args)\n", 48 | "args.encoder_adapter = False\n", 49 | "predictor_without_adapter = load_model(args)\n", 50 | "\n", 51 | "def run_sammed(input_image, selected_points, last_mask, adapter_type):\n", 52 | " if adapter_type == \"SAM-Med2D-B\":\n", 53 | " predictor = predictor_with_adapter\n", 54 | " else:\n", 55 | " predictor = predictor_without_adapter\n", 56 | " \n", 57 | " image_pil = Image.fromarray(input_image) #.convert(\"RGB\")\n", 58 | " image = input_image\n", 59 | " H,W,_ = image.shape\n", 60 | " predictor.set_image(image)\n", 61 | " centers = np.array([a for a,b in selected_points ])\n", 62 | " point_coords = centers\n", 63 | " point_labels = np.array([b for a,b in selected_points ])\n", 64 | "\n", 65 | " masks, _, logits = predictor.predict(\n", 66 | " point_coords=point_coords,\n", 67 | " point_labels=point_labels,\n", 68 | " mask_input = last_mask,\n", 69 | " multimask_output=True \n", 70 | " ) \n", 71 | "\n", 72 | " mask_image = Image.new('RGBA', (W, H), color=(0, 0, 0, 0))\n", 73 | " mask_draw = ImageDraw.Draw(mask_image)\n", 74 | " for mask in masks:\n", 75 | " draw_mask(mask, mask_draw, random_color=False)\n", 76 | " image_draw = ImageDraw.Draw(image_pil)\n", 77 | "\n", 78 | " draw_point(selected_points, image_draw)\n", 79 | "\n", 80 | " image_pil = image_pil.convert('RGBA')\n", 81 | " image_pil.alpha_composite(mask_image)\n", 82 | " last_mask = torch.sigmoid(torch.as_tensor(logits, dtype=torch.float, device=device))\n", 83 | " return [(image_pil, mask_image), last_mask]\n", 84 | "\n", 85 
| "\n", 86 | "def draw_mask(mask, draw, random_color=False):\n", 87 | " if random_color:\n", 88 | " color = (random.randint(0, 255), random.randint(\n", 89 | " 0, 255), random.randint(0, 255), 153)\n", 90 | " else:\n", 91 | " color = (30, 144, 255, 153)\n", 92 | "\n", 93 | " nonzero_coords = np.transpose(np.nonzero(mask))\n", 94 | "\n", 95 | " for coord in nonzero_coords:\n", 96 | " draw.point(coord[::-1], fill=color)\n", 97 | "\n", 98 | "def draw_point(point, draw, r=5):\n", 99 | " show_point = []\n", 100 | " for point, label in point:\n", 101 | " x,y = point\n", 102 | " if label == 1:\n", 103 | " draw.ellipse((x-r, y-r, x+r, y+r), fill='green')\n", 104 | " elif label == 0:\n", 105 | " draw.ellipse((x-r, y-r, x+r, y+r), fill='red')\n", 106 | "\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 3, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "Keyboard interruption in main thread... closing server.\n" 119 | ] 120 | }, 121 | { 122 | "data": { 123 | "text/plain": [] 124 | }, 125 | "execution_count": 3, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "colors = [(255, 0, 0), (0, 255, 0)]\n", 132 | "markers = [1, 5]\n", 133 | "block = gr.Blocks()\n", 134 | "with block:\n", 135 | " with gr.Row():\n", 136 | " gr.Markdown(\n", 137 | " '''# SAM-Med2D!🚀\n", 138 | " SAM-Med2D is an interactive segmentation model based on the SAM model for medical scenarios, supporting multi-point interactive segmentation and box interaction. \n", 139 | " Currently, only multi-point interaction is supported in this application. More information can be found on [**GitHub**](https://github.com/uni-medical/SAM-Med2D/tree/main).\n", 140 | " '''\n", 141 | " )\n", 142 | " with gr.Row():\n", 143 | " # select model\n", 144 | " adapter_type = gr.Dropdown([\"SAM-Med2D-B\", \"SAM-Med2D-B_w/o_adapter\"], value='SAM-Med2D-B', label=\"Select Adapter\")\n", 145 | " # adapter_type.change(fn = update_model, inputs=[adapter_type])\n", 146 | " \n", 147 | " with gr.Tab(label='Image'):\n", 148 | " with gr.Row().style(equal_height=True):\n", 149 | " with gr.Column():\n", 150 | " # input image\n", 151 | " original_image = gr.State(value=None) # store original image without points, default None\n", 152 | " input_image = gr.Image(type=\"numpy\")\n", 153 | " # point prompt\n", 154 | " with gr.Column():\n", 155 | " selected_points = gr.State([]) # store points\n", 156 | " last_mask = gr.State(None) \n", 157 | " with gr.Row():\n", 158 | " gr.Markdown('You can click on the image to select points prompt. 
Default: foreground_point.')\n", 159 | " undo_button = gr.Button('Undo point')\n", 160 | " radio = gr.Radio(['foreground_point', 'background_point'], label='point labels')\n", 161 | " button = gr.Button(\"Run!\")\n", 162 | " \n", 163 | " gallery_sammed = gr.Gallery(\n", 164 | " label=\"Generated images\", show_label=False, elem_id=\"gallery\").style(preview=True, grid=2,object_fit=\"scale-down\")\n", 165 | " \n", 166 | " def process_example(img):\n", 167 | " return img, [], None \n", 168 | " \n", 169 | " def store_img(img):\n", 170 | " return img, [], None # when new image is uploaded, `selected_points` should be empty\n", 171 | " input_image.upload(\n", 172 | " store_img,\n", 173 | " [input_image],\n", 174 | " [original_image, selected_points, last_mask]\n", 175 | " )\n", 176 | " # user click the image to get points, and show the points on the image\n", 177 | " def get_point(img, sel_pix, point_type, evt: gr.SelectData):\n", 178 | " if point_type == 'foreground_point':\n", 179 | " sel_pix.append((evt.index, 1)) # append the foreground_point\n", 180 | " elif point_type == 'background_point':\n", 181 | " sel_pix.append((evt.index, 0)) # append the background_point\n", 182 | " else:\n", 183 | " sel_pix.append((evt.index, 1)) # default foreground_point\n", 184 | " # draw points\n", 185 | " for point, label in sel_pix:\n", 186 | " cv2.drawMarker(img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)\n", 187 | " # if img[..., 0][0, 0] == img[..., 2][0, 0]: # BGR to RGB\n", 188 | " # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", 189 | " return img if isinstance(img, np.ndarray) else np.array(img)\n", 190 | " \n", 191 | " input_image.select(\n", 192 | " get_point,\n", 193 | " [input_image, selected_points, radio],\n", 194 | " [input_image],\n", 195 | " )\n", 196 | "\n", 197 | " # undo the selected point\n", 198 | " def undo_points(orig_img, sel_pix):\n", 199 | " if isinstance(orig_img, int): # if orig_img is int, the image if select from examples\n", 200 | " temp = cv2.imread(image_examples[orig_img][0])\n", 201 | " temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)\n", 202 | " else:\n", 203 | " temp = orig_img.copy()\n", 204 | " # draw points\n", 205 | " if len(sel_pix) != 0:\n", 206 | " sel_pix.pop()\n", 207 | " for point, label in sel_pix:\n", 208 | " cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)\n", 209 | " if temp[..., 0][0, 0] == temp[..., 2][0, 0]: # BGR to RGB\n", 210 | " temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)\n", 211 | " return temp, None if isinstance(temp, np.ndarray) else np.array(temp), None\n", 212 | " \n", 213 | " undo_button.click(\n", 214 | " undo_points,\n", 215 | " [original_image, selected_points],\n", 216 | " [input_image, last_mask]\n", 217 | " )\n", 218 | "\n", 219 | " with gr.Row():\n", 220 | " with gr.Column():\n", 221 | " gr.Examples([\"data_demo/images/amos_0507_31.png\", \"data_demo/images/s0114_111.png\" ], inputs=[input_image], outputs=[original_image, selected_points,last_mask], fn=process_example, run_on_click=True)\n", 222 | "\n", 223 | " button.click(fn=run_sammed, inputs=[original_image, selected_points, last_mask, adapter_type], outputs=[gallery_sammed, last_mask])\n", 224 | "\n", 225 | "block.launch(debug=True, share=True, show_error=True)\n" 226 | ] 227 | } 228 | ], 229 | "metadata": { 230 | "kernelspec": { 231 | "display_name": "MMseg", 232 | "language": "python", 233 | "name": "python3" 234 | }, 235 | "language_info": { 236 | "codemirror_mode": { 237 | "name": "ipython", 238 | 
"version": 3 239 | }, 240 | "file_extension": ".py", 241 | "mimetype": "text/x-python", 242 | "name": "python", 243 | "nbconvert_exporter": "python", 244 | "pygments_lexer": "ipython3", 245 | "version": "3.8.0" 246 | }, 247 | "orig_nbformat": 4 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /assets/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/SAM-Med2D_wechat_group.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/SAM-Med2D_wechat_group.jpeg -------------------------------------------------------------------------------- /assets/SAM-Med2D_wechat_group.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/SAM-Med2D_wechat_group.jpg -------------------------------------------------------------------------------- /assets/SAM_Med2D_wechat_group.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/SAM_Med2D_wechat_group.png -------------------------------------------------------------------------------- /assets/cover_SA-Med2D-20M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/cover_SA-Med2D-20M.png -------------------------------------------------------------------------------- /assets/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/dataset.png -------------------------------------------------------------------------------- /assets/framwork.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/framwork.png -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/result.png -------------------------------------------------------------------------------- /assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/assets/visualization.png -------------------------------------------------------------------------------- /data_demo/image2label_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_demo/images/amos_0006_90.png": [ 3 | "data_demo/masks/amos_0006_90_liver_000.png", 4 | "data_demo/masks/amos_0006_90_spleen_000.png", 5 | "data_demo/masks/amos_0006_90_inferior_vena_cava_000.png", 6 | "data_demo/masks/amos_0006_90_aorta_000.png" 7 | ], 8 | "data_demo/images/s0114_111.png": [ 9 | 
"data_demo/masks/s0114_111_lung_middle_lobe_right_000.png", 10 | "data_demo/masks/s0114_111_heart_ventricle_right_000.png", 11 | "data_demo/masks/s0114_111_lung_lower_lobe_left_000.png", 12 | "data_demo/masks/s0114_111_rib_left_9_000.png", 13 | "data_demo/masks/s0114_111_vertebrae_T9_000.png", 14 | "data_demo/masks/s0114_111_heart_ventricle_left_000.png", 15 | "data_demo/masks/s0114_111_lung_lower_lobe_right_000.png", 16 | "data_demo/masks/s0114_111_lung_upper_lobe_left_000.png", 17 | "data_demo/masks/s0114_111_aorta_000.png", 18 | "data_demo/masks/s0114_111_autochthon_right_000.png", 19 | "data_demo/masks/s0114_111_autochthon_left_000.png", 20 | "data_demo/masks/s0114_111_heart_myocardium_000.png", 21 | "data_demo/masks/s0114_111_heart_atrium_right_000.png", 22 | "data_demo/masks/s0114_111_heart_atrium_left_000.png", 23 | "data_demo/masks/s0114_111_rib_right_9_000.png" 24 | ] 25 | } -------------------------------------------------------------------------------- /data_demo/images/amos_0004_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/images/amos_0004_75.png -------------------------------------------------------------------------------- /data_demo/images/amos_0006_90.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/images/amos_0006_90.png -------------------------------------------------------------------------------- /data_demo/images/amos_0507_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/images/amos_0507_31.png -------------------------------------------------------------------------------- /data_demo/images/s0114_111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/images/s0114_111.png -------------------------------------------------------------------------------- /data_demo/images/s0619_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/images/s0619_32.png -------------------------------------------------------------------------------- /data_demo/label2image_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_demo/masks/amos_0004_75_inferior_vena_cava_000.png": "data_demo/images/amos_0004_75.png", 3 | "data_demo/masks/amos_0004_75_liver_000.png": "data_demo/images/amos_0004_75.png", 4 | "data_demo/masks/amos_0004_75_aorta_000.png": "data_demo/images/amos_0004_75.png", 5 | "data_demo/masks/s0619_32_femur_right_000.png": "data_demo/images/s0619_32.png", 6 | "data_demo/masks/s0619_32_gluteus_maximus_left_000.png": "data_demo/images/s0619_32.png", 7 | "data_demo/masks/s0619_32_hip_left_000.png": "data_demo/images/s0619_32.png", 8 | "data_demo/masks/s0619_32_colon_000.png": "data_demo/images/s0619_32.png", 9 | "data_demo/masks/s0619_32_hip_right_000.png": "data_demo/images/s0619_32.png", 10 | "data_demo/masks/s0619_32_hip_right_001.png": "data_demo/images/s0619_32.png", 11 | "data_demo/masks/s0619_32_hip_left_001.png": 
"data_demo/images/s0619_32.png", 12 | "data_demo/masks/s0619_32_gluteus_maximus_right_000.png": "data_demo/images/s0619_32.png" 13 | } -------------------------------------------------------------------------------- /data_demo/masks/amos_0004_75_aorta_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0004_75_aorta_000.png -------------------------------------------------------------------------------- /data_demo/masks/amos_0004_75_inferior_vena_cava_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0004_75_inferior_vena_cava_000.png -------------------------------------------------------------------------------- /data_demo/masks/amos_0004_75_liver_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0004_75_liver_000.png -------------------------------------------------------------------------------- /data_demo/masks/amos_0006_90_aorta_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0006_90_aorta_000.png -------------------------------------------------------------------------------- /data_demo/masks/amos_0006_90_inferior_vena_cava_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0006_90_inferior_vena_cava_000.png -------------------------------------------------------------------------------- /data_demo/masks/amos_0006_90_liver_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0006_90_liver_000.png -------------------------------------------------------------------------------- /data_demo/masks/amos_0006_90_spleen_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/amos_0006_90_spleen_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_aorta_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_aorta_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_autochthon_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_autochthon_left_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_autochthon_right_000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_autochthon_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_heart_atrium_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_heart_atrium_left_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_heart_atrium_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_heart_atrium_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_heart_myocardium_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_heart_myocardium_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_heart_ventricle_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_heart_ventricle_left_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_heart_ventricle_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_heart_ventricle_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_lung_lower_lobe_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_lung_lower_lobe_left_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_lung_lower_lobe_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_lung_lower_lobe_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_lung_middle_lobe_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_lung_middle_lobe_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_lung_upper_lobe_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_lung_upper_lobe_left_000.png -------------------------------------------------------------------------------- 
/data_demo/masks/s0114_111_rib_left_9_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_rib_left_9_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_rib_right_9_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_rib_right_9_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0114_111_vertebrae_T9_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0114_111_vertebrae_T9_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_colon_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_colon_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_femur_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_femur_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_gluteus_maximus_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_gluteus_maximus_left_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_gluteus_maximus_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_gluteus_maximus_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_hip_left_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_hip_left_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_hip_left_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_hip_left_001.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_hip_right_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_hip_right_000.png -------------------------------------------------------------------------------- /data_demo/masks/s0619_32_hip_right_001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/SAM-Med2D/bfd2b93b1158100c8abd81f61766a2de92c1c175/data_demo/masks/s0619_32_hip_right_001.png -------------------------------------------------------------------------------- /examples/SAM-Med2D-onnxruntime/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | import matplotlib.pyplot as plt 7 | 8 | from tqdm import tqdm 9 | from typing import Any, Union 10 | from copy import deepcopy 11 | 12 | 13 | parser = argparse.ArgumentParser( 14 | description="Inference an image with onnxruntime backend." 15 | ) 16 | 17 | parser.add_argument( 18 | "--encoder_model", 19 | type=str, 20 | required=True, 21 | help="Path to the SAM-Med2D onnx encoder model.", 22 | ) 23 | 24 | parser.add_argument( 25 | "--decoder_model", 26 | type=str, 27 | required=True, 28 | help="Path to the SAM-Med2D onnx decoder model.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--img_path", 33 | type=str, 34 | default="../../data_demo/images/amos_0507_31.png", 35 | help="Path to the image", 36 | ) 37 | 38 | parser.add_argument( 39 | "--input_size", 40 | type=int, 41 | default=256, 42 | help="input_size" 43 | ) 44 | 45 | parser.add_argument( 46 | "--work_dir", 47 | type=str, 48 | default="workdir", 49 | help="work dir" 50 | ) 51 | 52 | args = parser.parse_args() 53 | 54 | def show_mask(mask, ax, random_color=False): 55 | if random_color: 56 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 57 | else: 58 | color = np.array([30/255, 144/255, 255/255, 0.6]) 59 | h, w = mask.shape[-2:] 60 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 61 | ax.imshow(mask_image) 62 | 63 | def show_points(coords, labels, ax, marker_size=375): 64 | pos_points = coords[labels==1] 65 | neg_points = coords[labels==0] 66 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 67 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 68 | 69 | def show_box(box, ax): 70 | x0, y0 = box[0], box[1] 71 | w, h = box[2] - box[0], box[3] - box[1] 72 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 73 | 74 | 75 | class SamEncoder: 76 | """Sam encoder model. 77 | 78 | In this class, encoder model will encoder the input image. 79 | 80 | Args: 81 | model_path (str): sam encoder onnx model path. 82 | device (str): Inference device, user can choose 'cuda' or 'cpu'. default to 'cuda'. 83 | warmup_epoch (int): Warmup, if set 0,the model won`t use random inputs to warmup. default to 3. 
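        Example (illustrative sketch only; the onnx file name is an assumption,
        and the embedding shape applies to the 256-input SAM-Med2D export):

            encoder = SamEncoder("sam-med2d_b.encoder.onnx", device="cpu", warmup_epoch=0)
            img = cv2.imread("data_demo/images/amos_0507_31.png")  # BGR, as transform() expects
            embedding = encoder(img)                               # (1, 256, 16, 16)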
84 | """ 85 | 86 | def __init__(self, 87 | model_path: str, 88 | device: str = "cuda", 89 | warmup_epoch: int = 3, 90 | **kwargs): 91 | opt = ort.SessionOptions() 92 | 93 | if device == "cuda": 94 | provider = ['CUDAExecutionProvider'] 95 | elif device == "cpu": 96 | provider = ['CPUExecutionProvider'] 97 | else: 98 | raise ValueError("Invalid device, please use 'cuda' or 'cpu' device.") 99 | 100 | print("loading encoder model...") 101 | self.session = ort.InferenceSession(model_path, 102 | opt, 103 | providers=provider, 104 | **kwargs) 105 | 106 | self.input_name = self.session.get_inputs()[0].name 107 | self.input_shape = self.session.get_inputs()[0].shape 108 | self.output_name = self.session.get_outputs()[0].name 109 | self.output_shape = self.session.get_outputs()[0].shape 110 | 111 | self.pixel_mean = np.array([123.675, 116.28, 103.53]) 112 | self.pixel_std = np.array([58.395, 57.12, 57.375]) 113 | self.input_size = (self.input_shape[-1], self.input_shape[-2]) 114 | 115 | if warmup_epoch: 116 | self.warmup(warmup_epoch) 117 | 118 | def warmup(self, epoch: int) -> None: 119 | """warmup function 120 | 121 | Args: 122 | epoch (int): warmup epoch. 123 | """ 124 | x = np.random.random(self.input_shape).astype(np.float32) 125 | print("start warmup!") 126 | for i in tqdm(range(epoch)): 127 | self.session.run(None, {self.input_name: x}) 128 | print("warmup finish!") 129 | 130 | def transform(self, img: np.ndarray) -> np.ndarray: 131 | """image transform 132 | 133 | This function can convert the input image to the required input format for vit. 134 | 135 | Args: 136 | img (np.ndarray): input image, the image type should be BGR. 137 | 138 | Returns: 139 | np.ndarray: transformed image. 140 | """ 141 | # BGR -> RGB 142 | input_image = img[..., ::-1] 143 | 144 | # Normalization 145 | input_image = (input_image - self.pixel_mean) / self.pixel_std 146 | 147 | # Resize 148 | input_image = cv2.resize(input_image, self.input_size, cv2.INTER_NEAREST) 149 | 150 | # HWC -> CHW 151 | input_image = input_image.transpose((2, 0, 1)) 152 | 153 | # CHW -> NCHW 154 | input_image = np.expand_dims(input_image, 0).astype(np.float32) 155 | 156 | return input_image 157 | 158 | def _extract_feature(self, tensor: np.ndarray) -> np.ndarray: 159 | """extract image feature 160 | 161 | this function can use vit to extract feature from transformed image. 162 | 163 | Args: 164 | tensor (np.ndarray): input image with BGR format. 165 | 166 | Returns: 167 | np.ndarray: image`s feature. 168 | """ 169 | input_image = self.transform(tensor) 170 | assert list(input_image.shape) == self.input_shape 171 | feature = self.session.run(None, {self.input_name: input_image})[0] 172 | assert list(feature.shape) == self.output_shape 173 | return feature 174 | 175 | def __call__(self, img: np.array, *args: Any, **kwds: Any) -> Any: 176 | return self._extract_feature(img) 177 | 178 | class SamDecoder: 179 | """Sam decoder model. 180 | 181 | This class is the sam prompt encoder and lightweight mask decoder. 182 | 183 | Args: 184 | model_path (str): decoder model path. 185 | device (str): Inference device, user can choose 'cuda' or 'cpu'. default to 'cuda'. 
186 | """ 187 | 188 | def __init__(self, 189 | model_path: str, 190 | device: str = "cuda", 191 | img_size: int = 256, 192 | **kwargs): 193 | opt = ort.SessionOptions() 194 | 195 | if device == "cuda": 196 | provider = ['CUDAExecutionProvider'] 197 | elif device == "cpu": 198 | provider = ['CPUExecutionProvider'] 199 | else: 200 | raise ValueError("Invalid device, please use 'cuda' or 'cpu' device.") 201 | 202 | print("loading decoder model...") 203 | self.mask_threshold = 0.5 204 | self.img_size = (img_size, img_size) 205 | self.session = ort.InferenceSession(model_path, 206 | opt, 207 | providers=provider, 208 | **kwargs) 209 | 210 | def run(self, 211 | img_embeddings: np.ndarray, 212 | origin_image_size: Union[list, tuple], 213 | point_coords: Union[list, np.ndarray] = None, 214 | point_labels: Union[list, np.ndarray] = None, 215 | boxes: Union[list, np.ndarray] = None, 216 | mask_input: np.ndarray = None, 217 | return_logits: bool = False): 218 | """decoder forward function 219 | 220 | This function can use image feature and prompt to generate mask. Must input 221 | at least one box or point. 222 | 223 | Args: 224 | img_embeddings (np.ndarray): the image feature from vit encoder. 225 | origin_image_size (list or tuple): the input image size. 226 | point_coords (list or np.ndarray): the input points. 227 | point_labels (list or np.ndarray): the input points label, 1 indicates 228 | a foreground point and 0 indicates a background point. 229 | boxes (list or np.ndarray): A length 4 array given a box prompt to the 230 | model, in XYXY format. 231 | mask_input (np.ndarray): A low resolution mask input to the model, 232 | typically coming from a previous prediction iteration. Has form 233 | 1xHxW, where for SAM, H=W=4 * embedding.size. 234 | 235 | Returns: 236 | the segment results. 
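            Concretely, a tuple (masks, iou_predictions, low_res_masks) for the
            first batch element. Unless return_logits is True, the masks are
            binarized by applying a sigmoid and thresholding at 0.5.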
237 | """ 238 | if point_coords is None and point_labels is None and boxes is None: 239 | raise ValueError("Unable to segment, please input at least one box or point.") 240 | 241 | if img_embeddings.shape != (1, 256, 16, 16): 242 | raise ValueError("Got wrong embedding shape!") 243 | if mask_input is None: 244 | mask_input = np.zeros((1, 1, 64, 64), dtype=np.float32) 245 | has_mask_input = np.zeros(1, dtype=np.float32) 246 | else: 247 | mask_input = np.expand_dims(mask_input, axis=0) 248 | has_mask_input = np.ones(1, dtype=np.float32) 249 | if mask_input.shape != (1, 1, 64, 64): 250 | raise ValueError("Got wrong mask!") 251 | if point_coords is not None: 252 | if isinstance(point_coords, list): 253 | point_coords = np.array(point_coords, dtype=np.float32) 254 | if isinstance(point_labels, list): 255 | point_labels = np.array(point_labels, dtype=np.float32) 256 | 257 | if point_coords is not None: 258 | point_coords = self.apply_coords(point_coords, origin_image_size, self.img_size).astype(np.float32) 259 | point_coords = np.expand_dims(point_coords, axis=0) 260 | point_labels = np.expand_dims(point_labels, axis=0) 261 | 262 | if boxes is not None: 263 | if isinstance(boxes, list): 264 | boxes = np.array(boxes, dtype=np.float32) 265 | assert boxes.shape[-1] == 4 266 | 267 | boxes = self.apply_boxes(boxes, origin_image_size, self.img_size).reshape((1, -1, 2)).astype(np.float32) 268 | box_label = np.array([[2, 3] for i in range(boxes.shape[1] // 2)], dtype=np.float32).reshape((1, -1)) 269 | 270 | if point_coords is not None: 271 | point_coords = np.concatenate([point_coords, boxes], axis=1) 272 | point_labels = np.concatenate([point_labels, box_label], axis=1) 273 | else: 274 | point_coords = boxes 275 | point_labels = box_label 276 | 277 | assert point_coords.shape[0] == 1 and point_coords.shape[-1] == 2 278 | assert point_labels.shape[0] == 1 279 | print(f"point_coords={point_coords}, point_labels={point_labels}") 280 | input_dict = {"image_embeddings": img_embeddings, 281 | "point_coords": point_coords, 282 | "point_labels": point_labels, 283 | "mask_input": mask_input, 284 | "has_mask_input": has_mask_input, 285 | "orig_im_size": np.array(origin_image_size, dtype=np.float32)} 286 | masks, iou_predictions, low_res_masks = self.session.run(None, input_dict) 287 | 288 | if not return_logits: 289 | sigmoid_output = self.sigmoid(masks) 290 | masks = (sigmoid_output > self.mask_threshold).astype(np.float32) 291 | 292 | return masks[0], iou_predictions[0], low_res_masks[0] 293 | 294 | @staticmethod 295 | def sigmoid(x): 296 | return 0.5 * (np.tanh(0.5 * x) + 1) 297 | 298 | def apply_coords(self, coords, original_size, new_size): 299 | old_h, old_w = original_size 300 | new_h, new_w = new_size 301 | coords = deepcopy(coords).astype(float) 302 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 303 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 304 | return coords 305 | 306 | def apply_boxes(self, boxes, original_size, new_size): 307 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, new_size) 308 | return boxes.reshape(-1, 4) 309 | 310 | def main(): 311 | # Create save folder 312 | save_path = os.path.join(args.work_dir, 'ort_demo_results') 313 | if not os.path.exists(save_path): 314 | os.makedirs(save_path) 315 | base_name, file_extension = os.path.splitext(os.path.basename(args.img_path)) 316 | 317 | # Initialize the SAM-Med2D onnx model 318 | encoder = SamEncoder( 319 | model_path=args.encoder_model, 320 | warmup_epoch=3 321 | ) 322 | decoder = SamDecoder( 323 | 
model_path=args.decoder_model, 324 | ) 325 | 326 | '''Specifying a specific object with a point''' 327 | img_file = cv2.imread(args.img_path) 328 | img_embeddings = encoder(img_file) 329 | 330 | origin_image_size = img_file.shape[:2] 331 | point_coords = np.array([[162, 127]], dtype=np.float32) 332 | point_labels = np.array([1], dtype=np.float32) 333 | masks, _, logits = decoder.run( 334 | img_embeddings=img_embeddings, 335 | origin_image_size=origin_image_size, 336 | point_coords=point_coords, 337 | point_labels=point_labels 338 | ) 339 | 340 | plt.figure(figsize=(10,10)) 341 | plt.imshow(img_file) 342 | show_mask(masks, plt.gca()) 343 | show_points(point_coords, point_labels, plt.gca()) 344 | plt.axis('off') 345 | plt.savefig(os.path.join(save_path, base_name+'_point1'+file_extension)) 346 | plt.show() 347 | 348 | '''Optimizing Segmentation Results by Point Interaction''' 349 | new_point_coords = np.array([[169, 140]], dtype=np.float32) 350 | new_point_labels = np.array([0], dtype=np.float32) 351 | point_coords = np.concatenate((point_coords, new_point_coords)) 352 | point_labels = np.concatenate((point_labels, new_point_labels)) 353 | mask_inputs = 1. / (1. + np.exp(-logits.astype(np.float32))) 354 | 355 | masks, _, logits = decoder.run( 356 | img_embeddings=img_embeddings, 357 | origin_image_size=origin_image_size, 358 | point_coords=point_coords, 359 | point_labels=point_labels, 360 | mask_input = mask_inputs, 361 | ) 362 | 363 | plt.figure(figsize=(10,10)) 364 | plt.imshow(img_file) 365 | show_mask(masks, plt.gca()) 366 | show_points(point_coords, point_labels, plt.gca()) 367 | plt.axis('off') 368 | plt.savefig(os.path.join(save_path, base_name+'_point2'+file_extension)) 369 | plt.show() 370 | 371 | '''Specifying a specific object with a bounding box''' 372 | boxes = np.array([135,100,180,150]) 373 | 374 | masks, _, _ = decoder.run( 375 | img_embeddings=img_embeddings, 376 | origin_image_size=origin_image_size, 377 | boxes=boxes, 378 | ) 379 | plt.figure(figsize=(10,10)) 380 | plt.imshow(img_file) 381 | show_mask(masks, plt.gca()) 382 | show_box(boxes, plt.gca()) 383 | plt.axis('off') 384 | plt.savefig(os.path.join(save_path, base_name+'_box'+file_extension)) 385 | plt.show() 386 | 387 | 388 | if __name__ == '__main__': 389 | main() -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | def _threshold(x, threshold=None): 6 | if threshold is not None: 7 | return (x > threshold).type(x.dtype) 8 | else: 9 | return x 10 | 11 | 12 | def _list_tensor(x, y): 13 | m = torch.nn.Sigmoid() 14 | if type(x) is list: 15 | x = torch.tensor(np.array(x)) 16 | y = torch.tensor(np.array(y)) 17 | if x.min() < 0: 18 | x = m(x) 19 | else: 20 | x, y = x, y 21 | if x.min() < 0: 22 | x = m(x) 23 | return x, y 24 | 25 | 26 | def iou(pr, gt, eps=1e-7, threshold = 0.5): 27 | pr_, gt_ = _list_tensor(pr, gt) 28 | pr_ = _threshold(pr_, threshold=threshold) 29 | gt_ = _threshold(gt_, threshold=threshold) 30 | intersection = torch.sum(gt_ * pr_,dim=[1,2,3]) 31 | union = torch.sum(gt_,dim=[1,2,3]) + torch.sum(pr_,dim=[1,2,3]) - intersection 32 | return ((intersection + eps) / (union + eps)).cpu().numpy() 33 | 34 | 35 | def dice(pr, gt, eps=1e-7, threshold = 0.5): 36 | pr_, gt_ = _list_tensor(pr, gt) 37 | pr_ = _threshold(pr_, threshold=threshold) 38 | gt_ = _threshold(gt_, threshold=threshold) 39 | intersection = torch.sum(gt_ * 
pr_,dim=[1,2,3]) 40 | union = torch.sum(gt_,dim=[1,2,3]) + torch.sum(pr_,dim=[1,2,3]) 41 | return ((2. * intersection +eps) / (union + eps)).cpu().numpy() 42 | 43 | 44 | def SegMetrics(pred, label, metrics): 45 | metric_list = [] 46 | if isinstance(metrics, str): 47 | metrics = [metrics, ] 48 | for i, metric in enumerate(metrics): 49 | if not isinstance(metric, str): 50 | continue 51 | elif metric == 'iou': 52 | metric_list.append(np.mean(iou(pred, label))) 53 | elif metric == 'dice': 54 | metric_list.append(np.mean(dice(pred, label))) 55 | else: 56 | raise ValueError('metric %s not recognized' % metric) 57 | if pred is not None: 58 | metric = np.array(metric_list) 59 | else: 60 | raise ValueError('metric mistakes in calculations') 61 | return metric -------------------------------------------------------------------------------- /scripts/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import cv2 # type: ignore 8 | 9 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry 10 | 11 | import argparse 12 | import json 13 | import os 14 | from typing import Any, Dict, List 15 | 16 | parser = argparse.ArgumentParser( 17 | description=( 18 | "Runs automatic mask generation on an input image or directory of images, " 19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " 20 | "as well as pycocotools if saving in RLE format." 21 | ) 22 | ) 23 | 24 | parser.add_argument( 25 | "--input", 26 | type=str, 27 | required=True, 28 | help="Path to either a single input image or folder of images.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--output", 33 | type=str, 34 | required=True, 35 | help=( 36 | "Path to the directory where masks will be output. Output will be either a folder " 37 | "of PNGs per image or a single json with COCO-style masks." 38 | ), 39 | ) 40 | 41 | parser.add_argument( 42 | "--model-type", 43 | type=str, 44 | required=True, 45 | help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", 46 | ) 47 | 48 | parser.add_argument( 49 | "--checkpoint", 50 | type=str, 51 | required=True, 52 | help="The path to the SAM checkpoint to use for mask generation.", 53 | ) 54 | 55 | parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") 56 | 57 | parser.add_argument( 58 | "--convert-to-rle", 59 | action="store_true", 60 | help=( 61 | "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " 62 | "Requires pycocotools." 
63 | ), 64 | ) 65 | 66 | amg_settings = parser.add_argument_group("AMG Settings") 67 | 68 | amg_settings.add_argument( 69 | "--points-per-side", 70 | type=int, 71 | default=None, 72 | help="Generate masks by sampling a grid over the image with this many points to a side.", 73 | ) 74 | 75 | amg_settings.add_argument( 76 | "--points-per-batch", 77 | type=int, 78 | default=None, 79 | help="How many input points to process simultaneously in one batch.", 80 | ) 81 | 82 | amg_settings.add_argument( 83 | "--pred-iou-thresh", 84 | type=float, 85 | default=None, 86 | help="Exclude masks with a predicted score from the model that is lower than this threshold.", 87 | ) 88 | 89 | amg_settings.add_argument( 90 | "--stability-score-thresh", 91 | type=float, 92 | default=None, 93 | help="Exclude masks with a stability score lower than this threshold.", 94 | ) 95 | 96 | amg_settings.add_argument( 97 | "--stability-score-offset", 98 | type=float, 99 | default=None, 100 | help="Larger values perturb the mask more when measuring stability score.", 101 | ) 102 | 103 | amg_settings.add_argument( 104 | "--box-nms-thresh", 105 | type=float, 106 | default=None, 107 | help="The overlap threshold for excluding a duplicate mask.", 108 | ) 109 | 110 | amg_settings.add_argument( 111 | "--crop-n-layers", 112 | type=int, 113 | default=None, 114 | help=( 115 | "If >0, mask generation is run on smaller crops of the image to generate more masks. " 116 | "The value sets how many different scales to crop at." 117 | ), 118 | ) 119 | 120 | amg_settings.add_argument( 121 | "--crop-nms-thresh", 122 | type=float, 123 | default=None, 124 | help="The overlap threshold for excluding duplicate masks across different crops.", 125 | ) 126 | 127 | amg_settings.add_argument( 128 | "--crop-overlap-ratio", 129 | type=int, 130 | default=None, 131 | help="Larger numbers mean image crops will overlap more.", 132 | ) 133 | 134 | amg_settings.add_argument( 135 | "--crop-n-points-downscale-factor", 136 | type=int, 137 | default=None, 138 | help="The number of points-per-side in each layer of crop is reduced by this factor.", 139 | ) 140 | 141 | amg_settings.add_argument( 142 | "--min-mask-region-area", 143 | type=int, 144 | default=None, 145 | help=( 146 | "Disconnected mask regions or holes with area smaller than this value " 147 | "in pixels are removed by postprocessing." 
148 | ), 149 | ) 150 | 151 | 152 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: 153 | header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa 154 | metadata = [header] 155 | for i, mask_data in enumerate(masks): 156 | mask = mask_data["segmentation"] 157 | filename = f"{i}.png" 158 | cv2.imwrite(os.path.join(path, filename), mask * 255) 159 | mask_metadata = [ 160 | str(i), 161 | str(mask_data["area"]), 162 | *[str(x) for x in mask_data["bbox"]], 163 | *[str(x) for x in mask_data["point_coords"][0]], 164 | str(mask_data["predicted_iou"]), 165 | str(mask_data["stability_score"]), 166 | *[str(x) for x in mask_data["crop_box"]], 167 | ] 168 | row = ",".join(mask_metadata) 169 | metadata.append(row) 170 | metadata_path = os.path.join(path, "metadata.csv") 171 | with open(metadata_path, "w") as f: 172 | f.write("\n".join(metadata)) 173 | 174 | return 175 | 176 | 177 | def get_amg_kwargs(args): 178 | amg_kwargs = { 179 | "points_per_side": args.points_per_side, 180 | "points_per_batch": args.points_per_batch, 181 | "pred_iou_thresh": args.pred_iou_thresh, 182 | "stability_score_thresh": args.stability_score_thresh, 183 | "stability_score_offset": args.stability_score_offset, 184 | "box_nms_thresh": args.box_nms_thresh, 185 | "crop_n_layers": args.crop_n_layers, 186 | "crop_nms_thresh": args.crop_nms_thresh, 187 | "crop_overlap_ratio": args.crop_overlap_ratio, 188 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, 189 | "min_mask_region_area": args.min_mask_region_area, 190 | } 191 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} 192 | return amg_kwargs 193 | 194 | 195 | def main(args: argparse.Namespace) -> None: 196 | print("Loading model...") 197 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 198 | _ = sam.to(device=args.device) 199 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" 200 | amg_kwargs = get_amg_kwargs(args) 201 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) 202 | 203 | if not os.path.isdir(args.input): 204 | targets = [args.input] 205 | else: 206 | targets = [ 207 | f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) 208 | ] 209 | targets = [os.path.join(args.input, f) for f in targets] 210 | 211 | os.makedirs(args.output, exist_ok=True) 212 | 213 | for t in targets: 214 | print(f"Processing '{t}'...") 215 | image = cv2.imread(t) 216 | if image is None: 217 | print(f"Could not load '{t}' as an image, skipping...") 218 | continue 219 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 220 | 221 | masks = generator.generate(image) 222 | 223 | base = os.path.basename(t) 224 | base = os.path.splitext(base)[0] 225 | save_base = os.path.join(args.output, base) 226 | if output_mode == "binary_mask": 227 | os.makedirs(save_base, exist_ok=False) 228 | write_masks_to_folder(masks, save_base) 229 | else: 230 | save_file = save_base + ".json" 231 | with open(save_file, "w") as f: 232 | json.dump(masks, f) 233 | print("Done!") 234 | 235 | 236 | if __name__ == "__main__": 237 | args = parser.parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /scripts/export_onnx_encoder_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import os 5 | import cv2 6 | import onnx 
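# Illustrative invocation (paths are placeholders; the flags are defined further below):
#   python3 scripts/export_onnx_encoder_model.py --sam_checkpoint pretrain_model/sam-med2d_b.pth \
#       --output onnx_model/sam-med2d_b.encoder.onnx --model-type vit_b --use-preprocess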
7 | import pathlib 8 | import shutil 9 | import argparse 10 | import warnings 11 | import numpy as np 12 | import albumentations as A 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | from tempfile import mkdtemp 18 | from segment_anything import sam_model_registry 19 | from segment_anything.modeling import Sam 20 | from onnx.external_data_helper import convert_model_to_external_data 21 | 22 | try: 23 | import onnxruntime # type: ignore 24 | 25 | onnxruntime_exists = True 26 | except ImportError: 27 | onnxruntime_exists = False 28 | 29 | 30 | parser = argparse.ArgumentParser( 31 | description="Export the model to ONNX format with encoder support." 32 | ) 33 | 34 | parser.add_argument( 35 | "--sam_checkpoint", 36 | type=str, 37 | required=True, 38 | help="The path to the SAM-Med2D model checkpoint.", 39 | ) 40 | 41 | parser.add_argument( 42 | "--output", type=str, required=True, help="The filename to save the ONNX model to." 43 | ) 44 | 45 | parser.add_argument( 46 | "--model-type", 47 | type=str, 48 | required=True, 49 | help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM-Med2D model to export.", 50 | ) 51 | 52 | parser.add_argument( 53 | "--use-preprocess", 54 | action="store_true", 55 | help="Integrate preprocessing into the model.", 56 | ) 57 | 58 | parser.add_argument( 59 | "--opset", 60 | type=int, 61 | default=17, 62 | help="The ONNX opset version to use. Must be >=11", 63 | ) 64 | 65 | parser.add_argument( 66 | '--device', 67 | type=str, 68 | default='cpu' 69 | ) 70 | 71 | parser.add_argument( 72 | "--image_size", 73 | type=int, 74 | default=256, 75 | help="image_size" 76 | 77 | ) 78 | 79 | parser.add_argument( 80 | "--encoder_adapter", 81 | type=bool, 82 | default=True, 83 | help="use adapter" 84 | ) 85 | 86 | parser.add_argument( 87 | "--quantize-out", 88 | type=str, 89 | default=None, 90 | help=( 91 | "If set, will quantize the model and save it with this name. " 92 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 93 | ), 94 | ) 95 | 96 | parser.add_argument( 97 | "--gelu-approximate", 98 | action="store_true", 99 | help=( 100 | "Replace GELU operations with approximations using tanh. Useful " 101 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 102 | ), 103 | ) 104 | 105 | class OnnxEncoderModel(nn.Module): 106 | 107 | pixel_mean = [123.675, 116.28, 103.53] 108 | pixel_std = [58.395, 57.12, 57.375] 109 | 110 | def __init__( 111 | self, 112 | model: Sam, 113 | input_size: tuple = (256, 256), 114 | pixel_mean: list = [123.675, 116.28, 103.53], 115 | pixel_std: list=[58.395, 57.12, 57.375], 116 | use_preprocess: bool = False 117 | ): 118 | super().__init__() 119 | self.use_preprocess = use_preprocess 120 | self.pixel_mean = torch.tensor(pixel_mean, dtype=torch.float) 121 | self.pixel_std = torch.tensor(pixel_std, dtype=torch.float) 122 | self.input_size = input_size 123 | self.model = model 124 | self.image_encoder = model.image_encoder 125 | 126 | @torch.no_grad() 127 | def forward(self, input_image: torch.Tensor): 128 | if self.use_preprocess: 129 | input_image = self.preprocess(input_image) 130 | image_embeddings = self.image_encoder(input_image) 131 | return image_embeddings 132 | 133 | def preprocess(self, input_image: torch.Tensor) -> torch.Tensor: 134 | """Image transform 135 | 136 | This function can convert the input image to the required input format for VIT. 137 | 138 | Args: 139 | img (torch.Tensor): Input image in BGR format. 
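                Expected as an HxWx3 float tensor. This method normalizes it,
                permutes HWC -> CHW, adds a batch dimension and resizes it
                (nearest neighbour) to `input_size`; it is only called when
                --use-preprocess is set.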
140 | 141 | Returns: 142 | torch.Tensor: Transformed image. 143 | """ 144 | 145 | # Normalization 146 | input_image = (input_image - self.pixel_mean) / self.pixel_std 147 | 148 | # permute channels 149 | input_image = torch.permute(input_image, (2, 0, 1)) 150 | 151 | # CHW -> NCHW & Resize 152 | input_image = F.interpolate(input_image.unsqueeze(0), size=self.input_size, mode='nearest') 153 | 154 | return input_image 155 | 156 | 157 | def run_export(args): 158 | print("Loading model...") 159 | sam = sam_model_registry[args.model_type](args).to(args.device) 160 | 161 | model = OnnxEncoderModel( 162 | model=sam, 163 | use_preprocess=args.use_preprocess, 164 | pixel_mean=[123.675, 116.28, 103.53], 165 | pixel_std=[58.395, 57.12, 57.375], 166 | ) 167 | 168 | if args.gelu_approximate: 169 | for _, m in model.named_modules(): 170 | if isinstance(m, torch.nn.GELU): 171 | m.approximate = "tanh" 172 | 173 | image_size = sam.image_encoder.img_size 174 | if args.use_preprocess: 175 | dummy_inputs = { 176 | "input_image": torch.randn( 177 | (image_size, image_size, 3), dtype=torch.float 178 | ) 179 | } 180 | dynamic_axes = { 181 | "input_image": {0: "image_height", 1: "image_width"}, 182 | } 183 | else: 184 | dummy_inputs = { 185 | "input_image": torch.randn( 186 | (1, 3, image_size, image_size), dtype=torch.float 187 | ) 188 | } 189 | dynamic_axes = None 190 | _ = model(**dummy_inputs) 191 | 192 | output_names = ["image_embeddings"] 193 | onnx_base = os.path.splitext(os.path.basename(args.output))[0] 194 | with warnings.catch_warnings(): 195 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 196 | warnings.filterwarnings("ignore", category=UserWarning) 197 | print(f"Exporting onnx model to {args.output}...") 198 | if args.model_type == "vit_h": 199 | tmp_dir = mkdtemp() 200 | tmp_model_path = os.path.join(tmp_dir, f"{onnx_base}.onnx") 201 | torch.onnx.export( 202 | model, 203 | tuple(dummy_inputs.values()), 204 | tmp_model_path, 205 | export_params=True, 206 | verbose=False, 207 | opset_version=args.opset, 208 | do_constant_folding=True, 209 | input_names=list(dummy_inputs.keys()), 210 | output_names=output_names, 211 | dynamic_axes=dynamic_axes, 212 | ) 213 | 214 | # Combine the weights into a single file 215 | pathlib.Path(args.output).parent.mkdir(parents=True, exist_ok=True) 216 | model = onnx.load(tmp_model_path) 217 | convert_model_to_external_data( 218 | model, 219 | all_tensors_to_one_file=True, 220 | location=f"{onnx_base}_data.bin", 221 | size_threshold=1024, 222 | convert_attribute=False, 223 | ) 224 | 225 | # Save the model 226 | onnx.save(model, args.output) 227 | 228 | # Cleanup the temporary directory 229 | shutil.rmtree(tmp_dir) 230 | else: 231 | with open(args.output, "wb") as f: 232 | torch.onnx.export( 233 | model, 234 | tuple(dummy_inputs.values()), 235 | f, 236 | export_params=True, 237 | verbose=False, 238 | opset_version=args.opset, 239 | do_constant_folding=True, 240 | input_names=list(dummy_inputs.keys()), 241 | output_names=output_names, 242 | dynamic_axes=dynamic_axes, 243 | ) 244 | 245 | if onnxruntime_exists: 246 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} 247 | ort_session = onnxruntime.InferenceSession(args.output) 248 | _ = ort_session.run(None, ort_inputs) 249 | print("Model has successfully been run with ONNXRuntime.") 250 | 251 | 252 | def to_numpy(tensor): 253 | return tensor.cpu().numpy() 254 | 255 | 256 | if __name__ == "__main__": 257 | args = parser.parse_args() 258 | run_export(args=args) 259 | 260 | if args.quantize_out 
is not None: 261 | from onnxruntime.quantization import QuantType # type: ignore 262 | from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore 263 | 264 | print(f"Quantizing model and writing to {args.quantize_out}...") 265 | quantize_dynamic( 266 | model_input=args.output, 267 | model_output=args.quantize_out, 268 | optimize_model=True, 269 | per_channel=False, 270 | reduce_range=False, 271 | weight_type=QuantType.QUInt8, 272 | ) 273 | print("Done!") 274 | -------------------------------------------------------------------------------- /scripts/export_onnx_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | sys.path.append(".") 9 | 10 | import torch 11 | 12 | from segment_anything import sam_model_registry 13 | from segment_anything.utils.onnx import SamOnnxModel 14 | 15 | import argparse 16 | import warnings 17 | 18 | try: 19 | import onnxruntime # type: ignore 20 | 21 | onnxruntime_exists = True 22 | except ImportError: 23 | onnxruntime_exists = False 24 | 25 | parser = argparse.ArgumentParser( 26 | description="Export the SAM-Med2D prompt encoder and mask decoder to an ONNX model." 27 | ) 28 | 29 | parser.add_argument( 30 | "--sam_checkpoint", type=str, required=True, help="The path to the SAM-Med2D model checkpoint." 31 | "Usage: python3 scripts/export_onnx_model.py --checkpoint xxx/sam-med2d_b.pth \ 32 | --output xxx/sam-med2d_b.decoder.onnx --model-type vit_b --opset 12 --return-single-mask" 33 | ) 34 | 35 | parser.add_argument( 36 | "--output", type=str, required=True, help="The filename to save the ONNX model to." 37 | ) 38 | 39 | parser.add_argument( 40 | '--device', 41 | type=str, 42 | default='cpu' 43 | ) 44 | 45 | parser.add_argument( 46 | "--image_size", 47 | type=int, 48 | default=256, 49 | help="image_size" 50 | 51 | ) 52 | 53 | parser.add_argument( 54 | "--encoder_adapter", 55 | type=bool, 56 | default=True, 57 | help="use adapter" 58 | ) 59 | 60 | parser.add_argument( 61 | "--model-type", 62 | type=str, 63 | required=True, 64 | help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM-Med2D model to export.", 65 | ) 66 | 67 | parser.add_argument( 68 | "--return-single-mask", 69 | action="store_true", 70 | help=( 71 | "If true, the exported ONNX model will only return the best mask, " 72 | "instead of returning multiple masks. For high resolution images " 73 | "this can improve runtime when upscaling masks is expensive." 74 | ), 75 | ) 76 | 77 | parser.add_argument( 78 | "--opset", 79 | type=int, 80 | default=17, 81 | help="The ONNX opset version to use. Must be >=11", 82 | ) 83 | 84 | parser.add_argument( 85 | "--quantize-out", 86 | type=str, 87 | default=None, 88 | help=( 89 | "If set, will quantize the model and save it with this name. " 90 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 91 | ), 92 | ) 93 | 94 | parser.add_argument( 95 | "--gelu-approximate", 96 | action="store_true", 97 | help=( 98 | "Replace GELU operations with approximations using tanh. Useful " 99 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 
100 | ), 101 | ) 102 | 103 | parser.add_argument( 104 | "--use-stability-score", 105 | action="store_true", 106 | help=( 107 | "Replaces the model's predicted mask quality score with the stability " 108 | "score calculated on the low resolution masks using an offset of 1.0. " 109 | ), 110 | ) 111 | 112 | parser.add_argument( 113 | "--return-extra-metrics", 114 | action="store_true", 115 | help=( 116 | "The model will return five results: (masks, scores, stability_scores, " 117 | "areas, low_res_logits) instead of the usual three. This can be " 118 | "significantly slower for high resolution outputs." 119 | ), 120 | ) 121 | 122 | parser.add_argument( 123 | "--resize-logest-img-size", 124 | action="store_true", 125 | help=( 126 | "If enabled, the input image will be resized to fit the longest side " 127 | "and then undergo mask post-processing." 128 | ), 129 | ) 130 | 131 | 132 | def run_export( 133 | args, 134 | model_type: str, 135 | output: str, 136 | opset: int, 137 | return_single_mask: bool, 138 | gelu_approximate: bool = False, 139 | use_stability_score: bool = False, 140 | return_extra_metrics: bool = False, 141 | resize_logest_img_size: bool = False, 142 | ): 143 | print("Loading model...") 144 | sam = sam_model_registry[model_type](args).to(args.device) 145 | 146 | onnx_model = SamOnnxModel( 147 | model=sam, 148 | return_single_mask=return_single_mask, 149 | use_stability_score=use_stability_score, 150 | return_extra_metrics=return_extra_metrics, 151 | resize_logest_img_size=resize_logest_img_size, 152 | ) 153 | 154 | if gelu_approximate: 155 | for n, m in onnx_model.named_modules(): 156 | if isinstance(m, torch.nn.GELU): 157 | m.approximate = "tanh" 158 | 159 | dynamic_axes = { 160 | "point_coords": {1: "num_points"}, 161 | "point_labels": {1: "num_points"}, 162 | } 163 | 164 | embed_dim = sam.prompt_encoder.embed_dim 165 | embed_size = sam.prompt_encoder.image_embedding_size 166 | mask_input_size = [4 * x for x in embed_size] 167 | dummy_inputs = { 168 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 169 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 170 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 171 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 172 | "has_mask_input": torch.tensor([1], dtype=torch.float), 173 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), 174 | } 175 | 176 | _ = onnx_model(**dummy_inputs) 177 | 178 | output_names = ["masks", "iou_predictions", "low_res_masks"] 179 | 180 | with warnings.catch_warnings(): 181 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 182 | warnings.filterwarnings("ignore", category=UserWarning) 183 | with open(output, "wb") as f: 184 | print(f"Exporting onnx model to {output}...") 185 | torch.onnx.export( 186 | onnx_model, 187 | tuple(dummy_inputs.values()), 188 | f, 189 | export_params=True, 190 | verbose=False, 191 | opset_version=opset, 192 | do_constant_folding=True, 193 | input_names=list(dummy_inputs.keys()), 194 | output_names=output_names, 195 | dynamic_axes=dynamic_axes, 196 | ) 197 | 198 | if onnxruntime_exists: 199 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} 200 | ort_session = onnxruntime.InferenceSession(output) 201 | _ = ort_session.run(None, ort_inputs) 202 | print("Model has successfully been run with ONNXRuntime.") 203 | 204 | 205 | def to_numpy(tensor): 206 | return tensor.cpu().numpy() 207 | 208 | 209 | if __name__ == "__main__": 
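    # Illustrative invocation, adapted from the --sam_checkpoint help text (which
    # writes "--checkpoint", although the flag defined above is "--sam_checkpoint"):
    #   python3 scripts/export_onnx_model.py --sam_checkpoint xxx/sam-med2d_b.pth \
    #       --output xxx/sam-med2d_b.decoder.onnx --model-type vit_b --opset 12 --return-single-mask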
210 | args = parser.parse_args() 211 | run_export( 212 | args, 213 | model_type=args.model_type, 214 | output=args.output, 215 | opset=args.opset, 216 | return_single_mask=args.return_single_mask, 217 | gelu_approximate=args.gelu_approximate, 218 | use_stability_score=args.use_stability_score, 219 | return_extra_metrics=args.return_extra_metrics, 220 | resize_logest_img_size=args.resize_logest_img_size, 221 | ) 222 | 223 | if args.quantize_out is not None: 224 | assert onnxruntime_exists, "onnxruntime is required to quantize the model." 225 | from onnxruntime.quantization import QuantType # type: ignore 226 | from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore 227 | 228 | print(f"Quantizing model and writing to {args.quantize_out}...") 229 | quantize_dynamic( 230 | model_input=args.output, 231 | model_output=args.quantize_out, 232 | optimize_model=True, 233 | per_channel=False, 234 | reduce_range=False, 235 | weight_type=QuantType.QUInt8, 236 | ) 237 | print("Done!") 238 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
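# The builders in this file take an argparse-style namespace rather than keyword
# arguments. A minimal sketch (field values are illustrative):
#
#   from argparse import Namespace
#   args = Namespace(image_size=256,
#                    sam_checkpoint="pretrain_model/sam-med2d_b.pth",
#                    encoder_adapter=True)
#   sam = sam_model_registry["vit_b"](args)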
6 | 7 | import torch 8 | from functools import partial 9 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 10 | from torch.nn import functional as F 11 | 12 | def build_sam_vit_h(args): 13 | return _build_sam( 14 | encoder_embed_dim=1280, 15 | encoder_depth=32, 16 | encoder_num_heads=16, 17 | encoder_global_attn_indexes=[7, 15, 23, 31], 18 | image_size=args.image_size, 19 | checkpoint=args.sam_checkpoint, 20 | encoder_adapter = args.encoder_adapter, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(args): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | image_size=args.image_size, 34 | checkpoint=args.sam_checkpoint, 35 | encoder_adapter = args.encoder_adapter, 36 | ) 37 | 38 | 39 | def build_sam_vit_b(args): 40 | return _build_sam( 41 | encoder_embed_dim=768, 42 | encoder_depth=12, 43 | encoder_num_heads=12, 44 | encoder_global_attn_indexes=[2, 5, 8, 11], 45 | image_size=args.image_size, 46 | checkpoint=args.sam_checkpoint, 47 | encoder_adapter = args.encoder_adapter, 48 | 49 | ) 50 | 51 | 52 | sam_model_registry = { 53 | "default": build_sam_vit_h, 54 | "vit_h": build_sam_vit_h, 55 | "vit_l": build_sam_vit_l, 56 | "vit_b": build_sam_vit_b, 57 | } 58 | 59 | 60 | def _build_sam( 61 | encoder_embed_dim, 62 | encoder_depth, 63 | encoder_num_heads, 64 | encoder_global_attn_indexes, 65 | image_size, 66 | checkpoint, 67 | encoder_adapter, 68 | ): 69 | prompt_embed_dim = 256 70 | image_size = image_size 71 | vit_patch_size = 16 72 | image_embedding_size = image_size // vit_patch_size 73 | sam = Sam( 74 | image_encoder=ImageEncoderViT( 75 | depth=encoder_depth, 76 | embed_dim=encoder_embed_dim, 77 | img_size=image_size, 78 | mlp_ratio=4, 79 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 80 | num_heads=encoder_num_heads, 81 | patch_size=vit_patch_size, 82 | qkv_bias=True, 83 | use_rel_pos = True, 84 | global_attn_indexes=encoder_global_attn_indexes, 85 | window_size=14, 86 | out_chans=prompt_embed_dim, 87 | adapter_train = encoder_adapter, 88 | ), 89 | prompt_encoder=PromptEncoder( 90 | embed_dim=prompt_embed_dim, 91 | image_embedding_size=(image_embedding_size, image_embedding_size), 92 | input_image_size=(image_size, image_size), 93 | mask_in_chans=16, 94 | ), 95 | mask_decoder=MaskDecoder( 96 | num_multimask_outputs=3, 97 | transformer=TwoWayTransformer( 98 | depth=2, 99 | embedding_dim=prompt_embed_dim, 100 | mlp_dim=2048, 101 | num_heads=8, 102 | ), 103 | transformer_dim=prompt_embed_dim, 104 | iou_head_depth=3, 105 | iou_head_hidden_dim=256, 106 | ), 107 | pixel_mean=[123.675, 116.28, 103.53], 108 | pixel_std=[58.395, 57.12, 57.375], 109 | ) 110 | # sam.train() 111 | if checkpoint is not None: 112 | with open(checkpoint, "rb") as f: 113 | state_dict = torch.load(f, map_location="cpu") 114 | try: 115 | if 'model' in state_dict.keys(): 116 | print(encoder_adapter) 117 | sam.load_state_dict(state_dict['model'], False) 118 | else: 119 | if image_size==1024 and encoder_adapter==True: 120 | sam.load_state_dict(state_dict, False) 121 | else: 122 | sam.load_state_dict(state_dict) 123 | except: 124 | print('*******interpolate') 125 | new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size) 126 | sam.load_state_dict(new_state_dict) 127 | print(f"*******load {checkpoint}") 128 | 129 | return sam 130 | 131 | 132 | def load_from(sam, state_dicts, image_size, vit_patch_size): 133 | 134 | sam_dict = 
sam.state_dict() 135 | except_keys = ['mask_tokens', 'output_hypernetworks_mlps', 'iou_prediction_head'] 136 | new_state_dict = {k: v for k, v in state_dicts.items() if 137 | k in sam_dict.keys() and except_keys[0] not in k and except_keys[1] not in k and except_keys[2] not in k} 138 | pos_embed = new_state_dict['image_encoder.pos_embed'] 139 | token_size = int(image_size // vit_patch_size) 140 | if pos_embed.shape[1] != token_size: 141 | # resize pos embedding, which may sacrifice the performance, but I have no better idea 142 | pos_embed = pos_embed.permute(0, 3, 1, 2) # [b, c, h, w] 143 | pos_embed = F.interpolate(pos_embed, (token_size, token_size), mode='bilinear', align_corners=False) 144 | pos_embed = pos_embed.permute(0, 2, 3, 1) # [b, h, w, c] 145 | new_state_dict['image_encoder.pos_embed'] = pos_embed 146 | rel_pos_keys = [k for k in sam_dict.keys() if 'rel_pos' in k] 147 | 148 | global_rel_pos_keys = [k for k in rel_pos_keys if 149 | '2' in k or 150 | '5' in k or 151 | '7' in k or 152 | '8' in k or 153 | '11' in k or 154 | '13' in k or 155 | '15' in k or 156 | '23' in k or 157 | '31' in k] 158 | # print(sam_dict) 159 | for k in global_rel_pos_keys: 160 | h_check, w_check = sam_dict[k].shape 161 | rel_pos_params = new_state_dict[k] 162 | h, w = rel_pos_params.shape 163 | rel_pos_params = rel_pos_params.unsqueeze(0).unsqueeze(0) 164 | if h != h_check or w != w_check: 165 | rel_pos_params = F.interpolate(rel_pos_params, (h_check, w_check), mode='bilinear', align_corners=False) 166 | 167 | new_state_dict[k] = rel_pos_params[0, 0, ...] 168 | 169 | sam_dict.update(new_state_dict) 170 | return sam_dict 171 | 172 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # from .sam import Sam 8 | from .sam_model import Sam 9 | from .image_encoder import ImageEncoderViT 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
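# Shared building blocks used across the model. Rough shape sketch (illustrative):
#
#   block = MLPBlock(embedding_dim=256, mlp_dim=2048)  # (..., 256) -> (..., 256)
#   norm = LayerNorm2d(num_channels=256)               # (N, 256, H, W) -> (N, 256, H, W)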
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | y = self.weight[:, None, None] * x 43 | # y = torch.mul(self.weight[:, None, None], x) 44 | x = y + self.bias[:, None, None] 45 | return x 46 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
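        The decoder produces num_multimask_outputs + 1 candidate masks plus an
        IoU-style quality score per mask; depending on multimask_output, forward
        returns either the single first mask or the remaining multimask outputs.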
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) #256 256 4 3 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, #[B, 256, 64, 64] 74 | image_pe: torch.Tensor, #[1, 256, 64, 64] 75 | sparse_prompt_embeddings: torch.Tensor, #[B, 3, 256] 76 | dense_prompt_embeddings: torch.Tensor, #[B, 256, 64, 64] 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | image_pe=image_pe, 98 | sparse_prompt_embeddings=sparse_prompt_embeddings, 99 | dense_prompt_embeddings=dense_prompt_embeddings, 100 | ) 101 | 102 | # Select the correct mask or masks for output 103 | if multimask_output: 104 | mask_slice = slice(1, None) 105 | else: 106 | mask_slice = slice(0, 1) 107 | masks = masks[:, mask_slice, :, :] 108 | iou_pred = iou_pred[:, mask_slice] 109 | 110 | # Prepare output 111 | return masks, iou_pred 112 | 113 | def predict_masks( 114 | self, 115 | image_embeddings: torch.Tensor, 116 | image_pe: torch.Tensor, 117 | sparse_prompt_embeddings: torch.Tensor, 118 | dense_prompt_embeddings: torch.Tensor, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """Predicts masks. 
See 'forward' for more details.""" 121 | # Concatenate output tokens 122 | 123 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) #iou_token:[1,256] mask_tokens:[4,256] 124 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 125 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 126 | 127 | # Expand per-image data in batch direction to be per-mask 128 | # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 129 | src = image_embeddings 130 | src = src + dense_prompt_embeddings 131 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 132 | b, c, h, w = src.shape 133 | 134 | # Run the transformer 135 | hs, src = self.transformer(src, pos_src, tokens) 136 | iou_token_out = hs[:, 0, :] 137 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 138 | 139 | # Upscale mask embeddings and predict masks using the mask tokens 140 | src = src.transpose(1, 2).view(b, c, h, w) 141 | upscaled_embedding = self.output_upscaling(src) 142 | hyper_in_list: List[torch.Tensor] = [] 143 | for i in range(self.num_mask_tokens): 144 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 145 | hyper_in = torch.stack(hyper_in_list, dim=1) #[1,4,32] 146 | 147 | b, c, h, w = upscaled_embedding.shape #[1, 32, 256, 256] 148 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 149 | 150 | # Generate mask quality predictions 151 | iou_pred = self.iou_prediction_head(iou_token_out) 152 | 153 | return masks, iou_pred 154 | 155 | 156 | # Lightly adapted from 157 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 158 | class MLP(nn.Module): 159 | def __init__( 160 | self, 161 | input_dim: int, 162 | hidden_dim: int, 163 | output_dim: int, 164 | num_layers: int, 165 | sigmoid_output: bool = False, 166 | ) -> None: 167 | super().__init__() 168 | self.num_layers = num_layers 169 | h = [hidden_dim] * (num_layers - 1) 170 | self.layers = nn.ModuleList( 171 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 172 | ) 173 | self.sigmoid_output = sigmoid_output 174 | self.relu = nn.ReLU(inplace=False) 175 | def forward(self, x): 176 | for i, layer in enumerate(self.layers): 177 | # x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 178 | # x = self.relu(layer(x)) if i < self.num_layers - 1 else layer(x) #源码 179 | if i < self.num_layers - 1: 180 | x = F.relu(layer(x)) 181 | else: 182 | x = layer(x) 183 | 184 | if self.sigmoid_output: 185 | x = F.sigmoid(x) 186 | return x 187 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
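# PromptEncoder maps point, box and mask prompts to the sparse and dense
# embeddings consumed by the mask decoder. Rough call sketch (illustrative):
#
#   sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
#   # sparse: (B, N, embed_dim); dense: (B, embed_dim, embed_H, embed_W)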
6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | 82 | if pad: 83 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 84 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 85 | points = torch.cat([points, padding_point], dim=1) #B,N+1,2 86 | labels = torch.cat([labels, padding_label], dim=1) 87 | 88 | 89 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) #B,N+1,256 90 | point_embedding[labels == -1] = 0.0 91 | 92 | self.not_a_point_embed.weight = torch.nn.Parameter(self.not_a_point_embed.weight.to(point_embedding.dtype), requires_grad=True) # todo 93 | self.point_embeddings[0].weight = torch.nn.Parameter(self.point_embeddings[0].weight.to(point_embedding.dtype), requires_grad=True) #todo 94 | self.point_embeddings[1].weight = torch.nn.Parameter(self.point_embeddings[1].weight.to(point_embedding.dtype), requires_grad=True) #todo 95 | 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | return point_embedding 100 | 101 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 102 | """Embeds box prompts.""" 103 | 104 | boxes = boxes + 0.5 # Shift to center of pixel 105 | coords = boxes.reshape(-1, 2, 2) 106 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 107 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 108 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 109 | return corner_embedding 110 | 111 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 112 | """Embeds mask inputs.""" 113 | mask_embedding = self.mask_downscaling(masks) 114 | return mask_embedding 115 | 116 | def _get_batch_size( 117 | self, 118 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 119 | boxes: Optional[torch.Tensor], 120 | masks: Optional[torch.Tensor], 121 | ) -> int: 122 | """ 123 | Gets the batch size of the output given the batch size of the input prompts. 124 | """ 125 | if points is not None: 126 | return points[0].shape[0] 127 | elif boxes is not None: 128 | return boxes.shape[0] 129 | elif masks is not None: 130 | return masks.shape[0] 131 | else: 132 | return 1 133 | 134 | def _get_device(self) -> torch.device: 135 | return self.point_embeddings[0].weight.device 136 | 137 | def forward( 138 | self, 139 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 140 | boxes: Optional[torch.Tensor], 141 | masks: Optional[torch.Tensor], 142 | ) -> Tuple[torch.Tensor, torch.Tensor]: 143 | """ 144 | Embeds different types of prompts, returning both sparse and dense 145 | embeddings. 146 | 147 | Arguments: 148 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 149 | and labels to embed. 150 | boxes (torch.Tensor or none): boxes to embed 151 | masks (torch.Tensor or none): masks to embed 152 | 153 | Returns: 154 | torch.Tensor: sparse embeddings for the points and boxes, with shape 155 | BxNx(embed_dim), where N is determined by the number of input points 156 | and boxes. 
157 | torch.Tensor: dense embeddings for the masks, in the shape 158 | Bx(embed_dim)x(embed_H)x(embed_W) 159 | """ 160 | bs = self._get_batch_size(points, boxes, masks) 161 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) #B,0,256 空[] 162 | 163 | if points is not None: 164 | coords, labels = points #coords:B,N,2 labels:B,N 165 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 166 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 167 | 168 | if boxes is not None: 169 | box_embeddings = self._embed_boxes(boxes) 170 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 171 | 172 | if masks is not None: 173 | dense_embeddings = self._embed_masks(masks) 174 | else: 175 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 176 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 177 | ) 178 | 179 | return sparse_embeddings, dense_embeddings 180 | 181 | 182 | class PositionEmbeddingRandom(nn.Module): 183 | """ 184 | Positional encoding using random spatial frequencies. 185 | """ 186 | 187 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 188 | super().__init__() 189 | if scale is None or scale <= 0.0: 190 | scale = 1.0 191 | self.register_buffer( 192 | "positional_encoding_gaussian_matrix", 193 | scale * torch.randn((2, num_pos_feats)), 194 | ) 195 | 196 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 197 | """Positionally encode points that are normalized to [0,1].""" 198 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 199 | coords = 2 * coords - 1 200 | # coords = coords @ self.positional_encoding_gaussian_matrix 201 | coords = coords @ self.positional_encoding_gaussian_matrix.to(torch.float32) # todo 202 | coords = 2 * np.pi * coords 203 | # outputs d_1 x ... x d_n x C shape 204 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 205 | 206 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 207 | """Generate positional encoding for a grid of the specified size.""" 208 | h, w = size 209 | 210 | device: Any = self.positional_encoding_gaussian_matrix.device 211 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 212 | y_embed = grid.cumsum(dim=0) - 0.5 213 | x_embed = grid.cumsum(dim=1) - 0.5 214 | y_embed = y_embed / h 215 | x_embed = x_embed / w 216 | 217 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 218 | return pe.permute(2, 0, 1) # C x H x W 219 | 220 | def forward_with_coords( 221 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 222 | ) -> torch.Tensor: 223 | """Positionally encode points that are not normalized to [0,1].""" 224 | coords = coords_input.clone() 225 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 226 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 227 | 228 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 229 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
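# Sam chains the image encoder, prompt encoder and mask decoder into a single
# prompt-to-mask forward pass. Minimal batched_input sketch (values illustrative):
#
#   batched_input = [{
#       "image": image_tensor,        # 3xHxW, already transformed for the model
#       "original_size": (512, 512),  # (H, W) before resizing
#       "point_coords": coords,       # BxNx2, in the model's input frame
#       "point_labels": labels,       # BxN
#   }]
#   outputs = sam(batched_input, multimask_output=False)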
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. 
Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | 98 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 99 | image_embeddings = self.image_encoder(input_images) 100 | 101 | outputs = [] 102 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | ) 119 | masks = self.postprocess_masks( 120 | low_res_masks, 121 | input_size=image_record["image"].shape[-2:], 122 | original_size=image_record["original_size"], 123 | ) 124 | masks = masks > self.mask_threshold 125 | outputs.append( 126 | { 127 | "masks": masks, 128 | "iou_predictions": iou_predictions, 129 | "low_res_logits": low_res_masks, 130 | } 131 | ) 132 | return outputs 133 | 134 | def postprocess_masks( 135 | self, 136 | masks: torch.Tensor, 137 | input_size: Tuple[int, ...], 138 | original_size: Tuple[int, ...], 139 | ) -> torch.Tensor: 140 | """ 141 | Remove padding and upscale masks to the original image size. 142 | 143 | Arguments: 144 | masks (torch.Tensor): Batched masks from the mask_decoder, 145 | in BxCxHxW format. 146 | input_size (tuple(int, int)): The size of the image input to the 147 | model, in (H, W) format. Used to remove padding. 148 | original_size (tuple(int, int)): The original size of the image 149 | before resizing for input to the model, in (H, W) format. 150 | 151 | Returns: 152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 153 | is given by original_size. 154 | """ 155 | masks = F.interpolate( 156 | masks, 157 | (self.image_encoder.img_size, self.image_encoder.img_size), 158 | mode="bilinear", 159 | align_corners=False, 160 | ) 161 | masks = masks[..., : input_size[0], : input_size[1]] 162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 163 | return masks 164 | 165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 166 | """Normalize pixel values and pad to a square input.""" 167 | # Normalize colors 168 | x = (x - self.pixel_mean) / self.pixel_std 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
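(Illustrative aside on the Sam class above; this sketch is not part of sam_model.py. `sam`, `image`, `h0`, `w0`, `x`, `y` are placeholders for an already-built model, a 3xHxW image tensor transformed to the model's input frame, the original image size, and a prompt point.)

import torch

batched_input = [
    {
        "image": image,                  # 3xHxW, already resized for the model
        "original_size": (h0, w0),       # (H, W) before any transformation
        "point_coords": torch.tensor([[[x, y]]], device=sam.device),  # BxNx2
        "point_labels": torch.tensor([[1]], device=sam.device),       # BxN, 1 = foreground
    }
]
outputs = sam(batched_input, multimask_output=True)
masks = outputs[0]["masks"]               # BxCxh0xw0 boolean masks
scores = outputs[0]["iou_predictions"]    # BxC predicted mask quality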
6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from typing import Any, Dict, List, Tuple 10 | from .image_encoder import ImageEncoderViT 11 | from .mask_decoder import MaskDecoder 12 | from .prompt_encoder import PromptEncoder 13 | 14 | 15 | class Sam(nn.Module): 16 | mask_threshold: float = 0.0 17 | image_format: str = "RGB" 18 | 19 | def __init__( 20 | self, 21 | image_encoder: ImageEncoderViT, 22 | prompt_encoder: PromptEncoder, 23 | mask_decoder: MaskDecoder, 24 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 25 | pixel_std: List[float] = [58.395, 57.12, 57.375], 26 | ) -> None: 27 | """ 28 | SAM predicts object masks from an image and input prompts. 29 | 30 | Arguments: 31 | image_encoder (ImageEncoderViT): The backbone used to encode the 32 | image into image embeddings that allow for efficient mask prediction. 33 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 34 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 35 | and encoded prompts. 36 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 37 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 38 | """ 39 | super().__init__() 40 | self.image_encoder = image_encoder 41 | self.prompt_encoder = prompt_encoder 42 | self.mask_decoder = mask_decoder 43 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 44 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 45 | 46 | @property 47 | def device(self) -> Any: 48 | return self.pixel_mean.device 49 | 50 | def forward(self, batched_input: Dict[str, Any], multimask_output: bool) -> List[Dict[str, torch.Tensor]]: 51 | 52 | input_images = batched_input.get("image") 53 | image_embeddings = self.image_encoder(input_images) 54 | 55 | if "point_coords" in batched_input and batched_input["point_coords"] != None: 56 | points = (batched_input["point_coords"], batched_input["point_labels"]) 57 | else: 58 | points = None 59 | 60 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 61 | points=points, 62 | boxes=batched_input.get("boxes", None), 63 | masks=batched_input.get("mask_inputs", None), 64 | ) # sparse_embeddings:[2, 3, 256], dense_embeddings:[2, 256, 64, 64] 65 | 66 | low_res_masks, iou_predictions = self.mask_decoder( 67 | image_embeddings=image_embeddings, 68 | image_pe=self.prompt_encoder.get_dense_pe(), # 1x(256)x(64)x(64) 69 | sparse_prompt_embeddings=sparse_embeddings, 70 | dense_prompt_embeddings=dense_embeddings, 71 | multimask_output=multimask_output, 72 | ) 73 | 74 | masks = self.postprocess_masks( 75 | low_res_masks, 76 | input_size=batched_input["image"].shape[-2:], 77 | original_size=batched_input["original_size"], 78 | ) 79 | 80 | outputs = { 81 | "masks": masks, 82 | "iou_predictions": iou_predictions, 83 | "low_res_logits": low_res_masks, 84 | } 85 | 86 | return outputs 87 | 88 | def postprocess_masks(self,masks: torch.Tensor, input_size: Tuple[int, ...],original_size: Tuple[int, ...],) -> torch.Tensor: 89 | masks = F.interpolate( 90 | masks, 91 | (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", align_corners=False,) #[1,1024,1024] 92 | 93 | masks = masks[..., : input_size[0], : input_size[1]] 94 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 95 | return masks 96 | 97 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 98 | """Normalize pixel values and pad to a square input.""" 99 | # 
Normalize colors 100 | x = (x - self.pixel_mean) / self.pixel_std 101 | # Pad 102 | h, w = x.shape[-2:] 103 | padh = self.image_encoder.img_size - h 104 | padw = self.image_encoder.img_size - w 105 | x = F.pad(x, (0, padw, 0, padh)) 106 | return x 107 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q.to(self.q_proj.weight.dtype)) #todo 221 | k = self.k_proj(k.to(self.k_proj.weight.dtype)) #todo 222 | v = self.v_proj(v.to(self.v_proj.weight.dtype)) #todo 223 | 224 | # q = self.q_proj(q) 225 | # k = self.k_proj(k) 226 | # v = self.v_proj(v) 227 | 228 | # Separate into heads 229 | q = self._separate_heads(q, self.num_heads) 230 | k = self._separate_heads(k, self.num_heads) 231 | v = self._separate_heads(v, self.num_heads) 232 | 233 | # Attention 234 | _, _, _, c_per_head = q.shape 235 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 236 | attn = attn / math.sqrt(c_per_head) 237 | attn = torch.softmax(attn, dim=-1) 238 | 239 | # Get output 240 | out = attn @ v 241 | out = self._recombine_heads(out) 242 | out = self.out_proj(out) 243 | 244 | return out 245 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 
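        Example (illustrative; assumes `sam` is a loaded Sam model and `image`
        an HWC uint8 numpy array):
            predictor = SamPredictor(sam)
            predictor.set_image(image)
            masks, scores, low_res_logits = predictor.predict(
                point_coords=np.array([[500, 375]]),
                point_labels=np.array([1]),
            )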
28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 
116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | 143 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 144 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 145 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 146 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 147 | 148 | if box is not None: 149 | box = self.transform.apply_boxes(box, self.original_size) 150 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 151 | box_torch = box_torch[None, :] 152 | if mask_input is not None: 153 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 154 | mask_input_torch = mask_input_torch[None, :, :, :] 155 | 156 | masks, iou_predictions, low_res_masks = self.predict_torch( 157 | coords_torch, 158 | labels_torch, 159 | box_torch, 160 | mask_input_torch, 161 | multimask_output, 162 | return_logits=return_logits, 163 | ) 164 | 165 | masks = masks[0].detach().cpu().numpy() 166 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 167 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 168 | return masks, iou_predictions, low_res_masks 169 | 170 | @torch.no_grad() 171 | def predict_torch( 172 | self, 173 | point_coords: Optional[torch.Tensor], 174 | point_labels: Optional[torch.Tensor], 175 | boxes: Optional[torch.Tensor] = None, 176 | mask_input: Optional[torch.Tensor] = None, 177 | multimask_output: bool = True, 178 | return_logits: bool = False, 179 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 180 | """ 181 | Predict masks for the given input prompts, using the currently set image. 182 | Input prompts are batched torch tensors and are expected to already be 183 | transformed to the input frame using ResizeLongestSide. 184 | 185 | Arguments: 186 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 187 | model. Each point is in (X,Y) in pixels. 188 | point_labels (torch.Tensor or None): A BxN array of labels for the 189 | point prompts. 1 indicates a foreground point and 0 indicates a 190 | background point. 191 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 192 | model, in XYXY format. 
193 | mask_input (np.ndarray): A low resolution mask input to the model, typically 194 | coming from a previous prediction iteration. Has form Bx1xHxW, where 195 | for SAM, H=W=256. Masks returned by a previous iteration of the 196 | predict method do not need further transformation. 197 | multimask_output (bool): If true, the model will return three masks. 198 | For ambiguous input prompts (such as a single click), this will often 199 | produce better masks than a single prediction. If only a single 200 | mask is needed, the model's predicted quality score can be used 201 | to select the best mask. For non-ambiguous prompts, such as multiple 202 | input prompts, multimask_output=False can give better results. 203 | return_logits (bool): If true, returns un-thresholded masks logits 204 | instead of a binary mask. 205 | 206 | Returns: 207 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 208 | number of masks, and (H, W) is the original image size. 209 | (torch.Tensor): An array of shape BxC containing the model's 210 | predictions for the quality of each mask. 211 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 212 | of masks and H=W=256. These low res logits can be passed to 213 | a subsequent iteration as mask input. 214 | """ 215 | if not self.is_image_set: 216 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 217 | 218 | if point_coords is not None: 219 | points = (point_coords, point_labels) 220 | else: 221 | points = None 222 | 223 | # Embed prompts 224 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 225 | points=points, 226 | boxes=boxes, 227 | masks=mask_input, 228 | ) 229 | 230 | # Predict masks 231 | low_res_masks, iou_predictions = self.model.mask_decoder( 232 | image_embeddings=self.features, 233 | image_pe=self.model.prompt_encoder.get_dense_pe(), 234 | sparse_prompt_embeddings=sparse_embeddings, 235 | dense_prompt_embeddings=dense_embeddings, 236 | multimask_output=multimask_output, 237 | ) 238 | 239 | # Upscale the masks to the original image resolution 240 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 241 | 242 | if not return_logits: 243 | masks = masks > self.model.mask_threshold 244 | 245 | return masks, iou_predictions, low_res_masks 246 | 247 | def get_image_embedding(self) -> torch.Tensor: 248 | """ 249 | Returns the image embeddings for the currently set image, with 250 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 251 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 252 | """ 253 | if not self.is_image_set: 254 | raise RuntimeError( 255 | "An image must be set with .set_image(...) to generate an embedding." 256 | ) 257 | assert self.features is not None, "Features must exist if an image has been set." 
258 | return self.features 259 | 260 | @property 261 | def device(self) -> torch.device: 262 | return self.model.device 263 | 264 | def reset_image(self) -> None: 265 | """Resets the currently set image.""" 266 | self.is_image_set = False 267 | self.features = None 268 | self.orig_h = None 269 | self.orig_w = None 270 | self.input_h = None 271 | self.input_w = None 272 | -------------------------------------------------------------------------------- /segment_anything/predictor_sammed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Optional, Tuple 4 | from torch.nn import functional as F 5 | from copy import deepcopy 6 | from albumentations.pytorch import ToTensorV2 7 | import albumentations as A 8 | import cv2 9 | 10 | class SammedPredictor: 11 | def __init__(self, sam_model): 12 | 13 | super().__init__() 14 | self.model = sam_model 15 | self.devices = sam_model.device 16 | self.reset_image() 17 | 18 | 19 | def set_image(self,image: np.ndarray, image_format: str = "RGB") -> None: 20 | assert image_format in ["RGB","BGR",], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 21 | if image_format != self.model.image_format: 22 | image = image[..., ::-1] 23 | 24 | # Transform the image to the form expected by the model 25 | if self.model.pixel_mean.device.type == 'cuda': 26 | pixel_mean, pixel_std = self.model.pixel_mean.squeeze().cpu().numpy(), self.model.pixel_std.squeeze().cpu().numpy() 27 | input_image = (image - pixel_mean) / pixel_std 28 | else: 29 | pixel_mean, pixel_std = self.model.pixel_mean.squeeze().numpy(), self.model.pixel_std.squeeze().numpy() 30 | input_image = (image - pixel_mean) / pixel_std 31 | 32 | ori_h, ori_w, _ = input_image.shape 33 | self.original_size = (ori_h, ori_w) 34 | self.new_size = (self.model.image_encoder.img_size, self.model.image_encoder.img_size) 35 | transforms = self.transforms(self.new_size) 36 | augments = transforms(image=input_image) 37 | input_image = augments['image'][None, :, :, :] 38 | 39 | assert ( 40 | len(input_image.shape) == 4 41 | and input_image.shape[1] == 3 42 | and max(*input_image.shape[2:]) == self.model.image_encoder.img_size 43 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 44 | 45 | self.features = self.model.image_encoder(input_image.to(self.device)) 46 | self.is_image_set = True 47 | 48 | def predict( 49 | self, 50 | point_coords: Optional[np.ndarray] = None, 51 | point_labels: Optional[np.ndarray] = None, 52 | box: Optional[np.ndarray] = None, 53 | mask_input: Optional[np.ndarray] = None, 54 | multimask_output: bool = True, 55 | return_logits: bool = False, 56 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 57 | 58 | if not self.is_image_set: 59 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 60 | 61 | # Transform input prompts 62 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 63 | if point_coords is not None: 64 | assert ( 65 | point_labels is not None 66 | ), "point_labels must be supplied if point_coords is supplied." 
67 | 68 | point_coords = self.apply_coords(point_coords, self.original_size, self.new_size) 69 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 70 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 71 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 72 | 73 | if box is not None: 74 | box = self.apply_boxes(box, self.original_size, self.new_size) 75 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 76 | box_torch = box_torch[None, :] 77 | if mask_input is not None: 78 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 79 | mask_input_torch = mask_input_torch[None, :, :, :] 80 | 81 | masks, iou_predictions, low_res_masks = self.predict_torch( 82 | coords_torch, 83 | labels_torch, 84 | box_torch, 85 | mask_input_torch, 86 | multimask_output, 87 | return_logits=return_logits, 88 | ) 89 | 90 | masks = masks[0].detach().cpu().numpy() 91 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 92 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 93 | return masks, iou_predictions, low_res_masks 94 | 95 | @torch.no_grad() 96 | def predict_torch( 97 | self, 98 | point_coords: Optional[torch.Tensor], 99 | point_labels: Optional[torch.Tensor], 100 | boxes: Optional[torch.Tensor] = None, 101 | mask_input: Optional[torch.Tensor] = None, 102 | multimask_output: bool = True, 103 | return_logits: bool = False, 104 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 105 | 106 | if not self.is_image_set: 107 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 108 | 109 | if point_coords is not None: 110 | points = (point_coords, point_labels) 111 | else: 112 | points = None 113 | 114 | if boxes is not None and boxes.shape[0] > 1: 115 | mask_list = [] 116 | # Embed prompts 117 | for i in range(boxes.shape[0]): 118 | pre_boxes = boxes[i:i+1,...] 
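                # (added note) each box prompt is embedded and decoded in its own pass;
                # the per-box masks collected in mask_list are concatenated after the loop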
119 | 120 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 121 | points=points, 122 | boxes=pre_boxes, 123 | masks=mask_input, 124 | ) 125 | 126 | # Predict masks 127 | low_res_masks, iou_predictions = self.model.mask_decoder( 128 | image_embeddings=self.features, 129 | image_pe=self.model.prompt_encoder.get_dense_pe(), 130 | sparse_prompt_embeddings=sparse_embeddings, 131 | dense_prompt_embeddings=dense_embeddings, 132 | multimask_output=multimask_output, 133 | ) 134 | 135 | if multimask_output: 136 | max_values, max_indexs = torch.max(iou_predictions, dim=1) 137 | max_values = max_values.unsqueeze(1) 138 | iou_predictions = max_values 139 | low_res_masks = low_res_masks[:, max_indexs] 140 | 141 | # Upscale the masks to the original image resolution 142 | pre_masks = self.postprocess_masks(low_res_masks, self.model.image_encoder.img_size, self.original_size) 143 | 144 | mask_list.append(pre_masks) 145 | masks = torch.cat(mask_list, dim=0) 146 | 147 | else: 148 | # Embed prompts 149 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 150 | points=points, 151 | boxes=boxes, 152 | masks=mask_input, 153 | ) 154 | 155 | # Predict masks 156 | low_res_masks, iou_predictions = self.model.mask_decoder( 157 | image_embeddings=self.features, 158 | image_pe=self.model.prompt_encoder.get_dense_pe(), 159 | sparse_prompt_embeddings=sparse_embeddings, 160 | dense_prompt_embeddings=dense_embeddings, 161 | multimask_output=multimask_output, 162 | ) 163 | 164 | if multimask_output: 165 | max_values, max_indexs = torch.max(iou_predictions, dim=1) 166 | max_values = max_values.unsqueeze(1) 167 | iou_predictions = max_values 168 | low_res_masks = low_res_masks[:, max_indexs] 169 | 170 | # Upscale the masks to the original image resolution 171 | masks = self.postprocess_masks(low_res_masks, self.model.image_encoder.img_size, self.original_size) 172 | 173 | if not return_logits: 174 | sigmoid_output = torch.sigmoid(masks) 175 | masks = (sigmoid_output > 0.5).float() 176 | 177 | return masks, iou_predictions, low_res_masks 178 | 179 | 180 | def postprocess_masks(self, low_res_masks, image_size, original_size): 181 | ori_h, ori_w = original_size 182 | masks = F.interpolate(low_res_masks,(image_size, image_size), mode="bilinear", align_corners=False) 183 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 184 | return masks 185 | 186 | 187 | def apply_coords(self, coords, original_size, new_size): 188 | old_h, old_w = original_size 189 | new_h, new_w = new_size 190 | coords = deepcopy(coords).astype(float) 191 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 192 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 193 | 194 | return coords 195 | 196 | def apply_boxes(self, boxes, original_size, new_size): 197 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, new_size) 198 | return boxes.reshape(-1, 4) 199 | 200 | 201 | def apply_coords_torch(self, coords, original_size, new_size): 202 | old_h, old_w = original_size 203 | new_h, new_w = new_size 204 | coords = deepcopy(coords).to(torch.float) 205 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 206 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 207 | return coords 208 | 209 | def apply_boxes_torch(self, boxes, original_size, new_size): 210 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size, new_size) 211 | return boxes.reshape(-1, 4) 212 | 213 | 214 | def get_image_embedding(self) -> torch.Tensor: 215 | """ 216 | Returns the image embeddings for the 
currently set image, with 217 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 218 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 219 | """ 220 | if not self.is_image_set: 221 | raise RuntimeError( 222 | "An image must be set with .set_image(...) to generate an embedding." 223 | ) 224 | assert self.features is not None, "Features must exist if an image has been set." 225 | return self.features 226 | 227 | 228 | def transforms(self, new_size): 229 | Transforms = [] 230 | new_h, new_w = new_size 231 | Transforms.append(A.Resize(int(new_h), int(new_w), interpolation=cv2.INTER_NEAREST)) 232 | Transforms.append(ToTensorV2(p=1.0)) 233 | return A.Compose(Transforms, p=1.) 234 | 235 | @property 236 | def device(self) -> torch.device: 237 | return self.model.device 238 | 239 | def reset_image(self) -> None: 240 | """Resets the currently set image.""" 241 | self.is_image_set = False 242 | self.features = None 243 | self.orig_h = None 244 | self.orig_w = None 245 | self.input_h = None 246 | self.input_w = None 247 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
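    Each mask is flattened in column-major (Fortran) order and stored as alternating
    run lengths, always starting with the length of the initial background run
    (which may be zero).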
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | resize_logest_img_size: bool = False, 32 | ) -> None: 33 | super().__init__() 34 | self.mask_decoder = model.mask_decoder 35 | self.model = model 36 | self.img_size = model.image_encoder.img_size 37 | self.return_single_mask = return_single_mask 38 | self.use_stability_score = use_stability_score 39 | self.stability_score_offset = 1.0 40 | self.return_extra_metrics = return_extra_metrics 41 | self.resize_logest_img_size = resize_logest_img_size 42 | 43 | @staticmethod 44 | def resize_longest_image_size( 45 | input_image_size: torch.Tensor, longest_side: int 46 | ) -> torch.Tensor: 47 | input_image_size = input_image_size.to(torch.float32) 48 | scale = longest_side / torch.max(input_image_size) 49 | transformed_size = scale * input_image_size 50 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 51 | return transformed_size 52 | 53 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 54 | point_coords = point_coords + 0.5 55 | point_coords = point_coords / self.img_size 56 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 57 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 58 | 59 | point_embedding = point_embedding * (point_labels != -1) 60 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 61 | point_labels == -1 62 | ) 63 | 64 | for i in range(self.model.prompt_encoder.num_point_embeddings): 65 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 66 | i 67 | ].weight * (point_labels == i) 68 | 69 | return point_embedding 70 | 71 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 72 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 73 | mask_embedding = mask_embedding + ( 74 | 1 - has_mask_input 75 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 76 | return mask_embedding 77 | 78 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 79 | masks = F.interpolate( 80 | masks, 81 | size=(self.img_size, self.img_size), 82 | mode="bilinear", 83 | align_corners=False, 84 | ) 85 | 86 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 87 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 88 | 89 | orig_im_size = orig_im_size.to(torch.int64) 90 | h, w = orig_im_size[0], orig_im_size[1] 91 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 92 | return masks 93 | 94 | def mask_postprocessing_without_rescale(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 95 | masks = 
F.interpolate(masks,(self.img_size, self.img_size), mode="bilinear", align_corners=False) 96 | orig_im_size = orig_im_size.to(torch.int64) 97 | h, w = orig_im_size[0], orig_im_size[1] 98 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 99 | return masks 100 | 101 | def select_masks( 102 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 103 | ) -> Tuple[torch.Tensor, torch.Tensor]: 104 | # Determine if we should return the multiclick mask or not from the number of points. 105 | # The reweighting is used to avoid control flow. 106 | score_reweight = torch.tensor( 107 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 108 | ).to(iou_preds.device) 109 | score = iou_preds + (num_points - 2.5) * score_reweight 110 | best_idx = torch.argmax(score, dim=1) 111 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 112 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 113 | 114 | return masks, iou_preds 115 | 116 | @torch.no_grad() 117 | def forward( 118 | self, 119 | image_embeddings: torch.Tensor, 120 | point_coords: torch.Tensor, 121 | point_labels: torch.Tensor, 122 | mask_input: torch.Tensor, 123 | has_mask_input: torch.Tensor, 124 | orig_im_size: torch.Tensor, 125 | ): 126 | sparse_embedding = self._embed_points(point_coords, point_labels) 127 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 128 | 129 | masks, scores = self.model.mask_decoder.predict_masks( 130 | image_embeddings=image_embeddings, 131 | image_pe=self.model.prompt_encoder.get_dense_pe(), 132 | sparse_prompt_embeddings=sparse_embedding, 133 | dense_prompt_embeddings=dense_embedding, 134 | ) 135 | 136 | if self.use_stability_score: 137 | scores = calculate_stability_score( 138 | masks, self.model.mask_threshold, self.stability_score_offset 139 | ) 140 | 141 | if self.return_single_mask: 142 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 143 | 144 | if self.resize_logest_img_size: 145 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 146 | else: 147 | upscaled_masks = self.mask_postprocessing_without_rescale(masks, orig_im_size) 148 | 149 | if self.return_extra_metrics: 150 | stability_scores = calculate_stability_score( 151 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 152 | ) 153 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 154 | return upscaled_masks, scores, stability_scores, areas, masks 155 | 156 | return upscaled_masks, scores, masks 157 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 
21 |     """
22 | 
23 |     def __init__(self, target_length: int) -> None:
24 |         self.target_length = target_length
25 | 
26 |     def apply_image(self, image: np.ndarray) -> np.ndarray:
27 |         """
28 |         Expects a numpy array with shape HxWxC in uint8 format.
29 |         """
30 |         target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 |         return np.array(resize(to_pil_image(image), target_size))
32 | 
33 |     def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 |         """
35 |         Expects a numpy array of length 2 in the final dimension. Requires the
36 |         original image size in (H, W) format.
37 |         """
38 | 
39 |         old_h, old_w = original_size
40 |         new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
41 |         coords = deepcopy(coords).astype(float)
42 |         coords[..., 0] = coords[..., 0] * (new_w / old_w)
43 |         coords[..., 1] = coords[..., 1] * (new_h / old_h)
44 |         return coords
45 | 
46 |     def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
47 |         """
48 |         Expects a numpy array shape Bx4. Requires the original image size
49 |         in (H, W) format.
50 |         """
51 |         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
52 |         return boxes.reshape(-1, 4)
53 | 
54 |     def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
55 |         """
56 |         Expects batched images with shape BxCxHxW and float format. This
57 |         transformation may not exactly match apply_image. apply_image is
58 |         the transformation expected by the model.
59 |         """
60 |         # Expects an image in BCHW format. May not exactly match apply_image.
61 |         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)  # H, W of the BxCxHxW batch
62 |         return F.interpolate(
63 |             image, target_size, mode="bilinear", align_corners=False, antialias=True
64 |         )
65 | 
66 |     def apply_coords_torch(
67 |         self, coords: torch.Tensor, original_size: Tuple[int, ...]
68 |     ) -> torch.Tensor:
69 |         """
70 |         Expects a torch tensor with length 2 in the last dimension. Requires the
71 |         original image size in (H, W) format.
72 |         """
73 |         old_h, old_w = original_size
74 |         new_h, new_w = self.get_preprocess_shape(
75 |             original_size[0], original_size[1], self.target_length
76 |         )
77 | 
78 |         coords = deepcopy(coords).to(torch.float)
79 |         coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 |         coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 |         return coords
82 | 
83 |     def apply_boxes_torch(
84 |         self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 |     ) -> torch.Tensor:
86 |         """
87 |         Expects a torch tensor with shape Bx4. Requires the original image
88 |         size in (H, W) format.
89 |         """
90 |         boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 |         return boxes.reshape(-1, 4)
92 | 
93 |     @staticmethod
94 |     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 |         """
96 |         Compute the output size given input size and target long side length.
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from segment_anything import sam_model_registry 2 | import torch.nn as nn 3 | import torch 4 | import argparse 5 | import os 6 | from utils import FocalDiceloss_IoULoss, generate_point, save_masks 7 | from torch.utils.data import DataLoader 8 | from DataLoader import TestingDataset 9 | from metrics import SegMetrics 10 | import time 11 | from tqdm import tqdm 12 | import numpy as np 13 | from torch.nn import functional as F 14 | import logging 15 | import datetime 16 | import cv2 17 | import random 18 | import csv 19 | import json 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--work_dir", type=str, default="workdir", help="work dir") 25 | parser.add_argument("--run_name", type=str, default="sammed", help="run model name") 26 | parser.add_argument("--batch_size", type=int, default=1, help="batch size") 27 | parser.add_argument("--image_size", type=int, default=256, help="image_size") 28 | parser.add_argument('--device', type=str, default='cuda') 29 | parser.add_argument("--data_path", type=str, default="data_demo", help="train data path") 30 | parser.add_argument("--metrics", nargs='+', default=['iou', 'dice'], help="metrics") 31 | parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type") 32 | parser.add_argument("--sam_checkpoint", type=str, default="pretrain_model/sam-med2d_b.pth", help="sam checkpoint") 33 | parser.add_argument("--boxes_prompt", type=bool, default=True, help="use boxes prompt") 34 | parser.add_argument("--point_num", type=int, default=1, help="point num") 35 | parser.add_argument("--iter_point", type=int, default=1, help="iter num") 36 | parser.add_argument("--multimask", type=bool, default=True, help="ouput multimask") 37 | parser.add_argument("--encoder_adapter", type=bool, default=True, help="use adapter") 38 | parser.add_argument("--prompt_path", type=str, default=None, help="fix prompt path") 39 | parser.add_argument("--save_pred", type=bool, default=False, help="save reslut") 40 | args = parser.parse_args() 41 | if args.iter_point > 1: 42 | args.point_num = 1 43 | return args 44 | 45 | 46 | def to_device(batch_input, device): 47 | device_input = {} 48 | for key, value in batch_input.items(): 49 | if value is not None: 50 | if key=='image' or key=='label': 51 | device_input[key] = value.float().to(device) 52 | elif type(value) is list or type(value) is torch.Size: 53 | device_input[key] = value 54 | else: 55 | device_input[key] = value.to(device) 56 | else: 57 | device_input[key] = value 58 | return device_input 59 | 60 | 61 | def postprocess_masks(low_res_masks, image_size, original_size): 62 | ori_h, ori_w = original_size 63 | masks = F.interpolate( 64 | low_res_masks, 65 | (image_size, image_size), 66 | mode="bilinear", 67 | align_corners=False, 68 | ) 69 | 70 | if ori_h < image_size and ori_w < image_size: 71 | top = torch.div((image_size - ori_h), 2, rounding_mode='trunc') #(image_size - ori_h) // 2 72 | left = torch.div((image_size - ori_w), 2, rounding_mode='trunc') #(image_size - ori_w) // 2 73 | masks = masks[..., top : ori_h + top, left : ori_w + left] 74 | pad = (top, left) 75 | else: 76 | masks = 
F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 77 | pad = None 78 | return masks, pad 79 | 80 | 81 | def prompt_and_decoder(args, batched_input, ddp_model, image_embeddings): 82 | if batched_input["point_coords"] is not None: 83 | points = (batched_input["point_coords"], batched_input["point_labels"]) 84 | else: 85 | points = None 86 | 87 | with torch.no_grad(): 88 | sparse_embeddings, dense_embeddings = ddp_model.prompt_encoder( 89 | points=points, 90 | boxes=batched_input.get("boxes", None), 91 | masks=batched_input.get("mask_inputs", None), 92 | ) 93 | 94 | low_res_masks, iou_predictions = ddp_model.mask_decoder( 95 | image_embeddings = image_embeddings, 96 | image_pe = ddp_model.prompt_encoder.get_dense_pe(), 97 | sparse_prompt_embeddings=sparse_embeddings, 98 | dense_prompt_embeddings=dense_embeddings, 99 | multimask_output=args.multimask, 100 | ) 101 | 102 | if args.multimask: 103 | max_values, max_indexs = torch.max(iou_predictions, dim=1) 104 | max_values = max_values.unsqueeze(1) 105 | iou_predictions = max_values 106 | low_res = [] 107 | for i, idx in enumerate(max_indexs): 108 | low_res.append(low_res_masks[i:i+1, idx]) 109 | low_res_masks = torch.stack(low_res, 0) 110 | masks = F.interpolate(low_res_masks,(args.image_size, args.image_size), mode="bilinear", align_corners=False,) 111 | return masks, low_res_masks, iou_predictions 112 | 113 | 114 | def is_not_saved(save_path, mask_name): 115 | masks_path = os.path.join(save_path, f"{mask_name}") 116 | if os.path.exists(masks_path): 117 | return False 118 | else: 119 | return True 120 | 121 | 122 | def main(args): 123 | print('*'*100) 124 | for key, value in vars(args).items(): 125 | print(key + ': ' + str(value)) 126 | print('*'*100) 127 | 128 | model = sam_model_registry[args.model_type](args).to(args.device) 129 | 130 | criterion = FocalDiceloss_IoULoss() 131 | test_dataset = TestingDataset(data_path=args.data_path, image_size=args.image_size, mode='test', requires_name=True, point_num=args.point_num, return_ori_mask=True, prompt_path=args.prompt_path) 132 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4) 133 | print('Test data:', len(test_loader)) 134 | 135 | test_pbar = tqdm(test_loader) 136 | l = len(test_loader) 137 | 138 | model.eval() 139 | test_loss = [] 140 | test_iter_metrics = [0] * len(args.metrics) 141 | test_metrics = {} 142 | prompt_dict = {} 143 | 144 | for i, batched_input in enumerate(test_pbar): 145 | batched_input = to_device(batched_input, args.device) 146 | ori_labels = batched_input["ori_label"] 147 | original_size = batched_input["original_size"] 148 | labels = batched_input["label"] 149 | img_name = batched_input['name'][0] 150 | if args.prompt_path is None: 151 | prompt_dict[img_name] = { 152 | "boxes": batched_input["boxes"].squeeze(1).cpu().numpy().tolist(), 153 | "point_coords": batched_input["point_coords"].squeeze(1).cpu().numpy().tolist(), 154 | "point_labels": batched_input["point_labels"].squeeze(1).cpu().numpy().tolist() 155 | } 156 | 157 | with torch.no_grad(): 158 | image_embeddings = model.image_encoder(batched_input["image"]) 159 | 160 | if args.boxes_prompt: 161 | save_path = os.path.join(args.work_dir, args.run_name, "boxes_prompt") 162 | batched_input["point_coords"], batched_input["point_labels"] = None, None 163 | masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings) 164 | points_show = None 165 | 166 | else: 167 | save_path = os.path.join(f"{args.work_dir}", 
args.run_name, f"iter{args.iter_point if args.iter_point > 1 else args.point_num}_prompt") 168 | batched_input["boxes"] = None 169 | point_coords, point_labels = [batched_input["point_coords"]], [batched_input["point_labels"]] 170 | 171 | for iter in range(args.iter_point): 172 | masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings) 173 | if iter != args.iter_point-1: 174 | batched_input = generate_point(masks, labels, low_res_masks, batched_input, args.point_num) 175 | batched_input = to_device(batched_input, args.device) 176 | point_coords.append(batched_input["point_coords"]) 177 | point_labels.append(batched_input["point_labels"]) 178 | batched_input["point_coords"] = torch.concat(point_coords,dim=1) 179 | batched_input["point_labels"] = torch.concat(point_labels, dim=1) 180 | 181 | points_show = (torch.concat(point_coords, dim=1), torch.concat(point_labels, dim=1)) 182 | 183 | masks, pad = postprocess_masks(low_res_masks, args.image_size, original_size) 184 | if args.save_pred: 185 | save_masks(masks, save_path, img_name, args.image_size, original_size, pad, batched_input.get("boxes", None), points_show) 186 | 187 | loss = criterion(masks, ori_labels, iou_predictions) 188 | test_loss.append(loss.item()) 189 | 190 | test_batch_metrics = SegMetrics(masks, ori_labels, args.metrics) 191 | test_batch_metrics = [float('{:.4f}'.format(metric)) for metric in test_batch_metrics] 192 | 193 | for j in range(len(args.metrics)): 194 | test_iter_metrics[j] += test_batch_metrics[j] 195 | 196 | test_iter_metrics = [metric / l for metric in test_iter_metrics] 197 | test_metrics = {args.metrics[i]: '{:.4f}'.format(test_iter_metrics[i]) for i in range(len(test_iter_metrics))} 198 | 199 | average_loss = np.mean(test_loss) 200 | if args.prompt_path is None: 201 | with open(os.path.join(args.work_dir,f'{args.image_size}_prompt.json'), 'w') as f: 202 | json.dump(prompt_dict, f, indent=2) 203 | print(f"Test loss: {average_loss:.4f}, metrics: {test_metrics}") 204 | 205 | 206 | if __name__ == '__main__': 207 | args = parse_args() 208 | main(args) 209 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from segment_anything import sam_model_registry, SamPredictor 2 | import torch.nn as nn 3 | import torch 4 | import argparse 5 | import os 6 | from torch import optim 7 | from torch.utils.data import DataLoader 8 | from DataLoader import TrainingDataset, stack_dict_batched 9 | from utils import FocalDiceloss_IoULoss, get_logger, generate_point, setting_prompt_none 10 | from metrics import SegMetrics 11 | import time 12 | from tqdm import tqdm 13 | import numpy as np 14 | import datetime 15 | from torch.nn import functional as F 16 | from apex import amp 17 | import random 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--work_dir", type=str, default="workdir", help="work dir") 23 | parser.add_argument("--run_name", type=str, default="sam-med2d", help="run model name") 24 | parser.add_argument("--epochs", type=int, default=15, help="number of epochs") 25 | parser.add_argument("--batch_size", type=int, default=2, help="train batch size") 26 | parser.add_argument("--image_size", type=int, default=256, help="image_size") 27 | parser.add_argument("--mask_num", type=int, default=5, help="get mask number") 28 | parser.add_argument("--data_path", type=str, default="data_demo", help="train data 
path") 29 | parser.add_argument("--metrics", nargs='+', default=['iou', 'dice'], help="metrics") 30 | parser.add_argument('--device', type=str, default='cuda') 31 | parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 32 | parser.add_argument("--resume", type=str, default=None, help="load resume") 33 | parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type") 34 | parser.add_argument("--sam_checkpoint", type=str, default="pretrain_model/sam-med2d_b.pth", help="sam checkpoint") 35 | parser.add_argument("--iter_point", type=int, default=8, help="point iterations") 36 | parser.add_argument('--lr_scheduler', type=str, default=None, help='lr scheduler') 37 | parser.add_argument("--point_list", type=list, default=[1, 3, 5, 9], help="point_list") 38 | parser.add_argument("--multimask", type=bool, default=True, help="ouput multimask") 39 | parser.add_argument("--encoder_adapter", type=bool, default=True, help="use adapter") 40 | parser.add_argument("--use_amp", type=bool, default=False, help="use amp") 41 | args = parser.parse_args() 42 | if args.resume is not None: 43 | args.sam_checkpoint = None 44 | return args 45 | 46 | 47 | def to_device(batch_input, device): 48 | device_input = {} 49 | for key, value in batch_input.items(): 50 | if value is not None: 51 | if key=='image' or key=='label': 52 | device_input[key] = value.float().to(device) 53 | elif type(value) is list or type(value) is torch.Size: 54 | device_input[key] = value 55 | else: 56 | device_input[key] = value.to(device) 57 | else: 58 | device_input[key] = value 59 | return device_input 60 | 61 | 62 | def prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter = False): 63 | if batched_input["point_coords"] is not None: 64 | points = (batched_input["point_coords"], batched_input["point_labels"]) 65 | else: 66 | points = None 67 | 68 | if decoder_iter: 69 | with torch.no_grad(): 70 | sparse_embeddings, dense_embeddings = model.prompt_encoder( 71 | points=points, 72 | boxes=batched_input.get("boxes", None), 73 | masks=batched_input.get("mask_inputs", None), 74 | ) 75 | 76 | else: 77 | sparse_embeddings, dense_embeddings = model.prompt_encoder( 78 | points=points, 79 | boxes=batched_input.get("boxes", None), 80 | masks=batched_input.get("mask_inputs", None), 81 | ) 82 | 83 | low_res_masks, iou_predictions = model.mask_decoder( 84 | image_embeddings = image_embeddings, 85 | image_pe = model.prompt_encoder.get_dense_pe(), 86 | sparse_prompt_embeddings=sparse_embeddings, 87 | dense_prompt_embeddings=dense_embeddings, 88 | multimask_output=args.multimask, 89 | ) 90 | 91 | if args.multimask: 92 | max_values, max_indexs = torch.max(iou_predictions, dim=1) 93 | max_values = max_values.unsqueeze(1) 94 | iou_predictions = max_values 95 | low_res = [] 96 | for i, idx in enumerate(max_indexs): 97 | low_res.append(low_res_masks[i:i+1, idx]) 98 | low_res_masks = torch.stack(low_res, 0) 99 | 100 | masks = F.interpolate(low_res_masks,(args.image_size, args.image_size), mode="bilinear", align_corners=False,) 101 | return masks, low_res_masks, iou_predictions 102 | 103 | 104 | def train_one_epoch(args, model, optimizer, train_loader, epoch, criterion): 105 | train_loader = tqdm(train_loader) 106 | train_losses = [] 107 | train_iter_metrics = [0] * len(args.metrics) 108 | for batch, batched_input in enumerate(train_loader): 109 | batched_input = stack_dict_batched(batched_input) 110 | batched_input = to_device(batched_input, args.device) 111 | 112 | if random.random() > 0.5: 113 | 
batched_input["point_coords"] = None 114 | flag = "boxes" 115 | else: 116 | batched_input["boxes"] = None 117 | flag = "point" 118 | 119 | for n, value in model.image_encoder.named_parameters(): 120 | if "Adapter" in n: 121 | value.requires_grad = True 122 | else: 123 | value.requires_grad = False 124 | 125 | if args.use_amp: 126 | labels = batched_input["label"].half() 127 | image_embeddings = model.image_encoder(batched_input["image"].half()) 128 | 129 | B, _, _, _ = image_embeddings.shape 130 | image_embeddings_repeat = [] 131 | for i in range(B): 132 | image_embed = image_embeddings[i] 133 | image_embed = image_embed.repeat(args.mask_num, 1, 1, 1) 134 | image_embeddings_repeat.append(image_embed) 135 | image_embeddings = torch.cat(image_embeddings_repeat, dim=0) 136 | 137 | masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter = False) 138 | loss = criterion(masks, labels, iou_predictions) 139 | with amp.scale_loss(loss, optimizer) as scaled_loss: 140 | scaled_loss.backward(retain_graph=False) 141 | 142 | else: 143 | labels = batched_input["label"] 144 | image_embeddings = model.image_encoder(batched_input["image"]) 145 | 146 | B, _, _, _ = image_embeddings.shape 147 | image_embeddings_repeat = [] 148 | for i in range(B): 149 | image_embed = image_embeddings[i] 150 | image_embed = image_embed.repeat(args.mask_num, 1, 1, 1) 151 | image_embeddings_repeat.append(image_embed) 152 | image_embeddings = torch.cat(image_embeddings_repeat, dim=0) 153 | 154 | masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter = False) 155 | loss = criterion(masks, labels, iou_predictions) 156 | loss.backward(retain_graph=False) 157 | 158 | optimizer.step() 159 | optimizer.zero_grad() 160 | 161 | if int(batch+1) % 50 == 0: 162 | print(f'Epoch: {epoch+1}, Batch: {batch+1}, first {flag} prompt: {SegMetrics(masks, labels, args.metrics)}') 163 | 164 | point_num = random.choice(args.point_list) 165 | batched_input = generate_point(masks, labels, low_res_masks, batched_input, point_num) 166 | batched_input = to_device(batched_input, args.device) 167 | 168 | image_embeddings = image_embeddings.detach().clone() 169 | for n, value in model.named_parameters(): 170 | if "image_encoder" in n: 171 | value.requires_grad = False 172 | else: 173 | value.requires_grad = True 174 | 175 | init_mask_num = np.random.randint(1, args.iter_point - 1) 176 | for iter in range(args.iter_point): 177 | if iter == init_mask_num or iter == args.iter_point - 1: 178 | batched_input = setting_prompt_none(batched_input) 179 | 180 | if args.use_amp: 181 | masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter=True) 182 | loss = criterion(masks, labels, iou_predictions) 183 | with amp.scale_loss(loss, optimizer) as scaled_loss: 184 | scaled_loss.backward(retain_graph=True) 185 | else: 186 | masks, low_res_masks, iou_predictions = prompt_and_decoder(args, batched_input, model, image_embeddings, decoder_iter=True) 187 | loss = criterion(masks, labels, iou_predictions) 188 | loss.backward(retain_graph=True) 189 | 190 | optimizer.step() 191 | optimizer.zero_grad() 192 | 193 | if iter != args.iter_point - 1: 194 | point_num = random.choice(args.point_list) 195 | batched_input = generate_point(masks, labels, low_res_masks, batched_input, point_num) 196 | batched_input = to_device(batched_input, args.device) 197 | 198 | if int(batch+1) % 50 == 0: 199 | if iter == 
init_mask_num or iter == args.iter_point - 1: 200 | print(f'Epoch: {epoch+1}, Batch: {batch+1}, mask prompt: {SegMetrics(masks, labels, args.metrics)}') 201 | else: 202 | print(f'Epoch: {epoch+1}, Batch: {batch+1}, point {point_num} prompt: { SegMetrics(masks, labels, args.metrics)}') 203 | 204 | if int(batch+1) % 200 == 0: 205 | print(f"epoch:{epoch+1}, iteration:{batch+1}, loss:{loss.item()}") 206 | save_path = os.path.join(f"{args.work_dir}/models", args.run_name, f"epoch{epoch+1}_batch{batch+1}_sam.pth") 207 | state = {'model': model.state_dict(), 'optimizer': optimizer} 208 | torch.save(state, save_path) 209 | 210 | train_losses.append(loss.item()) 211 | 212 | gpu_info = {} 213 | gpu_info['gpu_name'] = args.device 214 | train_loader.set_postfix(train_loss=loss.item(), gpu_info=gpu_info) 215 | 216 | train_batch_metrics = SegMetrics(masks, labels, args.metrics) 217 | train_iter_metrics = [train_iter_metrics[i] + train_batch_metrics[i] for i in range(len(args.metrics))] 218 | 219 | return train_losses, train_iter_metrics 220 | 221 | 222 | 223 | def main(args): 224 | model = sam_model_registry[args.model_type](args).to(args.device) 225 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 226 | criterion = FocalDiceloss_IoULoss() 227 | 228 | if args.lr_scheduler: 229 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10], gamma = 0.5) 230 | print('*******Use MultiStepLR') 231 | 232 | if args.resume is not None: 233 | with open(args.resume, "rb") as f: 234 | checkpoint = torch.load(f) 235 | model.load_state_dict(checkpoint['model']) 236 | optimizer.load_state_dict(checkpoint['optimizer'].state_dict()) 237 | print(f"*******load {args.resume}") 238 | 239 | if args.use_amp: 240 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 241 | print("*******Mixed precision with Apex") 242 | else: 243 | print('*******Do not use mixed precision') 244 | 245 | train_dataset = TrainingDataset(args.data_path, image_size=args.image_size, mode='train', point_num=1, mask_num=args.mask_num, requires_name = False) 246 | train_loader = DataLoader(train_dataset, batch_size = args.batch_size, shuffle=True, num_workers=4) 247 | print('*******Train data:', len(train_dataset)) 248 | 249 | loggers = get_logger(os.path.join(args.work_dir, "logs", f"{args.run_name}_{datetime.datetime.now().strftime('%Y%m%d-%H%M.log')}")) 250 | 251 | best_loss = 1e10 252 | l = len(train_loader) 253 | 254 | for epoch in range(0, args.epochs): 255 | model.train() 256 | train_metrics = {} 257 | start = time.time() 258 | os.makedirs(os.path.join(f"{args.work_dir}/models", args.run_name), exist_ok=True) 259 | train_losses, train_iter_metrics = train_one_epoch(args, model, optimizer, train_loader, epoch, criterion) 260 | 261 | if args.lr_scheduler is not None: 262 | scheduler.step() 263 | 264 | train_iter_metrics = [metric / l for metric in train_iter_metrics] 265 | train_metrics = {args.metrics[i]: '{:.4f}'.format(train_iter_metrics[i]) for i in range(len(train_iter_metrics))} 266 | 267 | average_loss = np.mean(train_losses) 268 | lr = scheduler.get_last_lr()[0] if args.lr_scheduler is not None else args.lr 269 | loggers.info(f"epoch: {epoch + 1}, lr: {lr}, Train loss: {average_loss:.4f}, metrics: {train_metrics}") 270 | 271 | if average_loss < best_loss: 272 | best_loss = average_loss 273 | save_path = os.path.join(args.work_dir, "models", args.run_name, f"epoch{epoch+1}_sam.pth") 274 | state = {'model': model.float().state_dict(), 'optimizer': optimizer} 275 | torch.save(state, save_path) 276 | 
if args.use_amp: 277 | model = model.half() 278 | 279 | end = time.time() 280 | print("Run epoch time: %.2fs" % (end - start)) 281 | 282 | 283 | if __name__ == '__main__': 284 | args = parse_args() 285 | main(args) 286 | 287 | 288 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from albumentations.pytorch import ToTensorV2 2 | import cv2 3 | import albumentations as A 4 | import torch 5 | import numpy as np 6 | from torch.nn import functional as F 7 | from skimage.measure import label, regionprops 8 | from matplotlib import pyplot as plt 9 | import random 10 | import torch.nn as nn 11 | import logging 12 | import os 13 | 14 | 15 | def get_boxes_from_mask(mask, box_num=1, std = 0.1, max_pixel = 5): 16 | """ 17 | Args: 18 | mask: Mask, can be a torch.Tensor or a numpy array of binary mask. 19 | box_num: Number of bounding boxes, default is 1. 20 | std: Standard deviation of the noise, default is 0.1. 21 | max_pixel: Maximum noise pixel value, default is 5. 22 | Returns: 23 | noise_boxes: Bounding boxes after noise perturbation, returned as a torch.Tensor. 24 | """ 25 | if isinstance(mask, torch.Tensor): 26 | mask = mask.numpy() 27 | 28 | label_img = label(mask) 29 | regions = regionprops(label_img) 30 | 31 | # Iterate through all regions and get the bounding box coordinates 32 | boxes = [tuple(region.bbox) for region in regions] 33 | 34 | # If the generated number of boxes is greater than the number of categories, 35 | # sort them by region area and select the top n regions 36 | if len(boxes) >= box_num: 37 | sorted_regions = sorted(regions, key=lambda x: x.area, reverse=True)[:box_num] 38 | boxes = [tuple(region.bbox) for region in sorted_regions] 39 | 40 | # If the generated number of boxes is less than the number of categories, 41 | # duplicate the existing boxes 42 | elif len(boxes) < box_num: 43 | num_duplicates = box_num - len(boxes) 44 | boxes += [boxes[i % len(boxes)] for i in range(num_duplicates)] 45 | 46 | # Perturb each bounding box with noise 47 | noise_boxes = [] 48 | for box in boxes: 49 | y0, x0, y1, x1 = box 50 | width, height = abs(x1 - x0), abs(y1 - y0) 51 | # Calculate the standard deviation and maximum noise value 52 | noise_std = min(width, height) * std 53 | max_noise = min(max_pixel, int(noise_std * 5)) 54 | # Add random noise to each coordinate 55 | try: 56 | noise_x = np.random.randint(-max_noise, max_noise) 57 | except: 58 | noise_x = 0 59 | try: 60 | noise_y = np.random.randint(-max_noise, max_noise) 61 | except: 62 | noise_y = 0 63 | x0, y0 = x0 + noise_x, y0 + noise_y 64 | x1, y1 = x1 + noise_x, y1 + noise_y 65 | noise_boxes.append((x0, y0, x1, y1)) 66 | return torch.as_tensor(noise_boxes, dtype=torch.float) 67 | 68 | 69 | def select_random_points(pr, gt, point_num = 9): 70 | """ 71 | Selects random points from the predicted and ground truth masks and assigns labels to them. 72 | Args: 73 | pred (torch.Tensor): Predicted mask tensor. 74 | gt (torch.Tensor): Ground truth mask tensor. 75 | point_num (int): Number of random points to select. Default is 9. 76 | Returns: 77 | batch_points (np.array): Array of selected points coordinates (x, y) for each batch. 78 | batch_labels (np.array): Array of corresponding labels (0 for background, 1 for foreground) for each batch. 
79 |     """
80 |     pred, gt = pr.data.cpu().numpy(), gt.data.cpu().numpy()
81 |     error = np.zeros_like(pred)
82 |     error[pred != gt] = 1
83 | 
84 |     # error = np.logical_xor(pred, gt)
85 |     batch_points = []
86 |     batch_labels = []
87 |     for j in range(error.shape[0]):
88 |         one_pred = pred[j].squeeze(0)
89 |         one_gt = gt[j].squeeze(0)
90 |         one_error = error[j].squeeze(0)
91 | 
92 |         indices = np.argwhere(one_error == 1)
93 |         if indices.shape[0] > 0:
94 |             selected_indices = indices[np.random.choice(indices.shape[0], point_num, replace=True)]
95 |         else:
96 |             indices = np.random.randint(0, 256, size=(point_num, 2))
97 |             selected_indices = indices[np.random.choice(indices.shape[0], point_num, replace=True)]
98 |         selected_indices = selected_indices.reshape(-1, 2)
99 | 
100 |         points, labels = [], []
101 |         for i in selected_indices:
102 |             x, y = i[0], i[1]  # (row, col) from np.argwhere
103 |             if one_pred[x,y] == 0 and one_gt[x,y] == 1:
104 |                 label = 1
105 |             elif one_pred[x,y] == 1 and one_gt[x,y] == 0:
106 |                 label = 0
107 |             else:
108 |                 label = -1
109 |             points.append((y, x))  # swap to (x, y) order expected by the prompt encoder
110 |             labels.append(label)
111 | 
112 |         batch_points.append(points)
113 |         batch_labels.append(labels)
114 |     return np.array(batch_points), np.array(batch_labels)
115 | 
116 | 
117 | def init_point_sampling(mask, get_point=1):
118 |     """
119 |     Samples initial prompt points from the mask and assigns labels to them.
120 |     Args:
121 |         mask (torch.Tensor): Input mask tensor.
122 |         get_point (int): Number of points to sample. Default is 1.
123 |     Returns:
124 |         coords (torch.Tensor): Tensor containing the sampled points' coordinates (x, y).
125 |         labels (torch.Tensor): Tensor containing the corresponding labels (0 for background, 1 for foreground).
126 |     """
127 |     if isinstance(mask, torch.Tensor):
128 |         mask = mask.numpy()
129 | 
130 |     # Get (x, y) coordinates of foreground/background pixels (argwhere returns (row, col), so reverse the order)
131 |     fg_coords = np.argwhere(mask == 1)[:,::-1]
132 |     bg_coords = np.argwhere(mask == 0)[:,::-1]
133 | 
134 |     fg_size = len(fg_coords)
135 |     bg_size = len(bg_coords)
136 | 
137 |     if get_point == 1:
138 |         if fg_size > 0:
139 |             index = np.random.randint(fg_size)
140 |             fg_coord = fg_coords[index]
141 |             label = 1
142 |         else:
143 |             index = np.random.randint(bg_size)
144 |             fg_coord = bg_coords[index]
145 |             label = 0
146 |         return torch.as_tensor([fg_coord.tolist()], dtype=torch.float), torch.as_tensor([label], dtype=torch.int)
147 |     else:
148 |         num_fg = get_point // 2
149 |         num_bg = get_point - num_fg
150 |         fg_indices = np.random.choice(fg_size, size=num_fg, replace=True)
151 |         bg_indices = np.random.choice(bg_size, size=num_bg, replace=True)
152 |         fg_coords = fg_coords[fg_indices]
153 |         bg_coords = bg_coords[bg_indices]
154 |         coords = np.concatenate([fg_coords, bg_coords], axis=0)
155 |         labels = np.concatenate([np.ones(num_fg), np.zeros(num_bg)]).astype(int)
156 |         indices = np.random.permutation(get_point)
157 |         coords, labels = torch.as_tensor(coords[indices], dtype=torch.float), torch.as_tensor(labels[indices], dtype=torch.int)
158 |         return coords, labels
159 | 
160 | 
161 | def train_transforms(img_size, ori_h, ori_w):
162 |     transforms = []
163 |     if ori_h < img_size and ori_w < img_size:
164 |         transforms.append(A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)))
165 |     else:
166 |         transforms.append(A.Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST))
167 |     transforms.append(ToTensorV2(p=1.0))
168 |     return A.Compose(transforms, p=1.)
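As a usage note, here is a minimal sketch of how the prompt-generation helpers above can be exercised on a toy mask. It assumes the repository root is on PYTHONPATH so this utils.py is importable; the toy shapes and variable names are ours for illustration.

import numpy as np
import torch
from utils import get_boxes_from_mask, init_point_sampling, train_transforms

# Toy 64x64 binary mask with one rectangular foreground region.
mask = torch.zeros((64, 64), dtype=torch.uint8)
mask[20:40, 10:30] = 1

boxes = get_boxes_from_mask(mask, box_num=1)              # (1, 4) jittered (x0, y0, x1, y1) box
coords, labels = init_point_sampling(mask, get_point=3)   # (3, 2) point coords, (3,) fg/bg labels
print(boxes.shape, coords.shape, labels.tolist())

# Pad (or resize) an image/mask pair to the model input size used in this repo (256).
aug = train_transforms(img_size=256, ori_h=64, ori_w=64)
out = aug(image=np.zeros((64, 64, 3), dtype=np.uint8), mask=mask.numpy())
print(out["image"].shape)  # torch.Size([3, 256, 256]) after padding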
169 | 170 | 171 | def get_logger(filename, verbosity=1, name=None): 172 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 173 | formatter = logging.Formatter( 174 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s" 175 | ) 176 | logger = logging.getLogger(name) 177 | logger.setLevel(level_dict[verbosity]) 178 | 179 | os.makedirs(os.path.dirname(filename), exist_ok=True) 180 | 181 | fh = logging.FileHandler(filename, "w") 182 | fh.setFormatter(formatter) 183 | logger.addHandler(fh) 184 | 185 | sh = logging.StreamHandler() 186 | sh.setFormatter(formatter) 187 | logger.addHandler(sh) 188 | 189 | return logger 190 | 191 | 192 | def generate_point(masks, labels, low_res_masks, batched_input, point_num): 193 | masks_clone = masks.clone() 194 | masks_sigmoid = torch.sigmoid(masks_clone) 195 | masks_binary = (masks_sigmoid > 0.5).float() 196 | 197 | low_res_masks_clone = low_res_masks.clone() 198 | low_res_masks_logist = torch.sigmoid(low_res_masks_clone) 199 | 200 | points, point_labels = select_random_points(masks_binary, labels, point_num = point_num) 201 | batched_input["mask_inputs"] = low_res_masks_logist 202 | batched_input["point_coords"] = torch.as_tensor(points) 203 | batched_input["point_labels"] = torch.as_tensor(point_labels) 204 | batched_input["boxes"] = None 205 | return batched_input 206 | 207 | 208 | def setting_prompt_none(batched_input): 209 | batched_input["point_coords"] = None 210 | batched_input["point_labels"] = None 211 | batched_input["boxes"] = None 212 | return batched_input 213 | 214 | 215 | def draw_boxes(img, boxes): 216 | img_copy = np.copy(img) 217 | for box in boxes: 218 | cv2.rectangle(img_copy, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2) 219 | return img_copy 220 | 221 | 222 | def save_masks(preds, save_path, mask_name, image_size, original_size, pad=None, boxes=None, points=None, visual_prompt=False): 223 | ori_h, ori_w = original_size 224 | 225 | preds = torch.sigmoid(preds) 226 | preds[preds > 0.5] = int(1) 227 | preds[preds <= 0.5] = int(0) 228 | 229 | mask = preds.squeeze().cpu().numpy() 230 | mask = cv2.cvtColor(mask * 255, cv2.COLOR_GRAY2BGR) 231 | 232 | if visual_prompt: #visualize the prompt 233 | if boxes is not None: 234 | boxes = boxes.squeeze().cpu().numpy() 235 | 236 | x0, y0, x1, y1 = boxes 237 | if pad is not None: 238 | x0_ori = int((x0 - pad[1]) + 0.5) 239 | y0_ori = int((y0 - pad[0]) + 0.5) 240 | x1_ori = int((x1 - pad[1]) + 0.5) 241 | y1_ori = int((y1 - pad[0]) + 0.5) 242 | else: 243 | x0_ori = int(x0 * ori_w / image_size) 244 | y0_ori = int(y0 * ori_h / image_size) 245 | x1_ori = int(x1 * ori_w / image_size) 246 | y1_ori = int(y1 * ori_h / image_size) 247 | 248 | boxes = [(x0_ori, y0_ori, x1_ori, y1_ori)] 249 | mask = draw_boxes(mask, boxes) 250 | 251 | if points is not None: 252 | point_coords, point_labels = points[0].squeeze(0).cpu().numpy(), points[1].squeeze(0).cpu().numpy() 253 | point_coords = point_coords.tolist() 254 | if pad is not None: 255 | ori_points = [[int((x * ori_w / image_size)) , int((y * ori_h / image_size))]if l==0 else [x - pad[1], y - pad[0]] for (x, y), l in zip(point_coords, point_labels)] 256 | else: 257 | ori_points = [[int((x * ori_w / image_size)) , int((y * ori_h / image_size))] for x, y in point_coords] 258 | 259 | for point, label in zip(ori_points, point_labels): 260 | x, y = map(int, point) 261 | color = (0, 255, 0) if label == 1 else (0, 0, 255) 262 | mask[y, x] = color 263 | cv2.drawMarker(mask, (x, y), color, markerType=cv2.MARKER_CROSS , 
markerSize=7, thickness=2) 264 | os.makedirs(save_path, exist_ok=True) 265 | mask_path = os.path.join(save_path, f"{mask_name}") 266 | cv2.imwrite(mask_path, np.uint8(mask)) 267 | 268 | 269 | #Loss funcation 270 | class FocalLoss(nn.Module): 271 | def __init__(self, gamma=2.0, alpha=0.25): 272 | super(FocalLoss, self).__init__() 273 | self.gamma = gamma 274 | self.alpha = alpha 275 | 276 | def forward(self, pred, mask): 277 | """ 278 | pred: [B, 1, H, W] 279 | mask: [B, 1, H, W] 280 | """ 281 | assert pred.shape == mask.shape, "pred and mask should have the same shape." 282 | p = torch.sigmoid(pred) 283 | num_pos = torch.sum(mask) 284 | num_neg = mask.numel() - num_pos 285 | w_pos = (1 - p) ** self.gamma 286 | w_neg = p ** self.gamma 287 | 288 | loss_pos = -self.alpha * mask * w_pos * torch.log(p + 1e-12) 289 | loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * torch.log(1 - p + 1e-12) 290 | 291 | loss = (torch.sum(loss_pos) + torch.sum(loss_neg)) / (num_pos + num_neg + 1e-12) 292 | 293 | return loss 294 | 295 | 296 | class DiceLoss(nn.Module): 297 | def __init__(self, smooth=1.0): 298 | super(DiceLoss, self).__init__() 299 | self.smooth = smooth 300 | 301 | def forward(self, pred, mask): 302 | """ 303 | pred: [B, 1, H, W] 304 | mask: [B, 1, H, W] 305 | """ 306 | assert pred.shape == mask.shape, "pred and mask should have the same shape." 307 | p = torch.sigmoid(pred) 308 | intersection = torch.sum(p * mask) 309 | union = torch.sum(p) + torch.sum(mask) 310 | dice_loss = (2.0 * intersection + self.smooth) / (union + self.smooth) 311 | 312 | return 1 - dice_loss 313 | 314 | 315 | class MaskIoULoss(nn.Module): 316 | 317 | def __init__(self, ): 318 | super(MaskIoULoss, self).__init__() 319 | 320 | def forward(self, pred_mask, ground_truth_mask, pred_iou): 321 | """ 322 | pred_mask: [B, 1, H, W] 323 | ground_truth_mask: [B, 1, H, W] 324 | pred_iou: [B, 1] 325 | """ 326 | assert pred_mask.shape == ground_truth_mask.shape, "pred_mask and ground_truth_mask should have the same shape." 327 | 328 | p = torch.sigmoid(pred_mask) 329 | intersection = torch.sum(p * ground_truth_mask) 330 | union = torch.sum(p) + torch.sum(ground_truth_mask) - intersection 331 | iou = (intersection + 1e-7) / (union + 1e-7) 332 | iou_loss = torch.mean((iou - pred_iou) ** 2) 333 | return iou_loss 334 | 335 | 336 | class FocalDiceloss_IoULoss(nn.Module): 337 | 338 | def __init__(self, weight=20.0, iou_scale=1.0): 339 | super(FocalDiceloss_IoULoss, self).__init__() 340 | self.weight = weight 341 | self.iou_scale = iou_scale 342 | self.focal_loss = FocalLoss() 343 | self.dice_loss = DiceLoss() 344 | self.maskiou_loss = MaskIoULoss() 345 | 346 | def forward(self, pred, mask, pred_iou): 347 | """ 348 | pred: [B, 1, H, W] 349 | mask: [B, 1, H, W] 350 | """ 351 | assert pred.shape == mask.shape, "pred and mask should have the same shape." 352 | 353 | focal_loss = self.focal_loss(pred, mask) 354 | dice_loss =self.dice_loss(pred, mask) 355 | loss1 = self.weight * focal_loss + dice_loss 356 | loss2 = self.maskiou_loss(pred, mask, pred_iou) 357 | loss = loss1 + loss2 * self.iou_scale 358 | return loss 359 | --------------------------------------------------------------------------------
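To close, a minimal sketch of how the combined loss defined above is typically invoked. The dummy shapes and the `criterion` name are ours; `pred` is expected to be raw, pre-sigmoid mask logits, and the repository root is assumed to be on PYTHONPATH.

import torch
from utils import FocalDiceloss_IoULoss

criterion = FocalDiceloss_IoULoss()  # defaults: weight=20.0, iou_scale=1.0

# Dummy batch of two 256x256 predictions with binary ground truth and predicted IoU.
pred = torch.randn(2, 1, 256, 256)                 # raw mask logits
mask = (torch.rand(2, 1, 256, 256) > 0.5).float()  # binary ground-truth mask
pred_iou = torch.rand(2, 1)                        # IoU head output

loss = criterion(pred, mask, pred_iou)
print(loss)  # scalar tensor: 20 * focal + dice + IoU regression term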