├── segment_anything
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── build_sam.cpython-37.pyc
│   │   ├── predictor.cpython-37.pyc
│   │   └── automatic_mask_generator.cpython-37.pyc
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── amg.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── transforms.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── transforms.py
│   │   ├── onnx.py
│   │   └── amg.py
│   ├── modeling
│   │   ├── __pycache__
│   │   │   ├── sam.cpython-37.pyc
│   │   │   ├── common.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── transformer.cpython-37.pyc
│   │   │   ├── image_encoder.cpython-37.pyc
│   │   │   ├── mask_decoder.cpython-37.pyc
│   │   │   └── prompt_encoder.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── mask_decoder.py
│   │   ├── transformer.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── image_encoder.py
│   ├── __init__.py
│   ├── build_sam.py
│   ├── predictor.py
│   └── automatic_mask_generator.py
├── .gitignore
├── environment.yml
├── README.md
├── load_LIDC_data.py
├── evaluate.py
├── p2sam.py
├── train_second_stage.py
├── sam_lora_image_encoder.py
└── utils.py
--------------------------------------------------------------------------------
/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
--------------------------------------------------------------------------------
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 | 
--------------------------------------------------------------------------------
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | venv/ 13 | ENV/ 14 | env.bak/ 15 | venv.bak/ 16 | *.egg-info/ 17 | dist/ 18 | build/ 19 | 20 | # Jupyter Notebook checkpoints 21 | .ipynb_checkpoints 22 | 23 | # PyCharm 24 | .idea/ 25 | 26 | # VS Code 27 | .vscode/ 28 | 29 | # Logs and databases 30 | *.log 31 | *.sqlite3 32 | 33 | # Ignore data and model checkpoints 34 | checkpoint/ 35 | tf-logs/ 36 | *.pth 37 | 38 | 39 | # Ignore any other sensitive or unnecessary files 40 | evaluation_log.txt 41 | training_log_first_stage.txt -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: p2sam 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - backcall=0.2.0=pyhd3eb1b0_0 8 | - ca-certificates=2023.12.12=h06a4308_0 9 | - certifi=2022.12.7=py37h06a4308_0 10 | - decorator=5.1.1=pyhd3eb1b0_0 11 | - entrypoints=0.4=py37h06a4308_0 12 | - jupyter_client=7.4.9=py37h06a4308_0 13 | - ld_impl_linux-64=2.38=h1181459_1 14 | - 
libffi=3.3=he6710b0_2 15 | - libgcc-ng=11.2.0=h1234567_1 16 | - libgomp=11.2.0=h1234567_1 17 | - libsodium=1.0.18=h7b6447c_0 18 | - libstdcxx-ng=11.2.0=h1234567_1 19 | - matplotlib-inline=0.1.6=py37h06a4308_0 20 | - ncurses=6.4=h6a678d5_0 21 | - openssl=1.1.1w=h7f8727e_0 22 | - parso=0.8.3=pyhd3eb1b0_0 23 | - pickleshare=0.7.5=pyhd3eb1b0_1003 24 | - pip=22.3.1=py37h06a4308_0 25 | - ptyprocess=0.7.0=pyhd3eb1b0_2 26 | - python=3.7.11=h12debd9_0 27 | - python-dateutil=2.8.2=pyhd3eb1b0_0 28 | - readline=8.2=h5eee18b_0 29 | - setuptools=65.6.3=py37h06a4308_0 30 | - six=1.16.0=pyhd3eb1b0_1 31 | - sqlite=3.41.2=h5eee18b_0 32 | - tk=8.6.12=h1ccaba5_0 33 | - tornado=6.2=py37h5eee18b_0 34 | - wheel=0.38.4=py37h06a4308_0 35 | - xz=5.4.2=h5eee18b_0 36 | - zeromq=4.3.4=h2531618_0 37 | - zlib=1.2.13=h5eee18b_0 38 | - pip: 39 | - absl-py==2.0.0 40 | - asttokens==2.4.1 41 | - cached-property==1.5.2 42 | - cachetools==5.3.2 43 | - charset-normalizer==3.3.2 44 | - colorama==0.4.6 45 | - coloredlogs==15.0.1 46 | - contextlib2==21.6.0 47 | - cycler==0.11.0 48 | - debugpy==1.7.0 49 | - einops==0.6.1 50 | - executing==2.0.1 51 | - flatbuffers==23.5.26 52 | - fonttools==4.38.0 53 | - google-auth==2.23.4 54 | - google-auth-oauthlib==0.4.6 55 | - grpcio==1.59.2 56 | - h5py==3.5.0 57 | - humanfriendly==10.0 58 | - icecream==2.1.3 59 | - idna==3.4 60 | - imageio==2.10.1 61 | - importlib-metadata==6.7.0 62 | - ipykernel==6.16.2 63 | - ipython==7.34.0 64 | - jedi==0.19.1 65 | - joblib==1.3.2 66 | - jupyter-core==4.12.0 67 | - kiwisolver==1.4.5 68 | - kornia==0.6.12 69 | - markdown==3.4.4 70 | - markupsafe==2.1.3 71 | - matplotlib==3.5.3 72 | - medpy==0.4.0 73 | - ml-collections==0.1.1 74 | - monai==1.1.0 75 | - mpmath==1.3.0 76 | - nest-asyncio==1.5.8 77 | - netron==8.0.4 78 | - nibabel==4.0.2 79 | - numpy==1.21.6 80 | - nvidia-cublas-cu11==11.10.3.66 81 | - nvidia-cuda-nvrtc-cu11==11.7.99 82 | - nvidia-cuda-runtime-cu11==11.7.99 83 | - nvidia-cudnn-cu11==8.5.0.96 84 | - oauthlib==3.2.2 85 | - onnx==1.13.1 86 | - onnxruntime==1.14.1 87 | - opencv-python==4.5.4.58 88 | - packaging==23.2 89 | - pandas==1.3.5 90 | - pexpect==4.9.0 91 | - pillow==9.5.0 92 | - prompt-toolkit==3.0.43 93 | - protobuf==3.20.3 94 | - psutil==5.9.7 95 | - pyasn1==0.5.0 96 | - pyasn1-modules==0.3.0 97 | - pycocotools==2.0.6 98 | - pydicom==2.4.3 99 | - pygments==2.16.1 100 | - pyparsing==3.1.1 101 | - pytz==2024.1 102 | - pyyaml==6.0.1 103 | - pyzmq==25.1.2 104 | - requests==2.31.0 105 | - requests-oauthlib==1.3.1 106 | - rsa==4.9 107 | - safetensors==0.3.1 108 | - scikit-learn==1.0.2 109 | - scipy==1.7.3 110 | - simpleitk==2.2.1 111 | - sympy==1.10.1 112 | - tensorboard==2.11.2 113 | - tensorboard-data-server==0.6.1 114 | - tensorboard-plugin-wit==1.8.1 115 | - tensorboardx==2.6 116 | - threadpoolctl==3.1.0 117 | - torch==1.9.1+cu111 118 | - torchaudio==0.9.1 119 | - torchvision==0.10.1+cu111 120 | - tqdm==4.62.3 121 | - traitlets==5.9.0 122 | - typing-extensions==4.7.1 123 | - urllib3==2.0.7 124 | - wcwidth==0.2.12 125 | - werkzeug==2.2.3 126 | - zipp==3.15.0 127 | prefix: /data/cxli/miniconda3/envs/p2sam 128 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
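# Illustrative usage of the ResizeLongestSide transform defined below (a sketch, not
# part of the upstream file). It assumes a hypothetical HxWxC uint8 numpy image `img`,
# Nx2 prompt coordinates `pts`, and SAM's usual 1024-pixel long side:
#
#     transform = ResizeLongestSide(1024)
#     resized_img = transform.apply_image(img)                  # longest side becomes 1024
#     scaled_pts = transform.apply_coords(pts, img.shape[:2])   # rescale prompts to match
#
# apply_image_torch / apply_coords_torch provide the analogous behaviour for batched tensors.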
6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # P2SAM: Probabilistically Prompted SAMs for Ambiguous Medical Images 2 | 3 | This repository contains the implementation of **P2SAM**, a novel framework for ambiguous medical image segmentation. P2SAM leverages the prior knowledge of the Segment Anything Model (SAM) to enhance segmentation precision and diversity with minimal annotated data. The framework addresses the challenges of ambiguity and limited annotations in medical imaging. 4 | 5 | --- 6 | 7 | ## Features 8 | 9 | - **Probabilistic Prompt Generation**: Generates prompt distributions to guide SAM in ambiguous segmentation tasks. 10 | - **Diversity-Aware Assembling**: Aggregates diverse segmentation masks with learnable weights. 11 | - **Efficient Training**: Achieves high performance with significantly reduced training data requirements. 12 | - **State-of-the-Art Performance**: Outperforms baseline methods on metrics such as GED, HM-IoU, and Dmax. 13 | 14 | --- 15 | 16 | ## Repository Structure 17 | 18 | ``` 19 | . 20 | ├── environment.yml # Conda environment file for dependencies 21 | ├── evaluate.py # Evaluation script for segmentation results 22 | ├── load_LIDC_data.py # Data loader for the LIDC-IDRI dataset 23 | ├── p2sam.py # Implementation of the P2SAM framework 24 | ├── sam_lora_image_encoder.py # Enhanced SAM image encoder using LoRA 25 | ├── train_first_stage.py # Training script for the first stage (fine-tuning SAM) 26 | ├── train_second_stage.py # Training script for the second stage (probabilistic prompts) 27 | ├── utils.py # Utility functions for data preprocessing and model operations 28 | ``` 29 | 30 | --- 31 | 32 | ## Installation 33 | 34 | ### Prerequisites 35 | 36 | - Python >= 3.7 37 | - Conda (recommended for environment setup) 38 | 39 | ### Setup 40 | 41 | 1. Clone the repository: 42 | ```bash 43 | git clone https://github.com/yu2hi13/p2sam.git 44 | cd p2sam 45 | ``` 46 | 47 | 2. Create the environment: 48 | ```bash 49 | conda env create -f environment.yml 50 | conda activate p2sam 51 | ``` 52 | 53 | 3. Download the `sam_vit_b` weights and place them in the current directory: 54 | - [Download sam_vit_b weights](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) 55 | 56 | 4. Pre-trained Weights: 57 | - The pre-trained weights are available on Google Drive. Please download them and place them in the `checkpoint` directory. 58 | - [Download pre-trained weights](https://drive.google.com/drive/folders/1u7C67QTbJW8GN7oh9qrWIQGEJW47a2dH?usp=drive_link) 59 | - **best_lora_checkpoint.pth**: Fine-tuned LoRA SAM weights. 60 | - **best_prior_checkpoint.pth**: Prior network weights. 61 | - **best_weights.pt**: Integrable mask weights. 62 | 63 | --- 64 | 65 | ## Usage 66 | 67 | ### Data Preparation 68 | 69 | - Download the required datasets: 70 | - [LIDC-IDRI](https://drive.google.com/drive/folders/1xKfKCQo8qa6SAr3u7qWNtQjIphIrvmd5) 71 | - Place the downloaded datasets in the `/data` directory. 72 | - Ensure the datasets are preprocessed to the format required by `load_LIDC_data.py`. 73 | 74 | ### Training 75 | 76 | 1. Train the first stage (fine-tuning SAM): 77 | ```bash 78 | python train_first_stage.py 79 | ``` 80 | 81 | 2. 
Train the second stage (probabilistic prompt generation): 82 | ```bash 83 | python train_second_stage.py 84 | ``` 85 | 86 | ### Evaluation 87 | 88 | Evaluate the trained model on the test dataset: 89 | ```bash 90 | python evaluate.py 91 | ``` 92 | 93 | --- 94 | 95 | ## Results 96 | 97 | P2SAM demonstrates superior performance in ambiguous medical image segmentation compared to state-of-the-art methods, achieving: 98 | 99 | - Higher accuracy in segmentation 100 | - Enhanced diversity in predictions 101 | - Robust performance with limited annotated data 102 | 103 | Detailed experimental results can be found in the [paper](https://doi.org/10.1145/3664647.3680628). 104 | 105 | --- 106 | 107 | ## Citation 108 | 109 | If you use this repository, please cite: 110 | 111 | ``` 112 | @inproceedings{huang2024p2sam, 113 | title={P2SAM: Probabilistically Prompted SAMs Are Efficient Segmentator for Ambiguous Medical Images}, 114 | author={Huang, Yuzhi and Li, Chenxin and others}, 115 | booktitle={ACM MM '24}, 116 | year={2024}, 117 | doi={10.1145/3664647.3680628} 118 | } 119 | ``` 120 | 121 | --- 122 | 123 | ## License 124 | 125 | This project is licensed under the MIT License. See the LICENSE file for details.# P2SAM 126 | -------------------------------------------------------------------------------- /load_LIDC_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from scipy import ndimage 8 | from scipy.ndimage import zoom 9 | from einops import repeat 10 | 11 | def random_rot_flip(image, label): 12 | # Randomly flip the image and label horizontally 13 | if np.random.rand() < 0.5: 14 | image = np.flip(image, axis=(1, 2)) 15 | label = np.flip(label, axis=1) 16 | 17 | # Randomly flip the image and label vertically 18 | if np.random.rand() < 0.5: 19 | image = np.flip(image, axis=(0, 2)) 20 | label = np.flip(label, axis=0) 21 | 22 | # Randomly rotate the image and label by 0, 90, 180, or 270 degrees 23 | rots = np.random.randint(0, 4) 24 | image = np.rot90(image, rots, axes=(1, 2)) 25 | label = np.rot90(label, rots, axes=(0, 1)) 26 | 27 | return image, label 28 | 29 | def random_rotate(image, label): 30 | # Randomly rotate the image and label by an angle between -20 and 20 degrees 31 | angle = np.random.randint(-20, 20) 32 | image = ndimage.rotate(image, angle, order=0, reshape=False) 33 | label = ndimage.rotate(label, angle, order=0, reshape=False) 34 | return image, label 35 | 36 | class RandomGenerator(object): 37 | def __init__(self, output_size, low_res, test=False): 38 | self.output_size = output_size 39 | self.low_res = low_res 40 | self.test = test 41 | 42 | def __call__(self, sample): 43 | image, label, label_four = sample['image'], sample['label'], sample['label_four'] 44 | label_four = np.stack(label_four, axis=0).astype(np.int64) 45 | label = label.squeeze() 46 | 47 | if not self.test: 48 | image, label = random_rot_flip(image, label) 49 | 50 | image_oc = image.copy() 51 | x, y = image.shape[-2:] 52 | 53 | # Resize image and label to the specified output size 54 | if x != self.output_size[0] or y != self.output_size[1]: 55 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) 56 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 57 | 58 | label_h, label_w = label.shape 59 | low_res_label = zoom(label, (self.low_res[0] / label_h, self.low_res[1] / label_w), 
order=0) 60 | 61 | # Convert numpy arrays to torch tensors 62 | image = torch.from_numpy(image.astype(np.float32)) 63 | image = repeat(image, 'c h w -> (repeat c) h w', repeat=3) 64 | label = torch.from_numpy(label.astype(np.float32)) 65 | low_res_label = torch.from_numpy(low_res_label.astype(np.float32)) 66 | 67 | sample = { 68 | 'image': image, 69 | 'label': label.long(), 70 | 'low_res_label': low_res_label.long(), 71 | 'image_oc': image_oc, 72 | 'label_four': label_four 73 | } 74 | return sample 75 | 76 | class LIDC_IDRI(Dataset): 77 | def __init__(self, dataset_location, transform=None): 78 | self.transform = transform 79 | self.images = [] 80 | self.labels = [] 81 | self.series_uid = [] 82 | 83 | max_bytes = 2**31 - 1 84 | data = {} 85 | 86 | # Load data from pickle files 87 | for file in os.listdir(dataset_location): 88 | filename = os.fsdecode(file) 89 | if '.pickle' in filename: 90 | print("Loading file", filename) 91 | file_path = os.path.join(dataset_location, filename) 92 | bytes_in = bytearray(0) 93 | input_size = os.path.getsize(file_path) 94 | with open(file_path, 'rb') as f_in: 95 | for _ in range(0, input_size, max_bytes): 96 | bytes_in += f_in.read(max_bytes) 97 | new_data = pickle.loads(bytes_in) 98 | data.update(new_data) 99 | 100 | for key, value in data.items(): 101 | self.images.append(value['image'].astype(float)) 102 | self.labels.append(value['masks']) 103 | self.series_uid.append(value['series_uid']) 104 | 105 | assert len(self.images) == len(self.labels) == len(self.series_uid) 106 | 107 | for img in self.images: 108 | assert np.max(img) <= 1 and np.min(img) >= 0 109 | for label in self.labels: 110 | assert np.max(label) <= 1 and np.min(label) >= 0 111 | 112 | del new_data 113 | del data 114 | 115 | def __getitem__(self, index): 116 | image = np.expand_dims(self.images[index], axis=0) 117 | # Randomly select one of the four labels for this image 118 | label = self.labels[index][random.randint(0, 3)].astype(float) 119 | label_four = self.labels[index] 120 | while label.sum() == 0: 121 | label = self.labels[index][random.randint(0, 3)].astype(float) 122 | 123 | # Convert image and label to torch tensors 124 | image = torch.from_numpy(image).type(torch.FloatTensor) 125 | label = torch.from_numpy(label).type(torch.LongTensor).unsqueeze(0) 126 | 127 | image = np.array(image) 128 | label = np.array(label) 129 | 130 | sample = {'image': image, 'label': label, 'label_four': label_four} 131 | if self.transform: 132 | sample = self.transform(sample) 133 | 134 | return sample 135 | 136 | def __len__(self): 137 | return len(self.images) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import argparse 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader, Subset 11 | from torchvision import transforms 12 | from tensorboardX import SummaryWriter 13 | from utils import ( 14 | generalized_energy_distance_iou, hm_iou_cal, dice_max_cal2 15 | ) 16 | from load_LIDC_data import LIDC_IDRI, RandomGenerator 17 | from p2sam import P2SAM 18 | 19 | # Configure logging 20 | logging.basicConfig(filename='evaluation_log.txt', level=logging.INFO, format='%(asctime)s - %(message)s') 21 | 22 | # Parse command-line arguments 23 | parser = argparse.ArgumentParser(description='Evaluate the model with specified 
epochs and weights.') 24 | parser.add_argument('--epochs', nargs='+', type=int, default=[10, 50, 70, 100], help='Epochs to load weights from.') 25 | parser.add_argument('--combined_weights_path', type=str, default='best/best_prior_checkpoint.pth', help='Path to the combined weights file.') 26 | parser.add_argument('--weight_eight_path', type=str, default='best/best_weights.pt', help='Path to the weight_eight file.') 27 | parser.add_argument('--lora_sam_weights_path', type=str, default='best/best_lora_checkpoint.pth', help='Path to the lora_sam weights file.') 28 | parser.add_argument('--gpuid', type=int, default=3, help='ID of the GPU to use.') 29 | parser.add_argument('--batch_size', type=int, default=8, help='Batch size for data loading.') 30 | parser.add_argument('--total_samples', type=int, default=16, help='Total number of samples to generate.') 31 | args = parser.parse_args() 32 | 33 | # Set device 34 | device = torch.device(f'cuda:{args.gpuid}' if torch.cuda.is_available() else 'cpu') 35 | 36 | combined_weights = torch.load(args.combined_weights_path, map_location=device) 37 | weight_eight = torch.load(args.weight_eight_path, map_location=device) 38 | 39 | # Initialize networks 40 | def initialize_networks(epochs): 41 | networks = [] 42 | for epoch in epochs: 43 | epoch_key = f'epoch_{epoch}' 44 | if epoch_key in combined_weights: 45 | net = P2SAM(device=device, lora_ckpt=args.lora_sam_weights_path) # Initialize network 46 | non_lora_weights = combined_weights[epoch_key] 47 | net.load_state_dict(non_lora_weights, strict=False) # Load non-lora_sam weights 48 | net.to(device) 49 | networks.append(net) 50 | else: 51 | print(f"Warning: Weights for epoch {epoch} not found.") 52 | return networks 53 | 54 | ged_score = dice_max2_score = hm_iou_score = 0 55 | networks = initialize_networks(args.epochs) 56 | 57 | # Log the weights being used 58 | logging.info(f"Using combined weights from {args.combined_weights_path} for epochs: {args.epochs}") 59 | 60 | # Prepare dataset 61 | low_res = networks[0].img_embedding_size * 4 62 | db = LIDC_IDRI(dataset_location='/data/cxli/yuzhi/samed/SAMed-main/data/', transform=transforms.Compose([ 63 | RandomGenerator(output_size=[128, 128], low_res=[low_res, low_res], test=True) 64 | ])) 65 | dataset_size = len(db) 66 | indices = list(range(dataset_size)) 67 | train_split = int(np.floor(0.6 * dataset_size)) 68 | validation_split = int(np.floor(0.8 * dataset_size)) 69 | train_indices = indices[:train_split] 70 | validation_indices = indices[train_split:validation_split] 71 | test_indices = indices[validation_split:] 72 | 73 | train_dataset = Subset(db, train_indices) 74 | validation_dataset = Subset(db, validation_indices) 75 | test_dataset = Subset(db, test_indices) 76 | 77 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 78 | validation_loader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=False) 79 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 80 | 81 | print(f"Total dataset size: {dataset_size}") 82 | print(f"Training set size: {len(train_indices)}") 83 | print(f"Validation set size: {len(validation_indices)}") 84 | print(f"Test set size: {len(test_indices)}") 85 | 86 | # Hyperparameter: total number of samples 87 | samples_per_net = args.total_samples // len(networks) # Number of samples generated by each network 88 | 89 | # Process each batch and evaluate metrics 90 | for i_batch, sampled_batch in enumerate(test_loader): 91 | print(f'Processing batch 
{i_batch}') 92 | logging.info(f'Processing batch {i_batch}') 93 | image_batch, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device) 94 | label_four_batch = sampled_batch['label_four'] 95 | low_res_label_batch = sampled_batch['low_res_label'].to(device) 96 | assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}' 97 | 98 | pred_list = [[] for _ in range(image_batch.shape[0])] 99 | 100 | for i in range(samples_per_net): 101 | image_batch_oc = sampled_batch['image_oc'].to(device) 102 | outputs = [net.forward(image_batch, image_batch_oc, train=False) for net in networks] 103 | 104 | for j, output in enumerate(outputs): 105 | logits_high = output['masks'].to(device) * weight_eight.unsqueeze(-1) # Apply weight_eight 106 | logits_high_res = logits_high.sum(1).unsqueeze(1) 107 | 108 | for k in range(image_batch.shape[0]): 109 | pred_list[k].append(logits_high_res[k]) 110 | 111 | for index in range(len(pred_list)): 112 | label_four_filter = label_four_batch[index] 113 | pred_eval = torch.cat(pred_list[index], 0) 114 | pred_eval = (pred_eval > 0).cpu().detach().int() 115 | 116 | iou_score_iter, ged_score_iter = generalized_energy_distance_iou(pred_eval, label_four_filter) 117 | score, _ = hm_iou_cal(pred_eval, label_four_filter) 118 | hm_iou_score += score 119 | dice_max2_score += dice_max_cal2(pred_eval, label_four_filter) 120 | ged_score += ged_score_iter 121 | 122 | # Calculate average scores 123 | ged = ged_score / len(test_indices) 124 | dice_max2 = dice_max2_score / len(test_indices) 125 | hm_iou = hm_iou_score / len(test_indices) 126 | 127 | print(f"ged_score: {ged}, dice_max_score2: {dice_max2}, hm_iou_score: {hm_iou}") 128 | logging.info(f"ged_score: {ged}, dice_max_score2: {dice_max2}, hm_iou_score: {hm_iou}") -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
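        # Concretely (as the scoring below suggests): score_reweight puts a large weight on
        # mask index 0, SAM's single-mask output. When num_points exceeds 2.5 the term acts
        # as a bonus, so argmax selects index 0; with roughly one real click (plus the usual
        # padding point, so num_points <= 2) it acts as a penalty, and the best multimask
        # output is chosen by its predicted IoU instead.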
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
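# Usage sketch for this repository (mirroring the call in p2sam.py below; the values are
# that file's defaults, not requirements). The registry entry returns both the model and
# its image-embedding size, since _build_sam computes image_size // vit_patch_size:
#
#     sam, img_embedding_size = sam_model_registry["vit_b"](
#         image_size=128,
#         num_classes=8,
#         checkpoint="sam_vit_b_01ec64.pth",
#         pixel_mean=[0, 0, 0],
#         pixel_std=[1, 1, 1],
#     )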
6 | 7 | import torch 8 | from torch.nn import functional as F 9 | from icecream import ic 10 | 11 | from functools import partial 12 | 13 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 14 | 15 | 16 | def build_sam_vit_h(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], 17 | checkpoint=None): 18 | return _build_sam( 19 | encoder_embed_dim=1280, 20 | encoder_depth=32, 21 | encoder_num_heads=16, 22 | encoder_global_attn_indexes=[7, 15, 23, 31], 23 | checkpoint=checkpoint, 24 | num_classes=num_classes, 25 | image_size=image_size, 26 | pixel_mean=pixel_mean, 27 | pixel_std=pixel_std 28 | ) 29 | 30 | 31 | build_sam = build_sam_vit_h 32 | 33 | 34 | def build_sam_vit_l(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], 35 | checkpoint=None): 36 | return _build_sam( 37 | encoder_embed_dim=1024, 38 | encoder_depth=24, 39 | encoder_num_heads=16, 40 | encoder_global_attn_indexes=[5, 11, 17, 23], 41 | checkpoint=checkpoint, 42 | num_classes=num_classes, 43 | image_size=image_size, 44 | pixel_mean=pixel_mean, 45 | pixel_std=pixel_std 46 | ) 47 | 48 | 49 | def build_sam_vit_b(image_size, num_classes, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], 50 | checkpoint=None): 51 | return _build_sam( 52 | encoder_embed_dim=768, 53 | encoder_depth=12, 54 | encoder_num_heads=12, 55 | encoder_global_attn_indexes=[2, 5, 8, 11], 56 | # adopt global attention at [3, 6, 9, 12] transform layer, else window attention layer 57 | checkpoint=checkpoint, 58 | num_classes=num_classes, 59 | image_size=image_size, 60 | pixel_mean=pixel_mean, 61 | pixel_std=pixel_std 62 | ) 63 | 64 | 65 | sam_model_registry = { 66 | "default": build_sam_vit_h, 67 | "vit_h": build_sam_vit_h, 68 | "vit_l": build_sam_vit_l, 69 | "vit_b": build_sam_vit_b, 70 | } 71 | 72 | 73 | def _build_sam( 74 | encoder_embed_dim, 75 | encoder_depth, 76 | encoder_num_heads, 77 | encoder_global_attn_indexes, 78 | num_classes, 79 | image_size, 80 | pixel_mean, 81 | pixel_std, 82 | checkpoint=None, 83 | ): 84 | prompt_embed_dim = 256 85 | image_size = image_size 86 | vit_patch_size = 16 87 | image_embedding_size = image_size // vit_patch_size # Divide by 16 here 88 | sam = Sam( 89 | image_encoder=ImageEncoderViT( 90 | depth=encoder_depth, 91 | embed_dim=encoder_embed_dim, 92 | img_size=image_size, 93 | mlp_ratio=4, 94 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 95 | num_heads=encoder_num_heads, 96 | patch_size=vit_patch_size, 97 | qkv_bias=True, 98 | use_rel_pos=True, 99 | global_attn_indexes=encoder_global_attn_indexes, 100 | window_size=14, 101 | out_chans=prompt_embed_dim, 102 | ), 103 | prompt_encoder=PromptEncoder( 104 | embed_dim=prompt_embed_dim, 105 | image_embedding_size=(image_embedding_size, image_embedding_size), 106 | input_image_size=(image_size, image_size), 107 | mask_in_chans=16, 108 | ), 109 | mask_decoder=MaskDecoder( 110 | # num_multimask_outputs=3, 111 | num_multimask_outputs=num_classes, 112 | transformer=TwoWayTransformer( 113 | depth=2, 114 | embedding_dim=prompt_embed_dim, 115 | mlp_dim=2048, 116 | num_heads=8, 117 | ), 118 | transformer_dim=prompt_embed_dim, 119 | iou_head_depth=3, 120 | iou_head_hidden_dim=256, 121 | ), 122 | # pixel_mean=[123.675, 116.28, 103.53], 123 | # pixel_std=[58.395, 57.12, 57.375], 124 | pixel_mean=pixel_mean, 125 | pixel_std=pixel_std 126 | ) 127 | # sam.eval() 128 | sam.train() 129 | if checkpoint is not None: 130 | with open(checkpoint, 
"rb") as f: 131 | state_dict = torch.load(f) 132 | try: 133 | sam.load_state_dict(state_dict) 134 | except: 135 | new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size) 136 | sam.load_state_dict(new_state_dict) 137 | return sam, image_embedding_size 138 | 139 | 140 | def load_from(sam, state_dict, image_size, vit_patch_size): 141 | sam_dict = sam.state_dict() 142 | except_keys = ['mask_tokens', 'output_hypernetworks_mlps', 'iou_prediction_head'] 143 | new_state_dict = {k: v for k, v in state_dict.items() if 144 | k in sam_dict.keys() and except_keys[0] not in k and except_keys[1] not in k and except_keys[2] not in k} 145 | pos_embed = new_state_dict['image_encoder.pos_embed'] 146 | token_size = int(image_size // vit_patch_size) 147 | if pos_embed.shape[1] != token_size: 148 | # resize pos embedding, which may sacrifice the performance, but I have no better idea 149 | pos_embed = pos_embed.permute(0, 3, 1, 2) # [b, c, h, w] 150 | pos_embed = F.interpolate(pos_embed, (token_size, token_size), mode='bilinear', align_corners=False) 151 | pos_embed = pos_embed.permute(0, 2, 3, 1) # [b, h, w, c] 152 | new_state_dict['image_encoder.pos_embed'] = pos_embed 153 | rel_pos_keys = [k for k in sam_dict.keys() if 'rel_pos' in k] 154 | global_rel_pos_keys = [k for k in rel_pos_keys if '2' in k or '5' in k or '8' in k or '11' in k] 155 | for k in global_rel_pos_keys: 156 | rel_pos_params = new_state_dict[k] 157 | h, w = rel_pos_params.shape 158 | rel_pos_params = rel_pos_params.unsqueeze(0).unsqueeze(0) 159 | rel_pos_params = F.interpolate(rel_pos_params, (token_size * 2 - 1, w), mode='bilinear', align_corners=False) 160 | new_state_dict[k] = rel_pos_params[0, 0, ...] 161 | sam_dict.update(new_state_dict) 162 | return sam_dict 163 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from icecream import ic 11 | 12 | from typing import List, Tuple, Type 13 | 14 | from .common import LayerNorm2d 15 | 16 | 17 | class MaskDecoder(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | transformer_dim: int, 22 | transformer: nn.Module, 23 | num_multimask_outputs: int = 3, 24 | activation: Type[nn.Module] = nn.GELU, 25 | iou_head_depth: int = 3, 26 | iou_head_hidden_dim: int = 256, 27 | ) -> None: 28 | """ 29 | Predicts masks given an image and prompt embeddings, using a 30 | tranformer architecture. 
31 | 32 | Arguments: 33 | transformer_dim (int): the channel dimension of the transformer 34 | transformer (nn.Module): the transformer used to predict masks 35 | num_multimask_outputs (int): the number of masks to predict 36 | when disambiguating masks 37 | activation (nn.Module): the type of activation to use when 38 | upscaling masks 39 | iou_head_depth (int): the depth of the MLP used to predict 40 | mask quality 41 | iou_head_hidden_dim (int): the hidden dimension of the MLP 42 | used to predict mask quality 43 | """ 44 | super().__init__() 45 | self.transformer_dim = transformer_dim 46 | self.transformer = transformer 47 | 48 | self.num_multimask_outputs = num_multimask_outputs 49 | 50 | self.iou_token = nn.Embedding(1, transformer_dim) 51 | self.num_mask_tokens = num_multimask_outputs + 1 52 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 53 | 54 | self.output_upscaling = nn.Sequential( 55 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 56 | LayerNorm2d(transformer_dim // 4), 57 | activation(), 58 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 59 | activation(), 60 | ) 61 | self.output_hypernetworks_mlps = nn.ModuleList( 62 | [ 63 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 64 | for i in range(self.num_mask_tokens) 65 | ] 66 | ) 67 | 68 | self.iou_prediction_head = MLP( 69 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 70 | ) 71 | 72 | def forward( 73 | self, 74 | image_embeddings: torch.Tensor, 75 | image_pe: torch.Tensor, 76 | sparse_prompt_embeddings: torch.Tensor, 77 | dense_prompt_embeddings: torch.Tensor, 78 | multimask_output: bool, 79 | ) -> Tuple[torch.Tensor, torch.Tensor]: 80 | """ 81 | Predict masks given image and prompt embeddings. 82 | 83 | Arguments: 84 | image_embeddings (torch.Tensor): the embeddings from the image encoder 85 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 86 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 87 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 88 | multimask_output (bool): Whether to return multiple masks or a single 89 | mask. 90 | 91 | Returns: 92 | torch.Tensor: batched predicted masks 93 | torch.Tensor: batched predictions of mask quality 94 | """ 95 | masks, iou_pred = self.predict_masks( 96 | image_embeddings=image_embeddings, 97 | image_pe=image_pe, 98 | sparse_prompt_embeddings=sparse_prompt_embeddings, 99 | dense_prompt_embeddings=dense_prompt_embeddings, 100 | ) 101 | 102 | # Select the correct mask or masks for output 103 | if multimask_output: 104 | mask_slice = slice(1, None) 105 | else: 106 | mask_slice = slice(0, 1) 107 | masks = masks[:, mask_slice, :, :] 108 | iou_pred = iou_pred[:, mask_slice] 109 | 110 | # Prepare output 111 | return masks, iou_pred 112 | 113 | def predict_masks( 114 | self, 115 | image_embeddings: torch.Tensor, 116 | image_pe: torch.Tensor, 117 | sparse_prompt_embeddings: torch.Tensor, 118 | dense_prompt_embeddings: torch.Tensor, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """Predicts masks. 
See 'forward' for more details.""" 121 | # Concatenate output tokens 122 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 123 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 124 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 125 | 126 | # Expand per-image data in batch direction to be per-mask 127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 128 | src = src + dense_prompt_embeddings 129 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 130 | b, c, h, w = src.shape 131 | 132 | # Run the transformer 133 | hs, src = self.transformer(src, pos_src, tokens) 134 | iou_token_out = hs[:, 0, :] 135 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 136 | 137 | # Upscale mask embeddings and predict masks using the mask tokens 138 | src = src.transpose(1, 2).view(b, c, h, w) 139 | upscaled_embedding = self.output_upscaling(src) 140 | hyper_in_list: List[torch.Tensor] = [] 141 | for i in range(self.num_mask_tokens): 142 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 143 | hyper_in = torch.stack(hyper_in_list, dim=1) # [b, c, token_num] 144 | 145 | b, c, h, w = upscaled_embedding.shape # [h, token_num, h, w] 146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # [1, 4, 256, 256], 256 = 4 * 64, the size of image embeddings 147 | 148 | # Generate mask quality predictions 149 | iou_pred = self.iou_prediction_head(iou_token_out) 150 | 151 | return masks, iou_pred 152 | 153 | 154 | # Lightly adapted from 155 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 156 | class MLP(nn.Module): 157 | def __init__( 158 | self, 159 | input_dim: int, 160 | hidden_dim: int, 161 | output_dim: int, 162 | num_layers: int, 163 | sigmoid_output: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.num_layers = num_layers 167 | h = [hidden_dim] * (num_layers - 1) 168 | self.layers = nn.ModuleList( 169 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 170 | ) 171 | self.sigmoid_output = sigmoid_output 172 | 173 | def forward(self, x): 174 | for i, layer in enumerate(self.layers): 175 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 176 | if self.sigmoid_output: 177 | x = F.sigmoid(x) 178 | return x 179 | -------------------------------------------------------------------------------- /p2sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal, Independent 5 | from utils import init_weights, init_weights_orthogonal_normal 6 | from segment_anything import sam_model_registry 7 | from sam_lora_image_encoder import LoRA_Sam 8 | import numpy as np 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True): 12 | super(Encoder, self).__init__() 13 | self.input_channels = input_channels 14 | self.num_filters = num_filters 15 | 16 | layers = [] 17 | for i in range(len(self.num_filters)): 18 | input_dim = self.input_channels if i == 0 else output_dim 19 | output_dim = num_filters[i] 20 | 21 | if i != 0: 22 | layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)) 23 | 24 | layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, 
padding=int(padding))) 25 | layers.append(nn.ReLU(inplace=True)) 26 | 27 | for _ in range(no_convs_per_block-1): 28 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding))) 29 | layers.append(nn.ReLU(inplace=True)) 30 | 31 | self.layers = nn.Sequential(*layers) 32 | self.layers.apply(init_weights) 33 | 34 | def forward(self, x): 35 | return self.layers(x) 36 | 37 | class AxisAlignedConvGaussianPrior(nn.Module): 38 | def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers): 39 | super(AxisAlignedConvGaussianPrior, self).__init__() 40 | self.input_channels = input_channels 41 | self.channel_axis = 1 42 | self.num_filters = num_filters 43 | self.no_convs_per_block = no_convs_per_block 44 | self.latent_dim = latent_dim 45 | self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers) 46 | self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1) 47 | 48 | nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu') 49 | nn.init.normal_(self.conv_layer.bias) 50 | 51 | 52 | def forward(self, x): 53 | x = x.to(torch.float32) 54 | encoding = self.encoder(x) 55 | encoding = torch.mean(encoding, dim=(2, 3), keepdim=True) 56 | mu_log_sigma = self.conv_layer(encoding).squeeze(-1).squeeze(-1) 57 | mu = mu_log_sigma[:, :self.latent_dim] 58 | log_sigma = mu_log_sigma[:, self.latent_dim:] 59 | return Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) 60 | 61 | class Fcomb(nn.Module): 62 | def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True): 63 | super(Fcomb, self).__init__() 64 | self.use_tile = use_tile 65 | 66 | if self.use_tile: 67 | layers = [nn.Conv2d(512, 256, kernel_size=1), nn.ReLU(inplace=True)] 68 | for _ in range(no_convs_fcomb - 2): 69 | layers.extend([nn.Conv2d(256, 256, kernel_size=1), nn.ReLU(inplace=True)]) 70 | self.layers = nn.Sequential(*layers) 71 | self.last_layer = nn.Conv2d(256, 256, kernel_size=1) 72 | 73 | if initializers['w'] == 'orthogonal': 74 | self.layers.apply(init_weights_orthogonal_normal) 75 | self.last_layer.apply(init_weights_orthogonal_normal) 76 | else: 77 | self.layers.apply(init_weights) 78 | self.last_layer.apply(init_weights) 79 | 80 | def tile(self, a, dim, n_tile): 81 | init_dim = a.size(dim) 82 | repeat_idx = [1] * a.dim() 83 | repeat_idx[dim] = n_tile 84 | a = a.repeat(*repeat_idx) 85 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(a.device) 86 | return torch.index_select(a, dim, order_index) 87 | 88 | def forward(self, feature_map, z): 89 | if self.use_tile: 90 | z = z.unsqueeze(2).unsqueeze(2) 91 | z = self.tile(z, 2, feature_map.size(2)) 92 | z = self.tile(z, 3, feature_map.size(3)) 93 | feature_map = torch.cat((feature_map, z), dim=1) 94 | output = self.layers(feature_map) 95 | return self.last_layer(output) 96 | 97 | class P2SAM(nn.Module): 98 | def __init__(self, device, lora_ckpt, input_channels=1, num_classes=8, img_size=128, num_filters=[32, 64, 128, 192], latent_dim=256, no_convs_fcomb=4, beta=10.0): 99 | super(P2SAM, self).__init__() 100 | self.device = device 101 | self.img_size = img_size 102 | self.input_channels = input_channels 103 | self.num_classes = num_classes 104 | self.num_filters = num_filters 105 | self.latent_dim = latent_dim 106 | self.no_convs_per_block = 3 107 | self.no_convs_fcomb = no_convs_fcomb 108 | self.initializers = {'w': 
'he_normal', 'b': 'normal'} 109 | self.beta = beta 110 | 111 | self.sam, self.img_embedding_size = sam_model_registry["vit_b"]( 112 | image_size=self.img_size, 113 | num_classes=self.num_classes, 114 | checkpoint="sam_vit_b_01ec64.pth", 115 | pixel_mean=[0, 0, 0], 116 | pixel_std=[1, 1, 1] 117 | ) 118 | self.sam.to(self.device) 119 | self.lora_sam = LoRA_Sam(self.sam, 4).to(self.device) 120 | self.lora_sam.load_lora_parameters(lora_ckpt) 121 | self.prior_dense = AxisAlignedConvGaussianPrior(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers).to(self.device) 122 | self.fcomb = Fcomb(self.num_filters, self.latent_dim, self.input_channels, self.num_classes, self.no_convs_fcomb, {'w': 'orthogonal', 'b': 'normal'}, use_tile=True).to(self.device) 123 | 124 | def forward(self, batch_input, batch_input_ori, input_size=128, train=True): 125 | img_size = input_size 126 | self.prior_dense_latent_space = self.prior_dense(batch_input_ori) 127 | input_images = self.lora_sam.sam.preprocess(batch_input) 128 | image_embeddings = self.lora_sam.sam.image_encoder(input_images) 129 | sparse_embeddings, dense_embeddings = self.lora_sam.sam.prompt_encoder(points=None, boxes=None, masks=None) 130 | batch_shape = batch_input.size(0) 131 | dense_embeddings = dense_embeddings.repeat(batch_shape, 1, 1, 1) 132 | 133 | if train: 134 | z_posterior_dense = self.prior_dense_latent_space.rsample() 135 | else: 136 | z_posterior_dense = self.prior_dense_latent_space.sample() 137 | 138 | dense_embeddings_ditsturb = self.fcomb(dense_embeddings, z_posterior_dense) 139 | low_res_masks, iou_predictions = self.lora_sam.sam.mask_decoder( 140 | image_embeddings=image_embeddings, 141 | image_pe=self.lora_sam.sam.prompt_encoder.get_dense_pe(), 142 | sparse_prompt_embeddings=sparse_embeddings, 143 | dense_prompt_embeddings=dense_embeddings_ditsturb, 144 | multimask_output=True 145 | ) 146 | masks = self.lora_sam.sam.postprocess_masks( 147 | low_res_masks, 148 | input_size=(img_size, img_size), 149 | original_size=(128, 128) 150 | ) 151 | 152 | return { 153 | 'masks': masks, 154 | 'iou_predictions': iou_predictions, 155 | 'low_res_logits': low_res_masks 156 | } -------------------------------------------------------------------------------- /train_second_stage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader, Subset 5 | from torchvision import transforms 6 | from p2sam import P2SAM 7 | from load_LIDC_data import LIDC_IDRI, RandomGenerator 8 | from utils import l2_regularisation, calculate_dice_loss, calculate_sigmoid_focal_loss 9 | from tensorboardX import SummaryWriter 10 | import argparse 11 | import logging 12 | 13 | # Configure logging 14 | logging.basicConfig(filename='training_log_second_stage.txt', level=logging.INFO, format='%(asctime)s - %(message)s') 15 | 16 | def evaluate(model, data_loader, device, weight_eight): 17 | model.eval() 18 | total_loss = 0.0 19 | with torch.no_grad(): 20 | for sampled_batch in data_loader: 21 | image_batch, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device) 22 | image_batch_oc = sampled_batch['image_oc'].to(device) 23 | outputs = model.forward(image_batch, image_batch_oc) 24 | output_masks = outputs['masks'] 25 | logits_high = output_masks * weight_eight.unsqueeze(-1) 26 | logits_high_res = logits_high.sum(1).unsqueeze(1) 27 | gt_mask = label_batch.unsqueeze(1) 28 | dice_loss = 
calculate_dice_loss(logits_high_res, gt_mask[:].long()) 29 | focal_loss = calculate_sigmoid_focal_loss(logits_high_res, gt_mask[:].float()) 30 | total_loss += (dice_loss + focal_loss).item() 31 | return total_loss / len(data_loader) 32 | 33 | def save_non_lora_weights(model, epoch, combined_weights): 34 | non_lora_state_dict = {k: v.cpu() for k, v in model.state_dict().items() if not k.startswith('lora')} 35 | combined_weights[f'epoch_{epoch}'] = non_lora_state_dict 36 | print(f"Non-LoRA weights for epoch {epoch} added to combined weights.") 37 | 38 | def main(args): 39 | # Set global device 40 | device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu') 41 | os.makedirs(os.path.dirname(args.save_path), exist_ok=True) 42 | 43 | # Initialize model 44 | print(f"Loading best model from {args.model_path}") 45 | net = P2SAM(device=device, lora_ckpt=args.model_path).to(device) 46 | for param in net.lora_sam.parameters(): 47 | param.requires_grad = False 48 | 49 | # Load weights 50 | print(f"Loading best weights from {args.weight_path}") 51 | weight_eight = torch.load(args.weight_path).to(device) 52 | low_res = net.img_embedding_size * 4 53 | 54 | # Prepare dataset 55 | db = LIDC_IDRI(dataset_location='data/', transform=transforms.Compose([ 56 | RandomGenerator(output_size=[128, 128], low_res=[low_res, low_res], test=True) 57 | ])) 58 | dataset_size = len(db) 59 | indices = list(range(dataset_size)) 60 | train_split = int(np.floor(0.6 * dataset_size)) 61 | validation_split = int(np.floor(0.8 * dataset_size)) 62 | train_indices = indices[:500] 63 | validation_indices = indices[train_split:validation_split] 64 | test_indices = indices[validation_split:] 65 | 66 | train_dataset = Subset(db, train_indices) 67 | validation_dataset = Subset(db, validation_indices) 68 | test_dataset = Subset(db, test_indices) 69 | 70 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 71 | validation_loader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=False) 72 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 73 | 74 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, weight_decay=0) 75 | writer = SummaryWriter('tf-logs/train_second_stage') 76 | max_epoch = args.epochs 77 | start_epoch = 1 78 | 79 | # Dictionary to store all saved weights 80 | combined_weights = {} 81 | 82 | try: 83 | for epoch_num in range(start_epoch, max_epoch + 1): 84 | net.train() 85 | loss_epoch = 0.0 86 | reg_loss_epoch = 0.0 87 | print(f"Epoch {epoch_num}") 88 | 89 | for i_batch, sampled_batch in enumerate(train_loader): 90 | image_batch, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device) 91 | image_batch_oc = sampled_batch['image_oc'].to(device) 92 | assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}' 93 | 94 | # Forward pass 95 | outputs = net.forward(image_batch, image_batch_oc) 96 | output_masks = outputs['masks'] 97 | logits_high = output_masks * weight_eight.unsqueeze(-1) 98 | logits_high_res = logits_high.sum(1).unsqueeze(1) 99 | 100 | # Calculate loss 101 | cel = torch.nn.CrossEntropyLoss() 102 | cel_loss = cel(logits_high, label_batch[:].long()) 103 | reg_loss = l2_regularisation(net.prior_dense) + l2_regularisation(net.fcomb.layers) 104 | gt_mask = label_batch.unsqueeze(1) 105 | dice_loss = calculate_dice_loss(logits_high_res, gt_mask[:].long()) 106 | focal_loss = calculate_sigmoid_focal_loss(logits_high_res, 
gt_mask[:].float()) 107 | loss = cel_loss + args.reg_weight * reg_loss + dice_loss + focal_loss 108 | loss_epoch += loss.item() 109 | reg_loss_epoch += reg_loss.item() 110 | 111 | # Backward pass and optimization 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | 116 | avg_train_loss = loss_epoch / len(train_loader) 117 | writer.add_scalar("Train/Loss", avg_train_loss, epoch_num) 118 | print(f"Average Training Loss: {avg_train_loss}") 119 | logging.info(f"Epoch {epoch_num}: Average Training Loss: {avg_train_loss}") 120 | 121 | # Save model every 10 epochs 122 | if epoch_num % 10 == 0: 123 | save_non_lora_weights(net, epoch_num, combined_weights) 124 | print(f"Checkpoint saved at epoch {epoch_num}") 125 | logging.info(f"Checkpoint saved at epoch {epoch_num}") 126 | 127 | except Exception as e: 128 | logging.error(f"An error occurred: {e}") 129 | finally: 130 | # Save weights corresponding to epochs 131 | torch.save(combined_weights, args.save_path) 132 | print(f"Epoch weights saved to {args.save_path}.") 133 | logging.info(f"Epoch weights saved to {args.save_path}.") 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description='Train the second stage model with specified parameters.') 137 | parser.add_argument('--batch_size', type=int, default=10, help='Batch size for training.') 138 | parser.add_argument('--lr', type=float, default=1e-6, help='Learning rate.') 139 | parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train.') 140 | parser.add_argument('--gpu_id', type=int, default=2, help='GPU ID to use for training.') 141 | parser.add_argument('--model_path', type=str, default='checkpoint/last_model_epoch_101.pth', help='Path to the best model file.') 142 | parser.add_argument('--weight_path', type=str, default='checkpoint/last_mask_weights_epoch_101.pt', help='Path to the best weight file.') 143 | parser.add_argument('--save_path', type=str, default='checkpoint/final_weights.pth', help='Path to save the final weights.') 144 | parser.add_argument('--reg_weight', type=float, default=1e-5, help='Regularization weight.') 145 | 146 | args = parser.parse_args() 147 | main(args) -------------------------------------------------------------------------------- /sam_lora_image_encoder.py: -------------------------------------------------------------------------------- 1 | from segment_anything import build_sam, SamPredictor 2 | from segment_anything import sam_model_registry 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.nn.parameter import Parameter 10 | from segment_anything.modeling import Sam 11 | from safetensors import safe_open 12 | from safetensors.torch import save_file 13 | 14 | from icecream import ic 15 | 16 | 17 | class _LoRA_qkv(nn.Module): 18 | """In Sam it is implemented as 19 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 20 | B, N, C = x.shape 21 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 22 | q, k, v = qkv.unbind(0) 23 | """ 24 | 25 | def __init__( 26 | self, 27 | qkv: nn.Module, 28 | linear_a_q: nn.Module, 29 | linear_b_q: nn.Module, 30 | linear_a_v: nn.Module, 31 | linear_b_v: nn.Module, 32 | ): 33 | super().__init__() 34 | self.qkv = qkv 35 | self.linear_a_q = linear_a_q 36 | self.linear_b_q = linear_b_q 37 | self.linear_a_v = linear_a_v 38 | self.linear_b_v = linear_b_v 39 | self.dim = qkv.in_features 40 | self.w_identity = 
torch.eye(qkv.in_features) 41 | 42 | def forward(self, x): 43 | qkv = self.qkv(x) # B,N,N,3*org_C 44 | new_q = self.linear_b_q(self.linear_a_q(x)) 45 | new_v = self.linear_b_v(self.linear_a_v(x)) 46 | qkv[:, :, :, : self.dim] += new_q 47 | qkv[:, :, :, -self.dim:] += new_v 48 | return qkv 49 | 50 | class LoRA_Sam(nn.Module): 51 | """Applies low-rank adaptation (LoRA) to a Sam model's image encoder. 52 | 53 | Args: 54 | sam_model: the Sam model whose image encoder will be adapted 55 | r: rank of the LoRA decomposition 56 | num_classes: how many classes the model outputs (defaults to the backbone's setting) 57 | lora_layer: which encoder blocks to apply LoRA to (defaults to all of them) 58 | 59 | Examples:: 60 | >>> sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth") 61 | >>> lora_sam = LoRA_Sam(sam, r=4) 62 | >>> emb = lora_sam.sam.image_encoder(torch.rand(1, 3, 1024, 1024)) 63 | >>> print(emb.shape) 64 | torch.Size([1, 256, 64, 64]) 65 | """ 66 | 67 | def __init__(self, sam_model: Sam, r: int, lora_layer=None): 68 | super(LoRA_Sam, self).__init__() 69 | 70 | assert r > 0 71 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels 72 | # dim = base_vit_dim 73 | if lora_layer: 74 | self.lora_layer = lora_layer 75 | else: 76 | self.lora_layer = list( 77 | range(len(sam_model.image_encoder.blocks))) # Only apply LoRA to the image encoder by default 78 | # create storage first, so the adapters can be initialized or loaded later 79 | self.w_As = [] # These are linear layers 80 | self.w_Bs = [] 81 | 82 | # Freeze the image encoder parameters first 83 | for param in sam_model.image_encoder.parameters(): 84 | param.requires_grad = False 85 | 86 | # Here, we do the surgery 87 | for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks): 88 | # If we only want a few LoRA layers instead of all of them 89 | if t_layer_i not in self.lora_layer: 90 | continue 91 | w_qkv_linear = blk.attn.qkv 92 | self.dim = w_qkv_linear.in_features 93 | w_a_linear_q = nn.Linear(self.dim, r, bias=False) 94 | w_b_linear_q = nn.Linear(r, self.dim, bias=False) 95 | w_a_linear_v = nn.Linear(self.dim, r, bias=False) 96 | w_b_linear_v = nn.Linear(r, self.dim, bias=False) 97 | self.w_As.append(w_a_linear_q) 98 | self.w_Bs.append(w_b_linear_q) 99 | self.w_As.append(w_a_linear_v) 100 | self.w_Bs.append(w_b_linear_v) 101 | blk.attn.qkv = _LoRA_qkv( # wrap the block's original qkv projection with the LoRA branches 102 | w_qkv_linear, 103 | w_a_linear_q, 104 | w_b_linear_q, 105 | w_a_linear_v, 106 | w_b_linear_v, 107 | ) 108 | self.reset_parameters() 109 | self.sam = sam_model 110 | 111 | def save_lora_parameters(self, filename: str) -> None: 112 | r"""Save the LoRA parameters together with the fine-tuned prompt encoder and mask decoder weights. 113 | 114 | The checkpoint is written with torch.save, so the filename must end in .pt or .pth. 115 | 116 | Saves both the LoRA A/B matrices and the SAM head parameters. 
117 | """ 118 | 119 | assert filename.endswith(".pt") or filename.endswith('.pth') 120 | 121 | num_layer = len(self.w_As) # one q- and one v-adapter per adapted block, so this is twice the number of blocks 122 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} 123 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} 124 | prompt_encoder_tensors = {} 125 | mask_decoder_tensors = {} 126 | 127 | # collect the prompt encoder and mask decoder weights from the state_dict 128 | if isinstance(self.sam, torch.nn.DataParallel) or isinstance(self.sam, torch.nn.parallel.DistributedDataParallel): 129 | state_dict = self.sam.module.state_dict() 130 | else: 131 | state_dict = self.sam.state_dict() 132 | for key, value in state_dict.items(): 133 | if 'prompt_encoder' in key: 134 | prompt_encoder_tensors[key] = value 135 | if 'mask_decoder' in key: 136 | mask_decoder_tensors[key] = value 137 | 138 | merged_dict = {**a_tensors, **b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors} 139 | torch.save(merged_dict, filename) 140 | 141 | def load_lora_parameters(self, filename: str) -> None: 142 | r"""Load the LoRA parameters together with the fine-tuned prompt encoder and mask decoder weights. 143 | 144 | The checkpoint is read with torch.load, so the filename must end in .pt or .pth. 145 | 146 | Loads both the LoRA A/B matrices and the SAM head parameters. 147 | """ 148 | 149 | assert filename.endswith(".pt") or filename.endswith('.pth') 150 | 151 | state_dict = torch.load(filename) 152 | 153 | for i, w_A_linear in enumerate(self.w_As): 154 | saved_key = f"w_a_{i:03d}" 155 | saved_tensor = state_dict[saved_key] 156 | w_A_linear.weight = Parameter(saved_tensor) 157 | 158 | for i, w_B_linear in enumerate(self.w_Bs): 159 | saved_key = f"w_b_{i:03d}" 160 | saved_tensor = state_dict[saved_key] 161 | w_B_linear.weight = Parameter(saved_tensor) 162 | 163 | sam_dict = self.sam.state_dict() 164 | sam_keys = sam_dict.keys() 165 | 166 | # load prompt encoder 167 | prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k] 168 | prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys] 169 | prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)} 170 | sam_dict.update(prompt_encoder_new_state_dict) 171 | 172 | # load mask decoder 173 | mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k] 174 | mask_decoder_values = [state_dict[k] for k in mask_decoder_keys] 175 | mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)} 176 | sam_dict.update(mask_decoder_new_state_dict) 177 | self.sam.load_state_dict(sam_dict) 178 | 179 | def reset_parameters(self) -> None: 180 | for w_A in self.w_As: 181 | nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5)) 182 | for w_B in self.w_Bs: 183 | nn.init.zeros_(w_B.weight) 184 | 185 | def forward(self, batched_input, multimask_output, image_size): 186 | return self.sam(batched_input, multimask_output, image_size) 187 | 188 | 189 | # def forward(self, x: Tensor) -> Tensor: 190 | # return self.lora_vit(x) 191 | 192 | 193 | if __name__ == "__main__": 194 | sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth") 195 | lora_sam = LoRA_Sam(sam, 4) 196 | lora_sam.sam.image_encoder(torch.rand(size=(1, 3, 1024, 1024))) 197 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
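A standalone sketch of the low-rank update that _LoRA_qkv applies above, using toy sizes (dim=8, rank r=2) rather than SAM's real dimensions; the tensors and names below are illustrative only and are not part of the repository:

import torch
import torch.nn as nn

dim, r = 8, 2
qkv = nn.Linear(dim, dim * 3)                            # stands in for the frozen base qkv projection
a_q, b_q = nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False)
a_v, b_v = nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False)
nn.init.zeros_(b_q.weight); nn.init.zeros_(b_v.weight)   # B starts at zero, as in reset_parameters

x = torch.randn(1, 4, 4, dim)                            # (B, H, W, C) tokens, as inside SAM's ViT blocks
with torch.no_grad():
    out = qkv(x)
    out[..., :dim] += b_q(a_q(x))                        # rank-r perturbation of the query slice
    out[..., -dim:] += b_v(a_v(x))                       # rank-r perturbation of the value slice; keys untouched
print(out.shape)                                         # torch.Size([1, 4, 4, 24])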
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attenion layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) # downsample to 1/4 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from icecream import ic 11 | 12 | from typing import Any, Dict, List, Tuple 13 | 14 | from .image_encoder import ImageEncoderViT 15 | from .mask_decoder import MaskDecoder 16 | from .prompt_encoder import PromptEncoder 17 | 18 | 19 | class Sam(nn.Module): 20 | mask_threshold: float = 0.0 21 | image_format: str = "RGB" 22 | 23 | def __init__( 24 | self, 25 | image_encoder: ImageEncoderViT, 26 | prompt_encoder: PromptEncoder, 27 | mask_decoder: MaskDecoder, 28 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 29 | pixel_std: List[float] = [58.395, 57.12, 57.375], 30 | ) -> None: 31 | """ 32 | SAM predicts object masks from an image and input prompts. 33 | 34 | Arguments: 35 | image_encoder (ImageEncoderViT): The backbone used to encode the 36 | image into image embeddings that allow for efficient mask prediction. 37 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 38 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 39 | and encoded prompts. 40 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 41 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 
42 | """ 43 | super().__init__() 44 | self.image_encoder = image_encoder 45 | self.prompt_encoder = prompt_encoder 46 | self.mask_decoder = mask_decoder 47 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 48 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 49 | 50 | @property 51 | def device(self) -> Any: 52 | return self.pixel_mean.device 53 | 54 | def forward(self, batched_input, multimask_output, image_size): 55 | if isinstance(batched_input, list): 56 | print("testing") 57 | outputs = self.forward_test(batched_input, multimask_output) 58 | else: 59 | print("training") 60 | outputs = self.forward_train(batched_input, multimask_output, image_size) 61 | return outputs 62 | 63 | # The sparse and dense embeddings need to be perturbed before the forward pass 64 | def forward_train(self, batched_input, multimask_output, image_size): 65 | input_images = self.preprocess(batched_input) 66 | image_embeddings = self.image_encoder(input_images) 67 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 68 | points=None, boxes=None, masks=None 69 | ) 70 | low_res_masks, iou_predictions = self.mask_decoder( 71 | image_embeddings=image_embeddings, 72 | image_pe=self.prompt_encoder.get_dense_pe(), 73 | sparse_prompt_embeddings=sparse_embeddings, 74 | dense_prompt_embeddings=dense_embeddings, 75 | multimask_output=multimask_output 76 | ) 77 | masks = self.postprocess_masks( 78 | low_res_masks, 79 | input_size=(image_size, image_size), 80 | original_size=(image_size, image_size) 81 | ) 82 | outputs = { 83 | 'masks': masks, 84 | 'iou_predictions': iou_predictions, 85 | 'low_res_logits': low_res_masks 86 | } 87 | return outputs 88 | 89 | @torch.no_grad() 90 | def forward_test( 91 | self, 92 | batched_input: List[Dict[str, Any]], 93 | multimask_output: bool, 94 | ) -> List[Dict[str, torch.Tensor]]: 95 | """ 96 | Predicts masks end-to-end from provided images and prompts. 97 | If prompts are not known in advance, using SamPredictor is 98 | recommended over calling the model directly. 99 | 100 | Arguments: 101 | batched_input (list(dict)): A list over input images, each a 102 | dictionary with the following keys. A prompt key can be 103 | excluded if it is not present. 104 | 'image': The image as a torch tensor in 3xHxW format, 105 | already transformed for input to the model. 106 | 'original_size': (tuple(int, int)) The original size of 107 | the image before transformation, as (H, W). 108 | 'point_coords': (torch.Tensor) Batched point prompts for 109 | this image, with shape BxNx2. Already transformed to the 110 | input frame of the model. 111 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 112 | with shape BxN. 113 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 114 | Already transformed to the input frame of the model. 115 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 116 | in the form Bx1xHxW. 117 | multimask_output (bool): Whether the model should predict multiple 118 | disambiguating masks, or return a single mask. 119 | 120 | Returns: 121 | (list(dict)): A list over input images, where each element is 122 | a dictionary with the following keys. 123 | 'masks': (torch.Tensor) Batched binary mask predictions, 124 | with shape BxCxHxW, where B is the number of input prompts, 125 | C is determined by multimask_output, and (H, W) is the 126 | original size of the image. 127 | 'iou_predictions': (torch.Tensor) The model's predictions 128 | of mask quality, in shape BxC. 
129 | 'low_res_logits': (torch.Tensor) Low resolution logits with 130 | shape BxCxHxW, where H=W=256. Can be passed as mask input 131 | to subsequent iterations of prediction. 132 | """ 133 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 134 | image_embeddings = self.image_encoder(input_images) 135 | 136 | outputs = [] 137 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 138 | if "point_coords" in image_record: 139 | points = (image_record["point_coords"], image_record["point_labels"]) 140 | else: 141 | points = None 142 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 143 | points=points, 144 | boxes=image_record.get("boxes", None), 145 | masks=image_record.get("mask_inputs", None), 146 | ) 147 | low_res_masks, iou_predictions = self.mask_decoder( 148 | image_embeddings=curr_embedding.unsqueeze(0), 149 | image_pe=self.prompt_encoder.get_dense_pe(), 150 | sparse_prompt_embeddings=sparse_embeddings, 151 | dense_prompt_embeddings=dense_embeddings, 152 | multimask_output=multimask_output, 153 | ) 154 | masks = self.postprocess_masks( 155 | low_res_masks, 156 | input_size=image_record["image"].shape[-2:], 157 | original_size=image_record["original_size"], 158 | ) 159 | masks = masks > self.mask_threshold 160 | outputs.append( 161 | { 162 | "masks": masks, 163 | "iou_predictions": iou_predictions, 164 | "low_res_logits": low_res_masks, 165 | } 166 | ) 167 | return outputs 168 | 169 | def postprocess_masks( 170 | self, 171 | masks: torch.Tensor, 172 | input_size: Tuple[int, ...], 173 | original_size: Tuple[int, ...], 174 | ) -> torch.Tensor: 175 | """ 176 | Remove padding and upscale masks to the original image size. 177 | 178 | Arguments: 179 | masks (torch.Tensor): Batched masks from the mask_decoder, 180 | in BxCxHxW format. 181 | input_size (tuple(int, int)): The size of the image input to the 182 | model, in (H, W) format. Used to remove padding. 183 | original_size (tuple(int, int)): The original size of the image 184 | before resizing for input to the model, in (H, W) format. 185 | 186 | Returns: 187 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 188 | is given by original_size. 
189 | """ 190 | masks = F.interpolate( 191 | masks, 192 | (self.image_encoder.img_size, self.image_encoder.img_size), 193 | mode="bilinear", 194 | align_corners=False, 195 | ) 196 | masks = masks[..., : input_size[0], : input_size[1]] 197 | 198 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 199 | return masks 200 | 201 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 202 | """Normalize pixel values and pad to a square input.""" 203 | # Normalize colors 204 | x = (x - self.pixel_mean) / self.pixel_std 205 | 206 | # Pad 207 | h, w = x.shape[-2:] 208 | padh = self.image_encoder.img_size - h 209 | padw = self.image_encoder.img_size - w 210 | x = F.pad(x, (0, padw, 0, padh)) 211 | return x 212 | 213 | -------------------------------------------------------------------------------- /train_first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Subset 3 | from torchvision import transforms 4 | from tensorboardX import SummaryWriter 5 | from utils import l2_regularisation, calculate_dice_loss, calculate_sigmoid_focal_loss 6 | from load_LIDC_data import LIDC_IDRI, RandomGenerator 7 | import torch.nn as nn 8 | from sam_lora_image_encoder import LoRA_Sam 9 | from segment_anything import sam_model_registry 10 | import numpy as np 11 | import os 12 | import logging 13 | import argparse 14 | 15 | # Configure logging 16 | logging.basicConfig(filename='training_log_first_stage.txt', level=logging.INFO, format='%(asctime)s - %(message)s') 17 | 18 | class MaskWeights(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.weights = nn.Parameter(torch.ones(7, 1, requires_grad=True) / 8) 22 | 23 | def inference(batched_input, lora_sam): 24 | img_size = 128 25 | input_images = lora_sam.sam.preprocess(batched_input) 26 | image_embeddings = lora_sam.sam.image_encoder(input_images) 27 | sparse_embeddings, dense_embeddings = lora_sam.sam.prompt_encoder( 28 | points=None, boxes=None, masks=None 29 | ) 30 | low_res_masks, iou_predictions = lora_sam.sam.mask_decoder( 31 | image_embeddings=image_embeddings, 32 | image_pe=lora_sam.sam.prompt_encoder.get_dense_pe(), 33 | sparse_prompt_embeddings=sparse_embeddings, 34 | dense_prompt_embeddings=dense_embeddings, 35 | multimask_output=True 36 | ) 37 | masks = lora_sam.sam.postprocess_masks( 38 | low_res_masks, 39 | input_size=(img_size, img_size), 40 | original_size=(img_size, img_size) 41 | ) 42 | 43 | return { 44 | 'masks': masks, 45 | 'iou_predictions': iou_predictions, 46 | 'low_res_logits': low_res_masks 47 | } 48 | 49 | def evaluate(model, data_loader, device, mask_weights): 50 | model.eval() 51 | total_loss = 0.0 52 | with torch.no_grad(): 53 | for sampled_batch in data_loader: 54 | image_batch, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device) 55 | outputs = inference(image_batch, model) 56 | output_masks = outputs['masks'] 57 | logits_high = output_masks.to(device) 58 | weights_eight = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0).to(device) 59 | logits_high = logits_high * weights_eight.unsqueeze(-1) 60 | logits_high_res = logits_high.sum(1).unsqueeze(1) 61 | gt_mask = label_batch.unsqueeze(1) 62 | dice_loss = calculate_dice_loss(logits_high_res, gt_mask[:].long()) 63 | focal_loss = calculate_sigmoid_focal_loss(logits_high_res, gt_mask[:].float()) 64 | total_loss += (dice_loss + focal_loss).item() 65 | return total_loss / 
len(data_loader) 66 | 67 | def main(args): 68 | # Set global device 69 | device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu') 70 | os.makedirs("best", exist_ok=True) 71 | 72 | ckpt = args.checkpoint 73 | img_size = 128 74 | sam, img_embedding_size = sam_model_registry["vit_b"]( 75 | image_size=img_size, 76 | num_classes=8, 77 | pixel_mean=[0, 0, 0], 78 | pixel_std=[1, 1, 1], 79 | checkpoint=ckpt 80 | ) 81 | low_res = img_embedding_size * 4 82 | 83 | db = LIDC_IDRI(dataset_location='data/', transform=transforms.Compose([ 84 | RandomGenerator(output_size=[128, 128], low_res=[low_res, low_res], test=True) 85 | ])) 86 | dataset_size = len(db) 87 | indices = list(range(dataset_size)) 88 | train_split = int(np.floor(0.6 * dataset_size)) 89 | validation_split = int(np.floor(0.8 * dataset_size)) 90 | train_indices = indices[:train_split] 91 | validation_indices = indices[train_split:validation_split] 92 | test_indices = indices[validation_split:] 93 | 94 | train_dataset = Subset(db, train_indices) 95 | validation_dataset = Subset(db, validation_indices) 96 | test_dataset = Subset(db, test_indices) 97 | 98 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 99 | validation_loader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=False) 100 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 101 | 102 | print(f"Total dataset size: {dataset_size}") 103 | print(f"Training set size: {len(train_indices)}") 104 | print(f"Validation set size: {len(validation_indices)}") 105 | print(f"Test set size: {len(test_indices)}") 106 | 107 | mask_weights = MaskWeights().to(device) 108 | lora_sam = LoRA_Sam(sam, 4).to(device) 109 | 110 | for param in lora_sam.sam.prompt_encoder.parameters(): 111 | param.requires_grad = True 112 | for param in lora_sam.sam.mask_decoder.parameters(): 113 | param.requires_grad = True 114 | 115 | optimizer1 = torch.optim.Adam(filter(lambda p: p.requires_grad, lora_sam.parameters()), lr=args.lr, weight_decay=0) 116 | optimizer2 = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4) 117 | 118 | writer = SummaryWriter('tf-logs/train_first_stage') 119 | max_epoch = args.epochs 120 | best_val_loss = float('inf') 121 | best_epoch = 0 122 | start_epoch = 1 123 | 124 | # Check for the latest checkpoint 125 | latest_checkpoint = f"checkpoint/last_model_epoch_{max_epoch-1}.pth" 126 | if os.path.exists(latest_checkpoint): 127 | lora_sam.load_lora_parameters(latest_checkpoint) 128 | weights_eight = torch.load(f"checkpoint/last_mask_weights_epoch_{max_epoch-1}.pt") 129 | start_epoch = max_epoch 130 | 131 | try: 132 | for epoch_num in range(start_epoch, max_epoch + 1): 133 | lora_sam.train() 134 | mask_weights.train() 135 | loss_epoch = 0.0 136 | print(f"Epoch {epoch_num}") 137 | 138 | for i_batch, sampled_batch in enumerate(train_loader): 139 | image_batch, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device) 140 | assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}' 141 | 142 | outputs = inference(image_batch, lora_sam) 143 | output_masks = outputs['masks'] 144 | logits_high = output_masks.to(device) 145 | weights_eight = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0).to(device) 146 | logits_high = logits_high * weights_eight.unsqueeze(-1) 147 | logits_high_res = logits_high.sum(1).unsqueeze(1) 148 | 149 | cel = torch.nn.CrossEntropyLoss() 150 | loss1 = cel(logits_high, 
label_batch[:].long()) 151 | loss_epoch += loss1.item() 152 | 153 | gt_mask = label_batch.unsqueeze(1) 154 | dice_loss = calculate_dice_loss(logits_high_res, gt_mask[:].long()) 155 | focal_loss = calculate_sigmoid_focal_loss(logits_high_res, gt_mask[:].float()) 156 | loss2 = dice_loss + focal_loss 157 | loss = loss1 + loss2 158 | 159 | optimizer1.zero_grad() 160 | optimizer2.zero_grad() 161 | loss.backward() 162 | optimizer1.step() 163 | optimizer2.step() 164 | 165 | avg_train_loss = loss_epoch / len(train_loader) 166 | writer.add_scalar("Train/Loss", avg_train_loss, epoch_num) 167 | print(f"Average Training Loss: {avg_train_loss}") 168 | logging.info(f"Epoch {epoch_num}: Average Training Loss: {avg_train_loss}") 169 | 170 | # Validation evaluation 171 | val_loss = evaluate(lora_sam, validation_loader, device, mask_weights) 172 | writer.add_scalar("Validation/Loss", val_loss, epoch_num) 173 | print(f"Validation Loss: {val_loss}") 174 | logging.info(f"Epoch {epoch_num}: Validation Loss: {val_loss}") 175 | 176 | # Save the best model 177 | if val_loss < best_val_loss: 178 | best_val_loss = val_loss 179 | best_epoch = epoch_num 180 | 181 | # Remove previous best model files 182 | for file in os.listdir("checkpoint"): 183 | if file.startswith("best_mask_weights_epoch_") or file.startswith("best_model_epoch_"): 184 | os.remove(os.path.join("checkpoint", file)) 185 | 186 | # Save new best weights 187 | torch.save(weights_eight, f"checkpoint/best_mask_weights_epoch_{best_epoch}.pt") 188 | file_name = f"checkpoint/best_model_epoch_{best_epoch}.pth" 189 | lora_sam.save_lora_parameters(file_name) 190 | 191 | print(f"Best Validation Loss: {best_val_loss} at epoch {best_epoch}") 192 | logging.info(f"Best Validation Loss: {best_val_loss} at epoch {best_epoch}") 193 | 194 | except Exception as e: 195 | logging.error(f"An error occurred: {e}") 196 | finally: 197 | # Save the current checkpoint 198 | torch.save(weights_eight, f"checkpoint/last_mask_weights_epoch_{epoch_num}.pt") 199 | file_name = f"checkpoint/last_model_epoch_{epoch_num}.pth" 200 | lora_sam.save_lora_parameters(file_name) 201 | print(f"Saved checkpoint at epoch {epoch_num}") 202 | logging.info(f"Saved checkpoint at epoch {epoch_num}") 203 | 204 | if __name__ == "__main__": 205 | parser = argparse.ArgumentParser(description='Train the model with specified parameters.') 206 | parser.add_argument('--checkpoint', type=str, default='sam_vit_b_01ec64.pth', help='Path to the checkpoint file.') 207 | parser.add_argument('--batch_size', type=int, default=10, help='Batch size for training.') 208 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.') 209 | parser.add_argument('--epochs', type=int, default=101, help='Number of epochs to train.') 210 | parser.add_argument('--gpu_id', type=int, default=4, help='GPU ID to use for training.') 211 | 212 | args = parser.parse_args() 213 | main(args) -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
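The mask fusion used in train_first_stage.py above is easier to follow with concrete shapes, so here is a small sketch with dummy tensors; the batch size of 2, the 128x128 resolution, and the eight mask channels are illustrative assumptions chosen to match the eight learned weights:

import torch
import torch.nn as nn

weights = nn.Parameter(torch.ones(7, 1) / 8)                 # the 7 free weights from MaskWeights
weights_eight = torch.cat((1 - weights.sum(0).unsqueeze(0),  # implied 8th weight, so all eight sum to 1
                           weights), dim=0)                  # shape (8, 1)

output_masks = torch.randn(2, 8, 128, 128)                   # dummy stand-in for SAM's stacked mask logits
logits_high = output_masks * weights_eight.unsqueeze(-1)     # broadcast (8, 1, 1) over (2, 8, 128, 128)
logits_high_res = logits_high.sum(1).unsqueeze(1)            # weighted fusion into one channel: (2, 1, 128, 128)
print(logits_high_res.shape)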
6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 
107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks = masks[0].detach().cpu().numpy() 164 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 166 | return masks, iou_predictions, low_res_masks 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 
180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | box (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold#0.0 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 
256 | return self.features
257 | 
258 | @property
259 | def device(self) -> torch.device:
260 | return self.model.device
261 | 
262 | def reset_image(self) -> None:
263 | """Resets the currently set image."""
264 | self.is_image_set = False
265 | self.features = None
266 | self.orig_h = None
267 | self.orig_w = None
268 | self.input_h = None
269 | self.input_w = None
270 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from medpy import metric
5 | from scipy.ndimage import zoom
6 | import torch.nn as nn
7 | import SimpleITK as sitk
8 | import torch.nn.functional as F
9 | import imageio
10 | from einops import repeat
11 | from icecream import ic
12 | from torch.autograd import Variable
13 | import matplotlib.pyplot as plt
14 | from scipy.optimize import linear_sum_assignment
15 | 
16 | class FocalLoss(nn.Module):
17 | def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):
18 | super(FocalLoss, self).__init__()
19 | self.size_average = size_average
20 | if isinstance(alpha, list):
21 | assert len(alpha) == num_classes
22 | print(f'Focal loss alpha={alpha}, assigning a separate alpha value to each class')
23 | self.alpha = torch.Tensor(alpha)
24 | else:
25 | assert alpha < 1
26 | print(f'Focal loss alpha={alpha}, the background class will be down-weighted')
27 | self.alpha = torch.zeros(num_classes)
28 | self.alpha[0] = alpha
29 | self.alpha[1:] = 1 - alpha
30 | self.gamma = gamma
31 | self.num_classes = num_classes
32 | 
33 | def forward(self, preds, labels):
34 | """
35 | Calculate the focal loss.
36 | :param preds: predicted logits of size [B, N, C] (detection), [B, C] (classification), or [B, C, H, W] (segmentation); this forward pass handles the segmentation case
37 | :param labels: labels of size [B, N], [B], or [B, H, W] respectively
38 | :return: loss
39 | """
40 | self.alpha = self.alpha.to(preds.device)
41 | preds = preds.permute(0, 2, 3, 1).contiguous()
42 | preds = preds.view(-1, preds.size(-1))
43 | B, H, W = labels.shape
44 | assert B * H * W == preds.shape[0]
45 | assert preds.shape[-1] == self.num_classes
46 | preds_logsoft = F.log_softmax(preds, dim=1) # log softmax
47 | preds_softmax = torch.exp(preds_logsoft) # softmax
48 | 
49 | preds_softmax = preds_softmax.gather(1, labels.view(-1, 1))
50 | preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
51 | alpha = self.alpha.gather(0, labels.view(-1))
52 | loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft)
53 | 
54 | loss = torch.mul(alpha, loss.t())
55 | if self.size_average:
56 | loss = loss.mean()
57 | else:
58 | loss = loss.sum()
59 | return loss
60 | 
61 | class DiceLoss(nn.Module):
62 | def __init__(self, n_classes):
63 | super(DiceLoss, self).__init__()
64 | self.n_classes = n_classes
65 | 
66 | def _one_hot_encoder(self, input_tensor):
67 | tensor_list = []
68 | for i in range(self.n_classes):
69 | temp_prob = input_tensor == i
70 | tensor_list.append(temp_prob.unsqueeze(1))
71 | output_tensor = torch.cat(tensor_list, dim=1)
72 | return output_tensor.float()
73 | 
74 | def _dice_loss(self, score, target):
75 | target = target.float()
76 | smooth = 1e-5
77 | intersect = torch.sum(score * target)
78 | y_sum = torch.sum(target * target)
79 | z_sum = torch.sum(score * score)
80 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
81 | loss = 1 - loss
82 | return loss
83 | 
84 | def forward(self, inputs, target, weight=None, softmax=False):
85 | if softmax:
86 | inputs = 
torch.softmax(inputs, dim=1) 87 | target = self._one_hot_encoder(target) 88 | if weight is None: 89 | weight = [1] * self.n_classes 90 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size()) 91 | class_wise_dice = [] 92 | loss = 0.0 93 | for i in range(0, self.n_classes): 94 | dice = self._dice_loss(inputs[:, i], target[:, i]) 95 | class_wise_dice.append(1.0 - dice.item()) 96 | loss += dice * weight[i] 97 | return loss / self.n_classes 98 | 99 | 100 | def init_weights(m): 101 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 102 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 103 | truncated_normal_(m.bias, mean=0, std=0.001) 104 | 105 | def init_weights_orthogonal_normal(m): 106 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 107 | nn.init.orthogonal_(m.weight) 108 | truncated_normal_(m.bias, mean=0, std=0.001) 109 | 110 | def l2_regularisation(m): 111 | l2_reg = None 112 | for W in m.parameters(): 113 | if l2_reg is None: 114 | l2_reg = W.norm(2) 115 | else: 116 | l2_reg = l2_reg + W.norm(2) 117 | return l2_reg 118 | 119 | def save_mask_prediction_example(mask, pred, iter): 120 | plt.imshow(pred[0, :, :], cmap='Greys') 121 | plt.savefig(f'images/{iter}_prediction.png') 122 | plt.imshow(mask[0, :, :], cmap='Greys') 123 | plt.savefig(f'images/{iter}_mask.png') 124 | 125 | def show_mask(mask, ax, random_color=False): 126 | if random_color: 127 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 128 | else: 129 | color = np.array([251/255, 252/255, 30/255, 0.6]) 130 | h, w = mask.shape[-2:] 131 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 132 | ax.imshow(mask_image) 133 | 134 | def show_box(box, ax): 135 | x0, y0 = box[0], box[1] 136 | w, h = box[2] - box[0], box[3] - box[1] 137 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2)) 138 | 139 | def iou_score_cal(output, target): 140 | smooth = 1e-5 141 | if torch.is_tensor(output): 142 | output = torch.sigmoid(output).data.cpu().numpy() 143 | if torch.is_tensor(target): 144 | target = target.data.cpu().numpy() 145 | output_ = output > 0.5 146 | target_ = target > 0.5 147 | intersection = (output_ & target_).sum() 148 | union = (output_ | target_).sum() 149 | iou = (intersection + smooth) / (union + smooth) 150 | dice = (2 * iou) / (iou + 1) 151 | return iou 152 | 153 | def mask_IoU(prediction, groundtruth): 154 | prediction = prediction.detach().cpu().numpy() 155 | groundtruth = groundtruth.detach().cpu().numpy() 156 | intersection = np.logical_and(groundtruth, prediction) 157 | union = np.logical_or(groundtruth, prediction) 158 | if np.sum(union) == 0: 159 | return 1 160 | iou_score = np.sum(intersection) / np.sum(union) 161 | return iou_score 162 | 163 | def generalized_energy_distance_iou(predictions, masks): 164 | n = predictions.shape[0] 165 | m = masks.shape[0] 166 | d1 = d2 = d3 = 0 167 | for i in range(n): 168 | for j in range(m): 169 | d1 += (1 - mask_IoU(predictions[i], masks[j])) 170 | 171 | for i in range(n): 172 | for j in range(n): 173 | d2 += (1 - mask_IoU(predictions[i], predictions[j])) 174 | 175 | for i in range(m): 176 | for j in range(m): 177 | d3 += (1 - mask_IoU(masks[i], masks[j])) 178 | 179 | d1 *= (2 / (n * m)) 180 | d2 *= (1 / (n * n)) 181 | d3 *= (1 / (m * m)) 182 | 183 | ed = d1 - d2 - d3 184 | scores = mask_IoU(predictions[0], masks[0]) 185 | 186 | return ed, scores 187 | 188 | def dice_score_cal(pred, targs): 189 | pred = (pred > 
0).float() 190 | intersection = (pred * targs).sum() 191 | union = pred.sum() + targs.sum() 192 | if union == 0: 193 | return 1.0 # If both prediction and target are zero, return 1 194 | dice_score = 2. * intersection / union 195 | return dice_score 196 | 197 | def dice_coef_cal(output, target): 198 | smooth = 1e-5 199 | output = output.view(-1).data.cpu().numpy() 200 | target = target.view(-1).data.cpu().numpy() 201 | intersection = (output * target).sum() 202 | return (2. * intersection + smooth) / (output.sum() + target.sum() + smooth) 203 | 204 | def iou(pred, true): 205 | """Calculate IOU, input as PyTorch tensors""" 206 | pred_bool = pred.bool().detach().cpu() 207 | true_bool = true.bool().detach().cpu() 208 | intersection = (pred_bool & true_bool).float().sum() 209 | union = (pred_bool | true_bool).float().sum() 210 | if union == 0 and intersection == 0: 211 | return 1 212 | else: 213 | return intersection / union 214 | 215 | def hm_iou_cal(preds, trues): 216 | """Calculate Hungarian-Matched IOU, input as lists of PyTorch tensors""" 217 | num_preds = len(preds) 218 | num_trues = len(trues) 219 | cost_matrix = torch.zeros((num_preds, num_trues)) 220 | for i, pred in enumerate(preds): 221 | for j, true in enumerate(trues): 222 | cost_matrix[i, j] = 1 - iou(pred, true) 223 | row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy()) 224 | matched_iou = [iou(preds[i], trues[j]) for i, j in zip(row_ind, col_ind)] 225 | avg_iou = torch.tensor(matched_iou).mean().item() 226 | return avg_iou, matched_iou 227 | 228 | def calculate_dice_loss(inputs, targets, num_masks=5): 229 | """ 230 | Compute the DICE loss, similar to generalized IOU for masks 231 | Args: 232 | inputs: A float tensor of arbitrary shape. 233 | The predictions for each example. 234 | targets: A float tensor with the same shape as inputs. Stores the binary 235 | classification label for each element in inputs 236 | (0 for the negative class and 1 for the positive class). 237 | """ 238 | inputs = inputs.sigmoid() 239 | numerator = 2 * (inputs * targets).sum(-1) 240 | denominator = inputs.sum(-1) + targets.sum(-1) 241 | loss = 1 - (numerator + 1) / (denominator + 1) 242 | return loss.sum() / num_masks 243 | 244 | def calculate_sigmoid_focal_loss(inputs, targets, num_masks=5, alpha: float = 0.25, gamma: float = 2): 245 | """ 246 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 247 | Args: 248 | inputs: A float tensor of arbitrary shape. 249 | The predictions for each example. 250 | targets: A float tensor with the same shape as inputs. Stores the binary 251 | classification label for each element in inputs 252 | (0 for the negative class and 1 for the positive class). 253 | alpha: (optional) Weighting factor in range (0,1) to balance 254 | positive vs negative examples. Default = -1 (no weighting). 255 | gamma: Exponent of the modulating factor (1 - p_t) to 256 | balance easy vs hard examples. 
257 | Returns: 258 | Loss tensor 259 | """ 260 | prob = inputs.sigmoid() 261 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 262 | p_t = prob * targets + (1 - prob) * (1 - targets) 263 | loss = ce_loss * ((1 - p_t) ** gamma) 264 | 265 | if alpha >= 0: 266 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 267 | loss = alpha_t * loss 268 | 269 | return loss.mean(1).sum() / num_masks 270 | 271 | def mean_dice_cal(pred_list, label_four): 272 | n = len(pred_list) 273 | m = len(label_four) 274 | dice = 0 275 | for i in range(n): 276 | for j in range(m): 277 | dice += dice_score_cal(pred_list[i].to(dtype=torch.float).squeeze().cpu().detach(), label_four[j].squeeze(0).cpu().detach()) 278 | dice_mean = dice / (n * m) 279 | return dice_mean 280 | 281 | def dice_max_cal1(pred_eval, label_four): 282 | dice_max = 0 283 | for i in range(pred_eval.shape[0]): 284 | dice_max_iter = 0 285 | for j in range(label_four.shape[0]): 286 | dice_score_iter = dice_score_cal(pred_eval[i].to(dtype=torch.float).squeeze().cpu().detach(), label_four[j].squeeze(0).cpu().detach()) 287 | if j == 0: 288 | dice_max_iter = dice_score_iter 289 | else: 290 | if dice_score_iter > dice_max_iter: 291 | dice_max_iter = dice_score_iter 292 | dice_max += dice_max_iter 293 | return dice_max / pred_eval.shape[0] 294 | 295 | def dice_max_cal2(pred_eval, label_four): 296 | dice_max = -1 297 | for i in range(pred_eval.shape[0]): 298 | for j in range(label_four.shape[0]): 299 | dice_score_iter = dice_score_cal(pred_eval[i].to(dtype=torch.float).squeeze().cpu().detach(), label_four[j].squeeze(0).cpu().detach()) 300 | if dice_score_iter > dice_max: 301 | dice_max = dice_score_iter 302 | return dice_max 303 | 304 | def dice_avg_cal(pred_list, label_four): 305 | dice_all = 0 306 | pred_stack = torch.stack(pred_list, dim=0) 307 | pred_avg = (pred_stack > 0).cpu().detach() 308 | pred_avg = torch.where(pred_avg, torch.tensor(1), torch.tensor(0)) 309 | pred_avg = torch.mean(pred_stack, dim=0) 310 | pred_avg = (pred_avg > 0).cpu().detach() 311 | pred_avg = torch.where(pred_avg, torch.tensor(1), torch.tensor(0)) 312 | for i in range(label_four.shape[0]): 313 | dice_score_iter = dice_score_cal(pred_avg.to(dtype=torch.float).squeeze().cpu().detach(), label_four[i].squeeze(0).cpu().detach()) 314 | dice_all += dice_score_iter 315 | return dice_all / label_four.shape[0], pred_avg 316 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecesary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from icecream import ic 11 | 12 | from typing import Optional, Tuple, Type 13 | 14 | from .common import LayerNorm2d, MLPBlock 15 | 16 | 17 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 18 | class ImageEncoderViT(nn.Module): 19 | def __init__( 20 | self, 21 | img_size: int = 1024, 22 | patch_size: int = 16, 23 | in_chans: int = 3, 24 | embed_dim: int = 768, 25 | depth: int = 12, 26 | num_heads: int = 12, 27 | mlp_ratio: float = 4.0, 28 | out_chans: int = 256, 29 | qkv_bias: bool = True, 30 | norm_layer: Type[nn.Module] = nn.LayerNorm, 31 | act_layer: Type[nn.Module] = nn.GELU, 32 | use_abs_pos: bool = True, 33 | use_rel_pos: bool = False, 34 | rel_pos_zero_init: bool = True, 35 | window_size: int = 0, 36 | global_attn_indexes: Tuple[int, ...] = (), 37 | ) -> None: 38 | """ 39 | Args: 40 | img_size (int): Input image size. 41 | patch_size (int): Patch size. 42 | in_chans (int): Number of input image channels. 43 | embed_dim (int): Patch embedding dimension. 44 | depth (int): Depth of ViT. 45 | num_heads (int): Number of attention heads in each ViT block. 46 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 47 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 48 | norm_layer (nn.Module): Normalization layer. 49 | act_layer (nn.Module): Activation layer. 50 | use_abs_pos (bool): If True, use absolute positional embeddings. 51 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 52 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 53 | window_size (int): Window size for window attention blocks. 54 | global_attn_indexes (list): Indexes for blocks using global attention. 55 | """ 56 | super().__init__() 57 | self.img_size = img_size 58 | 59 | self.patch_embed = PatchEmbed( 60 | kernel_size=(patch_size, patch_size), 61 | stride=(patch_size, patch_size), 62 | in_chans=in_chans, 63 | embed_dim=embed_dim, 64 | ) 65 | 66 | self.pos_embed: Optional[nn.Parameter] = None 67 | if use_abs_pos: 68 | # Initialize absolute positional embedding with pretrain image size. 
69 | self.pos_embed = nn.Parameter( 70 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 71 | ) 72 | 73 | self.blocks = nn.ModuleList() 74 | for i in range(depth): 75 | block = Block( 76 | dim=embed_dim, 77 | num_heads=num_heads, 78 | mlp_ratio=mlp_ratio, 79 | qkv_bias=qkv_bias, 80 | norm_layer=norm_layer, 81 | act_layer=act_layer, 82 | use_rel_pos=use_rel_pos, 83 | rel_pos_zero_init=rel_pos_zero_init, 84 | window_size=window_size if i not in global_attn_indexes else 0, 85 | input_size=(img_size // patch_size, img_size // patch_size), 86 | ) 87 | self.blocks.append(block) 88 | 89 | self.neck = nn.Sequential( 90 | nn.Conv2d( 91 | embed_dim, 92 | out_chans, 93 | kernel_size=1, 94 | bias=False, 95 | ), 96 | LayerNorm2d(out_chans), 97 | nn.Conv2d( 98 | out_chans, 99 | out_chans, 100 | kernel_size=3, 101 | padding=1, 102 | bias=False, 103 | ), 104 | LayerNorm2d(out_chans), 105 | ) 106 | 107 | def forward(self, x: torch.Tensor) -> torch.Tensor: 108 | x = self.patch_embed(x) # pre embed: [1, 3, 1024, 1024], post embed: [1, 64, 64, 768] 109 | if self.pos_embed is not None: 110 | x = x + self.pos_embed 111 | 112 | for blk in self.blocks: 113 | x = blk(x) 114 | 115 | x = self.neck(x.permute(0, 3, 1, 2)) # [b, c, h, w], [1, 256, 64, 64] 116 | 117 | return x 118 | 119 | 120 | class Block(nn.Module): 121 | """Transformer blocks with support of window attention and residual propagation blocks""" 122 | 123 | def __init__( 124 | self, 125 | dim: int, 126 | num_heads: int, 127 | mlp_ratio: float = 4.0, 128 | qkv_bias: bool = True, 129 | norm_layer: Type[nn.Module] = nn.LayerNorm, 130 | act_layer: Type[nn.Module] = nn.GELU, 131 | use_rel_pos: bool = False, 132 | rel_pos_zero_init: bool = True, 133 | window_size: int = 0, 134 | input_size: Optional[Tuple[int, int]] = None, 135 | ) -> None: 136 | """ 137 | Args: 138 | dim (int): Number of input channels. 139 | num_heads (int): Number of attention heads in each ViT block. 140 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 141 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 142 | norm_layer (nn.Module): Normalization layer. 143 | act_layer (nn.Module): Activation layer. 144 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 145 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 146 | window_size (int): Window size for window attention blocks. If it equals 0, then 147 | use global attention. 148 | input_size (int or None): Input resolution for calculating the relative positional 149 | parameter size. 
150 | """ 151 | super().__init__() 152 | self.norm1 = norm_layer(dim) 153 | self.attn = Attention( 154 | dim, 155 | num_heads=num_heads, 156 | qkv_bias=qkv_bias, 157 | use_rel_pos=use_rel_pos, 158 | rel_pos_zero_init=rel_pos_zero_init, 159 | input_size=input_size if window_size == 0 else (window_size, window_size), 160 | ) 161 | 162 | self.norm2 = norm_layer(dim) 163 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 164 | 165 | self.window_size = window_size 166 | 167 | def forward(self, x: torch.Tensor) -> torch.Tensor: 168 | shortcut = x 169 | x = self.norm1(x) 170 | # Window partition 171 | if self.window_size > 0: 172 | H, W = x.shape[1], x.shape[2] 173 | x, pad_hw = window_partition(x, self.window_size) # [B * num_windows, window_size, window_size, C] 174 | 175 | x = self.attn(x) 176 | # Reverse window partition 177 | if self.window_size > 0: 178 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 179 | 180 | x = shortcut + x 181 | x = x + self.mlp(self.norm2(x)) 182 | 183 | return x 184 | 185 | 186 | class Attention(nn.Module): 187 | """Multi-head Attention block with relative position embeddings.""" 188 | 189 | def __init__( 190 | self, 191 | dim: int, 192 | num_heads: int = 8, 193 | qkv_bias: bool = True, 194 | use_rel_pos: bool = False, 195 | rel_pos_zero_init: bool = True, 196 | input_size: Optional[Tuple[int, int]] = None, 197 | ) -> None: 198 | """ 199 | Args: 200 | dim (int): Number of input channels. 201 | num_heads (int): Number of attention heads. 202 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 203 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 204 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 205 | input_size (int or None): Input resolution for calculating the relative positional 206 | parameter size. 207 | """ 208 | super().__init__() 209 | self.num_heads = num_heads 210 | head_dim = dim // num_heads 211 | self.scale = head_dim**-0.5 212 | 213 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 214 | self.proj = nn.Linear(dim, dim) 215 | 216 | self.use_rel_pos = use_rel_pos 217 | if self.use_rel_pos: 218 | assert ( 219 | input_size is not None 220 | ), "Input size must be provided if using relative positional encoding." 221 | # initialize relative positional embeddings 222 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 223 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 224 | 225 | def forward(self, x: torch.Tensor) -> torch.Tensor: 226 | B, H, W, _ = x.shape 227 | # qkv with shape (3, B, nHead, H * W, C) 228 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 229 | # q, k, v with shape (B * nHead, H * W, C) 230 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 231 | 232 | attn = (q * self.scale) @ k.transpose(-2, -1) 233 | 234 | if self.use_rel_pos: 235 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 236 | 237 | attn = attn.softmax(dim=-1) 238 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 239 | x = self.proj(x) 240 | 241 | return x 242 | 243 | 244 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 245 | """ 246 | Partition into non-overlapping windows with padding if needed. 247 | Args: 248 | x (tensor): input tokens with [B, H, W, C]. 249 | window_size (int): window size. 
250 | 251 | Returns: 252 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 253 | (Hp, Wp): padded height and width before partition 254 | """ 255 | B, H, W, C = x.shape 256 | 257 | pad_h = (window_size - H % window_size) % window_size 258 | pad_w = (window_size - W % window_size) % window_size 259 | if pad_h > 0 or pad_w > 0: 260 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 261 | Hp, Wp = H + pad_h, W + pad_w 262 | 263 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 264 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 265 | return windows, (Hp, Wp) 266 | 267 | 268 | def window_unpartition( 269 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 270 | ) -> torch.Tensor: 271 | """ 272 | Window unpartition into original sequences and removing padding. 273 | Args: 274 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 275 | window_size (int): window size. 276 | pad_hw (Tuple): padded height and width (Hp, Wp). 277 | hw (Tuple): original height and width (H, W) before padding. 278 | 279 | Returns: 280 | x: unpartitioned sequences with [B, H, W, C]. 281 | """ 282 | Hp, Wp = pad_hw 283 | H, W = hw 284 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 285 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 286 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 287 | 288 | if Hp > H or Wp > W: 289 | x = x[:, :H, :W, :].contiguous() 290 | return x 291 | 292 | 293 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 294 | """ 295 | Get relative positional embeddings according to the relative positions of 296 | query and key sizes. 297 | Args: 298 | q_size (int): size of query q. 299 | k_size (int): size of key k. 300 | rel_pos (Tensor): relative position embeddings (L, C). 301 | 302 | Returns: 303 | Extracted positional embeddings according to relative positions. 304 | """ 305 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 306 | # Interpolate rel pos if needed. 307 | if rel_pos.shape[0] != max_rel_dist: 308 | # Interpolate rel pos. 309 | rel_pos_resized = F.interpolate( 310 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 311 | size=max_rel_dist, 312 | mode="linear", 313 | ) 314 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 315 | else: 316 | rel_pos_resized = rel_pos 317 | 318 | # Scale the coords with short length if shapes for q and k are different. 319 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 320 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 321 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 322 | 323 | return rel_pos_resized[relative_coords.long()] 324 | 325 | 326 | def add_decomposed_rel_pos( 327 | attn: torch.Tensor, 328 | q: torch.Tensor, 329 | rel_pos_h: torch.Tensor, 330 | rel_pos_w: torch.Tensor, 331 | q_size: Tuple[int, int], 332 | k_size: Tuple[int, int], 333 | ) -> torch.Tensor: 334 | """ 335 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 336 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 337 | Args: 338 | attn (Tensor): attention map. 339 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
340 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 341 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 342 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 343 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 344 | 345 | Returns: 346 | attn (Tensor): attention map with added relative positional embeddings. 347 | """ 348 | q_h, q_w = q_size 349 | k_h, k_w = k_size 350 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 351 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 352 | 353 | B, _, dim = q.shape 354 | r_q = q.reshape(B, q_h, q_w, dim) 355 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 356 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 357 | 358 | attn = ( 359 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 360 | ).view(B, q_h * q_w, k_h * k_w) 361 | 362 | return attn 363 | 364 | 365 | class PatchEmbed(nn.Module): 366 | """ 367 | Image to Patch Embedding. 368 | """ 369 | 370 | def __init__( 371 | self, 372 | kernel_size: Tuple[int, int] = (16, 16), 373 | stride: Tuple[int, int] = (16, 16), 374 | padding: Tuple[int, int] = (0, 0), 375 | in_chans: int = 3, 376 | embed_dim: int = 768, 377 | ) -> None: 378 | """ 379 | Args: 380 | kernel_size (Tuple): kernel size of the projection layer. 381 | stride (Tuple): stride of the projection layer. 382 | padding (Tuple): padding size of the projection layer. 383 | in_chans (int): Number of input image channels. 384 | embed_dim (int): embed_dim (int): Patch embedding dimension. 385 | """ 386 | super().__init__() 387 | 388 | self.proj = nn.Conv2d( 389 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 390 | ) 391 | 392 | def forward(self, x: torch.Tensor) -> torch.Tensor: 393 | x = self.proj(x) 394 | # B C H W -> B H W C 395 | x = x.permute(0, 2, 3, 1) 396 | return x 397 | -------------------------------------------------------------------------------- /segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
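Before the implementation, a hedged usage sketch of the SamAutomaticMaskGenerator defined in this file. The sam_model_registry entry point and the checkpoint filename are assumptions borrowed from the upstream SAM interface (build_sam.py is not reproduced in this listing), and the input image is a random stand-in; only the constructor arguments and the record keys come from the code below.

import numpy as np

from segment_anything import sam_model_registry  # assumed export from build_sam.py
from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator

# Hypothetical checkpoint path; any checkpoint matching the chosen model type would do.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,            # a 32 x 32 grid of point prompts over the image
    pred_iou_thresh=0.88,          # drop masks the model itself scores as low quality
    stability_score_thresh=0.95,   # drop masks that change a lot under threshold jitter
)

image = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)  # stand-in HWC uint8 image
records = generator.generate(image)
print(len(records))
if records:
    # per-record keys: 'segmentation', 'area', 'bbox', 'predicted_iou',
    # 'point_coords', 'stability_score', 'crop_box'
    print(sorted(records[0].keys()))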
6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crops_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crops_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 
89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 
160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros(len(data["boxes"])), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros(len(data["boxes"])), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros(len(boxes)), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | --------------------------------------------------------------------------------