├── utils
│   ├── __init__.py
│   ├── tools_ui.py
│   ├── losses.py
│   ├── transforms.py
│   └── interactive_sampling.py
├── sum_on_hq-sam
│   ├── .gitignore
│   ├── figs
│   │   ├── sam-hf-framework.png
│   │   └── merged_iou_clean_HQSeg-44k-f1.png
│   ├── segment_anything
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── transforms.py
│   │   │   └── onnx.py
│   │   ├── modeling
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── mask_decoder.py
│   │   │   ├── sam.py
│   │   │   ├── transformer.py
│   │   │   └── prompt_encoder.py
│   │   ├── __init__.py
│   │   ├── build_sam_baseline.py
│   │   └── build_sam.py
│   ├── train
│   │   ├── segment_anything_training
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   └── transforms.py
│   │   │   ├── __init__.py
│   │   │   ├── modeling
│   │   │   │   ├── __init__.py
│   │   │   │   ├── common.py
│   │   │   │   ├── mask_decoder.py
│   │   │   │   ├── transformer.py
│   │   │   │   ├── prompt_encoder.py
│   │   │   │   └── sam.py
│   │   │   └── build_sam.py
│   │   ├── README.md
│   │   └── utils
│   │       ├── transforms.py
│   │       ├── loss_mask.py
│   │       └── dataloader.py
│   ├── setup.cfg
│   ├── setup.py
│   ├── README.md
│   ├── demo
│   │   ├── demo_sam.py
│   │   └── demo_hqsam.py
│   └── LICENSE
├── .gitignore
├── __init__.py
├── modeling
│   ├── __init__.py
│   ├── common.py
│   ├── mask_decoder.py
│   ├── sam.py
│   ├── transformer.py
│   └── prompt_encoder.py
└── README.md
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from the original Meta SAM code 2 | -------------------------------------------------------------------------------- /sum_on_hq-sam/.gitignore: -------------------------------------------------------------------------------- 1 | eval_scripts/* 2 | demo/input_imgs/* 3 | scripts/* -------------------------------------------------------------------------------- /sum_on_hq-sam/figs/sam-hf-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kangningthu/SUM/HEAD/sum_on_hq-sam/figs/sam-hf-framework.png -------------------------------------------------------------------------------- /sum_on_hq-sam/figs/merged_iou_clean_HQSeg-44k-f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kangningthu/SUM/HEAD/sum_on_hq-sam/figs/merged_iou_clean_HQSeg-44k-f1.png -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .idea 3 | 4 | # Byte-compiled / optimized / DLL files 5 | 6 | __pycache__/ 7 | 8 | 9 | .DS_Store 10 | 11 | 12 | # Jupyter Notebook 13 | .ipynb_checkpoints 14 | notebook/* 15 | utils/example_vis.py 16 | utils/eval_metrics.py 17 | utils/interactive_sampling_inference.py 18 | dataloaders/* 19 | 20 | -------------------------------------------------------------------------------- /utils/tools_ui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted based on Segment Anything Model from Meta Platforms, Inc. and affiliates. 2 | 3 | from .build_sam import ( 4 | build_sam, 5 | build_sam_vit_h, 6 | build_sam_vit_l, 7 | build_sam_vit_b, 8 | sam_model_registry, 9 | ) 10 | from .predictor import SamPredictor 11 | from .automatic_mask_generator import SamAutomaticMaskGenerator 12 | -------------------------------------------------------------------------------- /sum_on_hq-sam/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | 13 | 14 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder_hq import MaskDecoderHQ 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .build_sam_baseline import sam_model_registry_baseline 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | -------------------------------------------------------------------------------- /sum_on_hq-sam/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/README.md: -------------------------------------------------------------------------------- 1 | # Training instructions for SUM (HQ-SAM architecture) 2 | 3 | We closely adhere to the HQ-SAM training repository and framework in this implementation. To ensure a fair comparison, we followed the HQ-SAM implementation and did not use the interactive training method for point sampling. For details on the interactive sampling method, please refer to the main folder. 4 | 5 | We organize the training folder as follows: 6 | ``` 7 | train 8 | |____data 9 | |____pretrained_checkpoint 10 | |____train.py 11 | |____utils 12 | | |____dataloader.py 13 | | |____misc.py 14 | | |____loss_mask.py 15 | |____segment_anything_training 16 | |____work_dirs 17 | ``` 18 | 19 | ## 1. Data Preparation 20 | 21 | HQSeg-44K can be downloaded from the [Hugging Face link](https://huggingface.co/sam-hq-team/sam-hq-training/tree/main/data). 22 | 23 | ### Expected dataset structure for HQSeg-44K 24 | 25 | ``` 26 | data 27 | |____DIS5K 28 | |____cascade_psp 29 | | |____DUTS-TE 30 | | |____DUTS-TR 31 | | |____ecssd 32 | | |____fss_all 33 | | |____MSRA_10K 34 | |____thin_object_detection 35 | | |____COIFT 36 | | |____HRSOD 37 | | |____ThinObject5K 38 | ``` 39 | 40 | ### SAM1B 41 | ``` 42 | The SAM1B dataset can be downloaded from the official website. You will need to obtain the SAM pseudo labels and use the mask-refinement module to quantify the uncertainty map 43 | (##todo release the mask-refinement module and the uncertainty map) 44 | ``` 45 | 46 | ## 2. Init Checkpoint 47 | The init checkpoints can be downloaded from the [Hugging Face link](https://huggingface.co/sam-hq-team/sam-hq-training/tree/main/pretrained_checkpoint). 48 | 49 | ### Expected checkpoint 50 | 51 | ``` 52 | pretrained_checkpoint 53 | |____sam_vit_b_maskdecoder.pth 54 | |____sam_vit_b_01ec64.pth 55 | |____sam_vit_l_maskdecoder.pth 56 | |____sam_vit_l_0b3195.pth 57 | |____sam_vit_h_maskdecoder.pth 58 | |____sam_vit_h_4b8939.pth 59 | 60 | ``` 61 | 62 | ## 3. 
Training 63 | To train SUM (HQ-SAM architecture) on the HQSeg-44K dataset: 64 | 65 | ``` 66 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=1333 train_uncertainty_aware.py \ 67 | --checkpoint xxx/sam_vit_h_4b8939.pth --model-type vit_h --output xxx \ 68 | --use_uncertainmap yes \ 69 | --min_ratio 100 \ 70 | --min_refine_ratio 100 \ 71 | --use_task_prompt_token yes \ 72 | --find_unused_params 73 | ``` 74 | 75 | 76 | ## 4. Evaluation 77 | To evaluate bounding-box-prompt segmentation on the 4 HQ datasets: 78 | 79 | ``` 80 | python -m torch.distributed.launch --nproc_per_node=<num_gpus> train_uncertainty_aware.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output> --eval --restore-model <path/to/training/checkpoint> 81 | ``` 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Uncertainty-aware Fine-tuning of Segmentation Foundation Models (SUM) 2 | 3 | 4 | Official implementation of **Uncertainty-aware Fine-tuning of Segmentation Foundation Models** (NeurIPS 2024). 5 | 6 | [Kangning Liu](https://kangning-liu.github.io/)1,2, [Brian Price](https://research.adobe.com/person/brian-price/)2, [Jason Kuen](https://research.adobe.com/person/jason-kuen/)2, [Yifei Fan](https://openreview.net/profile?id=~Yifei_Fan1)2, [Zijun Wei](https://scholar.google.com/citations?user=8l3bFYYAAAAJ&hl=en)2, [Luis Figueroa](https://luisf.me/)2, [Krzysztof J. Geras](https://cs.nyu.edu/~kgeras/)1, [Carlos Fernandez-Granda](https://math.nyu.edu/~cfgranda/)1 7 | 8 | 1 New York University 9 | 2 Adobe 10 | 11 | [NeurIPS 2024 Poster](https://neurips.cc/virtual/2024/poster/93500) 12 | 13 | [Project Website](https://kangning-liu.github.io/SUM_website/) 14 | 15 | ## Table of Contents 16 | 17 | - [Status Update](#status-update) 18 | - [Current Progress](#current-progress) 19 | - [Next Steps](#next-steps) 20 | - [Known Issues](#known-issues) 21 | - [Prerequisites](#prerequisites) 22 | - [Dataset](#dataset) 23 | - [Notebook](#notebook) 24 | - [Contact](#contact) 25 | 26 | 27 | 28 | ## Status Update 29 | 30 | ### Current Progress 31 | 32 | 33 | - **[NEW]** SUM (HQ-SAM arch.): provided the training and inference code of SUM implemented with the HQ-SAM architecture in [sum_on_hq-sam](sum_on_hq-sam) 34 | 35 | - Main experiments 36 | - Provided the model building code [build_sam.py](build_sam.py) 37 | - Provided the key components of uncertainty-aware fine-tuning for the main experiment: 38 | - Uncertainty-aware loss [losses.py](utils%2Flosses.py) 39 | - Uncertainty-aware prompt sampling [interactive_sampling.py](utils%2Finteractive_sampling.py) 40 | 41 | 42 | ### Next Steps 43 | - Main experiments 44 | - Provide demo Jupyter notebooks 45 | - Add support for the evaluation dataloader 46 | - Release model weights trained on the public dataset 47 | - Provide the full training code 48 | 49 | ### Known Issues 50 | - Some scripts may require additional dependencies not listed in the prerequisites. 51 | - Documentation is still in progress and may lack detailed instructions for some scripts. 52 | 53 | ## Prerequisites 54 | 55 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. 
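As a quick smoke test that the environment and the released utilities work together, the snippet below is a minimal, illustrative sketch of calling the uncertainty-aware losses from [losses.py](utils%2Flosses.py) on random tensors. The `utils.losses` import path, the tensor shapes, and the focal/dice weighting are assumptions made for this example, not the exact training configuration.

```python
import torch

from utils.losses import uncertainty_aware_dice_loss, uncertainty_aware_focal_loss

# Dummy batch: predicted mask logits, binary ground-truth masks, and an
# uncertainty map in [0, 1], all of shape (N, 1, H, W).
pred_logits = torch.randn(2, 1, 256, 256)
gt_masks = (torch.rand(2, 1, 256, 256) > 0.5).float()
uncertain_map = torch.rand(2, 1, 256, 256)

# Per-sample dice loss; pixels whose certainty (1 - uncertainty) falls below
# the threshold are zeroed out in both the prediction and the target.
dice = uncertainty_aware_dice_loss(
    pred_logits, gt_masks, uncertain_map=uncertain_map, uncertain_map_threshold=0.5
)

# Per-sample focal loss, down-weighted pixel-wise by (1 - uncertainty).
focal = uncertainty_aware_focal_loss(pred_logits, gt_masks, uncertain_map)

# The 20x focal weighting below is only an illustrative choice for this sketch.
loss = (20.0 * focal + dice).mean()
print(loss.item())
```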
56 | 57 | 58 | 59 | ## Dataset 60 | *TODO* 61 | 62 | 63 | ## Notebook 64 | 65 | *TODO* 66 | 67 | 68 | 69 | 70 | ## Contact 71 | For any questions or issues, please contact: 72 | - Kangning Liu - [kangning.liu@nyu.edu](mailto:kangning.liu@nyu.edu) 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam, 49 | "vit_h": build_sam, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/build_sam_baseline.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry_baseline = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torchvision.ops import sigmoid_focal_loss 6 | 7 | 8 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor): 9 | pred_mask = (pred_mask >= 0.5).float() 10 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2)) 11 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection 12 | epsilon = 1e-7 13 | batch_iou = intersection / (union + epsilon) 14 | 15 | batch_iou = 
batch_iou.unsqueeze(1) 16 | return batch_iou 17 | 18 | 19 | 20 | def uncertainty_aware_dice_loss(y_pred, y_true, smooth=1, uncertain_map=None, uncertain_map_threshold=0.5): 21 | # Ensure the input tensors have shape (batch_size, channels, height, width) 22 | # y_true and y_pred should have the same shape 23 | y_pred = F.sigmoid(y_pred) 24 | # y_pred = torch.clamp(y_pred, min=0, max=1) 25 | 26 | # if uncertain_map is not None: 27 | # # adjust the y_true based on the uncertain_map 28 | # y_true = y_true * (1 - uncertain_map) 29 | # y_pred = y_pred * (1 - uncertain_map) 30 | 31 | if uncertain_map is not None: 32 | # adjust the y_true and y_pred based on the uncertain_map 33 | y_true = y_true * ((1 - uncertain_map) > uncertain_map_threshold) 34 | y_pred = y_pred * ((1 - uncertain_map) > uncertain_map_threshold) 35 | 36 | 37 | # Compute the intersection and the sum of cardinalities per sample 38 | intersection = (y_true * y_pred).sum(dim=(1, 2, 3)) 39 | cardinalities = y_true.sum(dim=(1, 2, 3)) + y_pred.sum(dim=(1, 2, 3)) 40 | 41 | # Compute the DICE score per sample 42 | dice_scores = (2. * intersection + smooth) / (cardinalities + smooth) 43 | 44 | # print(dice_scores) 45 | 46 | # Return the DICE loss per sample 47 | return 1 - dice_scores 48 | 49 | 50 | 51 | def uncertainty_aware_focal_loss(pred_mask: torch.Tensor, gt_mask: torch.Tensor, uncertain_map: torch.Tensor, alpha: float = 0.25, gamma: float = 2): 52 | """ 53 | Compute the uncertainty-aware focal loss. 54 | 55 | This function calculates the focal loss and adjusts it based on an uncertainty map. 56 | The focal loss is reduced in regions where the uncertainty is high, which can help 57 | the model focus on more certain regions during training. 58 | 59 | Args: 60 | pred_mask (torch.Tensor): The predicted mask tensor of shape (N, C, H, W), 61 | where N is the batch size, C is the number of channels, 62 | H is the height, and W is the width. 63 | gt_mask (torch.Tensor): The ground truth mask tensor of shape (N, C, H, W). 64 | uncertain_map (torch.Tensor): The uncertainty map tensor of shape (N, C, H, W), 65 | where higher values indicate higher uncertainty. 66 | alpha (float, optional): The alpha parameter for the focal loss. Default is 0.25. 67 | gamma (float, optional): The gamma parameter for the focal loss. Default is 2. 68 | 69 | Returns: 70 | torch.Tensor: The computed per-sample uncertainty-aware focal loss, averaged over the 71 | channel, height, and width dimensions. 72 | """ 73 | # Calculate the focal loss without reduction 74 | loss_focal_tempt = sigmoid_focal_loss( 75 | pred_mask, 76 | gt_mask, 77 | gamma=gamma, 78 | alpha=alpha, 79 | reduction='none' 80 | ) 81 | 82 | # Multiply the focal loss by (1 - uncertain_map) to reduce the loss in uncertain regions 83 | loss_focal_tempt *= (1 - uncertain_map) 84 | 85 | # Average the loss over the channel, height, and width dimensions (per sample) 86 | loss_focal_tempt = loss_focal_tempt.mean(dim=(1, 2, 3)) 87 | 88 | return loss_focal_tempt -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None,global_local_fusion=True): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | global_local_fusion=global_local_fusion, 22 | ) 23 | 24 | 25 | build_sam = build_sam_vit_h 26 | 27 | 28 | def build_sam_vit_l(checkpoint=None,global_local_fusion=True): 29 | return _build_sam( 30 | encoder_embed_dim=1024, 31 | encoder_depth=24, 32 | encoder_num_heads=16, 33 | encoder_global_attn_indexes=[5, 11, 17, 23], 34 | checkpoint=checkpoint, 35 | global_local_fusion=global_local_fusion, 36 | ) 37 | 38 | 39 | def build_sam_vit_b(checkpoint=None,global_local_fusion=True): 40 | return _build_sam( 41 | encoder_embed_dim=768, 42 | encoder_depth=12, 43 | encoder_num_heads=12, 44 | encoder_global_attn_indexes=[2, 5, 8, 11], 45 | checkpoint=checkpoint, 46 | global_local_fusion=global_local_fusion, 47 | ) 48 | 49 | 50 | sam_model_registry = { 51 | "default": build_sam_vit_h, 52 | "vit_h": build_sam_vit_h, 53 | "vit_l": build_sam_vit_l, 54 | "vit_b": build_sam_vit_b, 55 | } 56 | 57 | 58 | def _build_sam( 59 | encoder_embed_dim, 60 | encoder_depth, 61 | encoder_num_heads, 62 | encoder_global_attn_indexes, 63 | checkpoint=None, 64 | global_local_fusion=True, 65 | ): 66 | prompt_embed_dim = 256 67 | image_size = 1024 68 | vit_patch_size = 16 69 | image_embedding_size = image_size // vit_patch_size 70 | sam = Sam( 71 | image_encoder=ImageEncoderViT( 72 | depth=encoder_depth, 73 | embed_dim=encoder_embed_dim, 74 | img_size=image_size, 75 | mlp_ratio=4, 76 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 77 | num_heads=encoder_num_heads, 78 | patch_size=vit_patch_size, 79 | qkv_bias=True, 80 | use_rel_pos=True, 81 | global_attn_indexes=encoder_global_attn_indexes, 82 | window_size=14, 83 | out_chans=prompt_embed_dim, 84 | ), 85 | prompt_encoder=PromptEncoder( 86 | embed_dim=prompt_embed_dim, 87 | image_embedding_size=(image_embedding_size, image_embedding_size), 88 | input_image_size=(image_size, image_size), 89 | mask_in_chans=16, 90 | ), 91 | mask_decoder=MaskDecoderHQ( 92 | num_multimask_outputs=3, 93 | transformer=TwoWayTransformer( 94 | depth=2, 95 | embedding_dim=prompt_embed_dim, 96 | mlp_dim=2048, 97 | num_heads=8, 98 | ), 99 | transformer_dim=prompt_embed_dim, 100 | iou_head_depth=3, 101 | iou_head_hidden_dim=256, 102 | vit_dim=encoder_embed_dim, 103 | global_local_fusion=global_local_fusion, 104 | ), 105 | pixel_mean=[123.675, 116.28, 103.53], 106 | pixel_std=[58.395, 57.12, 57.375], 107 | ) 108 | # sam.eval() 109 | if checkpoint is not None: 110 | with open(checkpoint, "rb") as f: 111 | state_dict = torch.load(f) 112 | info = sam.load_state_dict(state_dict, strict=False) 113 | print(info) 114 | for n, p in sam.named_parameters(): 115 | if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: 116 | p.requires_grad = False 117 | 118 | return sam 119 | -------------------------------------------------------------------------------- /sum_on_hq-sam/README.md: -------------------------------------------------------------------------------- 1 | # SUM implemented with the HQ-SAM architecture 2 | 3 | This implementation builds on the HQ-SAM (Segment Anything in High Quality) architecture 
by incorporating the SUM method. We largely follow the HQ-SAM training approach, with the key difference being the addition of the SAM1B dataset and the use of an uncertainty-aware training method. 4 | 5 | To ensure a fair comparison, we followed the HQ-SAM implementation and did not use the interactive training method for point sampling, unlike the main experiments. 6 | 7 | For more details, refer to our [paper](https://openreview.net/pdf?id=qNXRXUC90b). 8 | 9 | 10 | 11 | ![SUM (HQ-SAM architecture) framework](figs/sam-hf-framework.png) 12 | 13 | 14 | Quantitative comparison between HQ-SAM and SUM (HQ-SAM architecture) 15 | ----------------- 16 | 17 | ### Comparison of HQ-SAM with Vanilla and SUM fine-tuned using the same lightweight scheme as HQ-SAM 18 | SUM matches HQ-SAM and outperforms Vanilla in salient-object segmentation, and is superior in entity and part segmentation. 19 | 20 | ![backbones](figs/merged_iou_clean_HQSeg-44k-f1.png) 21 | 22 | 23 | ### **Installation** 24 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. 25 | 26 | Go to the repository locally and install it with 27 | 28 | ``` 29 | pip install -e . 30 | ``` 31 | 32 | The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks. 33 | 34 | ``` 35 | pip install opencv-python pycocotools matplotlib onnxruntime onnx 36 | ``` 37 | 38 | ### Example conda environment setup 39 | ```bash 40 | conda create --name sam_hq python=3.8 -y 41 | conda activate sam_hq 42 | conda install pytorch==1.10.0 torchvision==0.11.0 cudatoolkit=11.1 -c pytorch -c nvidia 43 | pip install opencv-python pycocotools matplotlib onnxruntime onnx 44 | 45 | # under your working directory 46 | git clone https://github.com/SysCV/sam-hq.git 47 | cd sam-hq 48 | pip install -e . 49 | export PYTHONPATH=$(pwd) 50 | ``` 51 | 52 | ### **Model Checkpoints** 53 | 54 | Three model versions are available with different backbone sizes. These models can be instantiated by running 55 | 56 | ``` 57 | from segment_anything import sam_model_registry 58 | sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") 59 | ``` 60 | 61 | Download the provided trained models below and put them into the pretrained_checkpoint folder: 62 | ``` 63 | mkdir pretrained_checkpoint 64 | ``` 65 | 66 | Click the links below to download the checkpoints (##TODO) 67 | 68 | ### **Getting Started** 69 | 70 | First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt: 71 | 72 | ``` 73 | from segment_anything import SamPredictor, sam_model_registry 74 | sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") 75 | predictor = SamPredictor(sam) 76 | predictor.set_image(<your_image>) 77 | masks, _, _ = predictor.predict(<input_prompts>) 78 | ``` 79 | 80 | Additionally, see the usage examples in our [demo](/demo/demo_hqsam.py). 81 | 82 | To obtain the model's visual results: 83 | ``` 84 | python demo/demo_hqsam.py 85 | ``` 86 | 87 | ### **SUM (HQ-SAM architecture) Tuning** 88 | We provide detailed training, evaluation, visualization, and data downloading instructions in [SUM (HQ-SAM architecture) training](train/README.md). 
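Before launching any fine-tuning, it can help to sanity-check a checkpoint end to end with a single box prompt. The sketch below follows the pattern used in [demo/demo_sam.py](demo/demo_sam.py) and [demo/demo_hqsam.py](demo/demo_hqsam.py); the checkpoint path, model type, image path, and box coordinates are placeholders borrowed from the demo script, not required values.

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Placeholders -- substitute your own checkpoint (e.g. the released
# SUM (HQ-SAM arch.) weights once available) and input image.
checkpoint = "pretrained_checkpoint/sam_vit_l_0b3195.pth"
image_path = "demo/input_imgs/example0.png"

sam = sam_model_registry["vit_l"](checkpoint=checkpoint)
sam.to(device="cuda")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# One bounding-box prompt in (x0, y0, x1, y1) pixel coordinates.
input_box = np.array([[4, 13, 1007, 1023]])
masks, scores, logits = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box,
    multimask_output=False,
)
print(masks.shape, scores)
```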
89 | 90 | Please change the current folder path to: 91 | ``` 92 | cd train 93 | ``` 94 | and then refer to detailed [readme instruction](train/README.md). 95 | 96 | 97 | 98 | 99 | ## Acknowledgments 100 | - Thanks [SAM](https://github.com/facebookresearch/segment-anything) for their public code and released models. 101 | - Thanks [HQ-SAM](https://github.com/SysCV/sam-hq) for their public code and released models, this implementation largely follows the repo -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 
73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 
60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 
21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sum_on_hq-sam/demo/demo_sam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | from segment_anything import sam_model_registry_baseline, SamPredictor 6 | import os 7 | 8 | def show_mask(mask, ax, random_color=False): 9 | if random_color: 10 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 11 | else: 12 | color = np.array([30/255, 144/255, 255/255, 0.6]) 13 | h, w = mask.shape[-2:] 14 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 15 | ax.imshow(mask_image) 16 | 17 | def show_points(coords, labels, ax, marker_size=375): 18 | pos_points = coords[labels==1] 19 | neg_points = coords[labels==0] 20 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 21 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 22 | 23 | def show_box(box, ax): 24 | x0, y0 = box[0], box[1] 25 | w, h = box[2] - box[0], box[3] - box[1] 26 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 27 | 28 | 29 | def show_res(masks, scores, input_point, input_label, input_box, filename, image): 30 | for i, (mask, score) in enumerate(zip(masks, scores)): 31 | plt.figure(figsize=(10,10)) 32 | plt.imshow(image) 33 | show_mask(mask, plt.gca()) 34 | if input_box is not None: 35 | box = input_box[i] 36 | show_box(box, plt.gca()) 37 | if (input_point is not None) and (input_label is not None): 38 | show_points(input_point, input_label, plt.gca()) 39 | 40 | print(f"Score: {score:.3f}") 41 | plt.axis('off') 42 | plt.savefig(filename+'_'+str(i)+'.png',bbox_inches='tight',pad_inches=-0.1) 43 | plt.close() 44 | 45 | def show_res_multi(masks, scores, input_point, input_label, input_box, filename, image): 46 | plt.figure(figsize=(10, 10)) 47 | plt.imshow(image) 48 | for mask in masks: 49 | show_mask(mask, plt.gca(), random_color=True) 50 | for box in input_box: 51 | show_box(box, plt.gca()) 52 | for score in scores: 53 | print(f"Score: {score:.3f}") 54 | plt.axis('off') 55 | plt.savefig(filename +'.png',bbox_inches='tight',pad_inches=-0.1) 56 | plt.close() 57 | 58 | if __name__ == "__main__": 59 | sam_checkpoint = "./pretrained_checkpoint/sam_vit_l_0b3195.pth" 60 | model_type = "vit_l" 61 | device = "cuda" 62 | sam = sam_model_registry_baseline[model_type](checkpoint=sam_checkpoint) 63 | sam.to(device=device) 64 | predictor = SamPredictor(sam) 65 | 66 | for i in range(8): 67 | print("image: ",i) 68 | image = cv2.imread('demo/input_imgs/example'+str(i)+'.png') 69 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 70 | predictor.set_image(image) 71 | 72 | if i==0: 73 | input_box = np.array([[4,13,1007,1023]]) 74 | input_point, input_label = None, None 75 | elif i==1: 76 | input_box = np.array([[306, 132, 925, 893]]) 77 | input_point, input_label = None, None 78 | elif i==2: 79 | input_point = np.array([[495,518],[217,140]]) 80 | input_label = np.ones(input_point.shape[0]) 81 | input_box = None 82 | elif i==3: 83 | input_point = np.array([[221,482],[498,633],[750,379]]) 84 | input_label = np.ones(input_point.shape[0]) 85 | 
input_box = None 86 | elif i==4: 87 | input_box = np.array([[64,76,940,919]]) 88 | input_point, input_label = None, None 89 | elif i==5: 90 | input_point = np.array([[373,363], [452, 575]]) 91 | input_label = np.ones(input_point.shape[0]) 92 | input_box = None 93 | elif i==6: 94 | input_box = np.array([[181, 196, 757, 495]]) 95 | input_point, input_label = None, None 96 | elif i==7: 97 | # multi box input 98 | input_box = torch.tensor([[45,260,515,470], [310,228,424,296]],device=predictor.device) 99 | transformed_box = predictor.transform.apply_boxes_torch(input_box, image.shape[:2]) 100 | input_point, input_label = None, None 101 | 102 | batch_box = False if input_box is None else len(input_box)>1 103 | result_path = 'demo/baseline_sam_result/' 104 | os.makedirs(result_path, exist_ok=True) 105 | 106 | if not batch_box: 107 | masks, scores, logits = predictor.predict( 108 | point_coords=input_point, 109 | point_labels=input_label, 110 | box = input_box, 111 | multimask_output=False, 112 | ) 113 | show_res(masks,scores,input_point, input_label, input_box, result_path + 'example'+str(i), image) 114 | else: 115 | masks, scores, logits = predictor.predict_torch( 116 | point_coords=input_point, 117 | point_labels=input_label, 118 | boxes=transformed_box, 119 | multimask_output=False, 120 | ) 121 | masks = masks.squeeze(1).cpu().numpy() 122 | scores = scores.squeeze(1).cpu().numpy() 123 | input_box = input_box.cpu().numpy() 124 | show_res_multi(masks, scores, input_point, input_label, input_box, result_path + 'example'+str(i), image) 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /sum_on_hq-sam/demo/demo_hqsam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | from segment_anything import sam_model_registry, SamPredictor 6 | import os 7 | 8 | def show_mask(mask, ax, random_color=False): 9 | if random_color: 10 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 11 | else: 12 | color = np.array([30/255, 144/255, 255/255, 0.6]) 13 | h, w = mask.shape[-2:] 14 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 15 | ax.imshow(mask_image) 16 | 17 | def show_points(coords, labels, ax, marker_size=375): 18 | pos_points = coords[labels==1] 19 | neg_points = coords[labels==0] 20 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 21 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 22 | 23 | def show_box(box, ax): 24 | x0, y0 = box[0], box[1] 25 | w, h = box[2] - box[0], box[3] - box[1] 26 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 27 | 28 | 29 | def show_res(masks, scores, input_point, input_label, input_box, filename, image): 30 | for i, (mask, score) in enumerate(zip(masks, scores)): 31 | plt.figure(figsize=(10,10)) 32 | plt.imshow(image) 33 | show_mask(mask, plt.gca()) 34 | if input_box is not None: 35 | box = input_box[i] 36 | show_box(box, plt.gca()) 37 | if (input_point is not None) and (input_label is not None): 38 | show_points(input_point, input_label, plt.gca()) 39 | 40 | print(f"Score: {score:.3f}") 41 | plt.axis('off') 42 | plt.savefig(filename+'_'+str(i)+'.png',bbox_inches='tight',pad_inches=-0.1) 43 | plt.close() 44 | 45 | def show_res_multi(masks, scores, input_point, input_label, 
input_box, filename, image): 46 | plt.figure(figsize=(10, 10)) 47 | plt.imshow(image) 48 | for mask in masks: 49 | show_mask(mask, plt.gca(), random_color=True) 50 | for box in input_box: 51 | show_box(box, plt.gca()) 52 | for score in scores: 53 | print(f"Score: {score:.3f}") 54 | plt.axis('off') 55 | plt.savefig(filename +'.png',bbox_inches='tight',pad_inches=-0.1) 56 | plt.close() 57 | 58 | 59 | if __name__ == "__main__": 60 | sam_checkpoint = "./pretrained_checkpoint/sam_hq_vit_l.pth" 61 | model_type = "vit_l" 62 | device = "cuda" 63 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 64 | sam.to(device=device) 65 | predictor = SamPredictor(sam) 66 | 67 | for i in range(8): 68 | print("image: ",i) 69 | # hq_token_only: False means use hq output to correct SAM output. 70 | # True means use hq output only. 71 | # Default: False 72 | hq_token_only = False 73 | # For the best visualization, for images that contain multiple objects (like typical COCO images), we suggest setting hq_token_only=False 74 | # For images that contain a single object, we suggest setting hq_token_only=True 75 | # For quantitative evaluation on COCO/YTVOS/DAVIS/UVO/LVIS etc., we set hq_token_only=False 76 | 77 | image = cv2.imread('demo/input_imgs/example'+str(i)+'.png') 78 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 79 | predictor.set_image(image) 80 | 81 | if i==0: 82 | input_box = np.array([[4,13,1007,1023]]) 83 | input_point, input_label = None, None 84 | elif i==1: 85 | input_box = np.array([[306, 132, 925, 893]]) 86 | input_point, input_label = None, None 87 | hq_token_only = True 88 | elif i==2: 89 | input_point = np.array([[495,518],[217,140]]) 90 | input_label = np.ones(input_point.shape[0]) 91 | input_box = None 92 | hq_token_only = True 93 | elif i==3: 94 | input_point = np.array([[221,482],[498,633],[750,379]]) 95 | input_label = np.ones(input_point.shape[0]) 96 | input_box = None 97 | elif i==4: 98 | input_box = np.array([[64,76,940,919]]) 99 | input_point, input_label = None, None 100 | hq_token_only = True 101 | elif i==5: 102 | input_point = np.array([[373,363], [452, 575]]) 103 | input_label = np.ones(input_point.shape[0]) 104 | input_box = None 105 | elif i==6: 106 | input_box = np.array([[181, 196, 757, 495]]) 107 | input_point, input_label = None, None 108 | elif i==7: 109 | # multi box input 110 | input_box = torch.tensor([[45,260,515,470], [310,228,424,296]],device=predictor.device) 111 | transformed_box = predictor.transform.apply_boxes_torch(input_box, image.shape[:2]) 112 | input_point, input_label = None, None 113 | 114 | batch_box = False if input_box is None else len(input_box)>1 115 | result_path = 'demo/hq_sam_result/' 116 | os.makedirs(result_path, exist_ok=True) 117 | 118 | if not batch_box: 119 | masks, scores, logits = predictor.predict( 120 | point_coords=input_point, 121 | point_labels=input_label, 122 | box = input_box, 123 | multimask_output=False, 124 | hq_token_only=hq_token_only, 125 | mask_type=1, # 1: task prompt for salient object segmentation; 0: SAM default task 126 | ) 127 | show_res(masks,scores,input_point, input_label, input_box, result_path + 'example'+str(i), image) 128 | 129 | else: 130 | masks, scores, logits = predictor.predict_torch( 131 | point_coords=input_point, 132 | point_labels=input_label, 133 | boxes=transformed_box, 134 | multimask_output=False, 135 | hq_token_only=hq_token_only, 136 | mask_type=1, 137 | ) 138 | masks = masks.squeeze(1).cpu().numpy() 139 | scores = scores.squeeze(1).cpu().numpy() 140 | input_box = input_box.cpu().numpy()
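            # predict_torch returns masks of shape B x C x H x W; with multimask_output=False,
            # C is 1, so the squeeze(1) above drops the mask channel before visualization.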
141 | show_res_multi(masks, scores, input_point, input_label, input_box, result_path + 'example'+str(i), image) 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | hq_token_only: bool = False, 29 | multimask_output: bool = False, 30 | use_stability_score: bool = False, 31 | return_extra_metrics: bool = False, 32 | ) -> None: 33 | super().__init__() 34 | self.mask_decoder = model.mask_decoder 35 | self.model = model 36 | self.img_size = model.image_encoder.img_size 37 | self.hq_token_only = hq_token_only 38 | self.multimask_output = multimask_output 39 | self.use_stability_score = use_stability_score 40 | self.stability_score_offset = 1.0 41 | self.return_extra_metrics = return_extra_metrics 42 | 43 | @staticmethod 44 | def resize_longest_image_size( 45 | input_image_size: torch.Tensor, longest_side: int 46 | ) -> torch.Tensor: 47 | input_image_size = input_image_size.to(torch.float32) 48 | scale = longest_side / torch.max(input_image_size) 49 | transformed_size = scale * input_image_size 50 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 51 | return transformed_size 52 | 53 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 54 | point_coords = point_coords + 0.5 55 | point_coords = point_coords / self.img_size 56 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 57 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 58 | 59 | point_embedding = point_embedding * (point_labels != -1) 60 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 61 | point_labels == -1 62 | ) 63 | 64 | for i in range(self.model.prompt_encoder.num_point_embeddings): 65 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 66 | i 67 | ].weight * (point_labels == i) 68 | 69 | return point_embedding 70 | 71 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 72 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 73 | mask_embedding = mask_embedding + ( 74 | 1 - has_mask_input 75 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 76 | return mask_embedding 77 | 78 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 79 | masks = F.interpolate( 80 | masks, 81 | size=(self.img_size, self.img_size), 82 | mode="bilinear", 83 | 
align_corners=False, 84 | ) 85 | 86 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 87 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 88 | 89 | orig_im_size = orig_im_size.to(torch.int64) 90 | h, w = orig_im_size[0], orig_im_size[1] 91 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 92 | return masks 93 | 94 | 95 | @torch.no_grad() 96 | def forward( 97 | self, 98 | image_embeddings: torch.Tensor, 99 | interm_embeddings: torch.Tensor, 100 | point_coords: torch.Tensor, 101 | point_labels: torch.Tensor, 102 | mask_input: torch.Tensor, 103 | has_mask_input: torch.Tensor, 104 | orig_im_size: torch.Tensor, 105 | ): 106 | sparse_embedding = self._embed_points(point_coords, point_labels) 107 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 108 | 109 | vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT 110 | hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features) 111 | 112 | masks, scores = self.model.mask_decoder.predict_masks( 113 | image_embeddings=image_embeddings, 114 | image_pe=self.model.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embedding, 116 | dense_prompt_embeddings=dense_embedding, 117 | hq_features=hq_features, 118 | ) 119 | 120 | if self.use_stability_score: 121 | scores = calculate_stability_score( 122 | masks, self.model.mask_threshold, self.stability_score_offset 123 | ) 124 | 125 | if self.multimask_output: 126 | # mask with highest score 127 | mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1) 128 | scores = scores[:, mask_slice] 129 | scores, max_iou_idx = torch.max(scores,dim=1) 130 | scores = scores.unsqueeze(1) 131 | masks_multi = masks[:, mask_slice, :, :] 132 | masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) 133 | else: 134 | # singale mask output, default 135 | mask_slice = slice(0, 1) 136 | scores = scores[:,mask_slice] 137 | masks_sam = masks[:,mask_slice] 138 | 139 | masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)] 140 | 141 | if self.hq_token_only: 142 | masks = masks_hq 143 | else: 144 | masks = masks_sam + masks_hq 145 | 146 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 147 | 148 | if self.return_extra_metrics: 149 | stability_scores = calculate_stability_score( 150 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 151 | ) 152 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 153 | return upscaled_masks, scores, stability_scores, areas, masks 154 | 155 | return upscaled_masks, scores, masks 156 | -------------------------------------------------------------------------------- /modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Aadapted from the Meta Segment Anything Model codebase 2 | 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
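# The MaskDecoder below predicts (num_multimask_outputs + 1) candidate masks internally and
# returns either the single default mask or the multimask candidates, together with a
# predicted IoU score per returned mask. A typical construction, shown only as an
# illustrative sketch with SAM-style default sizes (not taken from this file), would be:
#
#     from .transformer import TwoWayTransformer
#
#     decoder = MaskDecoder(
#         transformer_dim=256,
#         transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
#         num_multimask_outputs=3,
#     )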
5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from typing import List, Tuple, Type 11 | 12 | from .common import LayerNorm2d 13 | 14 | 15 | class MaskDecoder(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | transformer_dim: int, 20 | transformer: nn.Module, 21 | num_multimask_outputs: int = 3, 22 | activation: Type[nn.Module] = nn.GELU, 23 | iou_head_depth: int = 3, 24 | iou_head_hidden_dim: int = 256, 25 | ) -> None: 26 | """ 27 | Predicts masks given an image and prompt embeddings, using a 28 | transformer architecture. 29 | 30 | Arguments: 31 | transformer_dim (int): the channel dimension of the transformer 32 | transformer (nn.Module): the transformer used to predict masks 33 | num_multimask_outputs (int): the number of masks to predict 34 | when disambiguating masks 35 | activation (nn.Module): the type of activation to use when 36 | upscaling masks 37 | iou_head_depth (int): the depth of the MLP used to predict 38 | mask quality 39 | iou_head_hidden_dim (int): the hidden dimension of the MLP 40 | used to predict mask quality 41 | """ 42 | super().__init__() 43 | self.transformer_dim = transformer_dim 44 | self.transformer = transformer 45 | 46 | self.num_multimask_outputs = num_multimask_outputs 47 | 48 | self.iou_token = nn.Embedding(1, transformer_dim) 49 | self.num_mask_tokens = num_multimask_outputs + 1 50 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 51 | 52 | self.output_upscaling = nn.Sequential( 53 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 54 | LayerNorm2d(transformer_dim // 4), 55 | activation(), 56 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 57 | activation(), 58 | ) 59 | self.output_hypernetworks_mlps = nn.ModuleList( 60 | [ 61 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 62 | for i in range(self.num_mask_tokens) 63 | ] 64 | ) 65 | 66 | self.iou_prediction_head = MLP( 67 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 68 | ) 69 | 70 | def forward( 71 | self, 72 | image_embeddings: torch.Tensor, 73 | image_pe: torch.Tensor, 74 | sparse_prompt_embeddings: torch.Tensor, 75 | dense_prompt_embeddings: torch.Tensor, 76 | multimask_output: bool, 77 | ) -> Tuple[torch.Tensor, torch.Tensor]: 78 | """ 79 | Predict masks given image and prompt embeddings. 80 | 81 | Arguments: 82 | image_embeddings (torch.Tensor): the embeddings from the image encoder 83 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 84 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 85 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 86 | multimask_output (bool): Whether to return multiple masks or a single 87 | mask. 
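                With the defaults above (num_multimask_outputs=3, i.e. num_mask_tokens=4),
                True returns the three multimask candidates and False returns the single
                default mask.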
88 | 89 | Returns: 90 | torch.Tensor: batched predicted masks 91 | torch.Tensor: batched predictions of mask quality 92 | """ 93 | masks, iou_pred = self.predict_masks( 94 | image_embeddings=image_embeddings, 95 | image_pe=image_pe, 96 | sparse_prompt_embeddings=sparse_prompt_embeddings, 97 | dense_prompt_embeddings=dense_prompt_embeddings, 98 | ) 99 | 100 | # Select the correct mask or masks for output 101 | if multimask_output: 102 | mask_slice = slice(1, None) 103 | else: 104 | mask_slice = slice(0, 1) 105 | masks = masks[:, mask_slice, :, :] 106 | iou_pred = iou_pred[:, mask_slice] 107 | 108 | # Prepare output 109 | return masks, iou_pred 110 | 111 | def predict_masks( 112 | self, 113 | image_embeddings: torch.Tensor, 114 | image_pe: torch.Tensor, 115 | sparse_prompt_embeddings: torch.Tensor, 116 | dense_prompt_embeddings: torch.Tensor, 117 | ) -> Tuple[torch.Tensor, torch.Tensor]: 118 | """Predicts masks. See 'forward' for more details.""" 119 | # Concatenate output tokens 120 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 121 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 122 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 123 | 124 | # Expand per-image data in batch direction to be per-mask 125 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 126 | src = src + dense_prompt_embeddings 127 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 128 | b, c, h, w = src.shape 129 | 130 | # Run the transformer 131 | hs, src = self.transformer(src, pos_src, tokens) 132 | iou_token_out = hs[:, 0, :] 133 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 134 | 135 | # Upscale mask embeddings and predict masks using the mask tokens 136 | src = src.transpose(1, 2).view(b, c, h, w) 137 | upscaled_embedding = self.output_upscaling(src) 138 | hyper_in_list: List[torch.Tensor] = [] 139 | for i in range(self.num_mask_tokens): 140 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 141 | hyper_in = torch.stack(hyper_in_list, dim=1) 142 | b, c, h, w = upscaled_embedding.shape 143 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 144 | 145 | # Generate mask quality predictions 146 | iou_pred = self.iou_prediction_head(iou_token_out) 147 | 148 | return masks, iou_pred 149 | 150 | 151 | # Lightly adapted from 152 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 153 | class MLP(nn.Module): 154 | def __init__( 155 | self, 156 | input_dim: int, 157 | hidden_dim: int, 158 | output_dim: int, 159 | num_layers: int, 160 | sigmoid_output: bool = False, 161 | ) -> None: 162 | super().__init__() 163 | self.num_layers = num_layers 164 | h = [hidden_dim] * (num_layers - 1) 165 | self.layers = nn.ModuleList( 166 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 167 | ) 168 | self.sigmoid_output = sigmoid_output 169 | 170 | def forward(self, x): 171 | for i, layer in enumerate(self.layers): 172 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 173 | if self.sigmoid_output: 174 | x = F.sigmoid(x) 175 | return x 176 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 
89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for outptu 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | hq_token_only: bool, 79 | interm_embeddings: torch.Tensor, 80 | ) -> Tuple[torch.Tensor, torch.Tensor]: 81 | """ 82 | Predict masks given image and prompt embeddings. 83 | 84 | Arguments: 85 | image_embeddings (torch.Tensor): the embeddings from the image encoder 86 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 87 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 88 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 89 | multimask_output (bool): Whether to return multiple masks or a single 90 | mask. 
91 | 92 | Returns: 93 | torch.Tensor: batched predicted masks 94 | torch.Tensor: batched predictions of mask quality 95 | """ 96 | masks, iou_pred = self.predict_masks( 97 | image_embeddings=image_embeddings, 98 | image_pe=image_pe, 99 | sparse_prompt_embeddings=sparse_prompt_embeddings, 100 | dense_prompt_embeddings=dense_prompt_embeddings, 101 | ) 102 | 103 | # Select the correct mask or masks for output 104 | if multimask_output: 105 | mask_slice = slice(1, None) 106 | else: 107 | mask_slice = slice(0, 1) 108 | masks = masks[:, mask_slice, :, :] 109 | iou_pred = iou_pred[:, mask_slice] 110 | 111 | # Prepare output 112 | return masks, iou_pred 113 | 114 | def predict_masks( 115 | self, 116 | image_embeddings: torch.Tensor, 117 | image_pe: torch.Tensor, 118 | sparse_prompt_embeddings: torch.Tensor, 119 | dense_prompt_embeddings: torch.Tensor, 120 | ) -> Tuple[torch.Tensor, torch.Tensor]: 121 | """Predicts masks. See 'forward' for more details.""" 122 | # Concatenate output tokens 123 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 124 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 125 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 126 | 127 | # Expand per-image data in batch direction to be per-mask 128 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 129 | src = src + dense_prompt_embeddings 130 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 131 | b, c, h, w = src.shape 132 | 133 | # Run the transformer 134 | hs, src = self.transformer(src, pos_src, tokens) 135 | iou_token_out = hs[:, 0, :] 136 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 137 | 138 | # Upscale mask embeddings and predict masks using the mask tokens 139 | src = src.transpose(1, 2).view(b, c, h, w) 140 | upscaled_embedding = self.output_upscaling(src) 141 | hyper_in_list: List[torch.Tensor] = [] 142 | for i in range(self.num_mask_tokens): 143 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 144 | hyper_in = torch.stack(hyper_in_list, dim=1) 145 | b, c, h, w = upscaled_embedding.shape 146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 147 | 148 | # Generate mask quality predictions 149 | iou_pred = self.iou_prediction_head(iou_token_out) 150 | 151 | return masks, iou_pred 152 | 153 | 154 | # Lightly adapted from 155 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 156 | class MLP(nn.Module): 157 | def __init__( 158 | self, 159 | input_dim: int, 160 | hidden_dim: int, 161 | output_dim: int, 162 | num_layers: int, 163 | sigmoid_output: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.num_layers = num_layers 167 | h = [hidden_dim] * (num_layers - 1) 168 | self.layers = nn.ModuleList( 169 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 170 | ) 171 | self.sigmoid_output = sigmoid_output 172 | 173 | def forward(self, x): 174 | for i, layer in enumerate(self.layers): 175 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 176 | if self.sigmoid_output: 177 | x = F.sigmoid(x) 178 | return x 179 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | def forward( 54 | self, 55 | batched_input: List[Dict[str, Any]], 56 | multimask_output: bool, 57 | hq_token_only: bool =False, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 
91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings, interm_embeddings = self.image_encoder(input_images) 99 | interm_embeddings = interm_embeddings[0] # early layer 100 | 101 | outputs = [] 102 | for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | hq_token_only=hq_token_only, 119 | interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), 120 | ) 121 | masks = self.postprocess_masks( 122 | low_res_masks, 123 | input_size=image_record["image"].shape[-2:], 124 | original_size=image_record["original_size"], 125 | ) 126 | masks = masks > self.mask_threshold 127 | outputs.append( 128 | { 129 | "masks": masks, 130 | "iou_predictions": iou_predictions, 131 | "low_res_logits": low_res_masks, 132 | } 133 | ) 134 | return outputs 135 | 136 | def postprocess_masks( 137 | self, 138 | masks: torch.Tensor, 139 | input_size: Tuple[int, ...], 140 | original_size: Tuple[int, ...], 141 | ) -> torch.Tensor: 142 | """ 143 | Remove padding and upscale masks to the original image size. 144 | 145 | Arguments: 146 | masks (torch.Tensor): Batched masks from the mask_decoder, 147 | in BxCxHxW format. 148 | input_size (tuple(int, int)): The size of the image input to the 149 | model, in (H, W) format. Used to remove padding. 150 | original_size (tuple(int, int)): The original size of the image 151 | before resizing for input to the model, in (H, W) format. 152 | 153 | Returns: 154 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 155 | is given by original_size. 
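        Example (derived from the steps below, assuming the standard 1024-pixel input size):
        a 600x800 image is resized to 768x1024 before entering the model, so low-res masks
        are first upsampled to 1024x1024, cropped to the unpadded 768x1024 region, and then
        interpolated back to 600x800.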
156 | """ 157 | masks = F.interpolate( 158 | masks, 159 | (self.image_encoder.img_size, self.image_encoder.img_size), 160 | mode="bilinear", 161 | align_corners=False, 162 | ) 163 | masks = masks[..., : input_size[0], : input_size[1]] 164 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 165 | return masks 166 | 167 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 168 | """Normalize pixel values and pad to a square input.""" 169 | # Normalize colors 170 | x = (x - self.pixel_mean) / self.pixel_std 171 | 172 | # Pad 173 | h, w = x.shape[-2:] 174 | padh = self.image_encoder.img_size - h 175 | padw = self.image_encoder.img_size - w 176 | x = F.pad(x, (0, padw, 0, padh)) 177 | return x 178 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/utils/loss_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from typing import List, Optional 4 | import utils.misc as misc 5 | 6 | def point_sample(input, point_coords, **kwargs): 7 | """ 8 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 9 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 10 | [0, 1] x [0, 1] square. 11 | Args: 12 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. 13 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 14 | [0, 1] x [0, 1] normalized point coordinates. 15 | Returns: 16 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 17 | features for points in `point_coords`. The features are obtained via bilinear 18 | interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 19 | """ 20 | add_dim = False 21 | if point_coords.dim() == 3: 22 | add_dim = True 23 | point_coords = point_coords.unsqueeze(2) 24 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 25 | if add_dim: 26 | output = output.squeeze(3) 27 | return output 28 | 29 | def cat(tensors: List[torch.Tensor], dim: int = 0): 30 | """ 31 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 32 | """ 33 | assert isinstance(tensors, (list, tuple)) 34 | if len(tensors) == 1: 35 | return tensors[0] 36 | return torch.cat(tensors, dim) 37 | 38 | def get_uncertain_point_coords_with_randomness( 39 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 40 | ): 41 | """ 42 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties 43 | are calculated for each point using 'uncertainty_func' function that takes point's logit 44 | prediction as input. 45 | See PointRend paper for details. 46 | Args: 47 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 48 | class-specific or class-agnostic prediction. 49 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 50 | contains logit predictions for P points and returns their uncertainties as a Tensor of 51 | shape (N, 1, P). 52 | num_points (int): The number of points P to sample. 53 | oversample_ratio (int): Oversampling parameter. 54 | importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. 
55 | Returns: 56 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 57 | sampled points. 58 | """ 59 | assert oversample_ratio >= 1 60 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 61 | num_boxes = coarse_logits.shape[0] 62 | num_sampled = int(num_points * oversample_ratio) 63 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 64 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 65 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 66 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 67 | # to incorrect results. 68 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 69 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 70 | # However, if we calculate uncertainties for the coarse predictions first, 71 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 72 | point_uncertainties = uncertainty_func(point_logits) 73 | num_uncertain_points = int(importance_sample_ratio * num_points) 74 | num_random_points = num_points - num_uncertain_points 75 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 76 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 77 | idx += shift[:, None] 78 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 79 | num_boxes, num_uncertain_points, 2 80 | ) 81 | if num_random_points > 0: 82 | point_coords = cat( 83 | [ 84 | point_coords, 85 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 86 | ], 87 | dim=1, 88 | ) 89 | return point_coords 90 | 91 | def dice_loss( 92 | inputs: torch.Tensor, 93 | targets: torch.Tensor, 94 | num_masks: float, 95 | ): 96 | """ 97 | Compute the DICE loss, similar to generalized IOU for masks 98 | Args: 99 | inputs: A float tensor of arbitrary shape. 100 | The predictions for each example. 101 | targets: A float tensor with the same shape as inputs. Stores the binary 102 | classification label for each element in inputs 103 | (0 for the negative class and 1 for the positive class). 104 | """ 105 | inputs = inputs.sigmoid() 106 | inputs = inputs.flatten(1) 107 | numerator = 2 * (inputs * targets).sum(-1) 108 | denominator = inputs.sum(-1) + targets.sum(-1) 109 | loss = 1 - (numerator + 1) / (denominator + 1) 110 | return loss.sum() / num_masks 111 | 112 | 113 | dice_loss_jit = torch.jit.script( 114 | dice_loss 115 | ) # type: torch.jit.ScriptModule 116 | 117 | 118 | 119 | def sigmoid_ce_loss( 120 | inputs: torch.Tensor, 121 | targets: torch.Tensor, 122 | num_masks: float, 123 | ): 124 | """ 125 | Args: 126 | inputs: A float tensor of arbitrary shape. 127 | The predictions for each example. 128 | targets: A float tensor with the same shape as inputs. Stores the binary 129 | classification label for each element in inputs 130 | (0 for the negative class and 1 for the positive class). 
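        num_masks: Normalization factor (typically the number of masks); the per-mask
            loss is summed and divided by this value.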
131 | Returns: 132 | Loss tensor 133 | """ 134 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 135 | 136 | return loss.mean(1).sum() / num_masks 137 | 138 | 139 | 140 | dice_loss_jit = torch.jit.script( 141 | dice_loss 142 | ) # type: torch.jit.ScriptModule 143 | 144 | 145 | 146 | sigmoid_ce_loss_jit = torch.jit.script( 147 | sigmoid_ce_loss 148 | ) # type: torch.jit.ScriptModule 149 | 150 | 151 | def calculate_uncertainty(logits): 152 | """ 153 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 154 | foreground class in `classes`. 155 | Args: 156 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or 157 | class-agnostic, where R is the total number of predicted masks in all images and C is 158 | the number of foreground classes. The values are logits. 159 | Returns: 160 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 161 | the most uncertain locations having the highest uncertainty score. 162 | """ 163 | assert logits.shape[1] == 1 164 | gt_class_logits = logits.clone() 165 | return -(torch.abs(gt_class_logits)) 166 | 167 | def loss_masks(src_masks, target_masks, num_masks, oversample_ratio=3.0): 168 | """Compute the losses related to the masks: the focal loss and the dice loss. 169 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 170 | """ 171 | 172 | # No need to upsample predictions as we are using normalized coordinates :) 173 | 174 | with torch.no_grad(): 175 | # sample point_coords 176 | point_coords = get_uncertain_point_coords_with_randomness( 177 | src_masks, 178 | lambda logits: calculate_uncertainty(logits), 179 | 112 * 112, 180 | oversample_ratio, 181 | 0.75, 182 | ) 183 | 184 | # print(point_coords) 185 | # get gt labels 186 | point_labels = point_sample( 187 | target_masks, 188 | point_coords, 189 | align_corners=False, 190 | ).squeeze(1) 191 | # print(point_labels) 192 | 193 | point_logits = point_sample( 194 | src_masks, 195 | point_coords, 196 | align_corners=False, 197 | ).squeeze(1) 198 | # print(point_logits) 199 | 200 | loss_mask = sigmoid_ce_loss_jit(point_logits, point_labels, num_masks) 201 | loss_dice = dice_loss_jit(point_logits, point_labels, num_masks) 202 | 203 | del src_masks 204 | del target_masks 205 | return loss_mask, loss_dice 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 
32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 
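        Example (an illustrative sketch, not taken from this file; it assumes `sam` is a
        loaded Sam model, `image` is an HxWx3 uint8 array, and the ResizeLongestSide-style
        transform from utils/transforms.py is used to build the inputs):

            transform = ResizeLongestSide(sam.image_encoder.img_size)
            input_image = torch.as_tensor(
                transform.apply_image(image), device=sam.device
            ).permute(2, 0, 1).contiguous()
            batched_input = [{
                "image": input_image,
                "original_size": image.shape[:2],
                "boxes": transform.apply_boxes_torch(
                    torch.tensor([[100, 100, 400, 400]], device=sam.device),  # hypothetical box
                    image.shape[:2],
                ),
            }]
            outputs = sam(batched_input, multimask_output=False)
            masks = outputs[0]["masks"]  # Bx1xHxW boolean masks at the original resolution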
96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | 165 | 166 | def postprocess_masks_ori_res( 167 | self, 168 | masks: torch.Tensor, 169 | input_size: Tuple[int, ...], 170 | original_size: Tuple[int, ...], 171 | ) -> torch.Tensor: 172 | """ 173 | Remove padding and upscale masks to the original image size. 174 | 175 | Arguments: 176 | masks (torch.Tensor): Batched masks from the mask_decoder, 177 | in BxCxHxW format. 178 | input_size (tuple(int, int)): The size of the image input to the 179 | model, in (H, W) format. Used to remove padding. 180 | original_size (tuple(int, int)): The original size of the image 181 | before resizing for input to the model, in (H, W) format. 182 | 183 | Returns: 184 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 185 | is given by original_size. 
186 | """ 187 | masks = F.interpolate( 188 | masks, 189 | (self.image_encoder.img_size, self.image_encoder.img_size), 190 | mode="bilinear", 191 | align_corners=False, 192 | ) 193 | masks = masks[..., : input_size[0], : input_size[1]] 194 | # masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 195 | return masks 196 | 197 | 198 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 199 | """Normalize pixel values and pad to a square input.""" 200 | # Normalize colors 201 | x = (x - self.pixel_mean) / self.pixel_std 202 | 203 | # Pad 204 | h, w = x.shape[-2:] 205 | padh = self.image_encoder.img_size - h 206 | padw = self.image_encoder.img_size - w 207 | x = F.pad(x, (0, padw, 0, padh)) 208 | return x 209 | -------------------------------------------------------------------------------- /modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. 
Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 
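        Each layer's output is added back to its input and normalized with LayerNorm (the first block's self-attention, flagged by skip_first_layer_pe, is applied without the residual); see forward() below.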
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
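        # For example, with embedding_dim=256, num_heads=8 and downsample_rate=2 (the values
        # used for the downsampled cross-attention layers above), the projections below map
        # 256 -> 128, attention runs with 16 channels per head, and out_proj maps the
        # result back to 256.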
202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. 
Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs.
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /sum_on_hq-sam/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
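        For reference, the standard SAM builds construct this module with embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024) and mask_in_chans=16; any consistent configuration works.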
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
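        Points take precedence over boxes, and boxes over masks; if no prompt is provided, a batch size of 1 is assumed.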
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /utils/interactive_sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 7 | 8 | def add_jitter_to_bounding_boxes(bboxes: torch.Tensor, image_width: int, image_height: int, 9 | std_percentage: float = 0.1) -> torch.Tensor: 10 | """ 11 | Add jitter to the bounding boxes by shifting each box with random noise. 12 | For every box, one noise value per axis is drawn from a normal distribution with standard deviation 13 | equal to a percentage of the box side length, clamped to [0, 20] pixels, and added to both the min and max coordinates, so each box is shifted rather than resized. Noise is sampled independently for each bounding box. 14 | 15 | Args: 16 | bboxes (torch.Tensor): A tensor containing bounding boxes with shape (batch_size, 1, 4). 17 | Each bounding box is represented by [x_min, y_min, x_max, y_max]. 18 | image_width (int): The width of the image. 19 | image_height (int): The height of the image. 20 | std_percentage (float): Standard deviation as a percentage of the box side length. 21 | 22 | Returns: 23 | torch.Tensor: A tensor of the same shape containing the jittered boxes, with coordinates clamped to the image boundary.
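    Example (illustrative; the box values are arbitrary):
        >>> boxes = torch.tensor([[[10., 20., 110., 220.]]])  # (1, 1, 4)
        >>> jittered = add_jitter_to_bounding_boxes(boxes, image_width=1024, image_height=1024)
        >>> jittered.shape
        torch.Size([1, 1, 4])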
24 | """ 25 | 26 | # Check if input is a PyTorch tensor 27 | if not torch.is_tensor(bboxes): 28 | raise ValueError("Input bboxes should be a PyTorch tensor") 29 | 30 | # Initialize tensor to store jittered bounding boxes 31 | jittered_bboxes = torch.zeros_like(bboxes) 32 | 33 | # Iterate over each bounding box in the batch 34 | for i in range(bboxes.size(0)): 35 | # Extract x_min, y_min, x_max, y_max for the current bounding box 36 | x_min, y_min, x_max, y_max = bboxes[i, 0] 37 | 38 | # Calculate side lengths for the current bounding box 39 | width = x_max - x_min 40 | height = y_max - y_min 41 | 42 | # Sample one noise value per axis, clamped to [0, 20] pixels 43 | noise_x = torch.clamp(torch.randn(1, device=bboxes.device) * width * std_percentage,0,20) 44 | noise_y = torch.clamp(torch.randn(1, device=bboxes.device) * height * std_percentage,0, 20) 45 | 46 | # Add noise to the coordinates for the current bounding box 47 | jittered_x_min = x_min + noise_x 48 | jittered_y_min = y_min + noise_y 49 | jittered_x_max = x_max + noise_x 50 | jittered_y_max = y_max + noise_y 51 | 52 | # Clamp the coordinates to lie within the image boundary 53 | jittered_x_min = torch.clamp(jittered_x_min, 0, image_width) 54 | jittered_y_min = torch.clamp(jittered_y_min, 0, image_height) 55 | jittered_x_max = torch.clamp(jittered_x_max, 0, image_width) 56 | jittered_y_max = torch.clamp(jittered_y_max, 0, image_height) 57 | 58 | # Store the jittered bounding box 59 | jittered_bboxes[i, 0, 0] = jittered_x_min 60 | jittered_bboxes[i, 0, 1] = jittered_y_min 61 | jittered_bboxes[i, 0, 2] = jittered_x_max 62 | jittered_bboxes[i, 0, 3] = jittered_y_max 63 | 64 | return jittered_bboxes 65 | 66 | 67 | def sample_different_position(pred_mask: torch.Tensor, gt_mask: torch.Tensor, backup_point, prob_weights=None): 68 | """ 69 | Samples a position where the prediction mask and ground truth mask are different. 70 | 71 | Args: 72 | pred_mask (torch.Tensor): The prediction mask (2D tensor). 73 | gt_mask (torch.Tensor): The ground truth mask (2D tensor). backup_point: the point returned when the two masks are identical or all sampling weights are zero. prob_weights (torch.Tensor, optional): per-pixel weights used to sample among the differing positions. 74 | 75 | Returns: 76 | tuple: A tuple (position, gt_value) where: 77 | - position is a tensor of shape (1, 2) containing the sampled position in (x, y) order, 78 | - gt_value is the value in the ground truth mask at the sampled position. 79 | If no differing points are found (or the weights sum to zero), the backup point is returned together with a positive label. 80 | """ 81 | # Check if inputs are PyTorch tensors 82 | if not torch.is_tensor(pred_mask) or not torch.is_tensor(gt_mask): 83 | raise ValueError("Input masks should be PyTorch tensors") 84 | 85 | # Check if the masks have the same shape 86 | if pred_mask.shape != gt_mask.shape: 87 | raise ValueError("Input masks should have the same shape") 88 | 89 | # Finding the indices of positions where the masks are different 90 | differing_points = torch.nonzero(pred_mask != gt_mask) 91 | 92 | # If there are no differing points, fall back to the backup point 93 | if differing_points.size(0) == 0: 94 | # raise ValueError("two masks are the same! ") 95 | 96 | return backup_point, torch.ones(1, dtype=torch.int) 97 | 98 | 99 | # If probability weights are provided, sample according to the weights 100 | if prob_weights is not None: 101 | 102 | # Compute flat indices from differing_points 103 | flat_indices = differing_points[:, 0] * pred_mask.shape[1] + differing_points[:, 1] 104 | 105 | # Retrieve the corresponding weights 106 | weights = prob_weights.flatten()[flat_indices] 107 | 108 | 109 | # Handle zero weights 110 | if weights.sum() == 0: 111 | return backup_point, torch.ones(1, dtype=torch.int) 112 | 113 | # Sample from weights 114 | sample_idx_in_differing_points = torch.multinomial(weights, 1) 115 | 116 | # Get the sampled point 117 | sample_idx = differing_points[sample_idx_in_differing_points].float().view(1, -1) 118 | 119 | else: 120 | index = torch.randint(0, differing_points.size(0), (1,)) 121 | sample_idx = differing_points[index].float() 122 | 123 | sampled_position = sample_idx 124 | 125 | # Getting the value from the ground truth mask at the sampled position 126 | gt_value = gt_mask[tuple(sampled_position[0].long())] 127 | 128 | # Return the position as a (1, 2) tensor in (x, y) order, together with gt_value 129 | return torch.flip(sampled_position, [1]), gt_value 130 | 131 | 132 | def sample_positive_point_from_binary_mask(mask: torch.Tensor, backup_point, prob_weights=None) -> torch.Tensor: 133 | """ 134 | Samples a point from a binary mask where the mask label is positive. 135 | 136 | Args: 137 | mask (torch.Tensor): A binary mask (2D tensor). backup_point: the point returned when the mask has no positive pixels or all sampling weights are zero. prob_weights (torch.Tensor, optional): per-pixel weights used to sample among the positive positions. 138 | 139 | Returns: 140 | torch.Tensor: A tensor of shape (1, 2) containing the sampled point in (x, y) order, or the backup point if no positive point can be sampled. 141 | """ 142 | # Check if input is a PyTorch tensor 143 | if not torch.is_tensor(mask): 144 | raise ValueError("Input mask should be a PyTorch tensor") 145 | 146 | # Finding the indices of positive points 147 | positive_points = torch.nonzero(mask == 1) 148 | 149 | # If there are no positive points, return backup_point 150 | if positive_points.size(0) == 0: 151 | return backup_point 152 | 153 | # If probability weights are provided, sample according to the weights 154 | if prob_weights is not None: 155 | 156 | flat_indices = positive_points[:, 0] * mask.shape[1] + positive_points[:, 1] 157 | 158 | # Retrieve the corresponding weights 159 | weights = prob_weights.flatten()[flat_indices] 160 | 161 | if weights.sum() == 0: 162 | print("weights sum is zero init! ") 163 | return backup_point 164 | 165 | # Sample from weights 166 | sample_idx_in_positive_points = torch.multinomial(weights, 1) 167 | 168 | # Get the sampled point 169 | sample_idx = positive_points[sample_idx_in_positive_points].float() 170 | else: 171 | index = torch.randint(0, positive_points.size(0), (1,)) 172 | sample_idx = positive_points[index].float() 173 | 174 | sample_idx = sample_idx.view(1, -1) 175 | sampled_point = sample_idx 176 | 177 | return torch.flip(sampled_point, [1]) 178 | 179 | 180 | def sample_point_frommask(gt_binary_mask_stack, backup_points, binary_mask=None, previous_points=None, device=0, prob_weights=None): 181 | """ 182 | Randomly sample one prompt point per mask. If a predicted mask is given, the point is drawn from the region where prediction and ground truth disagree; otherwise a positive point is drawn from the ground truth mask. 183 | Args: 184 | gt_binary_mask_stack: the ground truth masks, shape (B, 1, H, W). backup_points: per-mask fallback points used when no valid position can be sampled. 185 | binary_mask: the predicted masks from the previous step, or None for the first point. previous_points: tuple (coords, labels) of the previously sampled points, or None. 186 | device: device for the returned tensors. prob_weights: optional per-pixel sampling weights, shape (B, 1, H, W). 187 | Returns a tuple (coords, labels), where coords has shape (B, N, 2) in (x, y) order and labels has shape (B, N). """ 188 | coords_torch_list = [] 189 | 190 | if binary_mask is None or previous_points is None: 191 | input_label_torch = torch.ones((gt_binary_mask_stack.size()[0], 1), dtype=torch.int) 192 | else: 193 | input_label_torch = torch.ones((previous_points[1].size()[0], previous_points[1].size()[1] + 1), 194 | dtype=torch.int) 195 | input_label_torch[:, :-1] = previous_points[1] 196 | 197 | for mask_index in range(gt_binary_mask_stack.size()[0]): 198 | gt_mask = gt_binary_mask_stack[mask_index][0] # 256, 256 199 | 200 | backup_point = backup_points[mask_index] 201 | assert len(gt_mask.size()) == 2 # two dimension case 202 | 203 | if prob_weights is not None: 204 | mask_prob_weights = prob_weights[mask_index][0] # 256, 256 205 | else: 206 | mask_prob_weights = None 207 | 208 | if binary_mask is not None: 209 | # assert previous_points is not None 210 | 211 | pred_mask = binary_mask[mask_index][0] # 256, 256 212 | sampled_position, gt_value = sample_different_position(pred_mask, gt_mask, backup_point, prob_weights=mask_prob_weights) 213 | # sampled_position size is (1, 2) 214 | # previous_points is tuple, previous_points[0] is the coords_torch_all, get the corresponding coords_torch 215 | if previous_points is not None: 216 | coords_torch = torch.cat((previous_points[0][mask_index], sampled_position.to(device)), dim=0) 217 | coords_torch_list.append(coords_torch) 218 | input_label_torch[mask_index, -1] = gt_value # assign the label for that point 219 | else: 220 | coords_torch = sampled_position 221 | coords_torch_list.append(coords_torch) 222 | input_label_torch[mask_index, -1] = gt_value 223 | else: 224 | coords_torch = sample_positive_point_from_binary_mask(gt_mask, backup_point, prob_weights=mask_prob_weights) 225 | coords_torch_list.append(coords_torch) 226 | 227 | coords_torch_all = torch.stack(coords_torch_list, dim=0) 228 | return (coords_torch_all.to(device), input_label_torch.to(device)) 229 | -------------------------------------------------------------------------------- /modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | use_task_prompt = False, 25 | zero_token_no_grad = False, 26 | ) -> None: 27 | """ 28 | Encodes prompts for input to SAM's mask decoder. 29 | 30 | Arguments: 31 | embed_dim (int): The prompts' embedding dimension 32 | image_embedding_size (tuple(int, int)): The spatial size of the 33 | image embedding, as (H, W). 34 | input_image_size (int): The padded size of the image as input 35 | to the image encoder, as (H, W). 36 | mask_in_chans (int): The number of hidden channels used for 37 | encoding input masks. 38 | activation (nn.Module): The activation to use when encoding 39 | input masks. 40 | """ 41 | super().__init__() 42 | self.embed_dim = embed_dim 43 | self.input_image_size = input_image_size 44 | self.image_embedding_size = image_embedding_size 45 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 46 | 47 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 48 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 49 | self.point_embeddings = nn.ModuleList(point_embeddings) 50 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 51 | 52 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 53 | self.mask_downscaling = nn.Sequential( 54 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(mask_in_chans // 4), 56 | activation(), 57 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 58 | LayerNorm2d(mask_in_chans), 59 | activation(), 60 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 61 | ) 62 | self.no_mask_embed = nn.Embedding(1, embed_dim) 63 | 64 | # add the mask quality token 65 | 66 | task_prompt_embedding = [nn.Embedding(1, embed_dim) for i in range(4)] # 4 task prompt types: 0: 'un-specified', 1:'salient object segmentation', 2:'entityseg', 3:'part segmentation' (if available in the fine-tuning data) 67 | self.task_prompt_embedding = nn.ModuleList(task_prompt_embedding) 68 | for i in range(4): 69 | nn.init.zeros_(self.task_prompt_embedding[i].weight) 70 | if not use_task_prompt: 71 | self.task_prompt_embedding[i].weight.data.zero_() 72 | self.task_prompt_embedding[i].weight.requires_grad = False # this will stay as zero 73 | 74 | if zero_token_no_grad: 75 | self.task_prompt_embedding[0].weight.requires_grad = False 76 | 77 | 78 | 79 | 80 | def get_dense_pe(self) -> torch.Tensor: 81 | """ 82 | Returns the positional encoding used to encode point prompts, 83 | applied to a dense set of points the shape of the image encoding. 
84 | 85 | Returns: 86 | torch.Tensor: Positional encoding with shape 87 | 1x(embed_dim)x(embedding_h)x(embedding_w) 88 | """ 89 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 90 | 91 | def _embed_points( 92 | self, 93 | points: torch.Tensor, 94 | labels: torch.Tensor, 95 | pad: bool, 96 | ) -> torch.Tensor: 97 | """Embeds point prompts.""" 98 | points = points + 0.5 # Shift to center of pixel 99 | # print(points.size()) 100 | # print(labels.size()) 101 | if pad: 102 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 103 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 104 | points = torch.cat([points, padding_point], dim=1) 105 | labels = torch.cat([labels, padding_label], dim=1) 106 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 107 | point_embedding[labels == -1] = 0.0 108 | point_embedding[labels == -1] += self.not_a_point_embed.weight 109 | point_embedding[labels == 0] += self.point_embeddings[0].weight 110 | point_embedding[labels == 1] += self.point_embeddings[1].weight 111 | return point_embedding 112 | 113 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 114 | """Embeds box prompts.""" 115 | boxes = boxes + 0.5 # Shift to center of pixel 116 | coords = boxes.reshape(-1, 2, 2) 117 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 118 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 119 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 120 | return corner_embedding 121 | 122 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 123 | """Embeds mask inputs.""" 124 | mask_embedding = self.mask_downscaling(masks) 125 | return mask_embedding 126 | 127 | def _get_batch_size( 128 | self, 129 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 130 | boxes: Optional[torch.Tensor], 131 | masks: Optional[torch.Tensor], 132 | ) -> int: 133 | """ 134 | Gets the batch size of the output given the batch size of the input prompts. 135 | """ 136 | if points is not None: 137 | return points[0].shape[0] 138 | elif boxes is not None: 139 | return boxes.shape[0] 140 | elif masks is not None: 141 | return masks.shape[0] 142 | else: 143 | return 1 144 | 145 | def _get_device(self) -> torch.device: 146 | return self.point_embeddings[0].weight.device 147 | 148 | def forward( 149 | self, 150 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 151 | boxes: Optional[torch.Tensor], 152 | masks: Optional[torch.Tensor], 153 | task_prompt = 0, 154 | ) -> Tuple[torch.Tensor, torch.Tensor]: 155 | """ 156 | Embeds different types of prompts, returning both sparse and dense 157 | embeddings. 158 | 159 | Arguments: 160 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 161 | and labels to embed. 162 | boxes (torch.Tensor or none): boxes to embed 163 | masks (torch.Tensor or none): masks to embed 164 | task_prompt (str): type of mask to embed, 0: 'un-specified', 1:'salient object segmentation', 2:'entityseg' 165 | 166 | Returns: 167 | torch.Tensor: sparse embeddings for the points and boxes, with shape 168 | BxNx(embed_dim), where N is determined by the number of input points 169 | and boxes. 
170 | torch.Tensor: dense embeddings for the masks, in the shape 171 | Bx(embed_dim)x(embed_H)x(embed_W) 172 | """ 173 | bs = self._get_batch_size(points, boxes, masks) 174 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 175 | if points is not None: 176 | coords, labels = points 177 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 178 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 179 | if boxes is not None: 180 | box_embeddings = self._embed_boxes(boxes) 181 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 182 | 183 | if masks is not None: 184 | dense_embeddings = self._embed_masks(masks) + self.task_prompt_embedding[task_prompt].weight.reshape(1, -1, 1, 1).expand( 185 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 186 | ) 187 | else: 188 | dense_embeddings = (self.no_mask_embed.weight + self.task_prompt_embedding[task_prompt].weight).reshape(1, -1, 1, 1).expand( 189 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 190 | ) 191 | 192 | return sparse_embeddings, dense_embeddings 193 | 194 | 195 | class PositionEmbeddingRandom(nn.Module): 196 | """ 197 | Positional encoding using random spatial frequencies. 198 | """ 199 | 200 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 201 | super().__init__() 202 | if scale is None or scale <= 0.0: 203 | scale = 1.0 204 | self.register_buffer( 205 | "positional_encoding_gaussian_matrix", 206 | scale * torch.randn((2, num_pos_feats)), 207 | ) 208 | 209 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 210 | """Positionally encode points that are normalized to [0,1].""" 211 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 212 | coords = 2 * coords - 1 213 | coords = coords @ self.positional_encoding_gaussian_matrix 214 | coords = 2 * np.pi * coords 215 | # outputs d_1 x ... x d_n x C shape 216 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 217 | 218 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 219 | """Generate positional encoding for a grid of the specified size.""" 220 | h, w = size 221 | device: Any = self.positional_encoding_gaussian_matrix.device 222 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 223 | y_embed = grid.cumsum(dim=0) - 0.5 224 | x_embed = grid.cumsum(dim=1) - 0.5 225 | y_embed = y_embed / h 226 | x_embed = x_embed / w 227 | 228 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 229 | return pe.permute(2, 0, 1) # C x H x W 230 | 231 | def forward_with_coords( 232 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 233 | ) -> torch.Tensor: 234 | """Positionally encode points that are not normalized to [0,1].""" 235 | coords = coords_input.clone() 236 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 237 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 238 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 239 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright by HQ-SAM team 2 | # All rights reserved. 
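## NOTE: each entry of the `datasets` list consumed by get_im_gt_name_dict below is a
## dict with keys "name", "im_dir", "im_ext", "gt_dir" and "gt_ext", e.g. (these paths
## are illustrative only):
##
##     {"name": "DIS5K-TR", "im_dir": "./data/DIS5K/DIS-TR/im", "im_ext": ".jpg",
##      "gt_dir": "./data/DIS5K/DIS-TR/gt", "gt_ext": ".png"}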
3 | 4 | ## data loader 5 | from __future__ import print_function, division 6 | 7 | import numpy as np 8 | import random 9 | from copy import deepcopy 10 | from skimage import io 11 | import os 12 | from glob import glob 13 | 14 | import torch 15 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 16 | from torchvision import transforms, utils 17 | from torchvision.transforms.functional import normalize 18 | import torch.nn.functional as F 19 | from torch.utils.data.distributed import DistributedSampler 20 | 21 | #### --------------------- dataloader online ---------------------#### 22 | 23 | def get_im_gt_name_dict(datasets, flag='valid'): 24 | print("------------------------------", flag, "--------------------------------") 25 | name_im_gt_list = [] 26 | 27 | for i in range(len(datasets)): 28 | print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---") 29 | tmp_im_list, tmp_gt_list = [], [] 30 | tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"]) 31 | print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list)) 32 | 33 | if(datasets[i]["gt_dir"]==""): 34 | print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found') 35 | tmp_gt_list = [] 36 | else: 37 | tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list] 38 | print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list)) 39 | 40 | 41 | name_im_gt_list.append({"dataset_name":datasets[i]["name"], 42 | "im_path":tmp_im_list, 43 | "gt_path":tmp_gt_list, 44 | "im_ext":datasets[i]["im_ext"], 45 | "gt_ext":datasets[i]["gt_ext"]}) 46 | 47 | return name_im_gt_list 48 | 49 | def create_dataloaders(name_im_gt_list, my_transforms=[], batch_size=1, training=False): 50 | gos_dataloaders = [] 51 | gos_datasets = [] 52 | 53 | if(len(name_im_gt_list)==0): 54 | return gos_dataloaders, gos_datasets 55 | 56 | num_workers_ = 1 57 | if(batch_size>1): 58 | num_workers_ = 2 59 | if(batch_size>4): 60 | num_workers_ = 4 61 | if(batch_size>8): 62 | num_workers_ = 8 63 | 64 | 65 | if training: 66 | for i in range(len(name_im_gt_list)): 67 | gos_dataset = OnlineDataset([name_im_gt_list[i]], transform = transforms.Compose(my_transforms)) 68 | gos_datasets.append(gos_dataset) 69 | 70 | gos_dataset = ConcatDataset(gos_datasets) 71 | sampler = DistributedSampler(gos_dataset) 72 | batch_sampler_train = torch.utils.data.BatchSampler( 73 | sampler, batch_size, drop_last=True) 74 | dataloader = DataLoader(gos_dataset, batch_sampler=batch_sampler_train, num_workers=num_workers_) 75 | 76 | gos_dataloaders = dataloader 77 | gos_datasets = gos_dataset 78 | 79 | else: 80 | for i in range(len(name_im_gt_list)): 81 | gos_dataset = OnlineDataset([name_im_gt_list[i]], transform = transforms.Compose(my_transforms), eval_ori_resolution = True) 82 | sampler = DistributedSampler(gos_dataset, shuffle=False) 83 | dataloader = DataLoader(gos_dataset, batch_size, sampler=sampler, drop_last=False, num_workers=num_workers_) 84 | 85 | gos_dataloaders.append(dataloader) 86 | gos_datasets.append(gos_dataset) 87 | 88 | return gos_dataloaders, gos_datasets 89 | 90 | class RandomHFlip(object): 91 | def __init__(self,prob=0.5): 92 | self.prob = prob 93 | def __call__(self,sample): 94 | imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] 95 | 96 | # random horizontal flip 97 | if random.random() >= self.prob: 98 | image = 
torch.flip(image,dims=[2]) 99 | label = torch.flip(label,dims=[2]) 100 | 101 | return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} 102 | 103 | class Resize(object): 104 | def __init__(self,size=[320,320]): 105 | self.size = size 106 | def __call__(self,sample): 107 | imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] 108 | 109 | image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0) 110 | label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0) 111 | 112 | return {'imidx':imidx,'image':image, 'label':label, 'shape':torch.tensor(self.size)} 113 | 114 | class RandomCrop(object): 115 | def __init__(self,size=[288,288]): 116 | self.size = size 117 | def __call__(self,sample): 118 | imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] 119 | 120 | h, w = image.shape[1:] 121 | new_h, new_w = self.size 122 | 123 | top = np.random.randint(0, h - new_h) 124 | left = np.random.randint(0, w - new_w) 125 | 126 | image = image[:,top:top+new_h,left:left+new_w] 127 | label = label[:,top:top+new_h,left:left+new_w] 128 | 129 | return {'imidx':imidx,'image':image, 'label':label, 'shape':torch.tensor(self.size)} 130 | 131 | 132 | class Normalize(object): 133 | def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]): 134 | self.mean = mean 135 | self.std = std 136 | 137 | def __call__(self,sample): 138 | 139 | imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape'] 140 | image = normalize(image,self.mean,self.std) 141 | 142 | return {'imidx':imidx,'image':image, 'label':label, 'shape':shape} 143 | 144 | 145 | 146 | class LargeScaleJitter(object): 147 | """ 148 | implementation of large scale jitter from copy_paste 149 | https://github.com/gaopengcuhk/Pretrained-Pix2Seq/blob/7d908d499212bfabd33aeaa838778a6bfb7b84cc/datasets/transforms.py 150 | """ 151 | 152 | def __init__(self, output_size=1024, aug_scale_min=0.1, aug_scale_max=2.0): 153 | self.desired_size = torch.tensor(output_size) 154 | self.aug_scale_min = aug_scale_min 155 | self.aug_scale_max = aug_scale_max 156 | 157 | def pad_target(self, padding, target): 158 | target = target.copy() 159 | if "masks" in target: 160 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0])) 161 | return target 162 | 163 | def __call__(self, sample): 164 | imidx, image, label, image_size = sample['imidx'], sample['image'], sample['label'], sample['shape'] 165 | 166 | #resize keep ratio 167 | out_desired_size = (self.desired_size * image_size / max(image_size)).round().int() 168 | 169 | random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min 170 | scaled_size = (random_scale * self.desired_size).round() 171 | 172 | scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1]) 173 | scaled_size = (image_size * scale).round().long() 174 | 175 | scaled_image = torch.squeeze(F.interpolate(torch.unsqueeze(image,0),scaled_size.tolist(),mode='bilinear'),dim=0) 176 | scaled_label = torch.squeeze(F.interpolate(torch.unsqueeze(label,0),scaled_size.tolist(),mode='bilinear'),dim=0) 177 | 178 | # random crop 179 | crop_size = (min(self.desired_size, scaled_size[0]), min(self.desired_size, scaled_size[1])) 180 | 181 | margin_h = max(scaled_size[0] - crop_size[0], 0).item() 182 | margin_w = max(scaled_size[1] - crop_size[1], 0).item() 183 | offset_h = 
np.random.randint(0, margin_h + 1) 184 | offset_w = np.random.randint(0, margin_w + 1) 185 | crop_y1, crop_y2 = offset_h, offset_h + crop_size[0].item() 186 | crop_x1, crop_x2 = offset_w, offset_w + crop_size[1].item() 187 | 188 | scaled_image = scaled_image[:,crop_y1:crop_y2, crop_x1:crop_x2] 189 | scaled_label = scaled_label[:,crop_y1:crop_y2, crop_x1:crop_x2] 190 | 191 | # pad 192 | padding_h = max(self.desired_size - scaled_image.size(1), 0).item() 193 | padding_w = max(self.desired_size - scaled_image.size(2), 0).item() 194 | image = F.pad(scaled_image, [0,padding_w, 0,padding_h],value=128) 195 | label = F.pad(scaled_label, [0,padding_w, 0,padding_h],value=0) 196 | 197 | return {'imidx':imidx,'image':image, 'label':label, 'shape':torch.tensor(image.shape[-2:])} 198 | 199 | 200 | 201 | 202 | 203 | 204 | class OnlineDataset(Dataset): 205 | def __init__(self, name_im_gt_list, transform=None, eval_ori_resolution=False): 206 | 207 | self.transform = transform 208 | self.dataset = {} 209 | ## combine different datasets into one 210 | dataset_names = [] 211 | dt_name_list = [] # dataset name per image 212 | im_name_list = [] # image name 213 | im_path_list = [] # im path 214 | gt_path_list = [] # gt path 215 | im_ext_list = [] # im ext 216 | gt_ext_list = [] # gt ext 217 | for i in range(0,len(name_im_gt_list)): 218 | dataset_names.append(name_im_gt_list[i]["dataset_name"]) 219 | # dataset name repeated based on the number of images in this dataset 220 | dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]]) 221 | im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]]) 222 | im_path_list.extend(name_im_gt_list[i]["im_path"]) 223 | gt_path_list.extend(name_im_gt_list[i]["gt_path"]) 224 | im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]]) 225 | gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]]) 226 | 227 | 228 | self.dataset["data_name"] = dt_name_list 229 | self.dataset["im_name"] = im_name_list 230 | self.dataset["im_path"] = im_path_list 231 | self.dataset["ori_im_path"] = deepcopy(im_path_list) 232 | self.dataset["gt_path"] = gt_path_list 233 | self.dataset["ori_gt_path"] = deepcopy(gt_path_list) 234 | self.dataset["im_ext"] = im_ext_list 235 | self.dataset["gt_ext"] = gt_ext_list 236 | 237 | self.eval_ori_resolution = eval_ori_resolution 238 | 239 | def __len__(self): 240 | return len(self.dataset["im_path"]) 241 | def __getitem__(self, idx): 242 | im_path = self.dataset["im_path"][idx] 243 | gt_path = self.dataset["gt_path"][idx] 244 | im = io.imread(im_path) 245 | gt = io.imread(gt_path) 246 | 247 | if len(gt.shape) > 2: 248 | gt = gt[:, :, 0] 249 | if len(im.shape) < 3: 250 | im = im[:, :, np.newaxis] 251 | if im.shape[2] == 1: 252 | im = np.repeat(im, 3, axis=2) 253 | im = torch.tensor(im.copy(), dtype=torch.float32) 254 | im = torch.transpose(torch.transpose(im,1,2),0,1) 255 | gt = torch.unsqueeze(torch.tensor(gt, dtype=torch.float32),0) 256 | 257 | sample = { 258 | "imidx": torch.from_numpy(np.array(idx)), 259 | "image": im, 260 | "label": gt, 261 | "shape": torch.tensor(im.shape[-2:]), 262 | } 263 | 264 | if self.transform: 265 | sample = self.transform(sample) 266 | 267 | if self.eval_ori_resolution: 268 | sample["ori_label"] = gt.type(torch.uint8) # NOTE for evaluation only. 
And no flip here 269 | sample['ori_im_path'] = self.dataset["im_path"][idx] 270 | sample['ori_gt_path'] = self.dataset["gt_path"][idx] 271 | 272 | return sample -------------------------------------------------------------------------------- /sum_on_hq-sam/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /sum_on_hq-sam/train/segment_anything_training/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
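Before the training-side Sam module below, a usage note for the dataloader utilities defined earlier in sum_on_hq-sam/train/utils/dataloader.py: get_im_gt_name_dict expects a list of dataset dictionaries keyed by name, im_dir, gt_dir, im_ext, and gt_ext, and create_dataloaders wraps the resulting image/ground-truth path lists in OnlineDataset behind a DistributedSampler. The sketch below only shows the expected configuration shape and call pattern; the directory paths are hypothetical, the import path assumes the script runs from the train/ directory, and the calls are left commented because they require data on disk and an initialised torch.distributed process group.

# Hypothetical dataset description; every key below is read by get_im_gt_name_dict.
train_datasets = [
    {
        "name": "DIS5K-TR",                    # dataset identifier
        "im_dir": "./data/DIS5K/DIS-TR/im",    # hypothetical image directory
        "gt_dir": "./data/DIS5K/DIS-TR/gt",    # hypothetical ground-truth directory
        "im_ext": ".jpg",
        "gt_ext": ".png",
    },
]

# from utils.dataloader import (
#     get_im_gt_name_dict, create_dataloaders, RandomHFlip, LargeScaleJitter,
# )
# name_im_gt = get_im_gt_name_dict(train_datasets, flag="train")
# train_loader, train_set = create_dataloaders(
#     name_im_gt,
#     my_transforms=[RandomHFlip(), LargeScaleJitter()],
#     batch_size=4,
#     training=True,   # builds a ConcatDataset with a DistributedSampler
# )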
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> Tuple[List[Dict[str, torch.Tensor]], List[torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | a dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256.
Can be passed as mask input 95 | to subsequent iterations of prediction. In addition to this list, the forward returns interm_embeddings, the intermediate features collected from the image encoder. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | 99 | image_embeddings, interm_embeddings = self.image_encoder(input_images) 100 | 101 | outputs = [] 102 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output 118 | ) 119 | 120 | masks = self.postprocess_masks( 121 | low_res_masks, 122 | input_size=image_record["image"].shape[-2:], 123 | original_size=image_record["original_size"], 124 | ) 125 | assert image_record["image"].shape[-2:] == image_record["original_size"] 126 | masks = masks > self.mask_threshold 127 | 128 | outputs.append( 129 | { 130 | "masks": masks, 131 | "iou_predictions": iou_predictions, 132 | "low_res_logits": low_res_masks, 133 | "encoder_embedding": curr_embedding.unsqueeze(0), 134 | "image_pe": self.prompt_encoder.get_dense_pe(), 135 | "sparse_embeddings":sparse_embeddings, 136 | "dense_embeddings":dense_embeddings, 137 | } 138 | ) 139 | 140 | return outputs, interm_embeddings 141 | 142 | @torch.no_grad() 143 | def forward_one_example( 144 | self, 145 | image_record: Dict[str, Any], 146 | multimask_output: bool, 147 | num_masks: int 148 | ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]: 149 | """ 150 | Predicts masks end-to-end from a single provided image and its prompts. 151 | If prompts are not known in advance, using SamPredictor is 152 | recommended over calling the model directly. 153 | 154 | Arguments: 155 | image_record (dict): A dictionary describing a single input image, with 156 | the following keys. A prompt key can be 157 | excluded if it is not present. 158 | 'image': The image as a torch tensor in 3xHxW format, 159 | already transformed for input to the model. 160 | 'original_size': (tuple(int, int)) The original size of 161 | the image before transformation, as (H, W). 162 | 'point_coords': (torch.Tensor) Batched point prompts for 163 | this image, with shape BxNx2. Already transformed to the 164 | input frame of the model. 165 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 166 | with shape BxN. 167 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 168 | Already transformed to the input frame of the model. 169 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 170 | in the form Bx1xHxW. 171 | multimask_output (bool): Whether the model should predict multiple 172 | disambiguating masks, or return a single mask. 173 | 174 | Returns: 175 | (dict): A dictionary for the input image, with 176 | the following keys. 177 | 'masks': (torch.Tensor) Batched binary mask predictions, 178 | with shape BxCxHxW, where B is the number of input prompts, 179 | C is determined by multimask_output, and (H, W) is the 180 | original size of the image. 181 | 'iou_predictions': (torch.Tensor) The model's predictions 182 | of mask quality, in shape BxC.
183 | 'low_res_logits': (torch.Tensor) Low resolution logits with 184 | shape BxCxHxW, where H=W=256. Can be passed as mask input 185 | to subsequent iterations of prediction. 186 | """ 187 | # just a single image 188 | # input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 189 | # input_images = self.preprocess(image_record["image"]).unsqueeze(0) 190 | input_images = image_record["image"] 191 | # print(input_images.shape) 192 | 193 | 194 | image_embeddings, interm_embeddings = self.image_encoder(input_images) 195 | 196 | if "point_coords" in image_record: 197 | points = (image_record["point_coords"], image_record["point_labels"]) 198 | else: 199 | points = None 200 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 201 | points=points, 202 | boxes=image_record.get("boxes", None), 203 | masks=image_record.get("mask_inputs", None), 204 | ) 205 | 206 | 207 | low_res_masks, iou_predictions = self.mask_decoder( 208 | image_embeddings=image_embeddings, 209 | image_pe=self.prompt_encoder.get_dense_pe(), 210 | sparse_prompt_embeddings=sparse_embeddings, 211 | dense_prompt_embeddings=dense_embeddings, 212 | multimask_output=multimask_output 213 | ) 214 | 215 | 216 | 217 | masks = self.postprocess_masks( 218 | low_res_masks, 219 | input_size=image_record["image"].shape[-2:], 220 | original_size=image_record["image"].shape[-2:], 221 | ) 222 | 223 | 224 | masks = masks > self.mask_threshold 225 | 226 | outputs = { 227 | "masks": masks, 228 | "iou_predictions": iou_predictions, 229 | "low_res_logits": low_res_masks, 230 | "encoder_embedding": image_embeddings, 231 | "image_pe": self.prompt_encoder.get_dense_pe(), 232 | "sparse_embeddings": sparse_embeddings, 233 | "dense_embeddings": dense_embeddings, 234 | } 235 | 236 | 237 | return outputs, interm_embeddings 238 | def postprocess_masks( 239 | self, 240 | masks: torch.Tensor, 241 | input_size: Tuple[int, ...], 242 | original_size: Tuple[int, ...], 243 | ) -> torch.Tensor: 244 | """ 245 | Remove padding and upscale masks to the original image size. 246 | 247 | Arguments: 248 | masks (torch.Tensor): Batched masks from the mask_decoder, 249 | in BxCxHxW format. 250 | input_size (tuple(int, int)): The size of the image input to the 251 | model, in (H, W) format. Used to remove padding. 252 | original_size (tuple(int, int)): The original size of the image 253 | before resizing for input to the model, in (H, W) format. 254 | 255 | Returns: 256 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 257 | is given by original_size. 258 | """ 259 | masks = F.interpolate( 260 | masks, 261 | (self.image_encoder.img_size, self.image_encoder.img_size), 262 | mode="bilinear", 263 | align_corners=False, 264 | ) 265 | masks = masks[..., : input_size[0], : input_size[1]] 266 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 267 | return masks 268 | 269 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 270 | """Normalize pixel values and pad to a square input.""" 271 | # Normalize colors 272 | x = (x - self.pixel_mean) / self.pixel_std 273 | 274 | # Pad 275 | h, w = x.shape[-2:] 276 | padh = self.image_encoder.img_size - h 277 | padw = self.image_encoder.img_size - w 278 | x = F.pad(x, (0, padw, 0, padh)) 279 | return x 280 | --------------------------------------------------------------------------------
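A note on the preprocess/postprocess pair that closes the class above: preprocess normalizes the pixels and pads the (already resized) image to a square of side image_encoder.img_size, while postprocess_masks upsamples the low-resolution logits to that square, strips the padding using the pre-padding input size, and finally resizes to the original image size. The short sketch below replays that geometric contract with plain tensors; the 1024 side length and the 600x800 original size are illustrative assumptions, and resizing the long side to the model input size is assumed to happen outside this class (e.g. by the ResizeLongestSide transform in segment_anything/utils/transforms.py).

import torch
import torch.nn.functional as F

img_size = 1024                     # assumed image_encoder.img_size
orig_h, orig_w = 600, 800           # hypothetical original image size

# Resize so the long side equals img_size (done upstream of Sam.preprocess).
scale = img_size / max(orig_h, orig_w)
in_h, in_w = int(orig_h * scale + 0.5), int(orig_w * scale + 0.5)   # 768, 1024

x = torch.rand(3, in_h, in_w)
# Pad to a square, mirroring the padding step of Sam.preprocess.
x = F.pad(x, (0, img_size - in_w, 0, img_size - in_h))
assert x.shape[-2:] == (img_size, img_size)

# Mirror Sam.postprocess_masks on a dummy low-resolution (256x256) mask logit.
low_res = torch.rand(1, 1, 256, 256)
masks = F.interpolate(low_res, (img_size, img_size), mode="bilinear", align_corners=False)
masks = masks[..., :in_h, :in_w]    # remove the padding added above
masks = F.interpolate(masks, (orig_h, orig_w), mode="bilinear", align_corners=False)
print(masks.shape)                  # torch.Size([1, 1, 600, 800])

Normalization with pixel_mean/pixel_std is deliberately omitted here; only the pad-then-crop round trip is shown.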