├── work_dir_cod └── weight_place.txt ├── data └── BEGIN.png ├── segment_anything ├── utils │ ├── __init__.py │ ├── transforms.py │ ├── onnx.py │ └── amg.py ├── __init__.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── PUA_res_plus.py │ ├── bias_correction.py │ ├── CWDLoss.py │ ├── mix_enbedding.py │ ├── loop_finer.py │ ├── guided_filter.py │ ├── DWT.py │ ├── sd_merge.py │ ├── agent_swin.py │ ├── mask_decoder.py │ ├── sam.py │ ├── transformer.py │ ├── prompt_encoder.py │ └── image_encoder.py ├── build_sam.py ├── predictor.py └── automatic_mask_generator.py ├── MSCAF_COD_evaluation ├── test_speed_for_count_nonzero.py ├── evaluation.py └── sod_metrics │ └── __init__.py ├── transformer_npz_2_gt.py ├── utils ├── precompute_img_embed.py └── dataset.py ├── README.md ├── environment.yaml ├── Mytrain.py ├── Mytest.py └── pre_npz.py /work_dir_cod/weight_place.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/BEGIN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/DSAM/HEAD/data/BEGIN.png -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
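# Usage sketch for the package-level API re-exported in segment_anything/__init__.py
# above (sam_model_registry, SamPredictor). This is a minimal illustration of the
# upstream SAM prompt flow; the checkpoint path follows the README layout, and the
# dummy image and box prompt are hypothetical stand-ins. The helper is defined but
# never called, so importing the package is unaffected.
def _example_box_prompt_flow():
    import numpy as np
    from segment_anything import sam_model_registry, SamPredictor

    sam = sam_model_registry["vit_b"](checkpoint="work_dir_cod/SAM/sam_vit_b_01ec64.pth")
    predictor = SamPredictor(sam)
    image = np.zeros((512, 512, 3), dtype=np.uint8)   # stand-in for a real RGB image
    predictor.set_image(image)                        # the image embedding is computed once here
    masks, scores, _ = predictor.predict(
        box=np.array([100, 100, 400, 400]),           # xyxy box prompt (illustrative values)
        multimask_output=False,
    )
    return masks, scores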
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .pvtv2 import pvt_v2_b2 13 | from .DWT import extract_high_frequency 14 | from .mix_enbedding import ME 15 | from .bias_correction import bias_correction 16 | from .loop_finer import Loop_Finer 17 | # from .IRB import InvertedResidual 18 | # from torchvision.models.mobilenetv2 import InvertedResidual 19 | -------------------------------------------------------------------------------- /MSCAF_COD_evaluation/test_speed_for_count_nonzero.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/11/27 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | import time 6 | 7 | import numpy as np 8 | 9 | 10 | # np.count_nonzero is recommended for quickly counting the nonzero values of a numpy array; a simple little timing experiment 11 | def cal_nonzero(size): 12 | a = np.random.randn(size, size) 13 | a = a > 0 14 | start = time.time() 15 | print(np.count_nonzero(a), time.time() - start) 16 | start = time.time() 17 | print(np.sum(a), time.time() - start) 18 | start = time.time() 19 | print(len(np.nonzero(a)[0]), time.time() - start) 20 | 21 | 22 | if __name__ == '__main__': 23 | cal_nonzero(1000) 24 | # 499950 6.723403930664062e-05 25 | # 499950 0.0006949901580810547 26 | # 499950 0.007088184356689453 27 | -------------------------------------------------------------------------------- /transformer_npz_2_gt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import io 3 | import os 4 | join = os.path.join 5 | 6 | 7 | path = "data/inference_npz/DSAM/" 8 | save_path = 'data/inference_img/DSAM/' 9 | # data = np.load("data/demo2D_vit_b/CVC-ColonDB/CVC-ColonDB.npz") 10 | npz_folders = sorted(os.listdir(path)) 11 | for npz_folder in npz_folders: 12 | # load the npz file 13 | data = np.load(join(path, npz_folder,npz_folder) + '.npz') 14 | # get the image data 15 | imgs = data["medsam_segs"] # assumes the image data is stored in an array named "imgs" 16 | name = data['number'] 17 | 18 | # convert the image data to the correct dtype and value range 19 | imgs = imgs.astype(np.uint8) 20 | imgs = (imgs * 255.0).astype(np.uint8) # adjust according to the value range of the image data 21 | 22 | # save the image data as image files 23 | for i, img in enumerate(imgs): 24 | num = 149 + i 25 | img_path = join(save_path, npz_folder) + "/" + name[i] # set the path and file name for the saved image 26 | if not os.path.exists(join(save_path, npz_folder)): 27 | os.makedirs(join(save_path, npz_folder), exist_ok=True) 28 | io.imsave(img_path, img) 29 | 30 | print("Images saved successfully.") -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
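# Quick orientation for the two blocks defined below: MLPBlock is a plain
# Linear -> activation -> Linear feed-forward layer applied over the last dimension,
# and LayerNorm2d normalizes NCHW feature maps over the channel dimension,
#     y = (x - mean_C(x)) / sqrt(var_C(x) + eps) * weight + bias.
# A small illustrative shape check follows (assumed SAM-sized tensors; defined but never called):
def _example_shapes():
    import torch
    x = torch.randn(2, 256, 64, 64)                       # e.g. an image-embedding-sized feature map
    y = LayerNorm2d(256)(x)                               # per-position normalization over the 256 channels
    z = MLPBlock(embedding_dim=256, mlp_dim=2048)(x.permute(0, 2, 3, 1))  # MLP over the channel-last dim
    return y.shape, z.shape                               # both keep their input shapes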
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/PUA_res_plus.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | def _upsample_like(src, tar): 5 | # 将 src 移动到与 tar 相同的设备 (GPU) 6 | # src = src.to(device) 7 | src = F.interpolate(src, size=tar.shape[2:], mode='bilinear') 8 | return src 9 | class PUAModule(nn.Module): 10 | def __init__(self, channel): 11 | super(PUAModule, self).__init__() 12 | self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1, dilation=1, groups=channel//2) 13 | self.conv2 = nn.Conv2d(channel, 2*channel, kernel_size=3, stride=1, padding=1, dilation=2, groups=channel//2) 14 | self.conv3 = nn.Conv2d(2*channel, channel, kernel_size=3, stride=2, padding=1, dilation=3, groups=channel//2) 15 | 16 | # self.conv4 = nn.Conv2d(channel, channel/2, kernel_size=1, stride=1, padding=1) 17 | # self.classifier = nn.Conv2d(channel, 1, kernel_size=3, stride=2, padding=1) 18 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 19 | self.bn1 = nn.BatchNorm2d(channel) 20 | self.bn2 = nn.BatchNorm2d(2*channel) 21 | self.bn3 = nn.BatchNorm2d(channel) 22 | # self.bn4 = nn.BatchNorm2d(channel) 23 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 24 | # #self.sigmoid = nn.Sigmoid() 25 | def forward(self, x): 26 | res_x = x 27 | x = self.conv1(x) 28 | x = self.bn1(x) 29 | x = self.leaky_relu(x) 30 | res_x = _upsample_like(res_x, x) 31 | x = x + res_x 32 | res_x_2 = x 33 | x = self.conv2(x) 34 | x = self.bn2(x) 35 | x = self.leaky_relu(x) 36 | # res_x_2 = _upsample_like(res_x_2, x) 37 | # x = x + res_x_2 38 | # res_x_3 = x 39 | x = self.conv3(x) 40 | x = self.bn3(x) 41 | x = self.leaky_relu(x) 42 | res_x_2 = _upsample_like(res_x_2, x) 43 | x = x + res_x_2 44 | # x = self.conv4(x) 45 | # x = self.conv4(x) 46 | # x = self.bn4(x) 47 | # x = self.leaky_relu(x) 48 | # x = self.classifier(x) 49 | return x -------------------------------------------------------------------------------- /segment_anything/modeling/bias_correction.py: -------------------------------------------------------------------------------- 1 | 
import torch.nn as nn 2 | from segment_anything.modeling.PUA_res_plus import PUAModule 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | class CrossAttention(nn.Module): 7 | def __init__(self, in_channels): 8 | super(CrossAttention, self).__init__() 9 | 10 | # Query, Key, and Value projections for both input vectors 11 | self.query_v1 = nn.Conv2d(in_channels, in_channels, kernel_size=1) 12 | self.key_v1 = nn.Conv2d(in_channels, in_channels, kernel_size=1) 13 | self.value_v1 = nn.Conv2d(in_channels, in_channels, kernel_size=1) 14 | 15 | self.query_v2 = nn.Conv2d(in_channels, in_channels, kernel_size=1) 16 | self.key_v2 = nn.Conv2d(in_channels, in_channels, kernel_size=1) 17 | self.value_v2 = nn.Conv2d(in_channels, in_channels, kernel_size=1) 18 | 19 | def forward(self, v1, v2): 20 | # Project vectors to Query, Key, and Value 21 | query_v1 = self.query_v1(v1) 22 | key_v2 = self.key_v2(v2) 23 | value_v2 = self.value_v2(v2) 24 | 25 | # Compute attention scores 26 | scores = torch.matmul(query_v1.view(query_v1.size(0), -1, query_v1.size(-1)), 27 | key_v2.view(key_v2.size(0), -1, key_v2.size(-1)).transpose(1, 2)) 28 | attention = F.softmax(scores, dim=-1) 29 | 30 | # Apply attention to values 31 | output_v1 = torch.matmul(attention, value_v2.view(value_v2.size(0), -1, value_v2.size(-1))) 32 | output_v1 = output_v1.view(v1.size()) 33 | 34 | return output_v1 35 | 36 | def _upsample_like_64(src): 37 | src = F.interpolate(src, size=(64,64), mode='bilinear') 38 | return src 39 | 40 | class bias_correction(nn.Module): 41 | def __init__(self, out_channels): 42 | super(bias_correction, self).__init__() 43 | self.PUA = PUAModule(out_channels) 44 | self.conv = nn.Conv2d(out_channels, out_channels // 2, kernel_size=1, stride=1, padding=1) 45 | def forward(self, embedding): 46 | embedding_64 = _upsample_like_64(embedding) 47 | pua_em = self.PUA(embedding) 48 | embedding = self.conv(embedding) 49 | pua_em = self.conv(pua_em) 50 | pua_em = _upsample_like_64(pua_em) 51 | embedding = _upsample_like_64(embedding) 52 | out_em = pua_em + embedding 53 | return out_em, embedding_64 54 | 55 | -------------------------------------------------------------------------------- /segment_anything/modeling/CWDLoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ChannelNorm(nn.Module): 4 | '''Channel-wise knowledge distillation for dense prediction''' 5 | 6 | def __init__(self): 7 | super(ChannelNorm, self).__init__() 8 | def forward(self,featmap): 9 | n,c,h,w = featmap.shape 10 | featmap = featmap.reshape((n,c,-1)) 11 | featmap = featmap.softmax(dim=-1) 12 | return featmap 13 | 14 | class CriterionCWD(nn.Module): 15 | '''Channel-wise knowledge distillation for dense prediction''' 16 | 17 | def __init__(self, norm_type='none', divergence='mse', temperature=1.0): 18 | 19 | super(CriterionCWD, self).__init__() 20 | 21 | # define normalize function 22 | if norm_type == 'channel': 23 | self.normalize = ChannelNorm() 24 | elif norm_type == 'spatial': 25 | self.normalize = nn.Softmax(dim=1) 26 | elif norm_type == 'channel_mean': 27 | self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1) 28 | else: 29 | self.normalize = None 30 | self.norm_type = norm_type 31 | 32 | self.temperature = 1.0 33 | 34 | # define loss function 35 | if divergence == 'mse': 36 | self.criterion = nn.MSELoss(reduction='sum') 37 | elif divergence == 'kl': 38 | self.criterion = nn.KLDivLoss(reduction='sum') 39 | self.temperature = temperature 40 | 
self.divergence = divergence 41 | 42 | def forward(self, preds_S, preds_T): 43 | 44 | n, c, h, w = preds_S.shape 45 | # import pdb;pdb.set_trace() 46 | if self.normalize is not None: 47 | norm_s = self.normalize(preds_S / self.temperature) 48 | norm_t = self.normalize(preds_T.detach() / self.temperature) 49 | else: 50 | norm_s = preds_S[0] 51 | norm_t = preds_T[0].detach() 52 | 53 | if self.divergence == 'kl': 54 | norm_s = norm_s.log() 55 | loss = self.criterion(norm_s, norm_t) 56 | 57 | # item_loss = [round(self.criterion(norm_t[0][0].log(),norm_t[0][i]).item(),4) for i in range(c)] 58 | # import pdb;pdb.set_trace() 59 | if self.norm_type == 'channel' or self.norm_type == 'channel_mean': 60 | loss /= n * c 61 | # loss /= n * h * w 62 | else: 63 | loss /= n * h * w 64 | 65 | return loss * (self.temperature ** 2) -------------------------------------------------------------------------------- /utils/precompute_img_embed.py: -------------------------------------------------------------------------------- 1 | #%% import packages 2 | # precompute image embeddings and save them to disk for model training 3 | 4 | import numpy as np 5 | import os 6 | join = os.path.join 7 | from skimage import io, segmentation 8 | from tqdm import tqdm 9 | import torch 10 | from segment_anything import sam_model_registry 11 | from segment_anything.utils.transforms import ResizeLongestSide 12 | import argparse 13 | 14 | #%% parse arguments 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-i', '--img_path', type=str, default='./data/Tr_Release_Part1', help='# and also Tr_Release_Part2 when part1 is done') 17 | parser.add_argument('-o', '--save_path', type=str, default='./data/Tr_npy', help='path to save the image embeddings') 18 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type') 19 | parser.add_argument('--checkpoint', type=str, default='../work_dir/SAM/sam_vit_b_01ec64.pth', help='path to the pre-trained SAM model') 20 | args = parser.parse_args() 21 | 22 | pre_img_path = args.img_path 23 | save_img_emb_path = join(args.save_path, 'npy_embs') 24 | save_gt_path = join(args.save_path, 'npy_gts') 25 | os.makedirs(save_img_emb_path, exist_ok=True) 26 | os.makedirs(save_gt_path, exist_ok=True) 27 | npz_files = sorted(os.listdir(pre_img_path)) 28 | #%% set up the model 29 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to('cuda:0') 30 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size) 31 | 32 | # compute image embeddings 33 | for name in tqdm(npz_files): 34 | img = np.load(join(pre_img_path, name))['img'] # (256, 256, 3) 35 | gt = np.load(join(pre_img_path, name))['gt'] 36 | resize_img = sam_transform.apply_image(img) 37 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to('cuda:0') 38 | # model input: (1, 3, 1024, 1024) 39 | input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024) 40 | assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024' 41 | with torch.no_grad(): 42 | embedding = sam_model.image_encoder(input_image) 43 | 44 | # save as npy 45 | np.save(join(save_img_emb_path, name.split('.npz')[0]+'.npy'), embedding.cpu().numpy()[0]) 46 | np.save(join(save_gt_path, name.split('.npz')[0]+'.npy'), gt) 47 | # sanity check 48 | img_idx = img.copy() 49 | bd = segmentation.find_boundaries(gt, mode='inner') 50 | img_idx[bd, :] = [255, 0, 0] 51 | io.imsave(save_img_emb_path + '.png', 
img_idx) 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Exploring Deeper! Segment Anything Model with Depth Perception for Camouflaged Object Detection 3 | Zhenni Yu, Xiaoqin Zhang, Li Zhao, Yi Bin, Guobao Xiao 4 | ACM MM, 2024 5 |
6 | 7 | ![Comparison of our COMPrompter and other methods in COD](data/BEGIN.png) 8 | 9 | ## Usage 10 | 11 | ### Installation 12 | 13 | ```bash 14 | git clone https://github.com/guobaoxiao/DSAM 15 | cd DSAM 16 | ``` 17 | 18 | ### Environment 19 | 20 | ```bash 21 | conda env create -f environment.yaml 22 | ``` 23 | 24 | ## From datasets to npz 25 | You can download the COD datasets and run this to get the npz files for training. 26 | ```bash 27 | python pre_npz.py 28 | ``` 29 | 30 | - **COD datasets**: 31 | download the COD datasets (CAMO, COD10K, NC4K) from [here](https://github.com/lartpang/awesome-segmentation-saliency-dataset#camouflaged-object-detection-cod) and put them into 'data/' 32 | 33 | - **depth datasets**: 34 | download the depth datasets and put them into 'data/'. The depth images are from PopNet. 35 | 36 | - File shared via Baidu Netdisk: Train_depth.zip 37 | Link: https://pan.baidu.com/s/1grcASolza9GLpHIVk8mESQ 38 | Extraction code: wocz 39 | - File shared via Baidu Netdisk: Test_depth.zip 40 | Link: https://pan.baidu.com/s/1HobAvMBpfSUfUHNXGZeFLw 41 | Extraction code: 32ut 42 | 43 | 44 | ### Weights 45 | - **pre-weight**: 46 | download the weight of SAM from [here](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and the weight of PVT from xxx, then put them into 'work_dir_cod/SAM/' 47 | 48 | - **DSAM**: 49 | download the weight of the well-trained DSAM and put it into 'work_dir_cod/DSAM' 50 | - File shared via Baidu Netdisk: DSAM.pth 51 | Link: https://pan.baidu.com/s/1148mXSjTv7OKlWHcfZFh5A 52 | Extraction code: 39xx 53 | 54 | 55 | ### The predicted images 56 | - **DSAM**: 57 | - File shared via Baidu Netdisk: DSAM.zip 58 | Link: https://pan.baidu.com/s/1V5372Z_GdHzYEyOR3iEu4Q 59 | Extraction code: fu49 60 | 61 | ### Train 62 | ```bash 63 | python Mytrain.py 64 | ``` 65 | 66 | ### Test 67 | 68 | ```bash 69 | python Mytest.py 70 | ``` 71 | 72 | ### Translate npz to img 73 | 74 | ```bash 75 | python transformer_npz_2_gt.py 76 | ``` 77 | 78 | ### Eval 79 | 80 | ```bash 81 | python MSCAF_COD_evaluation/evaluation.py 82 | ``` 83 | ## Citation 84 | 85 | If you find this project useful, please consider citing: 86 | 87 | ```bibtex 88 | @inproceedings{yu2024exploring, 89 | title={Exploring Deeper!
Segment Anything Model with Depth Perception for Camouflaged Object Detection}, 90 | author={Zhenni Yu and Xiaoqin Zhang and LiZhao and Yi Bin and Guobao Xiao}, 91 | booktitle={ACM Multimedia 2024}, 92 | year={2024}, 93 | url={https://openreview.net/forum?id=d4A0Cw1gVS} 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /MSCAF_COD_evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/11/21 3 | # @Author : Lart Pang 4 | # @GitHub : https://github.com/lartpang 5 | 6 | import os 7 | import sys 8 | 9 | import cv2 10 | from tqdm import tqdm 11 | import torch 12 | import sod_metrics as M 13 | import torch.nn.functional as F 14 | 15 | FM = M.Fmeasure() 16 | WFM = M.WeightedFmeasure() 17 | SM = M.Smeasure() 18 | EM = M.Emeasure() 19 | MAE = M.MAE() 20 | 21 | mask_root = '/data/Jenny/2309COMPrompter_plus_DSAM/data/TestDataset/NC4K/GT' 22 | pred_root = '/data/Jenny/2309COMPrompter_plus_DSAM/data/inference_img/240611_offfice_DSAM_2/NC4K' 23 | 24 | def _upsample_like(src, tar): 25 | src = torch.tensor(src, dtype=torch.float32) 26 | tar = torch.tensor(tar) 27 | src = F.interpolate(src.unsqueeze(0).unsqueeze(0), size=tar.shape, mode='bilinear') 28 | src = src.squeeze(0).squeeze(0).numpy() 29 | return src 30 | mask_name_list = sorted(os.listdir(mask_root)) 31 | for mask_name in tqdm(mask_name_list, total=len(mask_name_list)): 32 | mask_path = os.path.join(mask_root, mask_name) 33 | mask_name_for_pred = mask_name.replace(".png", ".jpg") 34 | # mask_name_for_pred = mask_name 35 | # pred_path = os.path.join(pred_root, mask_name_for_pred) 36 | pred_path = os.path.join(pred_root, mask_name_for_pred) 37 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 38 | pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE) 39 | if len(pred.shape) != 2: 40 | pred = pred[:, :, 0] # 返回(height, width) 41 | if len(mask.shape) != 2: 42 | mask = mask[:, :, 0] 43 | pred = _upsample_like(pred, mask) 44 | assert pred.shape == mask.shape 45 | 46 | FM.step(pred=pred, gt=mask) 47 | WFM.step(pred=pred, gt=mask) 48 | SM.step(pred=pred, gt=mask) 49 | EM.step(pred=pred, gt=mask) 50 | MAE.step(pred=pred, gt=mask) 51 | 52 | fm = FM.get_results()['fm'] 53 | wfm = WFM.get_results()['wfm'] 54 | sm = SM.get_results()['sm'] 55 | em = EM.get_results()['em'] 56 | mae = MAE.get_results()['mae'] 57 | 58 | print( 59 | 'Smeasure:', sm.round(3), '; ', 60 | 'wFmeasure:', wfm.round(3), '; ', 61 | 'MAE:', mae.round(3), '; ', 62 | 'adpEm:', em['adp'].round(3), '; ', 63 | 'meanEm:', '-' if em['curve'] is None else em['curve'].mean().round(3), '; ', 64 | 'maxEm:', '-' if em['curve'] is None else em['curve'].max().round(3), '; ', 65 | 'adpFm:', fm['adp'].round(3), '; ', 66 | 'meanFm:', fm['curve'].mean().round(3), '; ', 67 | 'maxFm:', fm['curve'].max().round(3), 68 | sep='' 69 | ) 70 | 71 | with open("../result.txt", "a+") as f: 72 | print('Smeasure:', sm.round(3), '; ', 73 | 'meanEm:', '-' if em['curve'] is None else em['curve'].mean().round(3), '; ', 74 | 'wFmeasure:', wfm.round(3), '; ', 75 | 'MAE:', mae.round(3), '; ', 76 | file=f 77 | ) 78 | -------------------------------------------------------------------------------- /segment_anything/modeling/mix_enbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class BasicConv2d(nn.Module): 4 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 
dilation=1, need_relu=True, 5 | bn=nn.BatchNorm2d): 6 | super(BasicConv2d, self).__init__() 7 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 8 | stride=stride, padding=padding, dilation=dilation, bias=False) 9 | self.bn = bn(out_channels) 10 | self.relu = nn.ReLU() 11 | self.need_relu = need_relu 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | x = self.bn(x) 16 | if self.need_relu: 17 | x = self.relu(x) 18 | return x 19 | class ME(torch.nn.Module): 20 | def __init__(self, in_channels, out_channels): 21 | super(ME, self).__init__() 22 | 23 | self.depthwise_conv_reduce_channels_1 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=1, 24 | stride=1, padding=0, groups=in_channels // 16) 25 | self.depthwise_conv_reduce_channels_2 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, 26 | stride=1, padding=0, groups=in_channels // 16) 27 | self.relu = nn.ReLU(True) 28 | self.branch1 = nn.Sequential( 29 | BasicConv2d(out_channels, out_channels, 1), 30 | BasicConv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)), 31 | BasicConv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)), 32 | BasicConv2d(out_channels, out_channels, 3, padding=3, dilation=3) 33 | ) 34 | 35 | def initialize_parameters(self): 36 | for name, param in self.named_parameters(): 37 | if 'weight' in name: 38 | 39 | if len(param.shape) == 1: 40 | param_unsqueeze = param.unsqueeze(0) 41 | nn.init.xavier_uniform_(param_unsqueeze) 42 | param.data.copy_(param_unsqueeze.squeeze(0)) 43 | else: 44 | nn.init.xavier_uniform_(param) 45 | 46 | elif 'bias' in name: 47 | # print("bias:" + name) 48 | # The bias term is initialized 49 | nn.init.zeros_(param) 50 | def forward(self,dense_embeddings_box, high_frequency, sparse_embeddings_box): # high_frequency, 这里删了一个这个 51 | 52 | # dense_cat = torch.cat([dense_embeddings_boundary, dense_embeddings_box], dim=1) 53 | dense_cat_tmp = self.depthwise_conv_reduce_channels_1(dense_embeddings_box) 54 | dense_em = self.branch1(dense_embeddings_box) 55 | dense_em =dense_em + dense_cat_tmp 56 | 57 | dense_embeddings = torch.cat([dense_em, high_frequency], dim=1) # 这里注销了一个这个 58 | 59 | dense_embeddings = self.depthwise_conv_reduce_channels_2(dense_embeddings) 60 | sparse_embeddings = sparse_embeddings_box 61 | return dense_embeddings, sparse_embeddings 62 | -------------------------------------------------------------------------------- /segment_anything/modeling/loop_finer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .guided_filter import GuidedFilter 5 | from .agent_swin import AgentAttention 6 | 7 | class BasicConv2d(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, need_relu=True, 9 | bn=nn.BatchNorm2d): 10 | super(BasicConv2d, self).__init__() 11 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 12 | stride=stride, padding=padding) 13 | self.bn = bn(out_channels) 14 | self.relu = nn.ReLU() 15 | self.need_relu = need_relu 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | if self.need_relu: 21 | x = self.relu(x) 22 | return x 23 | 24 | def _upsample_like_64(src): 25 | src = F.interpolate(src, size=(64, 64), mode='bilinear') 26 | return src 27 | def _upsample_like_128(src): 28 | src = F.interpolate(src, size=(128, 128), mode='bilinear') 29 | return src 30 | 
def _upsample_like_256(src): 31 | src = F.interpolate(src, size=(256, 256), mode='bilinear') 32 | return src 33 | 34 | 35 | class Loop_Finer(nn.Module): 36 | def __init__(self, mid_ch): 37 | super(Loop_Finer, self).__init__() 38 | self.gfp = GuidedFilter() 39 | self.finer0 = BasicConv2d(264, mid_ch, kernel_size=3, stride=1, padding=1) 40 | self.finer1 = BasicConv2d(264, mid_ch, kernel_size=3, stride=1, padding=1) 41 | self.finer2 = BasicConv2d(mid_ch*2, mid_ch, kernel_size=1, stride=1, padding=1) 42 | self.finer_atten = AgentAttention(dim=64, num_heads=2) 43 | self.finer4 = BasicConv2d(mid_ch, 1, kernel_size=1, stride=1, padding=0) # fini_ch 44 | def forward(self, pred, image_em, depth_em): 45 | pred = _upsample_like_64(pred) 46 | reversed_pred = 1 - pred 47 | image_cut = torch.chunk(image_em, 8, dim=1) 48 | image_cat = torch.cat((image_cut[0], reversed_pred, image_cut[1], reversed_pred, image_cut[2], reversed_pred, image_cut[3], reversed_pred 49 | , image_cut[4], reversed_pred, image_cut[5], reversed_pred, image_cut[6], reversed_pred, image_cut[7], reversed_pred), 1) 50 | 51 | depth_cut = torch.chunk(depth_em, 8, dim=1) 52 | depth_cat = torch.cat((depth_cut[0], reversed_pred, depth_cut[1], reversed_pred, depth_cut[2], reversed_pred, depth_cut[3], reversed_pred 53 | , depth_cut[4], reversed_pred, depth_cut[5], reversed_pred, depth_cut[6], reversed_pred, depth_cut[7], reversed_pred), 1) 54 | 55 | image_fliter = self.gfp(image_cat, depth_cat) 56 | image_fliter = image_fliter + image_cat 57 | 58 | i_f0 = self.finer0(image_fliter) 59 | i_f0 = _upsample_like_64(i_f0) 60 | 61 | d_f1 = self.finer1(depth_cat) 62 | d_f1 = _upsample_like_64(d_f1) 63 | tmp = d_f1 64 | d_f1 = self.finer_atten(d_f1) 65 | d_f1 = d_f1 + tmp 66 | 67 | i_f2 = self.finer2(torch.cat((i_f0, d_f1), 1)) 68 | i_f2 = _upsample_like_128(i_f2) 69 | 70 | d_f3 = self.finer_atten(i_f2) 71 | d_f3 = _upsample_like_128(d_f3) 72 | d_f4 = self.finer4(d_f3) 73 | d_f4 = _upsample_like_256(d_f4) 74 | 75 | return d_f4 -------------------------------------------------------------------------------- /segment_anything/modeling/guided_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | 8 | def diff_x(input, r): 9 | assert input.dim() == 4 10 | 11 | left = input[:, :, r:2 * r + 1] 12 | middle = input[:, :, 2 * r + 1: ] - input[:, :, :-2 * r - 1] 13 | right = input[:, :, -1: ] - input[:, :, -2 * r - 1: -r - 1] 14 | 15 | output = torch.cat([left, middle, right], dim=2) 16 | 17 | return output 18 | 19 | def diff_y(input, r): 20 | assert input.dim() == 4 21 | 22 | left = input[:, :, :, r:2 * r + 1] 23 | middle = input[:, :, :, 2 * r + 1: ] - input[:, :, :, :-2 * r - 1] 24 | right = input[:, :, :, -1: ] - input[:, :, :, -2 * r - 1: -r - 1] 25 | 26 | output = torch.cat([left, middle, right], dim=3) 27 | 28 | return output 29 | 30 | class BoxFilter(nn.Module): 31 | def __init__(self, r): 32 | super(BoxFilter, self).__init__() 33 | 34 | self.r = r 35 | 36 | def forward(self, x): 37 | assert x.dim() == 4 38 | 39 | return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r) 40 | 41 | 42 | class FastGuidedFilter(nn.Module): 43 | def __init__(self, r, eps=1e-8): 44 | super(FastGuidedFilter, self).__init__() 45 | 46 | self.r = r 47 | self.eps = eps 48 | self.boxfilter = BoxFilter(r) 49 | 50 | 51 | def forward(self, lr_x, lr_y, hr_x): 52 | n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size() 
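# The lines below compute the standard guided-filter coefficients on the
# low-resolution pair and then apply them to the high-resolution guide:
#     A = cov(lr_x, lr_y) / (var(lr_x) + eps)
#     b = mean(lr_y) - A * mean(lr_x)
#     output = upsample(A) * hr_x + upsample(b)
# where every mean is a box-filtered local average normalized by N.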
53 | n_lry, c_lry, h_lry, w_lry = lr_y.size() 54 | n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size() 55 | 56 | assert n_lrx == n_lry and n_lry == n_hrx 57 | assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry) 58 | assert h_lrx == h_lry and w_lrx == w_lry 59 | assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1 60 | 61 | ## N 62 | N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))) 63 | 64 | ## mean_x 65 | mean_x = self.boxfilter(lr_x) / N 66 | ## mean_y 67 | mean_y = self.boxfilter(lr_y) / N 68 | ## cov_xy 69 | cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y 70 | ## var_x 71 | var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x 72 | 73 | ## A 74 | A = cov_xy / (var_x + self.eps) 75 | ## b 76 | b = mean_y - A * mean_x 77 | 78 | ## mean_A; mean_b 79 | mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True) 80 | mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True) 81 | 82 | return mean_A*hr_x+mean_b 83 | 84 | 85 | class GuidedFilter(nn.Module): 86 | def __init__(self, r=4, eps=1e-2): 87 | super(GuidedFilter, self).__init__() 88 | 89 | self.r = r 90 | self.eps = eps 91 | self.boxfilter = BoxFilter(r) 92 | 93 | 94 | def forward(self, x, y): 95 | n_x, c_x, h_x, w_x = x.size() 96 | n_y, c_y, h_y, w_y = y.size() 97 | 98 | assert n_x == n_y 99 | assert c_x == 1 or c_x == c_y 100 | assert h_x == h_y and w_x == w_y 101 | assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1 102 | 103 | # N 104 | N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0))) 105 | 106 | # mean_x 107 | mean_x = self.boxfilter(x) / N 108 | # mean_y 109 | mean_y = self.boxfilter(y) / N 110 | # cov_xy 111 | cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y 112 | # var_x 113 | var_x = self.boxfilter(x * x) / N - mean_x * mean_x 114 | 115 | # A 116 | A = cov_xy / (var_x + self.eps) 117 | # b 118 | b = mean_y - A * mean_x 119 | 120 | # mean_A; mean_b 121 | mean_A = self.boxfilter(A) / N 122 | mean_b = self.boxfilter(b) / N 123 | 124 | return mean_A * x + mean_b 125 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import cv2 5 | from typing import Any, Tuple 6 | import torch 7 | from torch.utils.data import Dataset 8 | import torchvision.transforms.functional as TF 9 | 10 | 11 | class MedSamDataset(Dataset): 12 | def __init__( 13 | self, 14 | df: pd.DataFrame, 15 | image_col: str, 16 | mask_col: str, 17 | image_dir: Any = None, 18 | mask_dir: str = None, 19 | image_size: Tuple = (256, 256), 20 | ): 21 | """ 22 | PyTorch dataset class for loading image,mask and bbox pairs from a dataframe. 23 | The dataframe will need to have atleast two columns for the image and mask file names. The columns can either have the full or relative 24 | path of the images or just the file names. 25 | If only file names are given in the columns, the `image_dir` and `mask_dir` arguments should be specified. 26 | 27 | Args: 28 | df (pd.DataFrame): the pandas dataframe object 29 | image_col (str): the name of the column on the dataframe that holds the image file names. 30 | mask_col (str): the name of the column on the dataframe that holds the mask file names. 31 | image_dir (Any, optional): Path to the input image directory. Defaults to None. 32 | mask_dir (str, optional): Path to the mask images directory. Defaults to None. 
33 | image_size (Tuple, optional): image size. Defaults to (256, 256). 34 | """ 35 | self.df = df 36 | self.image_dir = image_dir 37 | self.mask_dir = mask_dir 38 | self.image_col = image_col 39 | self.mask_col = mask_col 40 | self.image_size = image_size 41 | 42 | def __len__(self): 43 | return len(self.df) 44 | 45 | def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 46 | # read dataframe row 47 | row = self.df.iloc[idx] 48 | # If the `image_dir` attribute is set, the path will be relative to that directory. 49 | # Otherwise, the path will be the value of the `row[self.image_col]` attribute. 50 | image_file = ( 51 | os.path.join(self.image_dir, row[self.image_col]) 52 | if self.image_dir 53 | else row[self.image_col] 54 | ) 55 | mask_file = ( 56 | os.path.join(self.mask_dir, row[self.mask_col]) 57 | if self.mask_dir 58 | else row[self.mask_col] 59 | ) 60 | 61 | if not os.path.exists(image_file): 62 | raise FileNotFoundError(f"Couldn't find image {image_file}") 63 | if not os.path.exists(mask_file): 64 | raise FileNotFoundError(f"Couldn't find image {mask_file}") 65 | 66 | # read image and mask files 67 | image_data = cv2.imread(image_file) 68 | # read mask as gray scale 69 | mask_data = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) 70 | 71 | return self._preprocess(image_data, mask_data) 72 | 73 | def _preprocess( 74 | self, image: np.ndarray, mask: np.ndarray 75 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 76 | # Threshold mask to binary 77 | mask = cv2.threshold(mask, 127.0, 255.0, cv2.THRESH_BINARY)[1] 78 | # convert to tensor 79 | image = TF.to_tensor(image) 80 | mask = TF.to_tensor(mask) 81 | # min-max normalize and scale 82 | image = (image - image.min()) / (image.max() - image.min()) * 255.0 83 | # resize 84 | image = TF.resize(image, self.image_size, antialias=True) 85 | mask = TF.resize(mask, self.image_size, antialias=True) 86 | 87 | bbox = self._get_bbox(mask) 88 | 89 | return image, mask, bbox 90 | 91 | def _get_bbox(self, mask: torch.Tensor) -> torch.Tensor: 92 | _, y_indices, x_indices = torch.where(mask > 0) 93 | 94 | x_min, y_min = (x_indices.min(), y_indices.min()) 95 | x_max, y_max = (x_indices.max(), y_indices.max()) 96 | 97 | # add perturbation to bounding box coordinates 98 | H, W = mask.shape[1:] 99 | # add perfurbation to the bbox 100 | assert H == W, f"{W} and {H} are not equal size!!" 101 | x_min = max(0, x_min - np.random.randint(0, 10)) 102 | x_max = min(W, x_max + np.random.randint(0, 10)) 103 | y_min = max(0, y_min - np.random.randint(0, 10)) 104 | y_max = min(H, y_max + np.random.randint(0, 10)) 105 | 106 | return torch.tensor([x_min, y_min, x_max, y_max]) 107 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. 
Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | # for pvt's image size 34 | def apply_image_pvt(self, image: np.ndarray) -> np.ndarray: 35 | """ 36 | Expects a numpy array with shape HxWxC in uint8 format. 37 | """ 38 | target_size = self.get_preprocess_shape(image.shape[1], image.shape[2], self.target_length) 39 | return np.array(resize(to_pil_image(image), target_size)) 40 | 41 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 42 | """ 43 | Expects a numpy array of length 2 in the final dimension. Requires the 44 | original image size in (H, W) format. 45 | """ 46 | old_h, old_w = original_size 47 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length) 48 | new_coords = np.empty_like(coords) 49 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w) 50 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h) 51 | return new_coords 52 | 53 | 54 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 55 | """ 56 | Expects a numpy array shape Bx4. Requires the original image size 57 | in (H, W) format. 58 | """ 59 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 60 | return boxes.reshape(-1, 4) 61 | 62 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Expects batched images with shape BxCxHxW and float format. This 65 | transformation may not exactly match apply_image. apply_image is 66 | the transformation expected by the model. 67 | """ 68 | # Expects an image in BCHW format. May not exactly match apply_image. 69 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 70 | return F.interpolate( 71 | image, target_size, mode="bilinear", align_corners=False, antialias=True 72 | ) 73 | 74 | def apply_coords_torch( 75 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 76 | ) -> torch.Tensor: 77 | """ 78 | Expects a torch tensor with length 2 in the last dimension. Requires the 79 | original image size in (H, W) format. 80 | """ 81 | old_h, old_w = original_size 82 | new_h, new_w = self.get_preprocess_shape( 83 | original_size[0], original_size[1], self.target_length 84 | ) 85 | coords = deepcopy(coords).to(torch.float) 86 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 87 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 88 | return coords 89 | 90 | def apply_boxes_torch( 91 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 92 | ) -> torch.Tensor: 93 | """ 94 | Expects a torch tensor with shape Bx4. Requires the original image 95 | size in (H, W) format. 96 | """ 97 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 98 | return boxes.reshape(-1, 4) 99 | 100 | @staticmethod 101 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 102 | """ 103 | Compute the output size given input size and target long side length. 
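For example, with long_side_length=1024 an input of (oldh, oldw) = (768, 1024) is returned unchanged as (768, 1024), while (500, 750) becomes (683, 1024): the longer side is scaled to exactly 1024 and the shorter side is rounded to the nearest integer.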
104 | """ 105 | scale = long_side_length * 1.0 / max(oldh, oldw) 106 | newh, neww = oldh * scale, oldw * scale 107 | neww = int(neww + 0.5) 108 | newh = int(newh + 0.5) 109 | return (newh, neww) 110 | -------------------------------------------------------------------------------- /segment_anything/modeling/DWT.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | # affine_par = True 5 | class BasicConv2d(nn.Module): 6 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, need_relu=True, 7 | bn=nn.BatchNorm2d): 8 | super(BasicConv2d, self).__init__() 9 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 10 | stride=stride, padding=padding, dilation=dilation, bias=False) 11 | self.bn = bn(out_channels) 12 | self.relu = nn.ReLU() 13 | self.need_relu = need_relu 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | if self.need_relu: 19 | x = self.relu(x) 20 | return x 21 | 22 | # class ETM(nn.Module): 23 | # def __init__(self, in_channels, out_channels): 24 | # super(ETM, self).__init__() 25 | # self.relu = nn.ReLU(True) 26 | # self.branch0 = BasicConv2d(in_channels, out_channels, 1) 27 | # self.branch1 = nn.Sequential( 28 | # BasicConv2d(in_channels, out_channels, 1), 29 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)), 30 | # BasicConv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)), 31 | # BasicConv2d(out_channels, out_channels, 3, padding=3, dilation=3) 32 | # ) 33 | # self.branch2 = nn.Sequential( 34 | # BasicConv2d(in_channels, out_channels, 1), 35 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 5), padding=(0, 2)), 36 | # BasicConv2d(out_channels, out_channels, kernel_size=(5, 1), padding=(2, 0)), 37 | # BasicConv2d(out_channels, out_channels, 3, padding=5, dilation=5) 38 | # ) 39 | # self.branch3 = nn.Sequential( 40 | # BasicConv2d(in_channels, out_channels, 1), 41 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 7), padding=(0, 3)), 42 | # BasicConv2d(out_channels, out_channels, kernel_size=(7, 1), padding=(3, 0)), 43 | # BasicConv2d(out_channels, out_channels, 3, padding=7, dilation=7) 44 | # ) 45 | # self.conv_cat = BasicConv2d(4 * out_channels, out_channels, 3, padding=1) 46 | # self.conv_res = BasicConv2d(in_channels, out_channels, 1) 47 | # 48 | # def forward(self, x): 49 | # x0 = self.branch0(x) 50 | # x1 = self.branch1(x) 51 | # x2 = self.branch2(x) 52 | # x3 = self.branch3(x) 53 | # x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 54 | # 55 | # x = self.relu(x_cat + self.conv_res(x)) 56 | # return x 57 | def resize_tensor(tensor, size): 58 | """ 59 | 使用插值方法将张量调整大小 60 | 61 | Args: 62 | tensor (torch.Tensor): 输入的张量,形状为 (batch_size, num_channels, height, width) 63 | size (tuple): 调整后的目标大小,形状为 (new_height, new_width) 64 | 65 | Returns: 66 | torch.Tensor: 调整大小后的张量,形状为 (batch_size, num_channels, new_height, new_width) 67 | """ 68 | resized_tensor = nn.functional.interpolate(tensor, size=size, mode='bilinear', align_corners=False) 69 | return resized_tensor 70 | 71 | class DWT(nn.Module): 72 | def __init__(self): 73 | super(DWT, self).__init__() 74 | self.requires_grad = False 75 | 76 | def forward(self, x): 77 | # x01 x02是低频信号; x1 x2 x3 x4是高频信号 78 | x01 = x[:, :, 0::2, :] / 2 79 | # 表示在第三个维度上以步幅为2选择元素,即选择索引为0、2、4、...的元素。 80 | x02 = x[:, :, 1::2, :] / 2 81 | # 这意味着只选择索引为1、3、5等奇数位置上的元素。/ 2:这是除法运算符, 82 | # 将选取的部分 
x[:, :, 0::2, :] 的每个元素都除以2,即 / 2 的操作。这样做是为了缩小高频部分的值范围,将其变得更小。 83 | # 在一些信号处理的方法中,高频部分往往具有较大的值,可能会对某些操作产生较大的影响,例如激活函数的响应范围、优化算法的收敛性等。 84 | # 通过将高频部分除以2,可以将其数值范围缩小一半,使其相对于低频部分具有更小的权重,从而在一定程度上平衡了高频和低频的影响。 85 | x1 = x01[:, :, :, 0::2] 86 | # 是对 x01 在空间维度上进行下采样的操作,步长为2。这样的操作会使得 x01 在宽度维度上减半, 87 | # 即将每一行的元素进行间隔取样,形状变为 (batch_size, num_channels, height//2, width//2)。 88 | x2 = x02[:, :, :, 0::2] 89 | x3 = x01[:, :, :, 1::2] 90 | x4 = x02[:, :, :, 1::2] 91 | ll = x1 + x2 + x3 + x4 92 | lh = -x1 + x2 - x3 + x4 93 | hl = -x1 - x2 + x3 + x4 94 | hh = x1 - x2 - x3 + x4 95 | return ll, lh, hl, hh 96 | # ll 子图表示低频部分的信息,lh、hl、hh 分别表示高频部分在不同方向上的信息。 97 | 98 | class extract_high_frequency(nn.Module): 99 | def __init__(self): 100 | super(extract_high_frequency, self).__init__() 101 | self.dwt = DWT() 102 | self.conv_in = nn.Conv2d(768, 256, 3, padding=1) 103 | 104 | def forward(self, x): 105 | x = self.conv_in(x) 106 | ll, lh, hl, hh = self.dwt(x) 107 | high_frequency = resize_tensor(hh, (64, 64)) 108 | return high_frequency 109 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: pytorch183 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 12 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 13 | dependencies: 14 | - _libgcc_mutex=0.1=main 15 | - _openmp_mutex=4.5=1_gnu 16 | - blas=1.0=mkl 17 | - brotli=1.0.9=he6710b0_2 18 | - brotlipy=0.7.0=py37h27cfd23_1003 19 | - bzip2=1.0.8=h7b6447c_0 20 | - ca-certificates=2023.01.10=h06a4308_0 21 | - certifi=2022.12.7=py37h06a4308_0 22 | - cffi=1.15.0=py37hd667e15_1 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - cloudpickle=2.0.0=pyhd3eb1b0_0 25 | - colorama=0.4.4=pyhd3eb1b0_0 26 | - cryptography=39.0.1=py37h9ce1e76_0 27 | - cudatoolkit=11.3.1=h2bc3f7f_2 28 | - cycler=0.11.0=pyhd3eb1b0_0 29 | - cytoolz=0.11.0=py37h7b6447c_0 30 | - dask-core=2021.10.0=pyhd3eb1b0_0 31 | - ffmpeg=4.3=hf484d3e_0 32 | - fonttools=4.25.0=pyhd3eb1b0_0 33 | - freetype=2.11.0=h70c0345_0 34 | - fsspec=2022.2.0=pyhd3eb1b0_0 35 | - gmp=6.2.1=h2531618_2 36 | - gnutls=3.6.15=he1e5248_0 37 | - idna=3.4=py37h06a4308_0 38 | - imageio=2.9.0=pyhd3eb1b0_0 39 | - intel-openmp=2021.4.0=h06a4308_3561 40 | - jpeg=9b=h024ee3a_2 41 | - kiwisolver=1.3.2=py37h295c915_0 42 | - lame=3.100=h7b6447c_0 43 | - lcms2=2.12=h3be6417_0 44 | - ld_impl_linux-64=2.35.1=h7274673_9 45 | - libffi=3.3=he6710b0_2 46 | - libgcc-ng=9.3.0=h5101ec6_17 47 | - libgfortran-ng=7.5.0=ha8ba4b0_17 48 | - libgfortran4=7.5.0=ha8ba4b0_17 49 | - libgomp=9.3.0=h5101ec6_17 50 | - libiconv=1.15=h63c8f33_5 51 | - libidn2=2.3.2=h7f8727e_0 52 | - libpng=1.6.37=hbc83047_0 53 | - libstdcxx-ng=9.3.0=hd4cf53a_17 54 | - libtasn1=4.16.0=h27cfd23_0 55 | - libtiff=4.2.0=h85742a9_0 56 | - libunistring=0.9.10=h27cfd23_0 57 | - libuv=1.40.0=h7b6447c_0 58 | - libwebp-base=1.2.2=h7f8727e_0 59 | - locket=0.2.1=py37h06a4308_2 60 | - lz4-c=1.9.3=h295c915_1 61 | - matplotlib-base=3.5.1=py37ha18d171_1 62 | - mkl=2021.4.0=h06a4308_640 63 | - mkl-service=2.4.0=py37h7f8727e_0 64 | - mkl_fft=1.3.1=py37hd3c417c_0 
65 | - mkl_random=1.2.2=py37h51133e4_0 66 | - munkres=1.1.4=py_0 67 | - ncurses=6.3=h7f8727e_2 68 | - nettle=3.7.3=hbbd107a_1 69 | - networkx=2.6.3=pyhd3eb1b0_0 70 | - numpy=1.21.5=py37he7a7128_2 71 | - numpy-base=1.21.5=py37hf524024_2 72 | - olefile=0.46=py37_0 73 | - openh264=2.1.1=h4ff587b_0 74 | - openssl=1.1.1t=h7f8727e_0 75 | - packaging=21.3=pyhd3eb1b0_0 76 | - partd=1.2.0=pyhd3eb1b0_1 77 | - pillow=8.0.0=py37h9a89aac_0 78 | - pip=21.2.2=py37h06a4308_0 79 | - pycparser=2.21=pyhd3eb1b0_0 80 | - pyopenssl=23.0.0=py37h06a4308_0 81 | - pyparsing=3.0.4=pyhd3eb1b0_0 82 | - pysocks=1.7.1=py37_1 83 | - python=3.7.11=h12debd9_0 84 | - python-dateutil=2.8.2=pyhd3eb1b0_0 85 | - pytorch=1.12.1=py3.7_cuda11.3_cudnn8.3.2_0 86 | - pytorch-mutex=1.0=cuda 87 | - pywavelets=1.1.1=py37h7b6447c_2 88 | - pyyaml=6.0=py37h7f8727e_1 89 | - readline=8.1.2=h7f8727e_1 90 | - requests=2.28.1=py37h06a4308_0 91 | - scikit-image=0.15.0=py37hb3f55d8_2 92 | - scipy=1.7.3=py37hc147768_0 93 | - setuptools=58.0.4=py37h06a4308_0 94 | - six=1.16.0=pyhd3eb1b0_1 95 | - sqlite=3.38.0=hc218d9a_0 96 | - tk=8.6.11=h1ccaba5_0 97 | - toolz=0.11.2=pyhd3eb1b0_0 98 | - torchaudio=0.12.1=py37_cu113 99 | - torchvision=0.13.1=py37_cu113 100 | - urllib3=1.26.14=py37h06a4308_0 101 | - wheel=0.38.4=py37h06a4308_0 102 | - xz=5.2.5=h7b6447c_0 103 | - yaml=0.2.5=h7b6447c_0 104 | - zlib=1.2.11=h7f8727e_4 105 | - zstd=1.4.9=haebb681_0 106 | - pip: 107 | - addict==2.4.0 108 | - anykeystore==0.2 109 | - apex==0.1 110 | - causal-conv1d==1.2.2.post1 111 | - cryptacular==1.6.2 112 | - defusedxml==0.7.1 113 | - einops==0.6.1 114 | - filelock==3.12.2 115 | - greenlet==3.0.3 116 | - huggingface-hub==0.16.4 117 | - hupper==1.12.1 118 | - importlib-metadata==6.7.0 119 | - joblib==1.3.2 120 | - mamba-ssm==1.2.2 121 | - markupsafe==2.1.4 122 | - mmcv==1.7.0 123 | - monai==1.1.0 124 | - ninja==1.11.1.1 125 | - oauthlib==3.2.2 126 | - pandas==1.3.5 127 | - pastedeploy==3.1.0 128 | - pbkdf2==1.3 129 | - plaster==1.1.2 130 | - plaster-pastedeploy==1.0.1 131 | - platformdirs==4.0.0 132 | - protobuf==4.24.4 133 | - pyramid==2.0.2 134 | - pyramid-mailer==0.15.1 135 | - python3-openid==3.2.0 136 | - pytz==2024.1 137 | - regex==2024.4.16 138 | - repoze-sendmail==4.4.1 139 | - requests-oauthlib==1.3.1 140 | - safetensors==0.4.2 141 | - scikit-learn==1.0.2 142 | - seaborn==0.12.2 143 | - sklearn==0.0 144 | - sqlalchemy==2.0.25 145 | - tensorboardx==2.6.2.2 146 | - thop==0.1.1-2209072238 147 | - threadpoolctl==3.1.0 148 | - timm==0.9.12 149 | - tokenizers==0.13.3 150 | - tomli==2.0.1 151 | - torchsummary==1.5.1 152 | - tqdm==4.66.1 153 | - transaction==4.0 154 | - transformers==4.30.2 155 | - translationstring==1.4 156 | - triton==2.3.1 157 | - ttach==0.0.3 158 | - typing-extensions==4.7.1 159 | - velruse==1.1.1 160 | - venusian==3.1.0 161 | - webob==1.8.7 162 | - wtforms==3.0.1 163 | - wtforms-recaptcha==0.3.2 164 | - yapf==0.40.2 165 | - zipp==3.15.0 166 | - zope-deprecation==5.0 167 | - zope-interface==6.1 168 | - zope-sqlalchemy==3.1 169 | prefix: /root/anaconda3/envs/pytorch183 170 | -------------------------------------------------------------------------------- /segment_anything/modeling/sd_merge.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class BasicConv2d(nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, need_relu=True, 6 | bn=nn.BatchNorm2d): 7 | super(BasicConv2d, self).__init__() 8 | self.conv = 
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 9 | stride=stride, padding=padding, dilation=dilation, bias=False) 10 | self.bn = bn(out_channels) 11 | self.relu = nn.ReLU() 12 | self.need_relu = need_relu 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | x = self.bn(x) 17 | if self.need_relu: 18 | x = self.relu(x) 19 | return x 20 | # merge_dense_sparse 21 | class MDS(nn.Module): 22 | def __init__(self, in_channels, out_channels): 23 | super(MDS, self).__init__() 24 | '''8 :sparse_embedding:tensor(bs, 2, 256) -> tensor(bs, 1, 2, 256) 25 | :dense_embedding:tensor(bs, 256, 64, 64) 26 | 8 = 64*64/2/256''' 27 | self.sparse_in_channels = 1 28 | self.dense_in_channels = in_channels*8 29 | self.real_in_channels = in_channels*8+self.sparse_in_channels 30 | self.real_out_channels = out_channels*8+self.sparse_in_channels 31 | self.relu = nn.ReLU(True) 32 | # self.branch0 = BasicConv2d(self.real_in_channels, self.real_out_channels, 1) 33 | self.branch1 = nn.Sequential( 34 | BasicConv2d(self.real_in_channels, self.real_out_channels, 1), 35 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(1, 3), padding=(0, 1)), 36 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(3, 1), padding=(1, 0)), 37 | BasicConv2d(self.real_out_channels, self.real_out_channels, 3, padding=3, dilation=3) 38 | ) 39 | self.branch2 = nn.Sequential( 40 | BasicConv2d(self.real_in_channels, self.real_out_channels, 1), 41 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(1, 5), padding=(0, 2)), 42 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(5, 1), padding=(2, 0)), 43 | BasicConv2d(self.real_out_channels, self.real_out_channels, 3, padding=5, dilation=5) 44 | ) 45 | # self.branch3 = nn.Sequential( 46 | # BasicConv2d(in_channels, out_channels, 1), 47 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 7), padding=(0, 3)), 48 | # BasicConv2d(out_channels, out_channels, kernel_size=(7, 1), padding=(3, 0)), 49 | # BasicConv2d(out_channels, out_channels, 3, padding=7, dilation=7) 50 | # ) 51 | self.conv_cat = BasicConv2d(2 * self.real_out_channels, self.real_out_channels, 3, padding=1) 52 | # self.conv_res = BasicConv2d(in_channels, out_channels, 1) 53 | self.conv_spares_cat = BasicConv2d(2*self.sparse_in_channels, self.sparse_in_channels, 3, padding=1) 54 | self.conv_dense_cat = BasicConv2d(2*self.dense_in_channels, self.dense_in_channels, 3, padding=1) 55 | def initialize_parameters(self): 56 | for name, param in self.named_parameters(): 57 | if 'weight' in name: 58 | # Initialize the weight parameters 59 | # One-dimensional vectors cannot be initialized, so we first convert them to two dimensions, 60 | # initialize them, and then convert them to one dimension 61 | if len(param.shape) == 1: 62 | param_unsqueeze = param.unsqueeze(0) 63 | nn.init.xavier_uniform_(param_unsqueeze) 64 | param.data.copy_(param_unsqueeze.squeeze(0)) 65 | else: 66 | nn.init.xavier_uniform_(param) 67 | 68 | elif 'bias' in name: 69 | # print("bias:" + name) 70 | # The bias term is initialized 71 | nn.init.zeros_(param) 72 | 73 | def forward(self, sparse_embeddings, dense_embeddings): 74 | sparse_embeddings = sparse_embeddings.unsqueeze(1) 75 | bs_d, c_d, h_d, w_d = dense_embeddings.size() 76 | bs_s, c_s, h_s, w_s = sparse_embeddings.size() 77 | n = h_d * w_d // (h_s * w_s) 78 | dense_embeddings = dense_embeddings.view(bs_d, n*c_d, h_s, w_s) 79 | tmp_sparse = sparse_embeddings 80 | tmp_dense = 
dense_embeddings 81 | sd_cat = torch.cat([sparse_embeddings, dense_embeddings], dim=1) 82 | 83 | # x0 = self.branch0(sd_cat) 84 | x1 = self.branch1(sd_cat) 85 | x_cat = x1 86 | # x_cat = self.conv_cat(torch.cat((x0, x1), 1)) 87 | 88 | sparse_embeddings_up = x_cat[:, :self.sparse_in_channels, :, :] 89 | sparse_embeddings = self.relu(sparse_embeddings * sparse_embeddings_up) 90 | sparse_cat = torch.cat([sparse_embeddings, tmp_sparse], dim=1) 91 | sparse_embeddings = self.conv_spares_cat(sparse_cat) 92 | # sparse_embeddings = sparse_embeddings.view(32, 1, -1, 256) 93 | sparse_embeddings = sparse_embeddings.squeeze(dim=1) 94 | 95 | dense_embeddings_down = x_cat[:, self.sparse_in_channels:, :, :] 96 | dense_embeddings = self.relu(dense_embeddings * dense_embeddings_down) 97 | dense_embeddings = self.conv_dense_cat(torch.cat([dense_embeddings, tmp_dense], dim=1)) 98 | dense_embeddings = dense_embeddings.view(bs_d, c_d, h_d, w_d) 99 | # dense_embeddings = dense_embeddings + tmp_dense 100 | return sparse_embeddings, dense_embeddings -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from functools import partial 7 | from pathlib import Path 8 | import urllib.request 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | from .modeling import ( 14 | ImageEncoderViT, 15 | MaskDecoder, 16 | PromptEncoder, 17 | Sam, 18 | TwoWayTransformer, 19 | # Shield, 20 | extract_high_frequency, 21 | ME, 22 | Loop_Finer, 23 | bias_correction, 24 | pvt_v2_b2, 25 | # InvertedResidual 26 | 27 | ) 28 | 29 | 30 | def build_sam_vit_h(checkpoint=None): 31 | return _build_sam( 32 | encoder_embed_dim=1280, 33 | encoder_depth=32, 34 | encoder_num_heads=16, 35 | encoder_global_attn_indexes=[7, 15, 23, 31], 36 | checkpoint=checkpoint, 37 | ) 38 | 39 | 40 | build_sam = build_sam_vit_h 41 | 42 | 43 | def build_sam_vit_l(checkpoint=None): 44 | return _build_sam( 45 | encoder_embed_dim=1024, 46 | encoder_depth=24, 47 | encoder_num_heads=16, 48 | encoder_global_attn_indexes=[5, 11, 17, 23], 49 | checkpoint=checkpoint, 50 | ) 51 | 52 | 53 | def build_sam_vit_b(checkpoint=None): 54 | return _build_sam( 55 | encoder_embed_dim=768, 56 | encoder_depth=12, 57 | encoder_num_heads=12, 58 | encoder_global_attn_indexes=[2, 5, 8, 11], 59 | checkpoint=checkpoint, 60 | ) 61 | 62 | 63 | sam_model_registry = { 64 | "default": build_sam_vit_h, 65 | "vit_h": build_sam_vit_h, 66 | "vit_l": build_sam_vit_l, 67 | "vit_b": build_sam_vit_b, 68 | } 69 | 70 | 71 | def _build_sam( 72 | encoder_embed_dim, 73 | encoder_depth, 74 | encoder_num_heads, 75 | encoder_global_attn_indexes, 76 | checkpoint=None, 77 | ): 78 | prompt_embed_dim = 256 79 | image_size = 1024 80 | vit_patch_size = 16 81 | image_embedding_size = image_size // vit_patch_size 82 | sam = Sam( 83 | image_encoder=ImageEncoderViT( 84 | depth=encoder_depth, 85 | embed_dim=encoder_embed_dim, 86 | img_size=image_size, 87 | mlp_ratio=4, 88 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 89 | num_heads=encoder_num_heads, 90 | patch_size=vit_patch_size, 91 | qkv_bias=True, 92 | use_rel_pos=True, 93 | global_attn_indexes=encoder_global_attn_indexes, 94 | window_size=14, 95 | out_chans=prompt_embed_dim, 96 | ), 97 | 
prompt_encoder=PromptEncoder( 98 | embed_dim=prompt_embed_dim, 99 | image_embedding_size=(image_embedding_size, image_embedding_size), 100 | input_image_size=(image_size, image_size), 101 | mask_in_chans=16, 102 | ), 103 | mask_decoder=MaskDecoder( 104 | num_multimask_outputs=3, 105 | transformer=TwoWayTransformer( 106 | depth=2, 107 | embedding_dim=prompt_embed_dim, 108 | mlp_dim=2048, 109 | num_heads=8, 110 | ), 111 | transformer_dim=prompt_embed_dim, 112 | iou_head_depth=3, 113 | iou_head_hidden_dim=256, 114 | ), 115 | pvt=pvt_v2_b2(), 116 | DWT=extract_high_frequency(), 117 | ME=ME(256*2, 256), 118 | Loop_finer=Loop_Finer(64), 119 | BC=bias_correction(512), 120 | pixel_mean=[123.675, 116.28, 103.53], 121 | pixel_std=[58.395, 57.12, 57.375], 122 | ) 123 | 124 | sam.eval() 125 | checkpoint = Path(checkpoint) 126 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists(): 127 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ") 128 | if len(cmd) == 0 or cmd.lower() == 'y': 129 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 130 | print("Downloading SAM ViT-B checkpoint...") 131 | urllib.request.urlretrieve( 132 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 133 | checkpoint, 134 | ) 135 | print(checkpoint.name, " is downloaded!") 136 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists(): 137 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ") 138 | if len(cmd) == 0 or cmd.lower() == 'y': 139 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 140 | print("Downloading SAM ViT-H checkpoint...") 141 | urllib.request.urlretrieve( 142 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 143 | checkpoint, 144 | ) 145 | print(checkpoint.name, " is downloaded!") 146 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists(): 147 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ") 148 | if len(cmd) == 0 or cmd.lower() == 'y': 149 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 150 | print("Downloading SAM ViT-L checkpoint...") 151 | urllib.request.urlretrieve( 152 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 153 | checkpoint, 154 | ) 155 | print(checkpoint.name, " is downloaded!") 156 | 157 | 158 | if checkpoint is not None: 159 | # sam.ME = ME(256 * 3) # 在加载预训练权重之前,初始化 ME 模块 160 | with open(checkpoint, "rb") as f: 161 | state_dict = torch.load(f) 162 | 163 | sam.ME.initialize_parameters() 164 | sam.load_state_dict(state_dict, strict=False) # 设置 strict=False 允许部分加载权重 165 | path = 'work_dir_cod/SAM/pvt_v2_b2.pth' 166 | save_model = torch.load(path) 167 | model_dict = sam.pvt.state_dict() 168 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 169 | model_dict.update(state_dict) 170 | sam.pvt.load_state_dict(model_dict) 171 | return sam 172 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
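# A minimal, hedged usage sketch for the sam_model_registry defined in build_sam.py above;
# it is a sketch, not code from this repository's scripts. The checkpoint path is the one
# referenced by Mytrain.py, and _build_sam additionally expects
# work_dir_cod/SAM/pvt_v2_b2.pth to be present when a checkpoint is given:
#
#   from segment_anything import sam_model_registry
#   sam = sam_model_registry["vit_b"](checkpoint="work_dir_cod/SAM/sam_vit_b_01ec64.pth")
#   sam.to("cuda:0")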
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
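        # With the default 4 mask tokens, score_reweight below is [[1000, 0, 0, 0]], so the
        # (num_points - 2.5) factor penalizes the first (single-mask) token by at least 500
        # when there are <= 2 points (in the export convention a single click is usually
        # accompanied by a padding point, giving 2) and boosts it by at least 500 when there
        # are >= 3 points, selecting a multimask output for ambiguous single-point prompts
        # and the single mask otherwise.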
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/modeling/agent_swin.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Agent Attention: On the Integration of Softmax and Linear Attention 8 | # Modified by Dongchen Han 9 | # ----------------------------------------------------------------------- 10 | 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | 17 | class AgentAttention(nn.Module): 18 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 19 | It supports both of shifted and non-shifted window. 20 | 21 | Args: 22 | dim (int): Number of input channels. 23 | num_heads (int): Number of attention heads. 24 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 25 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 26 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 27 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 28 | """ 29 | 30 | def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., 31 | shift_size=0, agent_num=49, **kwargs): 32 | 33 | super().__init__() 34 | self.dim = dim # embed_dim=96 35 | # self.window_size = window_size # Wh, Ww 36 | self.num_heads = num_heads # num_heads=[3, 6, 12, 24] 37 | head_dim = dim // num_heads 38 | self.scale = head_dim ** -0.5 39 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 40 | self.attn_drop = nn.Dropout(attn_drop) 41 | self.proj = nn.Linear(dim, dim) 42 | self.proj_drop = nn.Dropout(proj_drop) 43 | self.softmax = nn.Softmax(dim=-1) 44 | self.shift_size = shift_size 45 | 46 | self.agent_num = agent_num 47 | self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim) 48 | # self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7)) 49 | # self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7)) 50 | # self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0], 1)) 51 | # self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1])) 52 | # self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num)) 53 | # self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num)) 54 | # trunc_normal_(self.an_bias, std=.02) 55 | # trunc_normal_(self.na_bias, std=.02) 56 | # trunc_normal_(self.ah_bias, std=.02) 57 | # trunc_normal_(self.aw_bias, std=.02) 58 | # trunc_normal_(self.ha_bias, std=.02) 59 | # trunc_normal_(self.wa_bias, std=.02) 60 | pool_size = int(agent_num ** 0.5) 61 | self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)) 62 | 63 | def forward(self, x, mask=None): 64 | """ 65 | Args: 66 | x: input features with shape of (num_windows*B, N, C) 67 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 68 | """ 69 | x = x.permute(0, 2, 3, 1) 70 | b, h, w, c = x.shape 71 | x = x.reshape(b, h * w, c) 72 | # b, c, h, w = x.shape 73 | # n = h * w 74 | b, n, c = x.shape 75 | h = int(n ** 0.5) 76 | w = int(n ** 0.5) 77 | # print(b) 78 | # print(h) 79 | # print(w) 80 | # print(c) 81 | num_heads = self.num_heads 82 | head_dim = c // num_heads 83 | qkv = self.qkv(x) 84 | # print(qkv.shape) 85 | qkv = qkv.reshape(b, n, 3, c).permute(2, 0, 1, 3) 86 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 87 | # q, k, v: b, n, c 88 | 89 | agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1) 90 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 91 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 92 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) 93 | agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3) 94 | 95 | # position_bias1 = nn.functional.interpolate(self.an_bias, size=self.window_size, mode='bilinear') 96 | # position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1) 97 | # position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1) 98 | # position_bias = position_bias1 + position_bias2 99 | agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1)) 100 | agent_attn = self.attn_drop(agent_attn) 101 | agent_v = agent_attn @ v 102 | 103 | # agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear') 104 | # agent_bias1 = agent_bias1.reshape(1, 
num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1) 105 | # agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1) 106 | # agent_bias = agent_bias1 + agent_bias2 107 | q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1)) 108 | q_attn = self.attn_drop(q_attn) 109 | x = q_attn @ agent_v 110 | 111 | x = x.transpose(1, 2).reshape(b, n, c) 112 | v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2) 113 | x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c) 114 | 115 | x = self.proj(x) 116 | x = self.proj_drop(x) 117 | x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) 118 | return x 119 | 120 | # def extra_repr(self) -> str: 121 | # return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 122 | 123 | def flops(self, N): 124 | # calculate flops for 1 window with token length of N 125 | flops = 0 126 | # qkv = self.qkv(x) 127 | flops += N * self.dim * 3 * self.dim 128 | # attn = (q @ k.transpose(-2, -1)) 129 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 130 | # x = (attn @ v) 131 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 132 | # x = self.proj(x) 133 | flops += N * self.dim * self.dim 134 | return flops 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
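        Internally, one learned IoU token and (num_multimask_outputs + 1) mask tokens are
        prepended to the prompt tokens; their transformer outputs feed the IoU prediction
        head and the per-mask hypernetwork MLPs, respectively.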
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | if image_embeddings.shape[0] != tokens.shape[0]: 127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 128 | # 将自己的tensor插入自己,是的bs和tokens的bs一致 129 | else: 130 | src = image_embeddings 131 | 132 | src = src + dense_prompt_embeddings # image_embeddings:32 512 64 64 dense_prompt_embeddings:32 256 64 64, img_em and dense_em are the same 133 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 134 | b, c, h, w = src.shape 135 | 136 | # Run the transformer 137 | hs, src = self.transformer(src, pos_src, tokens) 138 | iou_token_out = hs[:, 0, :] 139 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 140 | 141 | # Upscale mask embeddings and predict masks using the mask tokens 142 | src = src.transpose(1, 2).view(b, c, h, w) 143 | upscaled_embedding = self.output_upscaling(src) 144 | hyper_in_list: List[torch.Tensor] = [] 145 | for i in range(self.num_mask_tokens): 146 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 147 | hyper_in = torch.stack(hyper_in_list, dim=1) 148 | b, c, h, w = upscaled_embedding.shape 149 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 150 | 151 | # Generate mask quality predictions 152 | iou_pred = self.iou_prediction_head(iou_token_out) 153 | 154 | return masks, iou_pred 155 | 156 | 157 | # Lightly adapted from 158 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 159 | class MLP(nn.Module): 160 | def __init__( 161 | self, 162 | input_dim: int, 163 | hidden_dim: int, 164 | output_dim: int, 165 | num_layers: int, 166 | sigmoid_output: bool = False, 167 | ) -> None: 168 | super().__init__() 169 | self.num_layers = num_layers 170 | h = [hidden_dim] * (num_layers - 1) 171 | self.layers = nn.ModuleList( 172 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 173 | ) 174 | self.sigmoid_output = sigmoid_output 175 | 176 | def forward(self, x): 177 | for i, layer in enumerate(self.layers): 178 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 179 | if self.sigmoid_output: 180 | x = F.sigmoid(x) 181 | return x 182 | -------------------------------------------------------------------------------- /Mytrain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | join = os.path.join 5 | from tqdm import tqdm 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | import monai 9 | import torch.nn as nn 10 | from segment_anything import sam_model_registry 11 | from segment_anything.utils.transforms import ResizeLongestSide 12 | from segment_anything.modeling.CWDLoss import CriterionCWD 13 | from torch.nn import functional as F 14 | from torchvision.models.mobilenetv2 import InvertedResidual 15 | # set seeds 16 | torch.manual_seed(2024) 17 | np.random.seed(2024) 18 | 19 | def _upsample_like_1024(src): 20 | src = F.interpolate(src, size=(1024, 1024), mode='bilinear') 21 | return src 22 | 23 | 24 | class NpzDataset(Dataset): 25 | def __init__(self, 
data_root): 26 | self.data_root = data_root 27 | self.npz_files = sorted(os.listdir(self.data_root)) 28 | self.npz_data = [np.load(join(data_root, f)) for f in self.npz_files] 29 | # this implementation is ugly but it works (and is also fast for feeding data to GPU) 30 | # if your server has enough RAM 31 | # as an alternative, you can also use a list of npy files and load them one by one 32 | self.ori_gts = np.vstack([d['gts'] for d in self.npz_data]) 33 | self.ori_imgs = np.vstack([d['imgs'] for d in self.npz_data]) 34 | self.img_embeddings = np.vstack([d['img_embeddings'] for d in self.npz_data]) 35 | self.boundary = np.vstack([d['boundary'] for d in self.npz_data]) 36 | self.depth_embeddings = np.vstack([d['depth_embeddings'] for d in self.npz_data]) 37 | print(f"img_embeddings.shape={self.img_embeddings.shape}, ori_gts.shape={self.ori_gts.shape}, " 38 | f"boundary.shape={self.boundary.shape}", f"depth_embeddings.shape={self.depth_embeddings.shape}") 39 | 40 | def __len__(self): 41 | return self.ori_gts.shape[0] 42 | 43 | def __getitem__(self, index): 44 | img_embed = self.img_embeddings[index] 45 | gt2D = self.ori_gts[index] 46 | img = self.ori_imgs[index] 47 | boundary = self.boundary[index] 48 | depth_embed = self.depth_embeddings[index] 49 | y_indices, x_indices = np.where(gt2D > 0) 50 | x_min, x_max = np.min(x_indices), np.max(x_indices) 51 | y_min, y_max = np.min(y_indices), np.max(y_indices) 52 | # add perturbation to bounding box coordinates 53 | H, W = gt2D.shape 54 | x_min = max(0, x_min - np.random.randint(0, 20)) 55 | x_max = min(W, x_max + np.random.randint(0, 20)) 56 | y_min = max(0, y_min - np.random.randint(0, 20)) 57 | y_max = min(H, y_max + np.random.randint(0, 20)) 58 | bboxes = np.array([x_min, y_min, x_max, y_max]) 59 | # convert img embedding, mask, bounding box to torch tensor 60 | return torch.tensor(img_embed).float(), torch.tensor(img).float(), torch.tensor(gt2D[None, :, :]).long(), torch.tensor(bboxes).float(),\ 61 | torch.tensor(boundary[None, :, :]).long(), torch.tensor(depth_embed).float() 62 | 63 | # %% test dataset class and dataloader 64 | npz_tr_path = 'data/vit_b/COD_train' 65 | work_dir = './work_dir_cod' 66 | task_name = 'DSAM' 67 | # prepare SAM model 68 | model_type = 'vit_b' 69 | checkpoint = 'work_dir_cod/SAM/sam_vit_b_01ec64.pth' 70 | device = 'cuda:0' 71 | model_save_path = join(work_dir, task_name) 72 | os.makedirs(model_save_path, exist_ok=True) 73 | sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device) 74 | 75 | sam_model.train() 76 | # Set up the optimizer, hyperparameter tuning will improve performance here 77 | optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0) 78 | seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean') 79 | CWD_loss = CriterionCWD(norm_type='channel', divergence='kl', temperature=4.0) 80 | 81 | num_epochs = 100 82 | losses = [] 83 | best_loss = 1e10 84 | train_dataset = NpzDataset(npz_tr_path) 85 | mask_threshold = 0.0 86 | train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True) 87 | for epoch in range(num_epochs+1): 88 | epoch_loss = 0 89 | # train 90 | for step, (image_embedding, img, gt2D, boxes, boundary, depth_embedding) in enumerate(tqdm(train_dataloader)): 91 | 92 | with torch.no_grad(): 93 | box_np = boxes.numpy() 94 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size) 95 | box = sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1])) 96 | box_torch = torch.as_tensor(box, 
dtype=torch.float, device=device) 97 | boundary = torch.as_tensor(boundary, dtype=torch.float, device=device) 98 | # boundary = boun_conv(boundary) 99 | image_embedding = torch.as_tensor(image_embedding, dtype=torch.float, device=device) 100 | depth_embedding = torch.as_tensor(depth_embedding, dtype=torch.float, device=device) 101 | if len(box_torch.shape) == 2: 102 | box_torch = box_torch[:, None, :] # (B, 1, 4) 103 | # get prompt embeddings 104 | sparse_embeddings_box, dense_embeddings_box = sam_model.prompt_encoder( 105 | points=None, 106 | boxes=box_torch, 107 | masks=None 108 | ) 109 | 110 | resize_img_tensor = np.transpose(img, (0, 3, 1, 2)).to(device) 111 | input_image = _upsample_like_1024(resize_img_tensor) 112 | pvt_embedding = sam_model.pvt(input_image)[3] 113 | 114 | bc_embedding, pvt_64 = sam_model.BC(pvt_embedding) 115 | # bc_embedding shape:" 1, 256, 64, 64 116 | distill_loss = CWD_loss(bc_embedding, depth_embedding) 117 | hybrid_embedding = torch.cat([pvt_64, bc_embedding], dim=1) 118 | high_frequency = sam_model.DWT(hybrid_embedding) 119 | 120 | dense_embeddings, sparse_embeddings = sam_model.ME(dense_embeddings_box, 121 | high_frequency, sparse_embeddings_box) 122 | 123 | # predicted masks 124 | mask_predictions, _ = sam_model.mask_decoder( 125 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64) 126 | image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 127 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 128 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 129 | multimask_output=False, 130 | ) 131 | 132 | final_mask = sam_model.loop_finer(mask_predictions, depth_embedding, depth_embedding) 133 | 134 | mask_predictions = 0.1*final_mask + 0.9*mask_predictions 135 | 136 | loss = 0.9*seg_loss(mask_predictions, gt2D.to(device)) + 0.1*distill_loss 137 | optimizer.zero_grad() 138 | loss.backward() 139 | optimizer.step() 140 | epoch_loss += loss.item() 141 | 142 | epoch_loss /= step 143 | losses.append(epoch_loss) 144 | print(f'EPOCH: {epoch}, Loss: {epoch_loss}') 145 | # save the latest model checkpoint 146 | if epoch >= 80 and epoch % 10 == 0: 147 | torch.save(sam_model.state_dict(), join(model_save_path, str(epoch) + 'sam_model.pth')) 148 | # save the best model 149 | if epoch_loss < best_loss: 150 | best_loss = epoch_loss 151 | torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_best.pth')) 152 | # plot loss 153 | plt.plot(losses) 154 | plt.title('Dice + Cross Entropy Loss') 155 | plt.xlabel('Epoch') 156 | plt.ylabel('Loss') 157 | # plt.show() # comment this line if you are running on a server 158 | plt.savefig(join(model_save_path, 'train_loss.png')) 159 | plt.close() -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
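# Recap, as a comment for reference, of the objective hard-coded in Mytrain.py above
# (the 0.9/0.1 weights are the ones used in that script):
#
#   mask_predictions = 0.1 * loop_finer(decoder_masks, depth_emb, depth_emb) + 0.9 * decoder_masks
#   loss             = 0.9 * DiceCELoss(mask_predictions, gt) + 0.1 * CWD_loss(bc_embedding, depth_emb)
#
# Note that only sam_model.mask_decoder.parameters() are handed to the Adam optimizer there,
# so the PVT/BC/DWT/ME branches are not updated by that optimizer configuration.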
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | # from .FUSE import FUSE 17 | from .pvtv2 import pvt_v2_b2 18 | from .DWT import extract_high_frequency 19 | from .mix_enbedding import ME 20 | from .loop_finer import Loop_Finer 21 | from .bias_correction import bias_correction 22 | from torchvision.models.mobilenetv2 import InvertedResidual 23 | 24 | # from .sd_merge import MDS 25 | class Sam(nn.Module): 26 | mask_threshold: float = 0.0 27 | image_format: str = "RGB" 28 | 29 | def __init__( 30 | self, 31 | image_encoder: ImageEncoderViT, 32 | prompt_encoder: PromptEncoder, 33 | mask_decoder: MaskDecoder, 34 | pvt: pvt_v2_b2, 35 | DWT: extract_high_frequency, 36 | ME:ME, 37 | Loop_finer: Loop_Finer, 38 | BC: bias_correction, 39 | # IRB: InvertedResidual, 40 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 41 | pixel_std: List[float] = [58.395, 57.12, 57.375], 42 | ) -> None: 43 | """ 44 | SAM predicts object masks from an image and input prompts. 45 | 46 | Arguments: 47 | image_encoder (ImageEncoderViT): The backbone used to encode the 48 | image into image embeddings that allow for efficient mask prediction. 49 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 50 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 51 | and encoded prompts. 52 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 53 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 54 | """ 55 | super().__init__() 56 | self.image_encoder = image_encoder 57 | self.prompt_encoder = prompt_encoder 58 | self.mask_decoder = mask_decoder 59 | self.pvt = pvt 60 | self.BC = BC 61 | # self.IRB = IRB 62 | self. DWT = DWT 63 | self.ME = ME 64 | self.loop_finer = Loop_finer 65 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 66 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 67 | 68 | @property 69 | def device(self) -> Any: 70 | return self.pixel_mean.device 71 | 72 | @torch.no_grad() 73 | def forward( 74 | self, 75 | batched_input: List[Dict[str, Any]], 76 | multimask_output: bool, 77 | ) -> List[Dict[str, torch.Tensor]]: 78 | """ 79 | Predicts masks end-to-end from provided images and prompts. 80 | If prompts are not known in advance, using SamPredictor is 81 | recommended over calling the model directly. 82 | 83 | Arguments: 84 | batched_input (list(dict)): A list over input images, each a 85 | dictionary with the following keys. A prompt key can be 86 | excluded if it is not present. 87 | 'image': The image as a torch tensor in 3xHxW format, 88 | already transformed for input to the model. 89 | 'original_size': (tuple(int, int)) The original size of 90 | the image before transformation, as (H, W). 91 | 'point_coords': (torch.Tensor) Batched point prompts for 92 | this image, with shape BxNx2. Already transformed to the 93 | input frame of the model. 94 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 95 | with shape BxN. 96 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 97 | Already transformed to the input frame of the model. 98 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 99 | in the form Bx1xHxW. 
100 | multimask_output (bool): Whether the model should predict multiple 101 | disambiguating masks, or return a single mask. 102 | 103 | Returns: 104 | (list(dict)): A list over input images, where each element is 105 | as dictionary with the following keys. 106 | 'masks': (torch.Tensor) Batched binary mask predictions, 107 | with shape BxCxHxW, where B is the number of input prompts, 108 | C is determined by multimask_output, and (H, W) is the 109 | original size of the image. 110 | 'iou_predictions': (torch.Tensor) The model's predictions 111 | of mask quality, in shape BxC. 112 | 'low_res_logits': (torch.Tensor) Low resolution logits with 113 | shape BxCxHxW, where H=W=256. Can be passed as mask input 114 | to subsequent iterations of prediction. 115 | """ 116 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 117 | image_embeddings = self.image_encoder(input_images) 118 | 119 | outputs = [] 120 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 121 | if "point_coords" in image_record: 122 | points = (image_record["point_coords"], image_record["point_labels"]) 123 | else: 124 | points = None 125 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 126 | points=points, 127 | boxes=image_record.get("boxes", None), 128 | masks=image_record.get("mask_inputs", None), 129 | ) 130 | low_res_masks, iou_predictions = self.mask_decoder( 131 | image_embeddings=curr_embedding.unsqueeze(0), 132 | image_pe=self.prompt_encoder.get_dense_pe(), 133 | sparse_prompt_embeddings=sparse_embeddings, 134 | dense_prompt_embeddings=dense_embeddings, 135 | multimask_output=multimask_output, 136 | ) 137 | masks = self.postprocess_masks( 138 | low_res_masks, 139 | input_size=image_record["image"].shape[-2:], 140 | original_size=image_record["original_size"], 141 | ) 142 | masks = masks > self.mask_threshold 143 | outputs.append( 144 | { 145 | "masks": masks, 146 | "iou_predictions": iou_predictions, 147 | "low_res_logits": low_res_masks, 148 | } 149 | ) 150 | return outputs 151 | 152 | def postprocess_masks( 153 | self, 154 | masks: torch.Tensor, 155 | input_size: Tuple[int, ...], 156 | original_size: Tuple[int, ...], 157 | ) -> torch.Tensor: 158 | """ 159 | Remove padding and upscale masks to the original image size. 160 | 161 | Arguments: 162 | masks (torch.Tensor): Batched masks from the mask_decoder, 163 | in BxCxHxW format. 164 | input_size (tuple(int, int)): The size of the image input to the 165 | model, in (H, W) format. Used to remove padding. 166 | original_size (tuple(int, int)): The original size of the image 167 | before resizing for input to the model, in (H, W) format. 168 | 169 | Returns: 170 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 171 | is given by original_size. 
172 | """ 173 | masks = F.interpolate( 174 | masks, 175 | (self.image_encoder.img_size, self.image_encoder.img_size), 176 | mode="bilinear", 177 | align_corners=False, 178 | ) 179 | masks = masks[..., : input_size[0], : input_size[1]] 180 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 181 | return masks 182 | 183 | # def preprocess(self, x: torch.Tensor) -> torch.Tensor: 184 | # """Normalize pixel values and pad to a square input.""" 185 | # # Normalize colors 186 | # x = (x - self.pixel_mean) / self.pixel_std 187 | # 188 | # # Pad 189 | # h, w = x.shape[-2:] 190 | # padh = 352 - h 191 | # padw = 352 - w 192 | # x = F.pad(x, (0, padw, 0, padh)) 193 | # return x 194 | 195 | # 20240201 sam原本的代码 196 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 197 | """Normalize pixel values and pad to a square input.""" 198 | # Normalize colors 199 | x = (x - self.pixel_mean) / self.pixel_std 200 | 201 | # Pad 202 | h, w = x.shape[-2:] 203 | padh = self.image_encoder.img_size - h 204 | padw = self.image_encoder.img_size - w 205 | x = F.pad(x, (0, padw, 0, padh)) 206 | return x 207 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. 
Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = 
self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
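        Points take precedence, then boxes, then masks; if none are provided the
        batch size defaults to 1.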
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /Mytest.py: -------------------------------------------------------------------------------- 1 | # %% load environment 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | 6 | join = os.path.join 7 | import torch 8 | from segment_anything import sam_model_registry 9 | from segment_anything.utils.transforms import ResizeLongestSide 10 | from tqdm import tqdm 11 | import argparse 12 | import traceback 13 | import sys 14 | import torch.nn as nn 15 | import shutil 16 | from torchvision.models.mobilenetv2 import InvertedResidual 17 | import os 18 | 19 | torch.manual_seed(1) 20 | np.random.seed(1) 21 | # %% run inference 22 | # set up the parser 23 | parser = argparse.ArgumentParser(description='run inference on testing set') 24 | parser.add_argument('-i', '--data_path', type=str, default='data/vit_b/COD_test', help='path to the data folder') 25 | parser.add_argument('-o', '--seg_path_root', type=str, default='data/inference_npz/DSAM', 26 | help='path to the segmentation folder') 27 | parser.add_argument('--seg_png_path', type=str, default='data/inference_img/DSAM', 28 | help='path to the segmentation folder') 29 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type') 30 | parser.add_argument('--device', type=str, default='cuda:1', help='device') 31 | parser.add_argument('-chk', '--checkpoint', type=str, default='work_dir_cod/DSAM/DSAM.pth', 32 | help='path to the trained model') 33 | args = parser.parse_args() 34 | 35 | 36 | mDSC = [] 37 | num_of_mDSC = 0 38 | 39 | def show_mask(mask, ax, random_color=False): 40 | if random_color: 41 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 42 | else: 43 | color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6]) 44 | h, w = mask.shape[-2:] 45 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 46 | ax.imshow(mask_image) 47 | 48 | 49 | def show_box(box, ax): 50 | x0, y0 = box[0], box[1] 51 | w, h = box[2] - box[0], box[3] - box[1] 52 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2)) 53 | 54 | 55 | def compute_dice(mask_gt, mask_pred): 56 | """Compute soerensen-dice coefficient. 57 | Returns: 58 | the dice coeffcient as float. 
If both masks are empty, the result is NaN 59 | """ 60 | volume_sum = mask_gt.sum() + mask_pred.sum() 61 | if volume_sum == 0: 62 | return np.NaN 63 | volume_intersect = (mask_gt & mask_pred).sum() 64 | return 2 * volume_intersect / volume_sum 65 | 66 | 67 | def finetune_model_predict(img_np, box_np, depth_np, boundary, sam_trans, sam_model_tune, device=args.device): 68 | H, W = img_np.shape[:2] 69 | # Original image processing 70 | resize_img = sam_trans.apply_image(img_np) 71 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device) # (3, 1024, 1024) 72 | input_image = sam_model_tune.preprocess(resize_img_tensor[None, :, :, :]) # (1, 3, 1024, 1024) 73 | # Depth map processing 74 | resize_depth_img = sam_trans.apply_image(depth_np) 75 | resize_dep_tensor = torch.as_tensor(resize_depth_img.transpose(2, 0, 1)).to(device) 76 | depth_image = sam_model_tune.preprocess(resize_dep_tensor[None, :, :, :]) # (1, 3, 1024, 1024) 77 | 78 | with torch.no_grad(): 79 | image_embedding = sam_model_tune.image_encoder(input_image.to(device)) # (1, 256, 64, 64) 80 | depth_embedding = sam_model_tune.image_encoder(depth_image.to(device)) # (1, 256, 64, 64) 81 | 82 | # convert box to 1024x1024 grid 83 | box = sam_trans.apply_boxes(box_np, (H, W)) 84 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device) 85 | 86 | if len(box_torch.shape) == 2: 87 | box_torch = box_torch[:, None, :] # (B, 1, 4) 88 | 89 | sparse_embeddings_box, dense_embeddings_box = sam_model_tune.prompt_encoder( 90 | points=None, 91 | boxes=box_torch, 92 | masks=None 93 | ) 94 | 95 | pvt_embedding = sam_model_tune.pvt(input_image)[3] 96 | bc_embedding, pvt_64 = sam_model_tune.BC(pvt_embedding) 97 | hybrid_embedding = torch.cat([pvt_64, bc_embedding], dim=1) 98 | high_frequency = sam_model_tune.DWT(hybrid_embedding) 99 | dense_embeddings, sparse_embeddings = sam_model_tune.ME(dense_embeddings_box, 100 | high_frequency, sparse_embeddings_box) 101 | # predicted masks 102 | seg_prob, _ = sam_model_tune.mask_decoder( 103 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64) hybrid_embedding:(1, 512, 64, 64) 104 | image_pe=sam_model_tune.prompt_encoder.get_dense_pe(), # 105 | # (1, 256, 64, 64) 106 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 107 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 108 | multimask_output=False, 109 | ) 110 | final_mask = sam_model_tune.loop_finer(seg_prob, depth_embedding, depth_embedding) 111 | seg_prob = 0.1 * final_mask + 0.9 * seg_prob 112 | seg_prob = torch.sigmoid(seg_prob) 113 | # convert soft mask to hard mask 114 | seg_prob = seg_prob.cpu().numpy().squeeze() 115 | seg = (seg_prob > 0.5).astype(np.uint8) 116 | return seg 117 | 118 | def delete_folder(folder_path): 119 | try: 120 | # 删除文件夹及其内容 121 | shutil.rmtree(folder_path) 122 | print(f"delete folder: {folder_path}") 123 | except Exception as e: 124 | print(f"fail to delete folder: {folder_path}: {e}") 125 | 126 | def divide(x, y): 127 | try: 128 | result = x / y 129 | return result 130 | except ZeroDivisionError as e: 131 | # 处理除以零的情况 132 | delete_folder(args.seg_path_root) 133 | delete_folder(args.seg_png_path) 134 | print("division by zero:", e) 135 | 136 | 137 | device = args.device 138 | sam_model_tune = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to(device) 139 | 140 | sam_trans = ResizeLongestSide(sam_model_tune.image_encoder.img_size) 141 | 142 | npz_folders = sorted(os.listdir(args.data_path)) 143 | os.makedirs(args.seg_png_path, exist_ok=True) 144 | 
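# A tiny, hedged sanity check of compute_dice on toy masks (illustrative only; the arrays
# below are not part of the dataset): two 2x2 masks overlapping in one pixel give
# DSC = 2 * 1 / (2 + 2) = 0.5.
_toy_gt = np.array([[1, 1], [0, 0]], dtype=bool)
_toy_pred = np.array([[1, 0], [1, 0]], dtype=bool)
assert abs(compute_dice(_toy_gt, _toy_pred) - 0.5) < 1e-6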
145 | for npz_folder in npz_folders: 146 | npz_data_path = join(args.data_path, npz_folder) 147 | save_path = join(args.seg_path_root, npz_folder) 148 | if not os.path.exists(save_path): 149 | os.makedirs(save_path, exist_ok=True) 150 | npz_files = sorted(os.listdir(npz_data_path)) 151 | for npz_file in tqdm(npz_files): 152 | try: 153 | npz = np.load(join(npz_data_path, npz_file)) 154 | print(npz_file) 155 | ori_imgs = npz['imgs'] 156 | ori_gts = npz['gts'] 157 | ori_number = npz['number'] 158 | boundary = npz['boundary'] 159 | dep_imgs = npz['depth_imgs'] 160 | 161 | sam_segs = [] 162 | sam_bboxes = [] 163 | sam_dice_scores = [] 164 | for img_id, ori_img in enumerate(ori_imgs): 165 | # get bounding box from mask 166 | gt2D = ori_gts[img_id] 167 | bboundary = boundary[img_id] 168 | depth_img = dep_imgs[img_id] 169 | 170 | y_indices, x_indices = np.where(gt2D > 0) 171 | x_min, x_max = np.min(x_indices), np.max(x_indices) 172 | y_min, y_max = np.min(y_indices), np.max(y_indices) 173 | # add perturbation to bounding box coordinates 174 | H, W = gt2D.shape 175 | x_min = max(0, x_min - np.random.randint(0, 20)) 176 | x_max = min(W, x_max + np.random.randint(0, 20)) 177 | y_min = max(0, y_min - np.random.randint(0, 20)) 178 | y_max = min(H, y_max + np.random.randint(0, 20)) 179 | bbox = np.array([x_min, y_min, x_max, y_max]) 180 | seg_mask = finetune_model_predict(ori_img, bbox, depth_img, bboundary, sam_trans, sam_model_tune, device=device) 181 | sam_segs.append(seg_mask) 182 | sam_bboxes.append(bbox) 183 | # these 2D dice scores are for debugging purposes. 184 | # 3D dice scores should be computed for 3D images 185 | sam_dice_scores.append(compute_dice(seg_mask > 0, gt2D > 0)) 186 | 187 | # save npz, including sam_segs, sam_bboxes, sam_dice_scores 188 | np.savez_compressed(join(save_path, npz_file), medsam_segs=sam_segs, gts=ori_gts, number=ori_number, sam_bboxes=sam_bboxes) 189 | 190 | # visualize segmentation results 191 | img_id = np.random.randint(0, len(ori_imgs)) 192 | # show ground truth and segmentation results in two subplots 193 | fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 194 | axes[0].imshow(ori_imgs[img_id]) 195 | show_box(sam_bboxes[img_id], axes[0]) 196 | show_mask(ori_gts[img_id], axes[0]) 197 | axes[0].set_title('Ground Truth') 198 | axes[0].axis('off') 199 | 200 | axes[1].imshow(ori_imgs[img_id]) 201 | show_box(sam_bboxes[img_id], axes[1]) 202 | show_mask(sam_segs[img_id], axes[1]) 203 | axes[1].set_title('DSAM: DSC={:.3f}'.format(sam_dice_scores[img_id])) 204 | axes[1].axis('off') 205 | # save figure 206 | fig.savefig(join(args.seg_png_path, npz_file.split('.npz')[0] + '.png')) 207 | # close figure 208 | plt.close(fig) 209 | except Exception: 210 | traceback.print_exc() 211 | print('error in {}'.format(npz_file)) 212 | 213 | tmp_mDSC = divide(sum(sam_dice_scores), len(sam_dice_scores)) 214 | mDSC.append(tmp_mDSC) 215 | 216 | print(str(npz_folder)+": " + str(tmp_mDSC)) 217 | # average mDSC over all test datasets 218 | print("final average mDSC: " + str((sum(mDSC)/len(mDSC)))) -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
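# Typical SamPredictor usage (an illustrative sketch; the image array, box coordinates
# and checkpoint path below are examples, not part of this file):
#
#   from segment_anything import sam_model_registry, SamPredictor
#   sam = sam_model_registry["vit_b"](checkpoint="work_dir/SAM/sam_vit_b_01ec64.pth")
#   predictor = SamPredictor(sam)
#   predictor.set_image(image)  # HWC uint8 array in RGB order
#   masks, scores, low_res_logits = predictor.predict(box=np.array([x0, y0, x1, y1]))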
6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 
107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 
180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 
256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | 
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
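    # The IoU therefore reduces to |mask > t + offset| / |mask > t - offset|:
    # the high-threshold mask is the intersection and the low-threshold mask is the union.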
165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate 
transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 
335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /pre_npz.py: -------------------------------------------------------------------------------- 1 | # %% import packages 2 | import numpy as np 3 | import os 4 | from glob import glob 5 | import pandas as pd 6 | import cv2 7 | join = os.path.join 8 | from skimage import transform, io, segmentation 9 | from tqdm import tqdm 10 | import torch 11 | from segment_anything import sam_model_registry 12 | from segment_anything.utils.transforms import ResizeLongestSide 13 | import argparse 14 | from PIL import Image 15 | 16 | # set up the parser 17 | parser = argparse.ArgumentParser(description="preprocess grey and RGB images") 18 | 19 | # add arguments to the parser 20 | parser.add_argument( 21 | "-i", 22 | "--img_path", 23 | type=str, 24 | default="data/TrainDataset/Imgs", 25 | help="path to the images", 26 | ) 27 | parser.add_argument( 28 | "-gt", 29 | "--gt_path", 30 | type=str, 31 | default="data/TrainDataset/GT", 32 | help="path to the ground truth (gt)", 33 | ) 34 | 35 | parser.add_argument( 36 | "-depth", 37 | "--depth_path", 38 | type=str, 39 | default="data/TrainDataset/depth", 40 | help="path to the ground truth (gt)", 41 | ) 42 | 43 | parser.add_argument( 44 | "--csv", 45 | type=str, 46 | default=None, 47 | help="path to the csv file", 48 | ) 49 | 50 | parser.add_argument( 51 | "-o", 52 | "--npz_path", 53 | type=str, 54 | default="data/npz", 55 | help="path to save the npz files", 56 | ) 57 | parser.add_argument( 58 | "--data_name", 59 | type=str, 60 | default="COD_Test_CAMO", 61 | help="dataset name; used to name the final npz file, e.g., demo2d.npz", 62 | ) 63 | parser.add_argument("--image_size", type=int, default=256, help="image size") 64 | # parser.add_argument("--depth_size", type=int, default=256, help="depth image size") 65 | parser.add_argument( 66 | "--img_name_suffix", type=str, default=".jpg", help="image name suffix" 67 | ) 68 | parser.add_argument("--label_id", type=int, default=255, help="label id") 69 | parser.add_argument("--model_type", type=str, default="vit_b", help="model type") 70 | parser.add_argument( 71 | "--checkpoint", 72 | type=str, 73 | default="work_dir/SAM/sam_vit_b_01ec64.pth", 74 | help="checkpoint", 75 | ) 76 | parser.add_argument("--device", type=str, default="cuda:2", help="device") 77 | parser.add_argument("--seed", type=int, default=2023, help="random seed") 78 | 79 | # parse the arguments 80 | args = parser.parse_args() 81 | 82 | # convert 2d grey or rgb images to npz file 83 | imgs = [] 84 | gts = [] 85 | depth_imgs = [] 86 | number = [] 87 | boundary = [] 88 | img_embeddings = [] 89 | depth_embeddings = [] 90 | global num_of_processed_imgs 91 | num_of_processed_imgs = 0 92 | 93 | 94 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to( 95 | args.device 96 | ) 97 | # create a directory to save the npz files 98 | save_path = args.npz_path + "_" + args.model_type 99 | os.makedirs(save_path, exist_ok=True) 100 | 101 | 102 | def find_bundary(img, mask, path, name): 103 | mask = mask * 255 # convert to 0-255 104 | kernel = np.ones((3, 
3), dtype=np.uint8) 105 | fore = cv2.dilate(mask, kernel, 3) 106 | dilate = (fore - mask) 107 | kernel_again = np.ones((5, 5), dtype=np.uint8) 108 | dilate_again = cv2.dilate(dilate, kernel_again, 3) 109 | 110 | edges = cv2.Canny(img, 0.2, 0.6) 111 | # edges[edges > 0] = 1 112 | # edges = edges.astype(np.uint8)/255.0 113 | os.makedirs(os.path.join(path + '/dilate/'), exist_ok=True) 114 | os.makedirs(os.path.join(path + '/boundary/'), exist_ok=True) 115 | os.makedirs(os.path.join(path + '/canny/'), exist_ok=True) 116 | cv2.imwrite(os.path.join(path + '/dilate/', name), fore) 117 | cv2.imwrite(os.path.join(path + '/boundary/', name), dilate_again) 118 | cv2.imwrite(os.path.join(path + '/canny/', name), edges) 119 | 120 | edges_2 = Image.open(os.path.join(os.path.join(path + '/canny/', name))).convert('1') 121 | dilate_again_2 = Image.open(os.path.join(os.path.join(path + '/boundary/', name))).convert('1') 122 | 123 | boundary_grads = (Image.fromarray(np.array(edges_2) * np.array(dilate_again_2))) 124 | os.makedirs(os.path.join(path + '/boundary_grads/'), exist_ok=True) 125 | boundary_grads.save(os.path.join(os.path.join(path + '/boundary_grads/', gt_name))) 126 | 127 | return boundary_grads 128 | 129 | def process(gt_name: str, image_name: str, num_of_processed_imgs: int): 130 | if image_name is None: 131 | image_name = gt_name.split(".")[0] + args.img_name_suffix # derive the image name from the GT name 132 | gt_data = io.imread(join(args.gt_path, gt_name)) 133 | 134 | # if it is rgb, select the first channel 135 | if len(gt_data.shape) == 3: 136 | gt_data = gt_data[:, :, 0] 137 | assert len(gt_data.shape) == 2, "ground truth should be 2D" 138 | 139 | # resize ground truth image 140 | gt_data = transform.resize( 141 | gt_data == args.label_id, 142 | (args.image_size, args.image_size), 143 | order=0, 144 | preserve_range=True, 145 | mode="constant", 146 | ) 147 | 148 | # convert to uint8 149 | gt_data = np.uint8(gt_data) 150 | 151 | 152 | if np.sum(gt_data) > 0: # do not exclude tiny objects (polyps may be small, and COD datasets contain tiny objects) 153 | """Optional binary thresholding can be added""" 154 | assert ( 155 | np.max(gt_data) == 1 and np.unique(gt_data).shape[0] == 2 156 | ), "ground truth should be binary" 157 | # image preprocessing 158 | image_data = io.imread(join(args.img_path, image_name)) 159 | # Remove any alpha channel if present.
160 | if image_data.shape[-1] > 3 and len(image_data.shape) == 3: 161 | image_data = image_data[:, :, :3] 162 | # If image is grayscale, then repeat the last channel to convert to rgb 163 | if len(image_data.shape) == 2: 164 | image_data = np.repeat(image_data[:, :, None], 3, axis=-1) 165 | # nii preprocess start 166 | lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile( 167 | image_data, 99.5 168 | ) 169 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 170 | # min-max normalize and scale 171 | image_data_pre = ( 172 | (image_data_pre - np.min(image_data_pre)) 173 | / (np.max(image_data_pre) - np.min(image_data_pre)) 174 | * 255.0 175 | ) 176 | image_data_pre[image_data == 0] = 0 177 | 178 | image_data_pre = transform.resize( 179 | image_data_pre, 180 | (args.image_size, args.image_size), 181 | order=3, 182 | preserve_range=True, 183 | mode="constant", 184 | anti_aliasing=True, 185 | ) 186 | image_data_pre = np.uint8(image_data_pre) 187 | 188 | imgs.append(image_data_pre) 189 | # End of image preprocessing 190 | 191 | # depth image preprocessing 192 | depth_data = io.imread(join(args.depth_path, gt_name)) 193 | 194 | gt_data_pre = transform.resize( 195 | gt_data, 196 | (args.image_size, args.image_size), 197 | order=1, 198 | preserve_range=True, 199 | mode="constant", 200 | anti_aliasing=True, 201 | ) 202 | depth_data_pre = transform.resize( 203 | depth_data, 204 | (args.image_size, args.image_size), 205 | order=1, 206 | preserve_range=True, 207 | mode="constant", 208 | anti_aliasing=True, 209 | ) 210 | y_indices, x_indices = np.where(gt_data_pre > 0) 211 | x_min, x_max = np.min(x_indices), np.max(x_indices) 212 | y_min, y_max = np.min(y_indices), np.max(y_indices) 213 | # add perturbation to bounding box coordinates 214 | H, W = args.image_size, args.image_size 215 | x_min = max(0, x_min - int(10)) 216 | x_max = min(W, x_max + int(10)) 217 | y_min = max(0, y_min - int(10)) 218 | y_max = min(H, y_max + int(10)) 219 | 220 | depth_data = depth_data_pre[y_min:y_max, x_min:x_max] 221 | os.makedirs(os.path.join(save_path+'/'+args.data_name + '/depth_crop/'), exist_ok=True) 222 | cv2.imwrite(os.path.join(save_path+'/'+args.data_name + '/depth_crop/', gt_name), depth_data) 223 | # Remove any alpha channel if present. 
224 | if depth_data.shape[-1] > 3 and len(depth_data.shape) == 3: 225 | depth_data = depth_data[:, :, :3] 226 | # If image is grayscale, then repeat the last channel to convert to rgb 227 | if len(depth_data.shape) == 2: 228 | depth_data = np.repeat(depth_data[:, :, None], 3, axis=-1) 229 | # nii preprocess start 230 | lower_bound, upper_bound = np.percentile(depth_data, 0.5), np.percentile( 231 | depth_data, 99.5 232 | ) 233 | depth_data_pre = np.clip(depth_data, lower_bound, upper_bound) 234 | # min-max normalize and scale 235 | depth_data_pre = ( 236 | (depth_data_pre - np.min(depth_data_pre)) 237 | / (np.max(depth_data_pre) - np.min(depth_data_pre)) 238 | * 255.0 239 | ) 240 | depth_data_pre[depth_data == 0] = 0 241 | 242 | depth_data_pre = transform.resize( 243 | depth_data_pre, 244 | (args.image_size, args.image_size), 245 | order=3, 246 | preserve_range=True, 247 | mode="constant", 248 | anti_aliasing=True, 249 | ) 250 | depth_data_pre = np.uint8(depth_data_pre) 251 | depth_imgs.append(depth_data_pre) 252 | 253 | # End of depth image preprocessing 254 | 255 | number.append(image_name) 256 | 257 | print("the number of images: " + str(len(imgs)) + " and the name of image: " + str(image_name)) 258 | num_of_processed_imgs = num_of_processed_imgs + 1 259 | assert np.sum(gt_data) > 0, "ground truth should have more than 0 pixels ,because of the tiny objects in COD" 260 | 261 | gts.append(gt_data) 262 | boundary.append(find_bundary(image_data_pre, gt_data, save_path+'/'+args.data_name, gt_name)) 263 | print("the boundary_grad is produced now!!") 264 | 265 | # img -->img embedding 266 | # resize image to 3*1024*1024 267 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size) 268 | resize_img = sam_transform.apply_image(image_data_pre) 269 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to( 270 | args.device 271 | ) 272 | input_image = sam_model.preprocess( 273 | resize_img_tensor[None, :, :, :] 274 | ) # (1, 3, 1024, 1024) 275 | assert input_image.shape == ( 276 | 1, 277 | 3, 278 | sam_model.image_encoder.img_size, 279 | sam_model.image_encoder.img_size, 280 | ), "input image should be resized to 1024*1024" 281 | # pre-compute the image embedding 282 | with torch.no_grad(): 283 | embedding = sam_model.image_encoder(input_image) 284 | img_embeddings.append(embedding.cpu().numpy()[0]) 285 | # end of img -->img embedding 286 | 287 | # depth_img -->depth_img embedding 288 | # resize image to 3*1024*1024 289 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size) 290 | resize_depth_img = sam_transform.apply_image(depth_data_pre) 291 | resize_depth_img_tensor = torch.as_tensor(resize_depth_img.transpose(2, 0, 1)).to( 292 | args.device 293 | ) 294 | input_depth_image = sam_model.preprocess( 295 | resize_depth_img_tensor[None, :, :, :] 296 | ) # (1, 3, 1024, 1024) 297 | assert input_depth_image.shape == ( 298 | 1, 299 | 3, 300 | sam_model.image_encoder.img_size, 301 | sam_model.image_encoder.img_size, 302 | ), "input depth image should be resized to 1024*1024" 303 | # pre-compute the image embedding 304 | with torch.no_grad(): 305 | depth_embedding = sam_model.image_encoder(input_depth_image) 306 | depth_embeddings.append(depth_embedding.cpu().numpy()[0]) 307 | # end of depth_img -->depth_img embedding 308 | 309 | return num_of_processed_imgs 310 | 311 | 312 | if args.csv != None: 313 | # if data is presented in csv format 314 | # columns must be named image_filename and mask_filename respectively 315 | try: 316 | os.path.exists(args.csv) 317 | 
except FileNotFoundError as e: 318 | print(f"File {args.csv} not found!!") 319 | 320 | df = pd.read_csv(args.csv) 321 | bar = tqdm(df.iterrows(), total=len(df)) 322 | for idx, row in bar: 323 | process(row.mask_filename, row.image_filename) 324 | 325 | else: 326 | # get all the names of the images in the ground truth folder 327 | # names = sorted(os.listdir(args.gt_path), key=lambda x: int(x.split('.')[0])) # Sort by numeric size of filenames, but not all filenames are numeric 328 | names = sorted(os.listdir(args.gt_path)) 329 | # print the number of images found in the ground truth folder 330 | print("image number:", len(names)) 331 | for gt_name in tqdm(names): 332 | num_of_processed_imgs = process(gt_name, None, num_of_processed_imgs) 333 | print("the number of processed images: " + str(num_of_processed_imgs)) 334 | 335 | 336 | 337 | # stack the list to array 338 | print("Num. of images:", len(imgs)) 339 | if len(imgs) > 1: 340 | # The features are stacked and become embedding 341 | imgs = np.stack(imgs, axis=0) # (n, 256, 256, 3) 342 | gts = np.stack(gts, axis=0) # (n, 256, 256) 343 | depth_imgs = np.stack(depth_imgs, axis=0) # (n, 256, 256) 344 | img_embeddings = np.stack(img_embeddings, axis=0) # (n, 1, 256, 64, 64) 345 | depth_embeddings = np.stack(depth_embeddings, axis=0) # (n, 1, 256, 64, 64) 346 | boundary = np.stack(boundary, axis=0) # (n, 256, 256) 347 | np.savez_compressed( 348 | join(save_path, args.data_name + ".npz"), 349 | imgs=imgs, 350 | gts=gts, 351 | depth_imgs=depth_imgs, 352 | number=number, 353 | img_embeddings=img_embeddings, 354 | boundary=boundary, 355 | depth_embeddings=depth_embeddings, 356 | ) 357 | # save an example image for sanity check 358 | idx = np.random.randint(imgs.shape[0]) 359 | img_idx = imgs[idx, :, :, :] 360 | gt_idx = gts[idx, :, :] 361 | bd = segmentation.find_boundaries(gt_idx, mode="inner") 362 | img_idx[bd, :] = [255, 0, 0] 363 | io.imsave(save_path + ".png", img_idx, check_contrast=False) 364 | else: 365 | print( 366 | "Do not find image and ground-truth pairs. Please check your dataset and argument settings" 367 | ) -------------------------------------------------------------------------------- /segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] 
= (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 
144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 205 | positional parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 
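            # Each axis needs 2 * input_size - 1 entries, covering relative offsets
            # in [-(input_size - 1), input_size - 1].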
220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 
308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
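# Typical SamAutomaticMaskGenerator usage (an illustrative sketch; the checkpoint
# filename and the image variable are examples, not part of this file):
#
#   from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
#   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
#   mask_generator = SamAutomaticMaskGenerator(sam)
#   masks = mask_generator.generate(image)  # image: HWC uint8 array
#   # each record contains keys such as "segmentation", "bbox", "area" and "predicted_iou"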
6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 
89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 
160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros_like(data["boxes"][:, 0]), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros_like(data["boxes"][:, 0]), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros_like(boxes[:, 0]), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /MSCAF_COD_evaluation/sod_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 3 | 4 | __version__ = '1.2.1' 5 | 6 | _EPS = 1e-16 7 | _TYPE = np.float64 8 | 9 | 10 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: 11 | gt = gt > 128 12 | # im2double, mapminmax 13 | pred = pred / 255 14 | if pred.max() != pred.min(): 15 | pred = (pred - pred.min()) / (pred.max() - pred.min()) 16 | return pred, gt 17 | 18 | 19 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: 20 | return min(2 * matrix.mean(), max_value) 21 | 22 | 23 | class Fmeasure(object): 24 | def __init__(self, beta: float = 0.3): 25 | self.beta = beta 26 | self.precisions = [] 27 | self.recalls = [] 28 | self.adaptive_fms = [] 29 | self.changeable_fms = [] 30 | 31 | def step(self, pred: np.ndarray, gt: np.ndarray): 32 | pred, gt = _prepare_data(pred, gt) 33 | 34 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) 35 | self.adaptive_fms.append(adaptive_fm) 36 | 37 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) 38 | self.precisions.append(precisions) 39 | self.recalls.append(recalls) 40 | self.changeable_fms.append(changeable_fms) 41 | 42 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: 43 | # 快速统计numpy数组的非零值建议使用np.count_nonzero, 44 | # 一个简单的小实验可见tests/test_speed_for_count_nonzero.py 45 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 46 | binary_predcition = pred >= adaptive_threshold 47 | area_intersection = binary_predcition[gt].sum() 48 | if area_intersection == 0: 49 | adaptive_fm = 0 50 | else: 51 | pre = area_intersection / np.count_nonzero(binary_predcition) 52 | rec = area_intersection / np.count_nonzero(gt) 53 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) 54 | return adaptive_fm 55 | 56 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 57 | # 1. 
Get histograms of the prediction within the GT foreground and background regions
58 |         pred = (pred * 255).astype(np.uint8)
59 |         bins = np.linspace(0, 256, 257)
60 |         fg_hist, _ = np.histogram(pred[gt], bins=bins)  # the last bin is [255, 256]
61 |         bg_hist, _ = np.histogram(pred[~gt], bins=bins)
62 |         # 2. Use cumulative histograms to get, for each threshold, how many GT foreground/background pixels exceed it
63 |         # cumsum over the flipped histograms yields, in one pass, the number of pixels >= each threshold
64 |         fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0)
65 |         bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0)
66 |         # 3. Compute the precision and recall corresponding to each threshold
67 |         # the numerator of both p and r counts pixels with pred == 1 and gt == 1; only the denominators differ: pred == 1 for p, gt == 1 for r
68 |         # to evaluate every threshold at once, histogram + flip + cumsum provides the per-threshold foreground counts
69 |         TPs = fg_w_thrs
70 |         Ps = fg_w_thrs + bg_w_thrs
71 |         # to avoid division by zero, denominators equal to 0 are set to 1; the numerator is necessarily 0 in those cases
72 |         Ps[Ps == 0] = 1
73 |         T = max(np.count_nonzero(gt), 1)
74 |         # TODO: T = 0, or fg_w_thrs = 0 / bg_w_thrs = 0 at a specific threshold, are all covered by the TPs[i] = 0 case,
75 |         # but handling them through TPs is awkward for the array of per-threshold results
76 |         # T=0 -> fg_w_thrs=[0, ..., 0] -> TPs=[0, ..., 0]; workaround: reset T to 1
77 |         # Ps[i] = 0 -> fg_w_thrs[i] = 0, bg_w_thrs[i] = 0
78 |         precisions = TPs / Ps
79 |         recalls = TPs / T
80 | 
81 |         numerator = (1 + self.beta) * precisions * recalls
82 |         denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls)
83 |         changeable_fms = numerator / denominator
84 |         return precisions, recalls, changeable_fms
85 | 
86 |     def get_results(self) -> dict:
87 |         adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE))
88 |         changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0)
89 |         precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0)  # N, 256
90 |         recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0)  # N, 256
91 |         return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm),
92 |                     pr=dict(p=precision, r=recall))
93 | 
94 | 
95 | class MAE(object):
96 |     def __init__(self):
97 |         self.maes = []
98 | 
99 |     def step(self, pred: np.ndarray, gt: np.ndarray):
100 |         pred, gt = _prepare_data(pred, gt)
101 | 
102 |         mae = self.cal_mae(pred, gt)
103 |         self.maes.append(mae)
104 | 
105 |     def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float:
106 |         mae = np.mean(np.abs(pred - gt))
107 |         return mae
108 | 
109 |     def get_results(self) -> dict:
110 |         mae = np.mean(np.array(self.maes, _TYPE))
111 |         return dict(mae=mae)
112 | 
113 | 
114 | class Smeasure(object):
115 |     def __init__(self, alpha: float = 0.5):
116 |         self.sms = []
117 |         self.alpha = alpha
118 | 
119 |     def step(self, pred: np.ndarray, gt: np.ndarray):
120 |         pred, gt = _prepare_data(pred=pred, gt=gt)
121 | 
122 |         sm = self.cal_sm(pred, gt)
123 |         self.sms.append(sm)
124 | 
125 |     def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
126 |         y = np.mean(gt)
127 |         if y == 0:
128 |             sm = 1 - np.mean(pred)
129 |         elif y == 1:
130 |             sm = np.mean(pred)
131 |         else:
132 |             sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
133 |             sm = max(0, sm)
134 |         return sm
135 | 
136 |     def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
137 |         fg = pred * gt
138 |         bg = (1 - pred) * (1 - gt)
139 |         u = np.mean(gt)
140 |         object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt)
141 |         return object_score
142 | 
143 |     def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
144 |         x = np.mean(pred[gt == 1])
145 |         sigma_x = np.std(pred[gt == 1])
146 |         score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS)
147 |         return score
148 | 
149 |     def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
150 |         x, y = self.centroid(gt)
151 |         part_info = self.divide_with_xy(pred, gt, x, y)
152 |         w1, w2, w3, w4 = part_info['weight']
153 |         # assert
np.isclose(w1 + w2 + w3 + w4, 1), (w1 + w2 + w3 + w4, pred.mean(), gt.mean()) 154 | 155 | pred1, pred2, pred3, pred4 = part_info['pred'] 156 | gt1, gt2, gt3, gt4 = part_info['gt'] 157 | score1 = self.ssim(pred1, gt1) 158 | score2 = self.ssim(pred2, gt2) 159 | score3 = self.ssim(pred3, gt3) 160 | score4 = self.ssim(pred4, gt4) 161 | 162 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 163 | 164 | def centroid(self, matrix: np.ndarray) -> tuple: 165 | """ 166 | 为了保证与matlab代码的一致性,这里对中心坐标进行了加一,在后面划分区域的时候就不用使用多余的加一操作 167 | 因为matlab里的1:X生成的序列会包含X这个值 168 | """ 169 | h, w = matrix.shape 170 | if matrix.sum() == 0: 171 | x = np.round(w / 2) 172 | y = np.round(h / 2) 173 | else: 174 | area_object = np.sum(matrix) 175 | row_ids = np.arange(h) 176 | col_ids = np.arange(w) 177 | x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object) 178 | y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object) 179 | return int(x) + 1, int(y) + 1 180 | 181 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: 182 | h, w = gt.shape 183 | area = h * w 184 | 185 | gt_LT = gt[0:y, 0:x] 186 | gt_RT = gt[0:y, x:w] 187 | gt_LB = gt[y:h, 0:x] 188 | gt_RB = gt[y:h, x:w] 189 | 190 | pred_LT = pred[0:y, 0:x] 191 | pred_RT = pred[0:y, x:w] 192 | pred_LB = pred[y:h, 0:x] 193 | pred_RB = pred[y:h, x:w] 194 | 195 | w1 = x * y / area 196 | w2 = y * (w - x) / area 197 | w3 = (h - y) * x / area 198 | # w4 = (h - y) * (w - x) / area 199 | w4 = 1 - w1 - w2 - w3 200 | 201 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), 202 | pred=(pred_LT, pred_RT, pred_LB, pred_RB), 203 | weight=(w1, w2, w3, w4)) 204 | 205 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: 206 | h, w = pred.shape 207 | N = h * w 208 | 209 | x = np.mean(pred) 210 | y = np.mean(gt) 211 | 212 | sigma_x = np.sum((pred - x) ** 2) / (N - 1) 213 | sigma_y = np.sum((gt - y) ** 2) / (N - 1) 214 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) 215 | 216 | alpha = 4 * x * y * sigma_xy 217 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) 218 | 219 | if alpha != 0: 220 | score = alpha / (beta + _EPS) 221 | elif alpha == 0 and beta == 0: 222 | score = 1 223 | else: 224 | score = 0 225 | return score 226 | 227 | def get_results(self) -> dict: 228 | sm = np.mean(np.array(self.sms, dtype=_TYPE)) 229 | return dict(sm=sm) 230 | 231 | 232 | class Emeasure(object): 233 | def __init__(self): 234 | self.adaptive_ems = [] 235 | self.changeable_ems = [] 236 | 237 | def step(self, pred: np.ndarray, gt: np.ndarray): 238 | pred, gt = _prepare_data(pred=pred, gt=gt) 239 | self.gt_fg_numel = np.count_nonzero(gt) 240 | self.gt_size = gt.shape[0] * gt.shape[1] 241 | 242 | changeable_ems = self.cal_changeable_em(pred, gt) 243 | self.changeable_ems.append(changeable_ems) 244 | adaptive_em = self.cal_adaptive_em(pred, gt) 245 | self.adaptive_ems.append(adaptive_em) 246 | 247 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: 248 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 249 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) 250 | return adaptive_em 251 | 252 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 253 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) 254 | return changeable_ems 255 | 256 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: 257 | """ 258 | 函数内部变量命名规则: 259 | pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义 260 | 如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换 261 | """ 262 | 
binarized_pred = pred >= threshold 263 | fg_fg_numel = np.count_nonzero(binarized_pred & gt) 264 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) 265 | 266 | fg___numel = fg_fg_numel + fg_bg_numel 267 | bg___numel = self.gt_size - fg___numel 268 | 269 | if self.gt_fg_numel == 0: 270 | enhanced_matrix_sum = bg___numel 271 | elif self.gt_fg_numel == self.gt_size: 272 | enhanced_matrix_sum = fg___numel 273 | else: 274 | parts_numel, combinations = self.generate_parts_numel_combinations( 275 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel, 276 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel, 277 | ) 278 | 279 | results_parts = [] 280 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): 281 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 282 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 283 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 284 | results_parts.append(enhanced_matrix_value * part_numel) 285 | enhanced_matrix_sum = sum(results_parts) 286 | 287 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 288 | return em 289 | 290 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 291 | """ 292 | 函数内部变量命名规则: 293 | pred属性(前景fg、背景bg)_gt属性(前景fg、背景bg)_变量含义 294 | 如果仅考虑pred或者gt,则另一个对应的属性位置使用`_`替换 295 | """ 296 | pred = (pred * 255).astype(np.uint8) 297 | bins = np.linspace(0, 256, 257) 298 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) 299 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) 300 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) 301 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) 302 | 303 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs 304 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs 305 | 306 | if self.gt_fg_numel == 0: 307 | enhanced_matrix_sum = bg___numel_w_thrs 308 | elif self.gt_fg_numel == self.gt_size: 309 | enhanced_matrix_sum = fg___numel_w_thrs 310 | else: 311 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( 312 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs, 313 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs, 314 | ) 315 | 316 | results_parts = np.empty(shape=(4, 256), dtype=np.float64) 317 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): 318 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 319 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 320 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 321 | results_parts[i] = enhanced_matrix_value * part_numel 322 | enhanced_matrix_sum = results_parts.sum(axis=0) 323 | 324 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 325 | return em 326 | 327 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel): 328 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel 329 | bg_bg_numel = pred_bg_numel - bg_fg_numel 330 | 331 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] 332 | 333 | mean_pred_value = pred_fg_numel / self.gt_size 334 | mean_gt_value = self.gt_fg_numel / self.gt_size 335 | 336 | demeaned_pred_fg_value = 1 - mean_pred_value 337 | demeaned_pred_bg_value = 0 - mean_pred_value 338 | demeaned_gt_fg_value = 1 - mean_gt_value 339 | demeaned_gt_bg_value = 0 - mean_gt_value 340 | 341 | combinations = [ 342 | (demeaned_pred_fg_value, demeaned_gt_fg_value), 343 | (demeaned_pred_fg_value, demeaned_gt_bg_value), 344 | (demeaned_pred_bg_value, demeaned_gt_fg_value), 
345 |             (demeaned_pred_bg_value, demeaned_gt_bg_value)
346 |         ]
347 |         return parts_numel, combinations
348 | 
349 |     def get_results(self) -> dict:
350 |         adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE))
351 |         changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0)
352 |         return dict(em=dict(adp=adaptive_em, curve=changeable_em))
353 | 
354 | 
355 | class WeightedFmeasure(object):
356 |     def __init__(self, beta: float = 1):
357 |         self.beta = beta
358 |         self.weighted_fms = []
359 | 
360 |     def step(self, pred: np.ndarray, gt: np.ndarray):
361 |         pred, gt = _prepare_data(pred=pred, gt=gt)
362 | 
363 |         if np.all(~gt):
364 |             wfm = 0
365 |         else:
366 |             wfm = self.cal_wfm(pred, gt)
367 |         self.weighted_fms.append(wfm)
368 | 
369 |     def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float:
370 |         # [Dst,IDXT] = bwdist(dGT);
371 |         Dst, Idxt = bwdist(gt == 0, return_indices=True)
372 | 
373 |         # %Pixel dependency
374 |         # E = abs(FG-dGT);
375 |         E = np.abs(pred - gt)
376 |         # Et = E;
377 |         # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
378 |         Et = np.copy(E)
379 |         Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
380 | 
381 |         # K = fspecial('gaussian',7,5);
382 |         # EA = imfilter(Et,K);
383 |         K = self.matlab_style_gauss2D((7, 7), sigma=5)
384 |         EA = convolve(Et, weights=K, mode="constant", cval=0)
385 |         # MIN_E_EA = E;
386 |         # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
387 |         MIN_E_EA = np.where(gt & (EA < E), EA, E)
388 | 
389 |         # %Pixel importance
390 |         # B = ones(size(GT));
391 |         # B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT));
392 |         # Ew = MIN_E_EA.*B;
393 |         B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
394 |         Ew = MIN_E_EA * B
395 | 
396 |         # TPw = sum(dGT(:)) - sum(sum(Ew(GT)));
397 |         # FPw = sum(sum(Ew(~GT)));
398 |         TPw = np.sum(gt) - np.sum(Ew[gt == 1])
399 |         FPw = np.sum(Ew[gt == 0])
400 | 
401 |         # R = 1- mean2(Ew(GT)); %Weighted Recall
402 |         # P = TPw./(eps+TPw+FPw); %Weighted Precision
403 |         R = 1 - np.mean(Ew[gt == 1])
404 |         P = TPw / (TPw + FPw + _EPS)
405 | 
406 |         # Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
407 |         Q = (1 + self.beta) * R * P / (R + self.beta * P + _EPS)
408 | 
409 |         return Q
410 | 
411 | 
412 |     def matlab_style_gauss2D(self, shape: tuple = (7, 7), sigma: int = 5) -> np.ndarray:
413 |         """
414 |         2D gaussian mask - should give the same result as MATLAB's
415 |         fspecial('gaussian',[shape],[sigma])
416 |         """
417 |         m, n = [(ss - 1) / 2 for ss in shape]
418 |         y, x = np.ogrid[-m: m + 1, -n: n + 1]
419 |         h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
420 |         h[h < np.finfo(h.dtype).eps * h.max()] = 0
421 |         sumh = h.sum()
422 |         if sumh != 0:
423 |             h /= sumh
424 |         return h
425 | 
426 |     def get_results(self) -> dict:
427 |         weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE))
428 |         return dict(wfm=weighted_fm)
429 | 
--------------------------------------------------------------------------------
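The metric classes above share a two-step protocol: call step(pred, gt) once per image with uint8 maps in [0, 255] (which _prepare_data rescales and binarizes), then call get_results() for the aggregated scores. A minimal evaluation-loop sketch follows; it is not the repository's evaluation.py, and the directory paths and the assumption that predictions and ground-truth masks share file names are placeholders.

import os

import cv2

from MSCAF_COD_evaluation.sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure

pred_dir = "path/to/predictions"   # placeholder: grayscale prediction maps
gt_dir = "path/to/ground_truth"    # placeholder: binary GT masks with matching names

FM, WFM, SM, EM, M = Fmeasure(), WeightedFmeasure(), Smeasure(), Emeasure(), MAE()
for name in sorted(os.listdir(gt_dir)):
    pred = cv2.imread(os.path.join(pred_dir, name), cv2.IMREAD_GRAYSCALE)  # uint8, [0, 255]
    gt = cv2.imread(os.path.join(gt_dir, name), cv2.IMREAD_GRAYSCALE)
    for metric in (FM, WFM, SM, EM, M):
        metric.step(pred=pred, gt=gt)

print("Smeasure:", SM.get_results()["sm"],
      "wFmeasure:", WFM.get_results()["wfm"],
      "MAE:", M.get_results()["mae"],
      "adpEm:", EM.get_results()["em"]["adp"],
      "adpFm:", FM.get_results()["fm"]["adp"])

The "curve" entries returned by Fmeasure and Emeasure hold the 256 per-threshold values, so mean/max F-measure and E-measure can be read off the same results dictionaries.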