├── .gitignore ├── LICENSE ├── README.md ├── figs ├── Vis1.png ├── cam.jpg ├── pipeline.png ├── result1.png ├── result2.png └── sort.png ├── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── build_sam_baseline.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── mask_decoder_hq.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── tiny_vit_sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── setup.cfg ├── setup.py └── train ├── segment_anything_training ├── __init__.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py └── utils │ ├── __init__.py │ └── transforms.py ├── train.py └── utils ├── dataloader.py ├── loss_mask.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 chenly0618 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## EG-SAM: An Edge-Guided SAM for Accurate Complex Object Segmentation 2 | 3 | ### Pipeline 4 | 5 | ![pipeline](figs/pipeline.png) 6 | 7 | ### Environment 8 | 9 | Python 3.8 10 | 11 | CUDA 11.7 12 | 13 | PyTorch 1.13.1 14 | 15 | TorchVision 0.14.1 16 | 17 | ### Datasets 18 | We follow the dataset structure of HQSeg-44K: 19 | ``` 20 | data 21 | |____DIS5K 22 | |____cascade_psp 23 | | |____DUTS-TE 24 | | |____DUTS-TR 25 | | |____ecssd 26 | | |____fss_all 27 | | |____MSRA_10K 28 | |____thin_object_detection 29 | | |____COIFT 30 | | |____ThinObject5K 31 | ``` 32 | You can download the datasets from [here](https://drive.google.com/drive/folders/1j1yFEejTAdAQzSbCrdBWoHE4VjaAf25L?usp=drive_link) 33 | ### Train 34 | ``` 35 | python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output> 36 | ``` 37 | 38 | EG-SAM builds on HQ-SAM; you can follow its environment setup [here](https://github.com/SysCV/SAM-HQ?tab=readme-ov-file) 39 | ### Evaluation 40 | ``` 41 | python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output> --eval --restore-model <path/to/trained/checkpoint> 42 | ``` 43 | You can download the trained weight file [here](https://drive.google.com/file/d/1B9-bTQ4c_fG8s--837HpMhaUT5gWGdv9/view?usp=drive_link) 44 | 45 | ### Visualization 46 | 47 | ![Vis1](figs/Vis1.png) 48 | 49 | Visual comparison with **nine state-of-the-art COD methods**. EG-SAM demonstrates superior accuracy in delineating the boundaries of camouflaged objects. 50 | 51 | ![cam](figs/cam.jpg) 52 | ![sort](figs/sort.png) 53 | 54 | ### Results on DIS, COIFT, and ThinObject 55 | 56 | ![result1](figs/result1.png) 57 | 58 | ### Results on COD datasets 59 | 60 | ![result2](figs/result2.png) 61 | 62 | -------------------------------------------------------------------------------- /figs/Vis1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-797/EG-SAM/d2acb7a474a78fd25c7aec93dd395b96f2a32e24/figs/Vis1.png -------------------------------------------------------------------------------- /figs/cam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-797/EG-SAM/d2acb7a474a78fd25c7aec93dd395b96f2a32e24/figs/cam.jpg -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-797/EG-SAM/d2acb7a474a78fd25c7aec93dd395b96f2a32e24/figs/pipeline.png -------------------------------------------------------------------------------- /figs/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-797/EG-SAM/d2acb7a474a78fd25c7aec93dd395b96f2a32e24/figs/result1.png -------------------------------------------------------------------------------- /figs/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/code-797/EG-SAM/d2acb7a474a78fd25c7aec93dd395b96f2a32e24/figs/result2.png -------------------------------------------------------------------------------- /figs/sort.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/code-797/EG-SAM/d2acb7a474a78fd25c7aec93dd395b96f2a32e24/figs/sort.png -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .build_sam_baseline import sam_model_registry_baseline 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | -------------------------------------------------------------------------------- /segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 
69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 
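          multimask_output (bool): If True, several candidate masks are predicted
            for each point prompt; candidates are then filtered using the
            predicted IoU and stability score thresholds.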
143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image, multimask_output) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros_like(data["boxes"][:, 0]), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | multimask_output: bool = True, 232 | ) -> MaskData: 233 | # Crop the image and calculate embeddings 234 | x0, y0, x1, y1 = crop_box 235 | cropped_im = 
image[y0:y1, x0:x1, :] 236 | cropped_im_size = cropped_im.shape[:2] 237 | self.predictor.set_image(cropped_im) 238 | 239 | # Get points for this crop 240 | points_scale = np.array(cropped_im_size)[None, ::-1] 241 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 242 | 243 | # Generate masks for this crop in batches 244 | data = MaskData() 245 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 246 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output) 247 | data.cat(batch_data) 248 | del batch_data 249 | self.predictor.reset_image() 250 | 251 | # Remove duplicates within this crop. 252 | keep_by_nms = batched_nms( 253 | data["boxes"].float(), 254 | data["iou_preds"], 255 | torch.zeros_like(data["boxes"][:, 0]), # categories 256 | iou_threshold=self.box_nms_thresh, 257 | ) 258 | data.filter(keep_by_nms) 259 | 260 | # Return to the original image frame 261 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 262 | data["points"] = uncrop_points(data["points"], crop_box) 263 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 264 | 265 | return data 266 | 267 | def _process_batch( 268 | self, 269 | points: np.ndarray, 270 | im_size: Tuple[int, ...], 271 | crop_box: List[int], 272 | orig_size: Tuple[int, ...], 273 | multimask_output: bool = True, 274 | ) -> MaskData: 275 | orig_h, orig_w = orig_size 276 | 277 | # Run model on this batch 278 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 279 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 280 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 281 | masks, iou_preds, _ = self.predictor.predict_torch( 282 | in_points[:, None, :], 283 | in_labels[:, None], 284 | multimask_output=multimask_output, 285 | return_logits=True, 286 | ) 287 | 288 | # Serialize predictions and store in MaskData 289 | data = MaskData( 290 | masks=masks.flatten(0, 1), 291 | iou_preds=iou_preds.flatten(0, 1), 292 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 293 | ) 294 | del masks 295 | 296 | # Filter by predicted IoU 297 | if self.pred_iou_thresh > 0.0: 298 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 299 | data.filter(keep_mask) 300 | 301 | # Calculate stability score 302 | data["stability_score"] = calculate_stability_score( 303 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 304 | ) 305 | if self.stability_score_thresh > 0.0: 306 | keep_mask = data["stability_score"] >= self.stability_score_thresh 307 | data.filter(keep_mask) 308 | 309 | # Threshold masks and calculate boxes 310 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 311 | data["boxes"] = batched_mask_to_box(data["masks"]) 312 | 313 | # Filter boxes that touch crop boundaries 314 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 315 | if not torch.all(keep_mask): 316 | data.filter(keep_mask) 317 | 318 | # Compress to RLE 319 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 320 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 321 | del data["masks"] 322 | 323 | return data 324 | 325 | @staticmethod 326 | def postprocess_small_regions( 327 | mask_data: MaskData, min_area: int, nms_thresh: float 328 | ) -> MaskData: 329 | """ 330 | Removes small disconnected regions and holes in masks, then reruns 331 | box NMS to remove any new duplicates. 
332 | 333 | Edits mask_data in place. 334 | 335 | Requires open-cv as a dependency. 336 | """ 337 | if len(mask_data["rles"]) == 0: 338 | return mask_data 339 | 340 | # Filter small disconnected regions and holes 341 | new_masks = [] 342 | scores = [] 343 | for rle in mask_data["rles"]: 344 | mask = rle_to_mask(rle) 345 | 346 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 347 | unchanged = not changed 348 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 349 | unchanged = unchanged and not changed 350 | 351 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 352 | # Give score=0 to changed masks and score=1 to unchanged masks 353 | # so NMS will prefer ones that didn't need postprocessing 354 | scores.append(float(unchanged)) 355 | 356 | # Recalculate boxes and remove any new duplicates 357 | masks = torch.cat(new_masks, dim=0) 358 | boxes = batched_mask_to_box(masks) 359 | keep_by_nms = batched_nms( 360 | boxes.float(), 361 | torch.as_tensor(scores), 362 | torch.zeros_like(boxes[:, 0]), # categories 363 | iou_threshold=nms_thresh, 364 | ) 365 | 366 | # Only recalculate RLEs for masks that have changed 367 | for i_mask in keep_by_nms: 368 | if scores[i_mask] == 0.0: 369 | mask_torch = masks[i_mask].unsqueeze(0) 370 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 371 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 372 | mask_data.filter(keep_by_nms) 373 | 374 | return mask_data 375 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
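# Builders for SAM variants (ViT-H/L/B and TinyViT) wired to the HQ mask decoder
# (MaskDecoderHQ). When a checkpoint is given, the state dict is loaded with
# strict=False and all parameters are frozen except the HQ-specific modules
# (hf_token, hf_mlp, compress_vit_feat, embedding_encoder, embedding_maskfeature).
#
# Illustrative usage (the checkpoint path below is only an example):
#   from segment_anything import sam_model_registry, SamPredictor
#   sam = sam_model_registry["vit_l"](checkpoint="pretrained/sam_hq_vit_l.pth")
#   predictor = SamPredictor(sam)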
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer, TinyViT 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | def build_sam_vit_t(checkpoint=None): 48 | prompt_embed_dim = 256 49 | image_size = 1024 50 | vit_patch_size = 16 51 | image_embedding_size = image_size // vit_patch_size 52 | mobile_sam = Sam( 53 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, 54 | embed_dims=[64, 128, 160, 320], 55 | depths=[2, 2, 6, 2], 56 | num_heads=[2, 4, 5, 10], 57 | window_sizes=[7, 7, 14, 7], 58 | mlp_ratio=4., 59 | drop_rate=0., 60 | drop_path_rate=0.0, 61 | use_checkpoint=False, 62 | mbconv_expand_ratio=4.0, 63 | local_conv_size=3, 64 | layer_lr_decay=0.8 65 | ), 66 | prompt_encoder=PromptEncoder( 67 | embed_dim=prompt_embed_dim, 68 | image_embedding_size=(image_embedding_size, image_embedding_size), 69 | input_image_size=(image_size, image_size), 70 | mask_in_chans=16, 71 | ), 72 | mask_decoder=MaskDecoderHQ( 73 | num_multimask_outputs=3, 74 | transformer=TwoWayTransformer( 75 | depth=2, 76 | embedding_dim=prompt_embed_dim, 77 | mlp_dim=2048, 78 | num_heads=8, 79 | ), 80 | transformer_dim=prompt_embed_dim, 81 | iou_head_depth=3, 82 | iou_head_hidden_dim=256, 83 | vit_dim=160, 84 | ), 85 | pixel_mean=[123.675, 116.28, 103.53], 86 | pixel_std=[58.395, 57.12, 57.375], 87 | ) 88 | 89 | mobile_sam.eval() 90 | if checkpoint is not None: 91 | with open(checkpoint, "rb") as f: 92 | device = "cuda" if torch.cuda.is_available() else "cpu" 93 | state_dict = torch.load(f, map_location=device) 94 | info = mobile_sam.load_state_dict(state_dict, strict=False) 95 | print(info) 96 | for n, p in mobile_sam.named_parameters(): 97 | if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: 98 | p.requires_grad = False 99 | return mobile_sam 100 | 101 | sam_model_registry = { 102 | "default": build_sam_vit_h, 103 | "vit_h": build_sam_vit_h, 104 | "vit_l": build_sam_vit_l, 105 | "vit_b": build_sam_vit_b, 106 | "vit_tiny": build_sam_vit_t 107 | } 108 | 109 | 110 | def _build_sam( 111 | encoder_embed_dim, 112 | encoder_depth, 113 | encoder_num_heads, 114 | encoder_global_attn_indexes, 115 | checkpoint=None, 116 | ): 117 | prompt_embed_dim = 256 118 | image_size = 1024 119 | vit_patch_size = 16 120 | image_embedding_size = image_size // vit_patch_size 121 | sam = Sam( 122 | image_encoder=ImageEncoderViT( 123 | depth=encoder_depth, 124 | embed_dim=encoder_embed_dim, 125 | img_size=image_size, 126 | mlp_ratio=4, 127 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 128 | num_heads=encoder_num_heads, 129 | patch_size=vit_patch_size, 130 | qkv_bias=True, 131 | use_rel_pos=True, 132 | 
global_attn_indexes=encoder_global_attn_indexes, 133 | window_size=14, 134 | out_chans=prompt_embed_dim, 135 | ), 136 | prompt_encoder=PromptEncoder( 137 | embed_dim=prompt_embed_dim, 138 | image_embedding_size=(image_embedding_size, image_embedding_size), 139 | input_image_size=(image_size, image_size), 140 | mask_in_chans=16, 141 | ), 142 | mask_decoder=MaskDecoderHQ( 143 | num_multimask_outputs=3, 144 | transformer=TwoWayTransformer( 145 | depth=2, 146 | embedding_dim=prompt_embed_dim, 147 | mlp_dim=2048, 148 | num_heads=8, 149 | ), 150 | transformer_dim=prompt_embed_dim, 151 | iou_head_depth=3, 152 | iou_head_hidden_dim=256, 153 | vit_dim=encoder_embed_dim, 154 | ), 155 | pixel_mean=[123.675, 116.28, 103.53], 156 | pixel_std=[58.395, 57.12, 57.375], 157 | ) 158 | sam.eval() 159 | if checkpoint is not None: 160 | with open(checkpoint, "rb") as f: 161 | device = "cuda" if torch.cuda.is_available() else "cpu" 162 | state_dict = torch.load(f, map_location=device) 163 | info = sam.load_state_dict(state_dict, strict=False) 164 | print(info) 165 | for n, p in sam.named_parameters(): 166 | if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: 167 | p.requires_grad = False 168 | 169 | return sam 170 | -------------------------------------------------------------------------------- /segment_anything/build_sam_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | def build_sam_vit_t(checkpoint=None): 48 | prompt_embed_dim = 256 49 | image_size = 1024 50 | vit_patch_size = 16 51 | image_embedding_size = image_size // vit_patch_size 52 | mobile_sam = Sam( 53 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, 54 | embed_dims=[64, 128, 160, 320], 55 | depths=[2, 2, 6, 2], 56 | num_heads=[2, 4, 5, 10], 57 | window_sizes=[7, 7, 14, 7], 58 | mlp_ratio=4., 59 | drop_rate=0., 60 | drop_path_rate=0.0, 61 | use_checkpoint=False, 62 | mbconv_expand_ratio=4.0, 63 | local_conv_size=3, 64 | layer_lr_decay=0.8 65 | ), 66 | prompt_encoder=PromptEncoder( 67 | embed_dim=prompt_embed_dim, 68 | image_embedding_size=(image_embedding_size, image_embedding_size), 69 | input_image_size=(image_size, image_size), 70 | mask_in_chans=16, 71 | ), 72 | mask_decoder=MaskDecoder( 73 | num_multimask_outputs=3, 74 | 
transformer=TwoWayTransformer( 75 | depth=2, 76 | embedding_dim=prompt_embed_dim, 77 | mlp_dim=2048, 78 | num_heads=8, 79 | ), 80 | transformer_dim=prompt_embed_dim, 81 | iou_head_depth=3, 82 | iou_head_hidden_dim=256, 83 | ), 84 | pixel_mean=[123.675, 116.28, 103.53], 85 | pixel_std=[58.395, 57.12, 57.375], 86 | ) 87 | 88 | mobile_sam.eval() 89 | if checkpoint is not None: 90 | with open(checkpoint, "rb") as f: 91 | state_dict = torch.load(f) 92 | mobile_sam.load_state_dict(state_dict) 93 | return mobile_sam 94 | 95 | sam_model_registry_baseline = { 96 | "default": build_sam_vit_h, 97 | "vit_h": build_sam_vit_h, 98 | "vit_l": build_sam_vit_l, 99 | "vit_b": build_sam_vit_b, 100 | "vit_tiny": build_sam_vit_t 101 | } 102 | 103 | 104 | def _build_sam( 105 | encoder_embed_dim, 106 | encoder_depth, 107 | encoder_num_heads, 108 | encoder_global_attn_indexes, 109 | checkpoint=None, 110 | ): 111 | prompt_embed_dim = 256 112 | image_size = 1024 113 | vit_patch_size = 16 114 | image_embedding_size = image_size // vit_patch_size 115 | sam = Sam( 116 | image_encoder=ImageEncoderViT( 117 | depth=encoder_depth, 118 | embed_dim=encoder_embed_dim, 119 | img_size=image_size, 120 | mlp_ratio=4, 121 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 122 | num_heads=encoder_num_heads, 123 | patch_size=vit_patch_size, 124 | qkv_bias=True, 125 | use_rel_pos=True, 126 | global_attn_indexes=encoder_global_attn_indexes, 127 | window_size=14, 128 | out_chans=prompt_embed_dim, 129 | ), 130 | prompt_encoder=PromptEncoder( 131 | embed_dim=prompt_embed_dim, 132 | image_embedding_size=(image_embedding_size, image_embedding_size), 133 | input_image_size=(image_size, image_size), 134 | mask_in_chans=16, 135 | ), 136 | mask_decoder=MaskDecoder( 137 | num_multimask_outputs=3, 138 | transformer=TwoWayTransformer( 139 | depth=2, 140 | embedding_dim=prompt_embed_dim, 141 | mlp_dim=2048, 142 | num_heads=8, 143 | ), 144 | transformer_dim=prompt_embed_dim, 145 | iou_head_depth=3, 146 | iou_head_hidden_dim=256, 147 | ), 148 | pixel_mean=[123.675, 116.28, 103.53], 149 | pixel_std=[58.395, 57.12, 57.375], 150 | ) 151 | sam.eval() 152 | if checkpoint is not None: 153 | with open(checkpoint, "rb") as f: 154 | state_dict = torch.load(f) 155 | sam.load_state_dict(state_dict) 156 | return sam -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder_hq import MaskDecoderHQ 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | from .tiny_vit_sam import TinyViT 14 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
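# Shared building blocks: MLPBlock, a two-layer feed-forward block with a
# configurable activation, and LayerNorm2d, layer normalization applied over
# the channel dimension of NCHW feature maps.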
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 
52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | interm_embeddings=[] 112 | for blk in self.blocks: 113 | x = blk(x) 114 | if blk.window_size == 0: 115 | interm_embeddings.append(x) 116 | 117 | x = self.neck(x.permute(0, 3, 1, 2)) 118 | 119 | return x, interm_embeddings 120 | 121 | 122 | class Block(nn.Module): 123 | """Transformer blocks with support of window attention and residual propagation blocks""" 124 | 125 | def __init__( 126 | self, 127 | dim: int, 128 | num_heads: int, 129 | mlp_ratio: float = 4.0, 130 | qkv_bias: bool = True, 131 | norm_layer: Type[nn.Module] = nn.LayerNorm, 132 | act_layer: Type[nn.Module] = nn.GELU, 133 | use_rel_pos: bool = False, 134 | rel_pos_zero_init: bool = True, 135 | window_size: int = 0, 136 | input_size: Optional[Tuple[int, int]] = None, 137 | ) -> None: 138 | """ 139 | Args: 140 | dim (int): Number of input channels. 141 | num_heads (int): Number of attention heads in each ViT block. 142 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 143 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 144 | norm_layer (nn.Module): Normalization layer. 145 | act_layer (nn.Module): Activation layer. 146 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 147 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 148 | window_size (int): Window size for window attention blocks. If it equals 0, then 149 | use global attention. 150 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 151 | positional parameter size. 
152 | """ 153 | super().__init__() 154 | self.norm1 = norm_layer(dim) 155 | self.attn = Attention( 156 | dim, 157 | num_heads=num_heads, 158 | qkv_bias=qkv_bias, 159 | use_rel_pos=use_rel_pos, 160 | rel_pos_zero_init=rel_pos_zero_init, 161 | input_size=input_size if window_size == 0 else (window_size, window_size), 162 | ) 163 | 164 | self.norm2 = norm_layer(dim) 165 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 166 | 167 | self.window_size = window_size 168 | 169 | def forward(self, x: torch.Tensor) -> torch.Tensor: 170 | shortcut = x 171 | x = self.norm1(x) 172 | # Window partition 173 | if self.window_size > 0: 174 | H, W = x.shape[1], x.shape[2] 175 | x, pad_hw = window_partition(x, self.window_size) 176 | 177 | x = self.attn(x) 178 | # Reverse window partition 179 | if self.window_size > 0: 180 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 181 | 182 | x = shortcut + x 183 | x = x + self.mlp(self.norm2(x)) 184 | 185 | return x 186 | 187 | 188 | class Attention(nn.Module): 189 | """Multi-head Attention block with relative position embeddings.""" 190 | 191 | def __init__( 192 | self, 193 | dim: int, 194 | num_heads: int = 8, 195 | qkv_bias: bool = True, 196 | use_rel_pos: bool = False, 197 | rel_pos_zero_init: bool = True, 198 | input_size: Optional[Tuple[int, int]] = None, 199 | ) -> None: 200 | """ 201 | Args: 202 | dim (int): Number of input channels. 203 | num_heads (int): Number of attention heads. 204 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 205 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 206 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 207 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 208 | positional parameter size. 209 | """ 210 | super().__init__() 211 | self.num_heads = num_heads 212 | head_dim = dim // num_heads 213 | self.scale = head_dim**-0.5 214 | 215 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 216 | self.proj = nn.Linear(dim, dim) 217 | 218 | self.use_rel_pos = use_rel_pos 219 | if self.use_rel_pos: 220 | assert ( 221 | input_size is not None 222 | ), "Input size must be provided if using relative positional encoding." 223 | # initialize relative positional embeddings 224 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 225 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 226 | 227 | def forward(self, x: torch.Tensor) -> torch.Tensor: 228 | B, H, W, _ = x.shape 229 | # qkv with shape (3, B, nHead, H * W, C) 230 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 231 | # q, k, v with shape (B * nHead, H * W, C) 232 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 233 | 234 | attn = (q * self.scale) @ k.transpose(-2, -1) 235 | 236 | if self.use_rel_pos: 237 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 238 | 239 | attn = attn.softmax(dim=-1) 240 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 241 | x = self.proj(x) 242 | 243 | return x 244 | 245 | 246 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 247 | """ 248 | Partition into non-overlapping windows with padding if needed. 249 | Args: 250 | x (tensor): input tokens with [B, H, W, C]. 251 | window_size (int): window size. 
252 | 253 | Returns: 254 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 255 | (Hp, Wp): padded height and width before partition 256 | """ 257 | B, H, W, C = x.shape 258 | 259 | pad_h = (window_size - H % window_size) % window_size 260 | pad_w = (window_size - W % window_size) % window_size 261 | if pad_h > 0 or pad_w > 0: 262 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 263 | Hp, Wp = H + pad_h, W + pad_w 264 | 265 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 266 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 267 | return windows, (Hp, Wp) 268 | 269 | 270 | def window_unpartition( 271 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 272 | ) -> torch.Tensor: 273 | """ 274 | Window unpartition into original sequences and removing padding. 275 | Args: 276 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 277 | window_size (int): window size. 278 | pad_hw (Tuple): padded height and width (Hp, Wp). 279 | hw (Tuple): original height and width (H, W) before padding. 280 | 281 | Returns: 282 | x: unpartitioned sequences with [B, H, W, C]. 283 | """ 284 | Hp, Wp = pad_hw 285 | H, W = hw 286 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 287 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 288 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 289 | 290 | if Hp > H or Wp > W: 291 | x = x[:, :H, :W, :].contiguous() 292 | return x 293 | 294 | 295 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 296 | """ 297 | Get relative positional embeddings according to the relative positions of 298 | query and key sizes. 299 | Args: 300 | q_size (int): size of query q. 301 | k_size (int): size of key k. 302 | rel_pos (Tensor): relative position embeddings (L, C). 303 | 304 | Returns: 305 | Extracted positional embeddings according to relative positions. 306 | """ 307 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 308 | # Interpolate rel pos if needed. 309 | if rel_pos.shape[0] != max_rel_dist: 310 | # Interpolate rel pos. 311 | rel_pos_resized = F.interpolate( 312 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 313 | size=max_rel_dist, 314 | mode="linear", 315 | ) 316 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 317 | else: 318 | rel_pos_resized = rel_pos 319 | 320 | # Scale the coords with short length if shapes for q and k are different. 321 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 322 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 323 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 324 | 325 | return rel_pos_resized[relative_coords.long()] 326 | 327 | 328 | def add_decomposed_rel_pos( 329 | attn: torch.Tensor, 330 | q: torch.Tensor, 331 | rel_pos_h: torch.Tensor, 332 | rel_pos_w: torch.Tensor, 333 | q_size: Tuple[int, int], 334 | k_size: Tuple[int, int], 335 | ) -> torch.Tensor: 336 | """ 337 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 338 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 339 | Args: 340 | attn (Tensor): attention map. 341 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
342 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 343 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 344 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 345 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 346 | 347 | Returns: 348 | attn (Tensor): attention map with added relative positional embeddings. 349 | """ 350 | q_h, q_w = q_size 351 | k_h, k_w = k_size 352 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 353 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 354 | 355 | B, _, dim = q.shape 356 | r_q = q.reshape(B, q_h, q_w, dim) 357 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 358 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 359 | 360 | attn = ( 361 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 362 | ).view(B, q_h * q_w, k_h * k_w) 363 | 364 | return attn 365 | 366 | 367 | class PatchEmbed(nn.Module): 368 | """ 369 | Image to Patch Embedding. 370 | """ 371 | 372 | def __init__( 373 | self, 374 | kernel_size: Tuple[int, int] = (16, 16), 375 | stride: Tuple[int, int] = (16, 16), 376 | padding: Tuple[int, int] = (0, 0), 377 | in_chans: int = 3, 378 | embed_dim: int = 768, 379 | ) -> None: 380 | """ 381 | Args: 382 | kernel_size (Tuple): kernel size of the projection layer. 383 | stride (Tuple): stride of the projection layer. 384 | padding (Tuple): padding size of the projection layer. 385 | in_chans (int): Number of input image channels. 386 | embed_dim (int): Patch embedding dimension. 387 | """ 388 | super().__init__() 389 | 390 | self.proj = nn.Conv2d( 391 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 392 | ) 393 | 394 | def forward(self, x: torch.Tensor) -> torch.Tensor: 395 | x = self.proj(x) 396 | # B C H W -> B H W C 397 | x = x.permute(0, 2, 3, 1) 398 | return x 399 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | hq_token_only: bool, 79 | interm_embeddings: torch.Tensor, 80 | ) -> Tuple[torch.Tensor, torch.Tensor]: 81 | """ 82 | Predict masks given image and prompt embeddings. 83 | 84 | Arguments: 85 | image_embeddings (torch.Tensor): the embeddings from the image encoder 86 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 87 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 88 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 89 | multimask_output (bool): Whether to return multiple masks or a single 90 | mask. 91 | 92 | Returns: 93 | torch.Tensor: batched predicted masks 94 | torch.Tensor: batched predictions of mask quality 95 | """ 96 | masks, iou_pred = self.predict_masks( 97 | image_embeddings=image_embeddings, 98 | image_pe=image_pe, 99 | sparse_prompt_embeddings=sparse_prompt_embeddings, 100 | dense_prompt_embeddings=dense_prompt_embeddings, 101 | ) 102 | 103 | # Select the correct mask or masks for output 104 | if multimask_output: 105 | mask_slice = slice(1, None) 106 | else: 107 | mask_slice = slice(0, 1) 108 | masks = masks[:, mask_slice, :, :] 109 | iou_pred = iou_pred[:, mask_slice] 110 | 111 | # Prepare output 112 | return masks, iou_pred 113 | 114 | def predict_masks( 115 | self, 116 | image_embeddings: torch.Tensor, 117 | image_pe: torch.Tensor, 118 | sparse_prompt_embeddings: torch.Tensor, 119 | dense_prompt_embeddings: torch.Tensor, 120 | ) -> Tuple[torch.Tensor, torch.Tensor]: 121 | """Predicts masks. 
See 'forward' for more details.""" 122 | # Concatenate output tokens 123 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 124 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 125 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 126 | 127 | # Expand per-image data in batch direction to be per-mask 128 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 129 | src = src + dense_prompt_embeddings 130 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 131 | b, c, h, w = src.shape 132 | 133 | # Run the transformer 134 | hs, src = self.transformer(src, pos_src, tokens) 135 | iou_token_out = hs[:, 0, :] 136 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 137 | 138 | # Upscale mask embeddings and predict masks using the mask tokens 139 | src = src.transpose(1, 2).view(b, c, h, w) 140 | upscaled_embedding = self.output_upscaling(src) 141 | hyper_in_list: List[torch.Tensor] = [] 142 | for i in range(self.num_mask_tokens): 143 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 144 | hyper_in = torch.stack(hyper_in_list, dim=1) 145 | b, c, h, w = upscaled_embedding.shape 146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 147 | 148 | # Generate mask quality predictions 149 | iou_pred = self.iou_prediction_head(iou_token_out) 150 | 151 | return masks, iou_pred 152 | 153 | 154 | # Lightly adapted from 155 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 156 | class MLP(nn.Module): 157 | def __init__( 158 | self, 159 | input_dim: int, 160 | hidden_dim: int, 161 | output_dim: int, 162 | num_layers: int, 163 | sigmoid_output: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.num_layers = num_layers 167 | h = [hidden_dim] * (num_layers - 1) 168 | self.layers = nn.ModuleList( 169 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 170 | ) 171 | self.sigmoid_output = sigmoid_output 172 | 173 | def forward(self, x): 174 | for i, layer in enumerate(self.layers): 175 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 176 | if self.sigmoid_output: 177 | x = F.sigmoid(x) 178 | return x 179 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder_hq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # Modified by HQ-SAM team 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from typing import List, Tuple, Type 13 | 14 | from .common import LayerNorm2d 15 | 16 | 17 | class MaskDecoderHQ(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | transformer_dim: int, 22 | transformer: nn.Module, 23 | num_multimask_outputs: int = 3, 24 | activation: Type[nn.Module] = nn.GELU, 25 | iou_head_depth: int = 3, 26 | iou_head_hidden_dim: int = 256, 27 | vit_dim: int = 1024, 28 | ) -> None: 29 | """ 30 | Predicts masks given an image and prompt embeddings, using a 31 | transformer architecture. 
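        Compared to the original SAM mask decoder, this HQ variant adds a
        dedicated high-quality output token (with its own MLP) and fuses early
        ViT features with the decoder features to produce a refined
        high-quality mask.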
32 | 33 | Arguments: 34 | transformer_dim (int): the channel dimension of the transformer 35 | transformer (nn.Module): the transformer used to predict masks 36 | num_multimask_outputs (int): the number of masks to predict 37 | when disambiguating masks 38 | activation (nn.Module): the type of activation to use when 39 | upscaling masks 40 | iou_head_depth (int): the depth of the MLP used to predict 41 | mask quality 42 | iou_head_hidden_dim (int): the hidden dimension of the MLP 43 | used to predict mask quality 44 | """ 45 | super().__init__() 46 | self.transformer_dim = transformer_dim 47 | self.transformer = transformer 48 | 49 | self.num_multimask_outputs = num_multimask_outputs 50 | 51 | self.iou_token = nn.Embedding(1, transformer_dim) 52 | self.num_mask_tokens = num_multimask_outputs + 1 53 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 54 | 55 | self.output_upscaling = nn.Sequential( 56 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 57 | LayerNorm2d(transformer_dim // 4), 58 | activation(), 59 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 60 | activation(), 61 | ) 62 | self.output_hypernetworks_mlps = nn.ModuleList( 63 | [ 64 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 65 | for i in range(self.num_mask_tokens) 66 | ] 67 | ) 68 | 69 | self.iou_prediction_head = MLP( 70 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 71 | ) 72 | 73 | # HQ-SAM parameters 74 | self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Output-Token 75 | self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Output-Token 76 | self.num_mask_tokens = self.num_mask_tokens + 1 77 | 78 | # three conv fusion layers for obtaining HQ-Feature 79 | self.compress_vit_feat = nn.Sequential( 80 | nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), 81 | LayerNorm2d(transformer_dim), 82 | nn.GELU(), 83 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2)) 84 | 85 | self.embedding_encoder = nn.Sequential( 86 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 87 | LayerNorm2d(transformer_dim // 4), 88 | nn.GELU(), 89 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 90 | ) 91 | self.embedding_maskfeature = nn.Sequential( 92 | nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), 93 | LayerNorm2d(transformer_dim // 4), 94 | nn.GELU(), 95 | nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1)) 96 | 97 | 98 | 99 | def forward( 100 | self, 101 | image_embeddings: torch.Tensor, 102 | image_pe: torch.Tensor, 103 | sparse_prompt_embeddings: torch.Tensor, 104 | dense_prompt_embeddings: torch.Tensor, 105 | multimask_output: bool, 106 | hq_token_only: bool, 107 | interm_embeddings: torch.Tensor, 108 | ) -> Tuple[torch.Tensor, torch.Tensor]: 109 | """ 110 | Predict masks given image and prompt embeddings. 111 | 112 | Arguments: 113 | image_embeddings (torch.Tensor): the embeddings from the ViT image encoder 114 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 115 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 116 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 117 | multimask_output (bool): Whether to return multiple masks or a single 118 | mask. 
119 | 120 | Returns: 121 | torch.Tensor: batched predicted masks 122 | torch.Tensor: batched predictions of mask quality 123 | """ 124 | vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT 125 | hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) 126 | 127 | masks, iou_pred = self.predict_masks( 128 | image_embeddings=image_embeddings, 129 | image_pe=image_pe, 130 | sparse_prompt_embeddings=sparse_prompt_embeddings, 131 | dense_prompt_embeddings=dense_prompt_embeddings, 132 | hq_features=hq_features, 133 | ) 134 | 135 | # Select the correct mask or masks for output 136 | if multimask_output: 137 | # mask with highest score 138 | mask_slice = slice(1,self.num_mask_tokens-1) 139 | iou_pred = iou_pred[:, mask_slice] 140 | iou_pred, max_iou_idx = torch.max(iou_pred,dim=1) 141 | iou_pred = iou_pred.unsqueeze(1) 142 | masks_multi = masks[:, mask_slice, :, :] 143 | masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) 144 | else: 145 | # single mask output, default 146 | mask_slice = slice(0, 1) 147 | iou_pred = iou_pred[:,mask_slice] 148 | masks_sam = masks[:,mask_slice] 149 | 150 | masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)] 151 | if hq_token_only: 152 | masks = masks_hq 153 | else: 154 | masks = masks_sam + masks_hq 155 | # Prepare output 156 | return masks, iou_pred 157 | 158 | def predict_masks( 159 | self, 160 | image_embeddings: torch.Tensor, 161 | image_pe: torch.Tensor, 162 | sparse_prompt_embeddings: torch.Tensor, 163 | dense_prompt_embeddings: torch.Tensor, 164 | hq_features: torch.Tensor, 165 | ) -> Tuple[torch.Tensor, torch.Tensor]: 166 | """Predicts masks. See 'forward' for more details.""" 167 | # Concatenate output tokens 168 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0) 169 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 170 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 171 | 172 | # Expand per-image data in batch direction to be per-mask 173 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 174 | src = src + dense_prompt_embeddings 175 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 176 | b, c, h, w = src.shape 177 | 178 | # Run the transformer 179 | hs, src = self.transformer(src, pos_src, tokens) 180 | iou_token_out = hs[:, 0, :] 181 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 182 | 183 | # Upscale mask embeddings and predict masks using the mask tokens 184 | src = src.transpose(1, 2).view(b, c, h, w) 185 | 186 | upscaled_embedding_sam = self.output_upscaling(src) 187 | upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1) 188 | 189 | hyper_in_list: List[torch.Tensor] = [] 190 | for i in range(self.num_mask_tokens): 191 | if i < self.num_mask_tokens - 1: 192 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 193 | else: 194 | hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) 195 | 196 | hyper_in = torch.stack(hyper_in_list, dim=1) 197 | b, c, h, w = upscaled_embedding_sam.shape 198 | 199 | masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) 200 | masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w) 201 | 
masks = torch.cat([masks_sam,masks_sam_hq],dim=1) 202 | # Generate mask quality predictions 203 | iou_pred = self.iou_prediction_head(iou_token_out) 204 | 205 | return masks, iou_pred 206 | 207 | 208 | # Lightly adapted from 209 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 210 | class MLP(nn.Module): 211 | def __init__( 212 | self, 213 | input_dim: int, 214 | hidden_dim: int, 215 | output_dim: int, 216 | num_layers: int, 217 | sigmoid_output: bool = False, 218 | ) -> None: 219 | super().__init__() 220 | self.num_layers = num_layers 221 | h = [hidden_dim] * (num_layers - 1) 222 | self.layers = nn.ModuleList( 223 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 224 | ) 225 | self.sigmoid_output = sigmoid_output 226 | 227 | def forward(self, x): 228 | for i, layer in enumerate(self.layers): 229 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 230 | if self.sigmoid_output: 231 | x = F.sigmoid(x) 232 | return x 233 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | def forward( 54 | self, 55 | batched_input: List[Dict[str, Any]], 56 | multimask_output: bool, 57 | hq_token_only: bool =False, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings, interm_embeddings = self.image_encoder(input_images) 99 | interm_embeddings = interm_embeddings[0] # early layer 100 | 101 | outputs = [] 102 | for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | hq_token_only=hq_token_only, 119 | interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), 120 | ) 121 | masks = self.postprocess_masks( 122 | low_res_masks, 123 | input_size=image_record["image"].shape[-2:], 124 | original_size=image_record["original_size"], 125 | ) 126 | masks = masks > self.mask_threshold 127 | outputs.append( 128 | { 129 | "masks": masks, 130 | "iou_predictions": iou_predictions, 131 | "low_res_logits": low_res_masks, 132 | } 133 | ) 134 | return outputs 135 | 136 | def postprocess_masks( 137 | self, 138 | masks: torch.Tensor, 139 | input_size: Tuple[int, ...], 140 | original_size: Tuple[int, ...], 141 | ) -> torch.Tensor: 142 | """ 143 | Remove padding and upscale masks to the original image size. 144 | 145 | Arguments: 146 | masks (torch.Tensor): Batched masks from the mask_decoder, 147 | in BxCxHxW format. 
148 | input_size (tuple(int, int)): The size of the image input to the 149 | model, in (H, W) format. Used to remove padding. 150 | original_size (tuple(int, int)): The original size of the image 151 | before resizing for input to the model, in (H, W) format. 152 | 153 | Returns: 154 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 155 | is given by original_size. 156 | """ 157 | masks = F.interpolate( 158 | masks, 159 | (self.image_encoder.img_size, self.image_encoder.img_size), 160 | mode="bilinear", 161 | align_corners=False, 162 | ) 163 | masks = masks[..., : input_size[0], : input_size[1]] 164 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 165 | return masks 166 | 167 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 168 | """Normalize pixel values and pad to a square input.""" 169 | # Normalize colors 170 | x = (x - self.pixel_mean) / self.pixel_std 171 | 172 | # Pad 173 | h, w = x.shape[-2:] 174 | padh = self.image_encoder.img_size - h 175 | padw = self.image_encoder.img_size - w 176 | x = F.pad(x, (0, padw, 0, padh)) 177 | return x 178 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. 
Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = 
self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 
28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | # import pdb;pdb.set_trace() 53 | if image_format != self.model.image_format: 54 | image = image[..., ::-1] 55 | 56 | # Transform the image to the form expected by the model 57 | # import pdb;pdb.set_trace() 58 | input_image = self.transform.apply_image(image) 59 | input_image_torch = torch.as_tensor(input_image, device=self.device) 60 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 61 | 62 | self.set_torch_image(input_image_torch, image.shape[:2]) 63 | 64 | @torch.no_grad() 65 | def set_torch_image( 66 | self, 67 | transformed_image: torch.Tensor, 68 | original_image_size: Tuple[int, ...], 69 | ) -> None: 70 | """ 71 | Calculates the image embeddings for the provided image, allowing 72 | masks to be predicted with the 'predict' method. Expects the input 73 | image to be already transformed to the format expected by the model. 74 | 75 | Arguments: 76 | transformed_image (torch.Tensor): The input image, with shape 77 | 1x3xHxW, which has been transformed with ResizeLongestSide. 78 | original_image_size (tuple(int, int)): The size of the image 79 | before transformation, in (H, W) format. 80 | """ 81 | assert ( 82 | len(transformed_image.shape) == 4 83 | and transformed_image.shape[1] == 3 84 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 85 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 86 | self.reset_image() 87 | 88 | self.original_size = original_image_size 89 | self.input_size = tuple(transformed_image.shape[-2:]) 90 | input_image = self.model.preprocess(transformed_image) 91 | self.features, self.interm_features = self.model.image_encoder(input_image) 92 | self.is_image_set = True 93 | 94 | def predict( 95 | self, 96 | point_coords: Optional[np.ndarray] = None, 97 | point_labels: Optional[np.ndarray] = None, 98 | box: Optional[np.ndarray] = None, 99 | mask_input: Optional[np.ndarray] = None, 100 | multimask_output: bool = True, 101 | return_logits: bool = False, 102 | hq_token_only: bool =False, 103 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 104 | """ 105 | Predict masks for the given input prompts, using the currently set image. 106 | 107 | Arguments: 108 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 109 | model. Each point is in (X,Y) in pixels. 110 | point_labels (np.ndarray or None): A length N array of labels for the 111 | point prompts. 1 indicates a foreground point and 0 indicates a 112 | background point. 113 | box (np.ndarray or None): A length 4 array given a box prompt to the 114 | model, in XYXY format. 115 | mask_input (np.ndarray): A low resolution mask input to the model, typically 116 | coming from a previous prediction iteration. 
Has form 1xHxW, where 117 | for SAM, H=W=256. 118 | multimask_output (bool): If true, the model will return three masks. 119 | For ambiguous input prompts (such as a single click), this will often 120 | produce better masks than a single prediction. If only a single 121 | mask is needed, the model's predicted quality score can be used 122 | to select the best mask. For non-ambiguous prompts, such as multiple 123 | input prompts, multimask_output=False can give better results. 124 | return_logits (bool): If true, returns un-thresholded masks logits 125 | instead of a binary mask. 126 | 127 | Returns: 128 | (np.ndarray): The output masks in CxHxW format, where C is the 129 | number of masks, and (H, W) is the original image size. 130 | (np.ndarray): An array of length C containing the model's 131 | predictions for the quality of each mask. 132 | (np.ndarray): An array of shape CxHxW, where C is the number 133 | of masks and H=W=256. These low resolution logits can be passed to 134 | a subsequent iteration as mask input. 135 | """ 136 | if not self.is_image_set: 137 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 138 | 139 | # Transform input prompts 140 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 141 | if point_coords is not None: 142 | assert ( 143 | point_labels is not None 144 | ), "point_labels must be supplied if point_coords is supplied." 145 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 146 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 147 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 148 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 149 | if box is not None: 150 | box = self.transform.apply_boxes(box, self.original_size) 151 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 152 | box_torch = box_torch[None, :] 153 | if mask_input is not None: 154 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 155 | mask_input_torch = mask_input_torch[None, :, :, :] 156 | 157 | masks, iou_predictions, low_res_masks = self.predict_torch( 158 | coords_torch, 159 | labels_torch, 160 | box_torch, 161 | mask_input_torch, 162 | multimask_output, 163 | return_logits=return_logits, 164 | hq_token_only=hq_token_only, 165 | ) 166 | 167 | masks_np = masks[0].detach().cpu().numpy() 168 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 169 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 170 | return masks_np, iou_predictions_np, low_res_masks_np 171 | 172 | @torch.no_grad() 173 | def predict_torch( 174 | self, 175 | point_coords: Optional[torch.Tensor], 176 | point_labels: Optional[torch.Tensor], 177 | boxes: Optional[torch.Tensor] = None, 178 | mask_input: Optional[torch.Tensor] = None, 179 | multimask_output: bool = True, 180 | return_logits: bool = False, 181 | hq_token_only: bool =False, 182 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 183 | """ 184 | Predict masks for the given input prompts, using the currently set image. 185 | Input prompts are batched torch tensors and are expected to already be 186 | transformed to the input frame using ResizeLongestSide. 187 | 188 | Arguments: 189 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 190 | model. Each point is in (X,Y) in pixels. 
191 | point_labels (torch.Tensor or None): A BxN array of labels for the 192 | point prompts. 1 indicates a foreground point and 0 indicates a 193 | background point. 194 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 195 | model, in XYXY format. 196 | mask_input (np.ndarray): A low resolution mask input to the model, typically 197 | coming from a previous prediction iteration. Has form Bx1xHxW, where 198 | for SAM, H=W=256. Masks returned by a previous iteration of the 199 | predict method do not need further transformation. 200 | multimask_output (bool): If true, the model will return three masks. 201 | For ambiguous input prompts (such as a single click), this will often 202 | produce better masks than a single prediction. If only a single 203 | mask is needed, the model's predicted quality score can be used 204 | to select the best mask. For non-ambiguous prompts, such as multiple 205 | input prompts, multimask_output=False can give better results. 206 | return_logits (bool): If true, returns un-thresholded masks logits 207 | instead of a binary mask. 208 | 209 | Returns: 210 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 211 | number of masks, and (H, W) is the original image size. 212 | (torch.Tensor): An array of shape BxC containing the model's 213 | predictions for the quality of each mask. 214 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 215 | of masks and H=W=256. These low res logits can be passed to 216 | a subsequent iteration as mask input. 217 | """ 218 | if not self.is_image_set: 219 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 220 | 221 | if point_coords is not None: 222 | points = (point_coords, point_labels) 223 | else: 224 | points = None 225 | 226 | # Embed prompts 227 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 228 | points=points, 229 | boxes=boxes, 230 | masks=mask_input, 231 | ) 232 | 233 | # Predict masks 234 | low_res_masks, iou_predictions = self.model.mask_decoder( 235 | image_embeddings=self.features, 236 | image_pe=self.model.prompt_encoder.get_dense_pe(), 237 | sparse_prompt_embeddings=sparse_embeddings, 238 | dense_prompt_embeddings=dense_embeddings, 239 | multimask_output=multimask_output, 240 | hq_token_only=hq_token_only, 241 | interm_embeddings=self.interm_features, 242 | ) 243 | 244 | # Upscale the masks to the original image resolution 245 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 246 | 247 | if not return_logits: 248 | masks = masks > self.model.mask_threshold 249 | 250 | return masks, iou_predictions, low_res_masks 251 | 252 | def get_image_embedding(self) -> torch.Tensor: 253 | """ 254 | Returns the image embeddings for the currently set image, with 255 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 256 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 257 | """ 258 | if not self.is_image_set: 259 | raise RuntimeError( 260 | "An image must be set with .set_image(...) to generate an embedding." 261 | ) 262 | assert self.features is not None, "Features must exist if an image has been set." 
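# The cached embedding typically has shape 1x(embed_dim)x64x64 (C=256 for the default SAM image encoder, as noted in the docstring above); it is computed once in set_image()/set_torch_image() and reused by every subsequent predict() call.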
263 | return self.features 264 | 265 | @property 266 | def device(self) -> torch.device: 267 | return self.model.device 268 | 269 | def reset_image(self) -> None: 270 | """Resets the currently set image.""" 271 | self.is_image_set = False 272 | self.features = None 273 | self.orig_h = None 274 | self.orig_w = None 275 | self.input_h = None 276 | self.input_w = None 277 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
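# The type check above guarantees that every stored value is a list, numpy array, or torch tensor, which is what filter(), cat(), and to_numpy() below rely on.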
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
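# Rough usage sketch (illustrative only; `sam`, `dummy_inputs`, and the output names below are placeholders assumed to be built by an export script that is not part of this excerpt):
#
#   onnx_model = SamOnnxModel(model=sam, hq_token_only=False, multimask_output=False)
#   torch.onnx.export(
#       onnx_model,
#       tuple(dummy_inputs.values()),
#       "sam_hq_decoder.onnx",
#       input_names=list(dummy_inputs.keys()),
#       output_names=["masks", "iou_predictions", "low_res_masks"],
#   )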
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | hq_token_only: bool = False, 29 | multimask_output: bool = False, 30 | use_stability_score: bool = False, 31 | return_extra_metrics: bool = False, 32 | ) -> None: 33 | super().__init__() 34 | self.mask_decoder = model.mask_decoder 35 | self.model = model 36 | self.img_size = model.image_encoder.img_size 37 | self.hq_token_only = hq_token_only 38 | self.multimask_output = multimask_output 39 | self.use_stability_score = use_stability_score 40 | self.stability_score_offset = 1.0 41 | self.return_extra_metrics = return_extra_metrics 42 | 43 | @staticmethod 44 | def resize_longest_image_size( 45 | input_image_size: torch.Tensor, longest_side: int 46 | ) -> torch.Tensor: 47 | input_image_size = input_image_size.to(torch.float32) 48 | scale = longest_side / torch.max(input_image_size) 49 | transformed_size = scale * input_image_size 50 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 51 | return transformed_size 52 | 53 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 54 | point_coords = point_coords + 0.5 55 | point_coords = point_coords / self.img_size 56 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 57 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 58 | 59 | point_embedding = point_embedding * (point_labels != -1) 60 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 61 | point_labels == -1 62 | ) 63 | 64 | for i in range(self.model.prompt_encoder.num_point_embeddings): 65 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 66 | i 67 | ].weight * (point_labels == i) 68 | 69 | return point_embedding 70 | 71 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 72 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 73 | mask_embedding = mask_embedding + ( 74 | 1 - has_mask_input 75 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 76 | return mask_embedding 77 | 78 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 79 | masks = F.interpolate( 80 | masks, 81 | size=(self.img_size, self.img_size), 82 | mode="bilinear", 83 | align_corners=False, 84 | ) 85 | 86 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 87 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 88 | 89 | orig_im_size = orig_im_size.to(torch.int64) 90 | h, w = orig_im_size[0], orig_im_size[1] 91 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 92 | return masks 93 | 94 | 95 | @torch.no_grad() 96 | def forward( 97 | self, 98 | image_embeddings: torch.Tensor, 99 | interm_embeddings: torch.Tensor, 100 | point_coords: torch.Tensor, 101 | 
point_labels: torch.Tensor, 102 | mask_input: torch.Tensor, 103 | has_mask_input: torch.Tensor, 104 | orig_im_size: torch.Tensor, 105 | ): 106 | sparse_embedding = self._embed_points(point_coords, point_labels) 107 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 108 | 109 | vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT 110 | hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features) 111 | 112 | masks, scores = self.model.mask_decoder.predict_masks( 113 | image_embeddings=image_embeddings, 114 | image_pe=self.model.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embedding, 116 | dense_prompt_embeddings=dense_embedding, 117 | hq_features=hq_features, 118 | ) 119 | 120 | if self.use_stability_score: 121 | scores = calculate_stability_score( 122 | masks, self.model.mask_threshold, self.stability_score_offset 123 | ) 124 | 125 | if self.multimask_output: 126 | # select the mask with the highest score 127 | mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1) 128 | scores = scores[:, mask_slice] 129 | scores, max_iou_idx = torch.max(scores,dim=1) 130 | scores = scores.unsqueeze(1) 131 | masks_multi = masks[:, mask_slice, :, :] 132 | masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) 133 | else: 134 | # single mask output, default 135 | mask_slice = slice(0, 1) 136 | scores = scores[:,mask_slice] 137 | masks_sam = masks[:,mask_slice] 138 | 139 | masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)] 140 | 141 | if self.hq_token_only: 142 | masks = masks_hq 143 | else: 144 | masks = masks_sam + masks_hq 145 | 146 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 147 | 148 | if self.return_extra_metrics: 149 | stability_scores = calculate_stability_score( 150 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 151 | ) 152 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 153 | return upscaled_masks, scores, stability_scores, areas, masks 154 | 155 | return upscaled_masks, scores, masks 156 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 
29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime", "timm"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /train/segment_anything_training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | -------------------------------------------------------------------------------- /train/segment_anything_training/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
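# --- Illustrative sketch, not a file in this repository ---
# Typical use of ResizeLongestSide from segment_anything/utils/transforms.py above:
# resize an image so its longest side becomes 1024 and rescale point prompts to
# match. The image size and point coordinates are invented for this example.
import numpy as np

from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

# get_preprocess_shape keeps the aspect ratio: a 600x800 (HxW) image maps to (768, 1024).
assert ResizeLongestSide.get_preprocess_shape(600, 800, 1024) == (768, 1024)

image = np.zeros((600, 800, 3), dtype=np.uint8)    # HxWxC uint8, as apply_image expects
resized = transform.apply_image(image)             # -> shape (768, 1024, 3)

# Coordinates are rescaled using the *original* (H, W) size.
points = np.array([[400.0, 300.0], [10.0, 20.0]])  # (x, y) pairs
points_resized = transform.apply_coords(points, image.shape[:2])
# --- end of sketch ---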
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam, 49 | "vit_h": build_sam, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) #[4,1,64,64] 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, #b 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 
50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | 74 | for i in range(depth): 75 | block = Block( 76 | dim=embed_dim, 77 | num_heads=num_heads, 78 | mlp_ratio=mlp_ratio, 79 | qkv_bias=qkv_bias, 80 | norm_layer=norm_layer, 81 | act_layer=act_layer, 82 | use_rel_pos=use_rel_pos, 83 | rel_pos_zero_init=rel_pos_zero_init, 84 | window_size=window_size if i not in global_attn_indexes else 0, 85 | input_size=(img_size // patch_size, img_size // patch_size), 86 | ) 87 | self.blocks.append(block) 88 | 89 | self.neck = nn.Sequential( 90 | nn.Conv2d( 91 | embed_dim, 92 | out_chans, 93 | kernel_size=1, 94 | bias=False, 95 | ), 96 | LayerNorm2d(out_chans), 97 | nn.Conv2d( 98 | out_chans, 99 | out_chans, 100 | kernel_size=3, 101 | padding=1, 102 | bias=False, 103 | ), 104 | LayerNorm2d(out_chans), 105 | ) 106 | 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | x = self.patch_embed(x) 110 | if self.pos_embed is not None: 111 | x = x + self.pos_embed 112 | interm_embeddings = [] 113 | for blk in self.blocks: 114 | x = blk(x) 115 | if blk.window_size == 0: 116 | interm_embeddings.append(x) 117 | 118 | x = self.neck(x.permute(0, 3, 1, 2)) 119 | 120 | return x, interm_embeddings 121 | 122 | 123 | class Block(nn.Module): 124 | """Transformer blocks with support of window attention and residual propagation blocks""" 125 | 126 | def __init__( 127 | self, 128 | dim: int, 129 | num_heads: int, 130 | mlp_ratio: float = 4.0, 131 | qkv_bias: bool = True, 132 | norm_layer: Type[nn.Module] = nn.LayerNorm, 133 | act_layer: Type[nn.Module] = nn.GELU, 134 | use_rel_pos: bool = False, 135 | rel_pos_zero_init: bool = True, 136 | window_size: int = 0, 137 | input_size: Optional[Tuple[int, int]] = None, 138 | ) -> None: 139 | """ 140 | Args: 141 | dim (int): Number of input channels. 142 | num_heads (int): Number of attention heads in each ViT block. 143 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 144 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 145 | norm_layer (nn.Module): Normalization layer. 146 | act_layer (nn.Module): Activation layer. 147 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 148 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 149 | window_size (int): Window size for window attention blocks. If it equals 0, then 150 | use global attention. 151 | input_size (int or None): Input resolution for calculating the relative positional 152 | parameter size. 
153 | """ 154 | super().__init__() 155 | self.norm1 = norm_layer(dim) 156 | self.attn = Attention( 157 | dim, 158 | num_heads=num_heads, 159 | qkv_bias=qkv_bias, 160 | use_rel_pos=use_rel_pos, 161 | rel_pos_zero_init=rel_pos_zero_init, 162 | input_size=input_size if window_size == 0 else (window_size, window_size), 163 | ) 164 | 165 | self.norm2 = norm_layer(dim) 166 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 167 | 168 | self.window_size = window_size 169 | 170 | def forward(self, x: torch.Tensor) -> torch.Tensor: 171 | shortcut = x 172 | x = self.norm1(x) 173 | if self.window_size > 0: 174 | H, W = x.shape[1], x.shape[2] 175 | x, pad_hw = window_partition(x, self.window_size) 176 | x = self.attn(x) 177 | # Reverse window partition 178 | if self.window_size > 0: 179 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 180 | 181 | x = shortcut + x 182 | x = x + self.mlp(self.norm2(x)) 183 | 184 | return x 185 | 186 | class Attention(nn.Module): 187 | """Multi-head Attention block with relative position embeddings.""" 188 | 189 | def __init__( 190 | self, 191 | dim: int, 192 | num_heads: int = 8, 193 | qkv_bias: bool = True, 194 | use_rel_pos: bool = False, 195 | rel_pos_zero_init: bool = True, 196 | input_size: Optional[Tuple[int, int]] = None, 197 | ) -> None: 198 | """ 199 | Args: 200 | dim (int): Number of input channels. 201 | num_heads (int): Number of attention heads. 202 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 203 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 204 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 205 | input_size (int or None): Input resolution for calculating the relative positional 206 | parameter size. 207 | """ 208 | super().__init__() 209 | self.num_heads = num_heads 210 | head_dim = dim // num_heads 211 | self.scale = head_dim**-0.5 212 | 213 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 214 | self.proj = nn.Linear(dim, dim) 215 | 216 | self.use_rel_pos = use_rel_pos 217 | if self.use_rel_pos: 218 | assert ( 219 | input_size is not None 220 | ), "Input size must be provided if using relative positional encoding." 221 | # initialize relative positional embeddings 222 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 223 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 224 | 225 | def forward(self, x: torch.Tensor) -> torch.Tensor: 226 | B, H, W, _ = x.shape 227 | # qkv with shape (3, B, nHead, H * W, C) 228 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 229 | # q, k, v with shape (B * nHead, H * W, C) 230 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 231 | 232 | attn = (q * self.scale) @ k.transpose(-2, -1) 233 | 234 | if self.use_rel_pos: 235 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 236 | 237 | attn = attn.softmax(dim=-1) 238 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 239 | x = self.proj(x) 240 | return x 241 | 242 | 243 | 244 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 245 | """ 246 | Partition into non-overlapping windows with padding if needed. 247 | Args: 248 | x (tensor): input tokens with [B, H, W, C]. 249 | window_size (int): window size. 
250 | 251 | Returns: 252 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 253 | (Hp, Wp): padded height and width before partition 254 | """ 255 | B, H, W, C = x.shape 256 | 257 | pad_h = (window_size - H % window_size) % window_size 258 | pad_w = (window_size - W % window_size) % window_size 259 | if pad_h > 0 or pad_w > 0: 260 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 261 | Hp, Wp = H + pad_h, W + pad_w 262 | 263 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 264 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 265 | return windows, (Hp, Wp) 266 | 267 | 268 | def window_unpartition( 269 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 270 | ) -> torch.Tensor: 271 | """ 272 | Window unpartition into original sequences and removing padding. 273 | Args: 274 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 275 | window_size (int): window size. 276 | pad_hw (Tuple): padded height and width (Hp, Wp). 277 | hw (Tuple): original height and width (H, W) before padding. 278 | 279 | Returns: 280 | x: unpartitioned sequences with [B, H, W, C]. 281 | """ 282 | Hp, Wp = pad_hw 283 | H, W = hw 284 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 285 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 286 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 287 | 288 | if Hp > H or Wp > W: 289 | x = x[:, :H, :W, :].contiguous() 290 | return x 291 | 292 | 293 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 294 | """ 295 | Get relative positional embeddings according to the relative positions of 296 | query and key sizes. 297 | Args: 298 | q_size (int): size of query q. 299 | k_size (int): size of key k. 300 | rel_pos (Tensor): relative position embeddings (L, C). 301 | 302 | Returns: 303 | Extracted positional embeddings according to relative positions. 304 | """ 305 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 306 | # Interpolate rel pos if needed. 307 | if rel_pos.shape[0] != max_rel_dist: 308 | # Interpolate rel pos. 309 | rel_pos_resized = F.interpolate( 310 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 311 | size=max_rel_dist, 312 | mode="linear", 313 | ) 314 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 315 | else: 316 | rel_pos_resized = rel_pos 317 | 318 | # Scale the coords with short length if shapes for q and k are different. 319 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 320 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 321 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 322 | 323 | return rel_pos_resized[relative_coords.long()] 324 | 325 | 326 | def add_decomposed_rel_pos( 327 | attn: torch.Tensor, 328 | q: torch.Tensor, 329 | rel_pos_h: torch.Tensor, 330 | rel_pos_w: torch.Tensor, 331 | q_size: Tuple[int, int], 332 | k_size: Tuple[int, int], 333 | ) -> torch.Tensor: 334 | """ 335 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 336 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 337 | Args: 338 | attn (Tensor): attention map. 339 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
340 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 341 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 342 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 343 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 344 | 345 | Returns: 346 | attn (Tensor): attention map with added relative positional embeddings. 347 | """ 348 | q_h, q_w = q_size 349 | k_h, k_w = k_size 350 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 351 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 352 | 353 | B, _, dim = q.shape 354 | r_q = q.reshape(B, q_h, q_w, dim) 355 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 356 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 357 | 358 | attn = ( 359 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 360 | ).view(B, q_h * q_w, k_h * k_w) 361 | 362 | return attn 363 | 364 | 365 | class PatchEmbed(nn.Module): 366 | """ 367 | Image to Patch Embedding. 368 | """ 369 | 370 | def __init__( 371 | self, 372 | kernel_size: Tuple[int, int] = (16, 16), 373 | stride: Tuple[int, int] = (16, 16), 374 | padding: Tuple[int, int] = (0, 0), 375 | in_chans: int = 3, 376 | embed_dim: int = 768, 377 | ) -> None: 378 | """ 379 | Args: 380 | kernel_size (Tuple): kernel size of the projection layer. 381 | stride (Tuple): stride of the projection layer. 382 | padding (Tuple): padding size of the projection layer. 383 | in_chans (int): Number of input image channels. 384 | embed_dim (int): embed_dim (int): Patch embedding dimension. 385 | """ 386 | super().__init__() 387 | 388 | self.proj = nn.Conv2d( 389 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 390 | ) 391 | 392 | def forward(self, x: torch.Tensor) -> torch.Tensor: 393 | x = self.proj(x) 394 | # B C H W -> B H W C 395 | x = x.permute(0, 2, 3, 1) 396 | return x 397 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | 129 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 130 | b, c, h, w = src.shape 131 | 132 | # Run the transformer 133 | hs, src = self.transformer(src, pos_src, tokens) 134 | iou_token_out = hs[:, 0, :] 135 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 136 | 137 | # Upscale mask embeddings and predict masks using the mask tokens 138 | src = src.transpose(1, 2).view(b, c, h, w) 139 | upscaled_embedding = self.output_upscaling(src) 140 | hyper_in_list: List[torch.Tensor] = [] 141 | # -----mlp----- 142 | for i in range(self.num_mask_tokens): 143 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 144 | hyper_in = torch.stack(hyper_in_list, dim=1) 145 | b, c, h, w = upscaled_embedding.shape 146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 147 | # -----mlp----- 148 | # Generate mask quality predictions 149 | iou_pred = self.iou_prediction_head(iou_token_out) 150 | 151 | return masks, iou_pred 152 | 153 | 154 | # Lightly adapted from 155 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 156 | class MLP(nn.Module): 157 | def __init__( 158 | self, 159 | input_dim: int, 160 | hidden_dim: int, 161 | output_dim: int, 162 | num_layers: int, 163 | sigmoid_output: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.num_layers = num_layers 167 | h = [hidden_dim] * (num_layers - 1) 168 | self.layers = nn.ModuleList( 169 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 170 | ) 171 | self.sigmoid_output = sigmoid_output 172 | 173 | def forward(self, x): 174 | for i, layer in enumerate(self.layers): 175 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 176 | if self.sigmoid_output: 177 | x = F.sigmoid(x) 178 | return x 179 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 
32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | a dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | 99 | image_embeddings, interm_embeddings = self.image_encoder(input_images) 100 | 101 | outputs = [] 102 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | ) 119 | 120 | masks = self.postprocess_masks( 121 | low_res_masks, 122 | input_size=image_record["image"].shape[-2:], 123 | original_size=image_record["original_size"], 124 | ) 125 | masks = masks > self.mask_threshold 126 | 127 | outputs.append( 128 | { 129 | "masks": masks, 130 | "iou_predictions": iou_predictions, 131 | "low_res_logits": low_res_masks, 132 | "encoder_embedding": curr_embedding.unsqueeze(0), 133 | "image_pe": self.prompt_encoder.get_dense_pe(), 134 | "sparse_embeddings": sparse_embeddings, 135 | "dense_embeddings": dense_embeddings, 136 | } 137 | ) 138 | 139 | return outputs, interm_embeddings 140 | 141 | def postprocess_masks( 142 | self, 143 | masks: torch.Tensor, 144 | input_size: Tuple[int, ...], 145 | original_size: Tuple[int, ...], 146 | ) -> torch.Tensor: 147 | """ 148 | Remove padding and upscale masks to the original image size. 149 | 150 | Arguments: 151 | masks (torch.Tensor): Batched masks from the mask_decoder, 152 | in BxCxHxW format. 
153 | input_size (tuple(int, int)): The size of the image input to the 154 | model, in (H, W) format. Used to remove padding. 155 | original_size (tuple(int, int)): The original size of the image 156 | before resizing for input to the model, in (H, W) format. 157 | 158 | Returns: 159 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 160 | is given by original_size. 161 | """ 162 | masks = F.interpolate( 163 | masks, 164 | (self.image_encoder.img_size, self.image_encoder.img_size), 165 | mode="bilinear", 166 | align_corners=False, 167 | ) 168 | masks = masks[..., : input_size[0], : input_size[1]] 169 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 170 | return masks 171 | 172 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 173 | """Normalize pixel values and pad to a square input.""" 174 | # Normalize colors 175 | x = (x - self.pixel_mean) / self.pixel_std 176 | 177 | # Pad 178 | h, w = x.shape[-2:] 179 | padh = self.image_encoder.img_size - h 180 | padw = self.image_encoder.img_size - w 181 | x = F.pad(x, (0, padw, 0, padh)) 182 | return x 183 | -------------------------------------------------------------------------------- /train/segment_anything_training/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. 
Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attenion layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = 
self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /train/segment_anything_training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /train/segment_anything_training/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) # H and W are the last two dims of a BCHW tensor 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length.
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /train/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright by HQ-SAM team 2 | # All rights reserved. 3 | 4 | ## data loader 5 | from __future__ import print_function, division 6 | 7 | import numpy as np 8 | import random 9 | from copy import deepcopy 10 | from skimage import io 11 | import os 12 | from glob import glob 13 | import matplotlib.pyplot as plt 14 | 15 | import torch 16 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 17 | from torchvision import transforms, utils 18 | from torchvision.transforms.functional import normalize 19 | import torch.nn.functional as F 20 | from torch.utils.data.distributed import DistributedSampler 21 | 22 | #### --------------------- dataloader online ---------------------#### 23 | 24 | def get_im_gt_name_dict(datasets, flag='valid'): 25 | print("------------------------------", flag, "--------------------------------") 26 | name_im_gt_list = [] 27 | 28 | for i in range(len(datasets)): 29 | print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---") 30 | #tmp_im_list, tmp_gt_list = [], [] 31 | tmp_im_list, tmp_gt_list,tmp_eg_list = [], [],[] 32 | tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"]) 33 | print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list)) 34 | 35 | if(datasets[i]["gt_dir"]==""): 36 | print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found') 37 | tmp_gt_list = [] 38 | else: 39 | tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list] 40 | print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list)) 41 | 42 | if(datasets[i]["eg_dir"]==""): 43 | print('-eg-', datasets[i]["name"], datasets[i]["eg_dir"], ': ', 'No Ground Truth Found') 44 | tmp_eg_list = [] 45 | else: 46 | tmp_eg_list = [datasets[i]["eg_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["eg_ext"] for x in tmp_im_list] 47 | print('-eg-', datasets[i]["name"],datasets[i]["eg_dir"], ': ',len(tmp_eg_list)) 48 | 49 | name_im_gt_list.append({"dataset_name":datasets[i]["name"], 50 | "im_path":tmp_im_list, 51 | "gt_path":tmp_gt_list, 52 | "eg_path":tmp_eg_list, 53 | "im_ext":datasets[i]["im_ext"], 54 | "gt_ext":datasets[i]["gt_ext"], 55 | "eg_ext":datasets[i]["eg_ext"], 56 | }) 57 | 58 | return name_im_gt_list 59 | 60 | def create_dataloaders(name_im_gt_list, my_transforms=[], batch_size=1, training=False): 61 | gos_dataloaders = [] 62 | gos_datasets = [] 63 | 64 | if(len(name_im_gt_list)==0): 65 | return gos_dataloaders, gos_datasets 66 | 67 | num_workers_ = 1 68 | if(batch_size>1): 69 | num_workers_ = 2 70 | if(batch_size>4): 71 | num_workers_ = 4 72 | if(batch_size>8): 73 | num_workers_ = 8 74 | 75 | 76 | if training: 77 | for i in range(len(name_im_gt_list)): 78 | gos_dataset = OnlineDataset([name_im_gt_list[i]], transform = transforms.Compose(my_transforms)) 79 | gos_datasets.append(gos_dataset) 80 | 81 | gos_dataset = ConcatDataset(gos_datasets) 82 | sampler = DistributedSampler(gos_dataset) 83 | batch_sampler_train = torch.utils.data.BatchSampler( 84 | sampler, batch_size, drop_last=True) 
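# (Editor's note) DistributedSampler shards the concatenated training set across processes, and
# wrapping it in a BatchSampler with drop_last=True gives every rank the same number of full
# batches, which keeps distributed data-parallel training in step.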
85 | dataloader = DataLoader(gos_dataset, batch_sampler=batch_sampler_train, num_workers=num_workers_) 86 | 87 | gos_dataloaders = dataloader 88 | gos_datasets = gos_dataset 89 | 90 | else: 91 | for i in range(len(name_im_gt_list)): 92 | gos_dataset = OnlineDataset([name_im_gt_list[i]], transform = transforms.Compose(my_transforms), eval_ori_resolution = True) 93 | sampler = DistributedSampler(gos_dataset, shuffle=False) 94 | dataloader = DataLoader(gos_dataset, batch_size, sampler=sampler, drop_last=False, num_workers=num_workers_) 95 | 96 | gos_dataloaders.append(dataloader) 97 | gos_datasets.append(gos_dataset) 98 | 99 | return gos_dataloaders, gos_datasets 100 | 101 | class RandomHFlip(object): 102 | def __init__(self,prob=0.5): 103 | self.prob = prob 104 | def __call__(self,sample): 105 | imidx, image, label,edge_label, shape,im_name = sample['imidx'], sample['image'], sample['label'],sample['edge_label'], sample['shape'],sample['im_name'] 106 | 107 | # random horizontal flip 108 | if random.random() >= self.prob: 109 | image = torch.flip(image,dims=[2]) 110 | label = torch.flip(label,dims=[2]) 111 | edge_label = torch.flip(edge_label,dims=[2]) 112 | 113 | return {'imidx':imidx,'image':image, 'label':label,'edge_label':edge_label , 'shape':shape,'im_name':im_name} 114 | 115 | class Resize(object): 116 | def __init__(self,size=[320,320]): 117 | self.size = size 118 | def __call__(self,sample): 119 | imidx, image, label,edge_label, shape,im_name = sample['imidx'], sample['image'], sample['label'],sample['edge_label'], sample['shape'],sample['im_name'] 120 | 121 | image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0) 122 | label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0) 123 | edge_label = torch.squeeze(F.interpolate(torch.unsqueeze(edge_label,0),self.size,mode='bilinear'),dim=0) 124 | 125 | return {'imidx':imidx,'image':image, 'label':label,'edge_label':edge_label, 'shape':torch.tensor(self.size),'im_name':im_name} 126 | 127 | class RandomCrop(object): 128 | def __init__(self,size=[288,288]): 129 | self.size = size 130 | def __call__(self,sample): 131 | imidx, image, label,edge_label, shape,im_name = sample['imidx'], sample['image'], sample['label'],sample['edge_label'], sample['shape'],sample['im_name'] 132 | 133 | h, w = image.shape[1:] 134 | new_h, new_w = self.size 135 | 136 | top = np.random.randint(0, h - new_h) 137 | left = np.random.randint(0, w - new_w) 138 | 139 | image = image[:,top:top+new_h,left:left+new_w] 140 | label = label[:,top:top+new_h,left:left+new_w] 141 | edge_label = edge_label[:,top:top+new_h,left:left+new_w] 142 | 143 | return {'imidx':imidx,'image':image, 'label':label, 'edge_label':edge_label,'shape':torch.tensor(self.size),'im_name':im_name} 144 | 145 | 146 | class Normalize(object): 147 | def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]): 148 | self.mean = mean 149 | self.std = std 150 | 151 | def __call__(self,sample): 152 | 153 | imidx, image, label, shape,edge_label,im_name= sample['imidx'], sample['image'], sample['label'], sample['shape'],sample['edge_label'],sample['im_name'] 154 | image = normalize(image,self.mean,self.std) 155 | 156 | return {'imidx':imidx,'image':image, 'label':label, 'edge_label':edge_label,'shape':shape,'im_name':im_name} 157 | 158 | 159 | 160 | class LargeScaleJitter(object): 161 | """ 162 | implementation of large scale jitter from copy_paste 163 | 
https://github.com/gaopengcuhk/Pretrained-Pix2Seq/blob/7d908d499212bfabd33aeaa838778a6bfb7b84cc/datasets/transforms.py 164 | """ 165 | 166 | def __init__(self, output_size=1024, aug_scale_min=0.1, aug_scale_max=2.0): 167 | self.desired_size = torch.tensor(output_size) 168 | self.aug_scale_min = aug_scale_min 169 | self.aug_scale_max = aug_scale_max 170 | 171 | def pad_target(self, padding, target): 172 | target = target.copy() 173 | if "masks" in target: 174 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0])) 175 | return target 176 | 177 | def __call__(self, sample): 178 | imidx, image, label, image_size,edge_label,im_name = sample['imidx'], sample['image'], sample['label'], sample['shape'],sample['edge_label'],sample['im_name'] 179 | #resize keep ratio 180 | out_desired_size = (self.desired_size * image_size / max(image_size)).round().int() 181 | 182 | random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min 183 | scaled_size = (random_scale * self.desired_size).round() 184 | 185 | scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1]) 186 | scaled_size = (image_size * scale).round().long() 187 | 188 | scaled_image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),scaled_size.tolist(),mode='bilinear'),dim=0) 189 | scaled_label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),scaled_size.tolist(),mode='bilinear'),dim=0) 190 | scaled_edge_label = torch.squeeze(F.interpolate(torch.unsqueeze(edge_label,0),scaled_size.tolist(),mode='bilinear'),dim=0) 191 | 192 | # random crop 193 | crop_size = (min(self.desired_size, scaled_size[0]), min(self.desired_size, scaled_size[1])) 194 | 195 | margin_h = max(scaled_size[0] - crop_size[0], 0).item() 196 | margin_w = max(scaled_size[1] - crop_size[1], 0).item() 197 | offset_h = np.random.randint(0, margin_h + 1) 198 | offset_w = np.random.randint(0, margin_w + 1) 199 | crop_y1, crop_y2 = offset_h, offset_h + crop_size[0].item() 200 | crop_x1, crop_x2 = offset_w, offset_w + crop_size[1].item() 201 | 202 | scaled_image = scaled_image[:,crop_y1:crop_y2, crop_x1:crop_x2] 203 | scaled_label = scaled_label[:,crop_y1:crop_y2, crop_x1:crop_x2] 204 | scaled_edge_label = scaled_edge_label[:,crop_y1:crop_y2, crop_x1:crop_x2] 205 | 206 | # pad 207 | padding_h = max(self.desired_size - scaled_image.size(1), 0).item() 208 | padding_w = max(self.desired_size - scaled_image.size(2), 0).item() 209 | image = F.pad(scaled_image, [0,padding_w, 0,padding_h],value=128) 210 | label = F.pad(scaled_label, [0,padding_w, 0,padding_h],value=0) 211 | edge_label = F.pad(scaled_edge_label, [0,padding_w, 0,padding_h],value=0) 212 | 213 | return {'imidx':imidx,'image':image, 'label':label,"edge_label":edge_label,'shape':torch.tensor(image.shape[-2:]),'im_name':im_name} 214 | 215 | 216 | 217 | 218 | 219 | 220 | class OnlineDataset(Dataset): 221 | def __init__(self, name_im_gt_list, transform=None, eval_ori_resolution=False): 222 | 223 | self.transform = transform 224 | self.dataset = {} 225 | ## combine different datasets into one 226 | dataset_names = [] 227 | dt_name_list = [] # dataset name per image 228 | im_name_list = [] # image name 229 | im_path_list = [] # im path 230 | gt_path_list = [] # gt path 231 | eg_path_list = [] # eg path 232 | im_ext_list = [] # im ext 233 | gt_ext_list = [] # gt ext 234 | eg_ext_list = [] # eg ext 235 | for i in range(0,len(name_im_gt_list)): 236 | dataset_names.append(name_im_gt_list[i]["dataset_name"]) 237 | # dataset name repeated 
based on the number of images in this dataset 238 | dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]]) 239 | im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]]) 240 | im_path_list.extend(name_im_gt_list[i]["im_path"]) 241 | gt_path_list.extend(name_im_gt_list[i]["gt_path"]) 242 | eg_path_list.extend(name_im_gt_list[i]["eg_path"]) 243 | im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]) 244 | gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]) 245 | eg_ext_list.extend([name_im_gt_list[i]["eg_ext"] for x in name_im_gt_list[i]["eg_path"]]) 246 | 247 | 248 | self.dataset["data_name"] = dt_name_list 249 | self.dataset["im_name"] = im_name_list 250 | self.dataset["im_path"] = im_path_list 251 | self.dataset["ori_im_path"] = deepcopy(im_path_list) 252 | self.dataset["gt_path"] = gt_path_list 253 | self.dataset["ori_gt_path"] = deepcopy(gt_path_list) 254 | self.dataset["im_ext"] = im_ext_list 255 | self.dataset["gt_ext"] = gt_ext_list 256 | self.dataset["eg_path"] = eg_path_list 257 | self.dataset["ori_eg_path"] = deepcopy(eg_path_list) 258 | self.dataset["eg_ext"] = eg_ext_list 259 | 260 | self.eval_ori_resolution = eval_ori_resolution 261 | 262 | def __len__(self): 263 | return len(self.dataset["im_path"]) 264 | def __getitem__(self, idx): 265 | im_path = self.dataset["im_path"][idx] 266 | gt_path = self.dataset["gt_path"][idx] 267 | eg_path = self.dataset["eg_path"][idx] 268 | im_name = self.dataset["im_name"][idx] 269 | im = io.imread(im_path) 270 | gt = io.imread(gt_path) 271 | eg = io.imread(eg_path) 272 | if len(gt.shape) > 2: 273 | gt = gt[:, :, 0] 274 | if len(eg.shape) > 2: 275 | eg = eg[:, :, 0] 276 | if len(im.shape) < 3: 277 | im = im[:, :, np.newaxis] 278 | if im.shape[2] == 1: 279 | im = np.repeat(im, 3, axis=2) 280 | im = torch.tensor(im.copy(), dtype=torch.float32) 281 | im = torch.transpose(torch.transpose(im,1,2),0,1) 282 | gt = torch.unsqueeze(torch.tensor(gt, dtype=torch.float32),0) 283 | eg = torch.unsqueeze(torch.tensor(eg, dtype=torch.float32),0) 284 | 285 | sample = { 286 | "imidx": torch.from_numpy(np.array(idx)), 287 | "image": im, 288 | "label": gt, 289 | "edge_label":eg, 290 | "shape": torch.tensor(im.shape[-2:]), 291 | "im_name": im_name, 292 | } 293 | 294 | if self.transform: 295 | sample = self.transform(sample) 296 | 297 | if self.eval_ori_resolution: 298 | sample["ori_label"] = gt.type(torch.uint8) # NOTE for evaluation only. And no flip here 299 | sample['ori_im_path'] = self.dataset["im_path"][idx] 300 | sample['ori_gt_path'] = self.dataset["gt_path"][idx] 301 | sample['ori_eg_path'] = self.dataset["eg_path"][idx] 302 | 303 | return sample -------------------------------------------------------------------------------- /train/utils/loss_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from typing import List, Optional 4 | import utils.misc as misc 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | def point_sample(input, point_coords, **kwargs): 9 | """ 10 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 11 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 12 | [0, 1] x [0, 1] square. 
13 | Args: 14 | input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on an H x W grid. 15 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 16 | [0, 1] x [0, 1] normalized point coordinates. 17 | Returns: 18 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 19 | features for points in `point_coords`. The features are obtained via bilinear 20 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 21 | """ 22 | add_dim = False 23 | if point_coords.dim() == 3: 24 | add_dim = True 25 | point_coords = point_coords.unsqueeze(2) 26 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 27 | if add_dim: 28 | output = output.squeeze(3) 29 | return output 30 | 31 | def cat(tensors: List[torch.Tensor], dim: int = 0): 32 | """ 33 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 34 | """ 35 | assert isinstance(tensors, (list, tuple)) 36 | if len(tensors) == 1: 37 | return tensors[0] 38 | return torch.cat(tensors, dim) 39 | 40 | def get_uncertain_point_coords_with_randomness( 41 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 42 | ): 43 | """ 44 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties 45 | are calculated for each point using the 'uncertainty_func' function that takes a point's logit 46 | prediction as input. 47 | See PointRend paper for details. 48 | Args: 49 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 50 | class-specific or class-agnostic prediction. 51 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 52 | contains logit predictions for P points and returns their uncertainties as a Tensor of 53 | shape (N, 1, P). 54 | num_points (int): The number of points P to sample. 55 | oversample_ratio (int): Oversampling parameter. 56 | importance_sample_ratio (float): Ratio of points that are sampled via importance sampling. 57 | Returns: 58 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 59 | sampled points. 60 | """ 61 | assert oversample_ratio >= 1 62 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 63 | num_boxes = coarse_logits.shape[0] 64 | num_sampled = int(num_points * oversample_ratio) 65 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 66 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 67 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 68 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 69 | # to incorrect results. 70 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 71 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 72 | # However, if we calculate uncertainties for the coarse predictions first, 73 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
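# (Editor's note) The code below scores each oversampled point, keeps the top num_uncertain_points
# per example via topk, and adds a per-example offset (shift) so the flattened
# (num_boxes * num_sampled, 2) coordinate tensor can be indexed in a single gather; any remaining
# budget (num_random_points) is filled with fresh uniformly random points.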
74 | point_uncertainties = uncertainty_func(point_logits) 75 | num_uncertain_points = int(importance_sample_ratio * num_points) 76 | num_random_points = num_points - num_uncertain_points 77 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 78 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 79 | idx += shift[:, None] 80 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 81 | num_boxes, num_uncertain_points, 2 82 | ) 83 | if num_random_points > 0: 84 | point_coords = cat( 85 | [ 86 | point_coords, 87 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 88 | ], 89 | dim=1, 90 | ) 91 | return point_coords 92 | 93 | def edge_dice_loss( 94 | inputs: torch.Tensor, 95 | targets: torch.Tensor, 96 | num_masks: float, 97 | ): 98 | ep = 1e-8 99 | inputs = F.interpolate(inputs, size=(1024, 1024), mode='bilinear', align_corners=False) 100 | inputs = inputs.sigmoid() 101 | intersection = 2 * torch.sum(inputs * targets) + ep 102 | union = torch.sum(inputs) + torch.sum(targets) + ep 103 | dice = 1 - intersection / union 104 | return dice / num_masks 105 | 106 | def dice_loss( 107 | inputs: torch.Tensor, 108 | targets: torch.Tensor, 109 | num_masks: float, 110 | ): 111 | """ 112 | Compute the DICE loss, similar to generalized IOU for masks 113 | Args: 114 | inputs: A float tensor of arbitrary shape. 115 | The predictions for each example. 116 | targets: A float tensor with the same shape as inputs. Stores the binary 117 | classification label for each element in inputs 118 | (0 for the negative class and 1 for the positive class). 119 | """ 120 | inputs = inputs.sigmoid() 121 | inputs = inputs.flatten(1) 122 | numerator = 2 * (inputs * targets).sum(-1) 123 | denominator = inputs.sum(-1) + targets.sum(-1) 124 | loss = 1 - (numerator + 1) / (denominator + 1) 125 | return loss.sum() / num_masks 126 | 127 | 128 | dice_loss_jit = torch.jit.script( 129 | dice_loss 130 | ) # type: torch.jit.ScriptModule 131 | 132 | def calculate_l1_loss(pred_mask, gt_mask): 133 | pred_mask1 = F.interpolate(pred_mask, size=(1024, 1024), mode='bilinear', align_corners=False) 134 | pred_mask = pred_mask1.sigmoid() 135 | pred_mask = torch.round(pred_mask) 136 | l1_loss = F.l1_loss(pred_mask, gt_mask) 137 | return l1_loss 138 | 139 | def sigmoid_ce_loss( 140 | inputs: torch.Tensor, 141 | targets: torch.Tensor, 142 | num_masks: float, 143 | ): 144 | """ 145 | Args: 146 | inputs: A float tensor of arbitrary shape. 147 | The predictions for each example. 148 | targets: A float tensor with the same shape as inputs. Stores the binary 149 | classification label for each element in inputs 150 | (0 for the negative class and 1 for the positive class). 151 | Returns: 152 | Loss tensor 153 | """ 154 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 155 | 156 | return loss.mean(1).sum() / num_masks 157 | 158 | 159 | sigmoid_ce_loss_jit = torch.jit.script( 160 | sigmoid_ce_loss 161 | ) # type: torch.jit.ScriptModule 162 | 163 | 164 | def calculate_uncertainty(logits): 165 | """ 166 | We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 167 | foreground class in `classes`. 168 | Args: 169 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or 170 | class-agnostic, where R is the total number of predicted masks in all images and C is 171 | the number of foreground classes. The values are logits. 172 | Returns: 173 | scores (Tensor): A tensor of shape (R, 1, ...)
that contains uncertainty scores with 174 | the most uncertain locations having the highest uncertainty score. 175 | """ 176 | assert logits.shape[1] == 1 177 | gt_class_logits = logits.clone() 178 | return -(torch.abs(gt_class_logits)) 179 | 180 | def loss_masks(src_masks, edge_mask, target_masks, edge_target_masks, num_masks, oversample_ratio=3.0): 181 | """Compute the losses related to the masks: the sigmoid cross-entropy loss, the dice loss and the edge dice loss. 182 | target_masks and edge_target_masks must be float tensors of dim [nb_predicted_masks, 1, H, W] 183 | """ 184 | 185 | # No need to upsample predictions as we are using normalized coordinates :) 186 | 187 | with torch.no_grad(): 188 | # sample point_coords 189 | point_coords = get_uncertain_point_coords_with_randomness( 190 | src_masks, 191 | lambda logits: calculate_uncertainty(logits), 192 | 112 * 112, 193 | oversample_ratio, 194 | 0.75, 195 | ) 196 | edge_point_coords = get_uncertain_point_coords_with_randomness( 197 | edge_mask, 198 | lambda logits_edge: calculate_uncertainty(logits_edge), 199 | 112 * 112, 200 | oversample_ratio, 201 | 0.75, 202 | ) 203 | # get gt labels 204 | point_labels = point_sample( 205 | target_masks, 206 | point_coords, 207 | align_corners=False, 208 | ).squeeze(1) 209 | edge_point_labels = point_sample( 210 | edge_target_masks, 211 | edge_point_coords, # sample edge labels at the same locations as the edge logits below 212 | align_corners=False, 213 | ).squeeze(1) 214 | 215 | point_logits = point_sample( 216 | src_masks, 217 | point_coords, 218 | align_corners=False, 219 | ).squeeze(1) 220 | 221 | edge_point_logits = point_sample( 222 | edge_mask, 223 | edge_point_coords, 224 | align_corners=False, 225 | ).squeeze(1) 226 | loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks) 227 | loss_dice = dice_loss_jit(point_logits, point_labels, num_masks) 228 | loss_edge = dice_loss_jit(edge_point_logits, edge_point_labels, num_masks) + edge_dice_loss(edge_mask, edge_target_masks, num_masks) 229 | del src_masks 230 | del target_masks 231 | del edge_mask 232 | del edge_target_masks 233 | return loss_mask, loss_dice, loss_edge 234 | 235 | 236 | 237 | --------------------------------------------------------------------------------
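Editor's note: the snippet below is a minimal, illustrative sketch of how loss_masks above might be exercised with dummy tensors. It assumes the train/ directory is on PYTHONPATH so that utils.loss_mask (and its utils.misc import) resolves; the tensor sizes are illustrative assumptions, while the (N, 1, H, W) layout, the num_masks argument, and the returned triple follow the code above.

import torch
from utils.loss_mask import loss_masks

# Two predicted masks and two predicted edge maps as low-resolution logits (N, 1, H, W).
src_masks = torch.randn(2, 1, 256, 256)
edge_mask = torch.randn(2, 1, 256, 256)

# Binary ground-truth masks and edge maps at the padded 1024x1024 input resolution.
target_masks = (torch.rand(2, 1, 1024, 1024) > 0.5).float()
edge_target_masks = (torch.rand(2, 1, 1024, 1024) > 0.5).float()

loss_mask, loss_dice, loss_edge = loss_masks(
    src_masks, edge_mask, target_masks, edge_target_masks, num_masks=2.0
)
print(loss_mask.item(), loss_dice.item(), loss_edge.item())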