├── work_dir_cod
│   └── weight_place.txt
├── data
│   └── BEGIN.png
├── segment_anything
│   ├── utils
│   │   ├── __init__.py
│   │   ├── transforms.py
│   │   ├── onnx.py
│   │   └── amg.py
│   ├── __init__.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── PUA_res_plus.py
│   │   ├── bias_correction.py
│   │   ├── CWDLoss.py
│   │   ├── mix_enbedding.py
│   │   ├── loop_finer.py
│   │   ├── guided_filter.py
│   │   ├── DWT.py
│   │   ├── sd_merge.py
│   │   ├── agent_swin.py
│   │   ├── mask_decoder.py
│   │   ├── sam.py
│   │   ├── transformer.py
│   │   ├── prompt_encoder.py
│   │   └── image_encoder.py
│   ├── build_sam.py
│   ├── predictor.py
│   └── automatic_mask_generator.py
├── MSCAF_COD_evaluation
│   ├── test_speed_for_count_nonzero.py
│   ├── evaluation.py
│   └── sod_metrics
│       └── __init__.py
├── transformer_npz_2_gt.py
├── utils
│   ├── precompute_img_embed.py
│   └── dataset.py
├── README.md
├── environment.yaml
├── Mytrain.py
├── Mytest.py
└── pre_npz.py
/work_dir_cod/weight_place.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/BEGIN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guobaoxiao/DSAM/HEAD/data/BEGIN.png
--------------------------------------------------------------------------------
/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 | from .pvtv2 import pvt_v2_b2
13 | from .DWT import extract_high_frequency
14 | from .mix_enbedding import ME
15 | from .bias_correction import bias_correction
16 | from .loop_finer import Loop_Finer
17 | # from .IRB import InvertedResidual
18 | # from torchvision.models.mobilenetv2 import InvertedResidual
19 |
--------------------------------------------------------------------------------
/MSCAF_COD_evaluation/test_speed_for_count_nonzero.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/11/27
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 | import time
6 |
7 | import numpy as np
8 |
9 |
10 | # A quick experiment: np.count_nonzero is recommended for quickly counting the non-zero values of a numpy array
11 | def cal_nonzero(size):
12 | a = np.random.randn(size, size)
13 | a = a > 0
14 | start = time.time()
15 | print(np.count_nonzero(a), time.time() - start)
16 | start = time.time()
17 | print(np.sum(a), time.time() - start)
18 | start = time.time()
19 | print(len(np.nonzero(a)[0]), time.time() - start)
20 |
21 |
22 | if __name__ == '__main__':
23 | cal_nonzero(1000)
24 | # 499950 6.723403930664062e-05
25 | # 499950 0.0006949901580810547
26 | # 499950 0.007088184356689453
27 |
--------------------------------------------------------------------------------
/transformer_npz_2_gt.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage import io
3 | import os
4 | join = os.path.join
5 |
6 |
7 | path = "data/inference_npz/DSAM/"
8 | save_path = 'data/inference_img/DSAM/'
9 | # data = np.load("data/demo2D_vit_b/CVC-ColonDB/CVC-ColonDB.npz")
10 | npz_folders = sorted(os.listdir(path))
11 | for npz_folder in npz_folders:
12 | # load the npz file
13 | data = np.load(join(path, npz_folder,npz_folder) + '.npz')
14 | # 获取图像数据
15 | imgs = data["medsam_segs"]  # the predicted segmentation maps are stored under the "medsam_segs" key
16 | name = data['number']
17 |
18 | # convert the masks to the correct dtype and value range
19 | imgs = imgs.astype(np.uint8)
20 | imgs = (imgs * 255.0).astype(np.uint8)  # scale binary masks to the 0-255 range
21 |
22 | # save each mask as an image file
23 | for i, img in enumerate(imgs):
24 | num = 149 + i
25 | img_path = join(save_path, npz_folder) + "/" + name[i]  # output path and file name
26 | if not os.path.exists(join(save_path, npz_folder)):
27 | os.makedirs(join(save_path, npz_folder), exist_ok=True)
28 | io.imsave(img_path, img)
29 |
30 | print("Images saved successfully.")
--------------------------------------------------------------------------------
/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
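
A quick sanity sketch (not part of the repository): LayerNorm2d above normalizes over the channel dimension of an NCHW tensor, which matches torch's built-in layer norm applied channels-last. It assumes the repository root is on the Python path:

```python
import torch
import torch.nn.functional as F

from segment_anything.modeling.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)
ln2d = LayerNorm2d(num_channels=8, eps=1e-6)

# Reference computation: move channels last, normalize over them, move back.
ref = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(8,), eps=1e-6)
ref = ref.permute(0, 3, 1, 2)

print(torch.allclose(ln2d(x), ref, atol=1e-5))  # True
```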
/segment_anything/modeling/PUA_res_plus.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | def _upsample_like(src, tar):
5 | # (optionally) move src to the same device (GPU) as tar
6 | # src = src.to(device)
7 | src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
8 | return src
9 | class PUAModule(nn.Module):
10 | def __init__(self, channel):
11 | super(PUAModule, self).__init__()
12 | self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1, dilation=1, groups=channel//2)
13 | self.conv2 = nn.Conv2d(channel, 2*channel, kernel_size=3, stride=1, padding=1, dilation=2, groups=channel//2)
14 | self.conv3 = nn.Conv2d(2*channel, channel, kernel_size=3, stride=2, padding=1, dilation=3, groups=channel//2)
15 |
16 | # self.conv4 = nn.Conv2d(channel, channel/2, kernel_size=1, stride=1, padding=1)
17 | # self.classifier = nn.Conv2d(channel, 1, kernel_size=3, stride=2, padding=1)
18 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
19 | self.bn1 = nn.BatchNorm2d(channel)
20 | self.bn2 = nn.BatchNorm2d(2*channel)
21 | self.bn3 = nn.BatchNorm2d(channel)
22 | # self.bn4 = nn.BatchNorm2d(channel)
23 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
24 | # #self.sigmoid = nn.Sigmoid()
25 | def forward(self, x):
26 | res_x = x
27 | x = self.conv1(x)
28 | x = self.bn1(x)
29 | x = self.leaky_relu(x)
30 | res_x = _upsample_like(res_x, x)
31 | x = x + res_x
32 | res_x_2 = x
33 | x = self.conv2(x)
34 | x = self.bn2(x)
35 | x = self.leaky_relu(x)
36 | # res_x_2 = _upsample_like(res_x_2, x)
37 | # x = x + res_x_2
38 | # res_x_3 = x
39 | x = self.conv3(x)
40 | x = self.bn3(x)
41 | x = self.leaky_relu(x)
42 | res_x_2 = _upsample_like(res_x_2, x)
43 | x = x + res_x_2
44 | # x = self.conv4(x)
45 | # x = self.conv4(x)
46 | # x = self.bn4(x)
47 | # x = self.leaky_relu(x)
48 | # x = self.classifier(x)
49 | return x
--------------------------------------------------------------------------------
/segment_anything/modeling/bias_correction.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from segment_anything.modeling.PUA_res_plus import PUAModule
3 | import torch.nn.functional as F
4 | import torch
5 |
6 | class CrossAttention(nn.Module):
7 | def __init__(self, in_channels):
8 | super(CrossAttention, self).__init__()
9 |
10 | # Query, Key, and Value projections for both input vectors
11 | self.query_v1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
12 | self.key_v1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
13 | self.value_v1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
14 |
15 | self.query_v2 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
16 | self.key_v2 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
17 | self.value_v2 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
18 |
19 | def forward(self, v1, v2):
20 | # Project vectors to Query, Key, and Value
21 | query_v1 = self.query_v1(v1)
22 | key_v2 = self.key_v2(v2)
23 | value_v2 = self.value_v2(v2)
24 |
25 | # Compute attention scores
26 | scores = torch.matmul(query_v1.view(query_v1.size(0), -1, query_v1.size(-1)),
27 | key_v2.view(key_v2.size(0), -1, key_v2.size(-1)).transpose(1, 2))
28 | attention = F.softmax(scores, dim=-1)
29 |
30 | # Apply attention to values
31 | output_v1 = torch.matmul(attention, value_v2.view(value_v2.size(0), -1, value_v2.size(-1)))
32 | output_v1 = output_v1.view(v1.size())
33 |
34 | return output_v1
35 |
36 | def _upsample_like_64(src):
37 | src = F.interpolate(src, size=(64,64), mode='bilinear')
38 | return src
39 |
40 | class bias_correction(nn.Module):
41 | def __init__(self, out_channels):
42 | super(bias_correction, self).__init__()
43 | self.PUA = PUAModule(out_channels)
44 | self.conv = nn.Conv2d(out_channels, out_channels // 2, kernel_size=1, stride=1, padding=1)
45 | def forward(self, embedding):
46 | embedding_64 = _upsample_like_64(embedding)
47 | pua_em = self.PUA(embedding)
48 | embedding = self.conv(embedding)
49 | pua_em = self.conv(pua_em)
50 | pua_em = _upsample_like_64(pua_em)
51 | embedding = _upsample_like_64(embedding)
52 | out_em = pua_em + embedding
53 | return out_em, embedding_64
54 |
55 |
--------------------------------------------------------------------------------
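
A minimal usage sketch (not part of the repository): a dummy 512-channel feature map is pushed through bias_correction; the channel count matches BC=bias_correction(512) in build_sam.py, and both outputs are resized to 64x64 by _upsample_like_64.

```python
import torch

from segment_anything.modeling.bias_correction import bias_correction

bc = bias_correction(out_channels=512)
embedding = torch.randn(1, 512, 32, 32)  # dummy stand-in for the fused embedding

out_em, embedding_64 = bc(embedding)
print(out_em.shape)        # torch.Size([1, 256, 64, 64])
print(embedding_64.shape)  # torch.Size([1, 512, 64, 64])
```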
/segment_anything/modeling/CWDLoss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class ChannelNorm(nn.Module):
4 | '''Channel-wise knowledge distillation for dense prediction'''
5 |
6 | def __init__(self):
7 | super(ChannelNorm, self).__init__()
8 | def forward(self,featmap):
9 | n,c,h,w = featmap.shape
10 | featmap = featmap.reshape((n,c,-1))
11 | featmap = featmap.softmax(dim=-1)
12 | return featmap
13 |
14 | class CriterionCWD(nn.Module):
15 | '''Channel-wise knowledge distillation for dense prediction'''
16 |
17 | def __init__(self, norm_type='none', divergence='mse', temperature=1.0):
18 |
19 | super(CriterionCWD, self).__init__()
20 |
21 | # define normalize function
22 | if norm_type == 'channel':
23 | self.normalize = ChannelNorm()
24 | elif norm_type == 'spatial':
25 | self.normalize = nn.Softmax(dim=1)
26 | elif norm_type == 'channel_mean':
27 | self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)
28 | else:
29 | self.normalize = None
30 | self.norm_type = norm_type
31 |
32 | self.temperature = 1.0
33 |
34 | # define loss function
35 | if divergence == 'mse':
36 | self.criterion = nn.MSELoss(reduction='sum')
37 | elif divergence == 'kl':
38 | self.criterion = nn.KLDivLoss(reduction='sum')
39 | self.temperature = temperature
40 | self.divergence = divergence
41 |
42 | def forward(self, preds_S, preds_T):
43 |
44 | n, c, h, w = preds_S.shape
45 | # import pdb;pdb.set_trace()
46 | if self.normalize is not None:
47 | norm_s = self.normalize(preds_S / self.temperature)
48 | norm_t = self.normalize(preds_T.detach() / self.temperature)
49 | else:
50 | norm_s = preds_S[0]
51 | norm_t = preds_T[0].detach()
52 |
53 | if self.divergence == 'kl':
54 | norm_s = norm_s.log()
55 | loss = self.criterion(norm_s, norm_t)
56 |
57 | # item_loss = [round(self.criterion(norm_t[0][0].log(),norm_t[0][i]).item(),4) for i in range(c)]
58 | # import pdb;pdb.set_trace()
59 | if self.norm_type == 'channel' or self.norm_type == 'channel_mean':
60 | loss /= n * c
61 | # loss /= n * h * w
62 | else:
63 | loss /= n * h * w
64 |
65 | return loss * (self.temperature ** 2)
--------------------------------------------------------------------------------
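
A minimal usage sketch (not part of the repository): channel-wise distillation between a student and a (detached) teacher feature map. The KL variant with temperature 4.0 is an illustrative choice here; the class itself defaults to norm_type='none', divergence='mse', temperature=1.0.

```python
import torch

from segment_anything.modeling.CWDLoss import CriterionCWD

cwd = CriterionCWD(norm_type='channel', divergence='kl', temperature=4.0)

preds_student = torch.randn(2, 256, 64, 64, requires_grad=True)
preds_teacher = torch.randn(2, 256, 64, 64)  # teacher features are detached inside forward()

loss = cwd(preds_student, preds_teacher)
loss.backward()
print(loss.item())
```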
/utils/precompute_img_embed.py:
--------------------------------------------------------------------------------
1 | #%% import packages
2 | # precompute image embeddings and save them to disk for model training
3 |
4 | import numpy as np
5 | import os
6 | join = os.path.join
7 | from skimage import io, segmentation
8 | from tqdm import tqdm
9 | import torch
10 | from segment_anything import sam_model_registry
11 | from segment_anything.utils.transforms import ResizeLongestSide
12 | import argparse
13 |
14 | #%% parse arguments
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('-i', '--img_path', type=str, default='./data/Tr_Release_Part1', help='# and also Tr_Release_Part2 when part1 is done')
17 | parser.add_argument('-o', '--save_path', type=str, default='./data/Tr_npy', help='path to save the image embeddings')
18 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type')
19 | parser.add_argument('--checkpoint', type=str, default='../work_dir/SAM/sam_vit_b_01ec64.pth', help='path to the pre-trained SAM model')
20 | args = parser.parse_args()
21 |
22 | pre_img_path = args.img_path
23 | save_img_emb_path = join(args.save_path, 'npy_embs')
24 | save_gt_path = join(args.save_path, 'npy_gts')
25 | os.makedirs(save_img_emb_path, exist_ok=True)
26 | os.makedirs(save_gt_path, exist_ok=True)
27 | npz_files = sorted(os.listdir(pre_img_path))
28 | #%% set up the model
29 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to('cuda:0')
30 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
31 |
32 | # compute image embeddings
33 | for name in tqdm(npz_files):
34 | img = np.load(join(pre_img_path, name))['img'] # (256, 256, 3)
35 | gt = np.load(join(pre_img_path, name))['gt']
36 | resize_img = sam_transform.apply_image(img)
37 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to('cuda:0')
38 | # model input: (1, 3, 1024, 1024)
39 | input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024)
40 | assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024'
41 | with torch.no_grad():
42 | embedding = sam_model.image_encoder(input_image)
43 |
44 | # save as npy
45 | np.save(join(save_img_emb_path, name.split('.npz')[0]+'.npy'), embedding.cpu().numpy()[0])
46 | np.save(join(save_gt_path, name.split('.npz')[0]+'.npy'), gt)
47 | # sanity check
48 | img_idx = img.copy()
49 | bd = segmentation.find_boundaries(gt, mode='inner')
50 | img_idx[bd, :] = [255, 0, 0]
51 | io.imsave(save_img_emb_path + '.png', img_idx)
52 |
--------------------------------------------------------------------------------
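
A small follow-up sketch (not part of the repository), assuming the script above has already been run with its default --save_path: each file in npy_embs holds a (256, 64, 64) SAM ViT-B image embedding, and the matching ground-truth mask is stored under the same name in npy_gts.

```python
import os

import numpy as np

emb_dir = './data/Tr_npy/npy_embs'
gt_dir = './data/Tr_npy/npy_gts'

name = sorted(os.listdir(emb_dir))[0]
img_embed = np.load(os.path.join(emb_dir, name))  # (256, 64, 64)
gt = np.load(os.path.join(gt_dir, name))          # ground-truth mask
print(name, img_embed.shape, gt.shape)
```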
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Exploring Deeper! Segment Anything Model with Depth Perception for Camouflaged Object Detection
3 | Zhenni Yu, Xiaoqin Zhang, Li Zhao, Yi Bin, Guobao Xiao
4 | ACM MM, 2024
5 |
6 |
7 | 
8 |
9 | ## Usage
10 |
11 | ### Installation
12 |
13 | ```bash
14 | git clone https://github.com/guobaoxiao/DSAM
15 | cd DSAM
16 | ```
17 |
18 | ### Environment
19 |
20 | ```bash
21 | conda env create -f environment.yaml
22 | ```
23 |
24 | ## From datasets to npz
25 | You can download the COD datasets and run the command below to generate the npz files for training.
26 | ```bash
27 | python pre_npz.py
28 | ```
29 |
30 | - **COD datasets**:
31 | download the COD datasets (CAMO, COD10K, NC4K) from [here](https://github.com/lartpang/awesome-segmentation-saliency-dataset#camouflaged-object-detection-cod) and put them into 'data/'
32 |
33 | - **depth datasets**:
34 | download the depth datasets and put them into 'data/'. The depth images are from PopNet.
35 |
36 | - File shared via Baidu Netdisk: Train_depth.zip
37 | Link: https://pan.baidu.com/s/1grcASolza9GLpHIVk8mESQ
38 | Extraction code: wocz
39 | - File shared via Baidu Netdisk: Test_depth.zip
40 | Link: https://pan.baidu.com/s/1HobAvMBpfSUfUHNXGZeFLw
41 | Extraction code: 32ut
42 |
43 |
44 | ### Weights
45 | - **pre-weight**:
46 | download the SAM weight from [here](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and the pvt weight from xxx, then put them into 'work_dir_cod/SAM/'
47 |
48 | - **DSAM**:
49 | download the weight of the well-trained DSAM and put it into 'work_dir_cod/DSAM'
50 | - File shared via Baidu Netdisk: DSAM.pth
51 | Link: https://pan.baidu.com/s/1148mXSjTv7OKlWHcfZFh5A
52 | Extraction code: 39xx
53 |
54 |
55 | ### The predicted images
56 | - **DSAM**:
57 | - File shared via Baidu Netdisk: DSAM.zip
58 | Link: https://pan.baidu.com/s/1V5372Z_GdHzYEyOR3iEu4Q
59 | Extraction code: fu49
60 |
61 | ### Train
62 | ```bash
63 | python Mytrain.py
64 | ```
65 |
66 | ### Test
67 |
68 | ```bash
69 | python Mytest.py
70 | ```
71 |
72 | ### Translate npz to img
73 |
74 | ```bash
75 | python transformer_npz_2_gt.py
76 | ```
77 |
78 | ### Eval
79 |
80 | ```bash
81 | python MSCAF_COD_evaluation/evaluation.py
82 | ```
83 | ## Citation
84 |
85 | If you find this project useful, please consider citing:
86 |
87 | ```bibtex
88 | @inproceedings{yu2024exploring,
89 | title={Exploring Deeper! Segment Anything Model with Depth Perception for Camouflaged Object Detection},
90 | author={Zhenni Yu and Xiaoqin Zhang and Li Zhao and Yi Bin and Guobao Xiao},
91 | booktitle={ACM Multimedia 2024},
92 | year={2024},
93 | url={https://openreview.net/forum?id=d4A0Cw1gVS}
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/MSCAF_COD_evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/11/21
3 | # @Author : Lart Pang
4 | # @GitHub : https://github.com/lartpang
5 |
6 | import os
7 | import sys
8 |
9 | import cv2
10 | from tqdm import tqdm
11 | import torch
12 | import sod_metrics as M
13 | import torch.nn.functional as F
14 |
15 | FM = M.Fmeasure()
16 | WFM = M.WeightedFmeasure()
17 | SM = M.Smeasure()
18 | EM = M.Emeasure()
19 | MAE = M.MAE()
20 |
21 | mask_root = '/data/Jenny/2309COMPrompter_plus_DSAM/data/TestDataset/NC4K/GT'
22 | pred_root = '/data/Jenny/2309COMPrompter_plus_DSAM/data/inference_img/240611_offfice_DSAM_2/NC4K'
23 |
24 | def _upsample_like(src, tar):
25 | src = torch.tensor(src, dtype=torch.float32)
26 | tar = torch.tensor(tar)
27 | src = F.interpolate(src.unsqueeze(0).unsqueeze(0), size=tar.shape, mode='bilinear')
28 | src = src.squeeze(0).squeeze(0).numpy()
29 | return src
30 | mask_name_list = sorted(os.listdir(mask_root))
31 | for mask_name in tqdm(mask_name_list, total=len(mask_name_list)):
32 | mask_path = os.path.join(mask_root, mask_name)
33 | mask_name_for_pred = mask_name.replace(".png", ".jpg")
34 | # mask_name_for_pred = mask_name
35 | # pred_path = os.path.join(pred_root, mask_name_for_pred)
36 | pred_path = os.path.join(pred_root, mask_name_for_pred)
37 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
38 | pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
39 | if len(pred.shape) != 2:
40 | pred = pred[:, :, 0]  # keep a single channel -> (height, width)
41 | if len(mask.shape) != 2:
42 | mask = mask[:, :, 0]
43 | pred = _upsample_like(pred, mask)
44 | assert pred.shape == mask.shape
45 |
46 | FM.step(pred=pred, gt=mask)
47 | WFM.step(pred=pred, gt=mask)
48 | SM.step(pred=pred, gt=mask)
49 | EM.step(pred=pred, gt=mask)
50 | MAE.step(pred=pred, gt=mask)
51 |
52 | fm = FM.get_results()['fm']
53 | wfm = WFM.get_results()['wfm']
54 | sm = SM.get_results()['sm']
55 | em = EM.get_results()['em']
56 | mae = MAE.get_results()['mae']
57 |
58 | print(
59 | 'Smeasure:', sm.round(3), '; ',
60 | 'wFmeasure:', wfm.round(3), '; ',
61 | 'MAE:', mae.round(3), '; ',
62 | 'adpEm:', em['adp'].round(3), '; ',
63 | 'meanEm:', '-' if em['curve'] is None else em['curve'].mean().round(3), '; ',
64 | 'maxEm:', '-' if em['curve'] is None else em['curve'].max().round(3), '; ',
65 | 'adpFm:', fm['adp'].round(3), '; ',
66 | 'meanFm:', fm['curve'].mean().round(3), '; ',
67 | 'maxFm:', fm['curve'].max().round(3),
68 | sep=''
69 | )
70 |
71 | with open("../result.txt", "a+") as f:
72 | print('Smeasure:', sm.round(3), '; ',
73 | 'meanEm:', '-' if em['curve'] is None else em['curve'].mean().round(3), '; ',
74 | 'wFmeasure:', wfm.round(3), '; ',
75 | 'MAE:', mae.round(3), '; ',
76 | file=f
77 | )
78 |
--------------------------------------------------------------------------------
/segment_anything/modeling/mix_enbedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | class BasicConv2d(nn.Module):
4 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, need_relu=True,
5 | bn=nn.BatchNorm2d):
6 | super(BasicConv2d, self).__init__()
7 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
8 | stride=stride, padding=padding, dilation=dilation, bias=False)
9 | self.bn = bn(out_channels)
10 | self.relu = nn.ReLU()
11 | self.need_relu = need_relu
12 |
13 | def forward(self, x):
14 | x = self.conv(x)
15 | x = self.bn(x)
16 | if self.need_relu:
17 | x = self.relu(x)
18 | return x
19 | class ME(torch.nn.Module):
20 | def __init__(self, in_channels, out_channels):
21 | super(ME, self).__init__()
22 |
23 | self.depthwise_conv_reduce_channels_1 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=1,
24 | stride=1, padding=0, groups=in_channels // 16)
25 | self.depthwise_conv_reduce_channels_2 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1,
26 | stride=1, padding=0, groups=in_channels // 16)
27 | self.relu = nn.ReLU(True)
28 | self.branch1 = nn.Sequential(
29 | BasicConv2d(out_channels, out_channels, 1),
30 | BasicConv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
31 | BasicConv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)),
32 | BasicConv2d(out_channels, out_channels, 3, padding=3, dilation=3)
33 | )
34 |
35 | def initialize_parameters(self):
36 | for name, param in self.named_parameters():
37 | if 'weight' in name:
38 |
39 | if len(param.shape) == 1:
40 | param_unsqueeze = param.unsqueeze(0)
41 | nn.init.xavier_uniform_(param_unsqueeze)
42 | param.data.copy_(param_unsqueeze.squeeze(0))
43 | else:
44 | nn.init.xavier_uniform_(param)
45 |
46 | elif 'bias' in name:
47 | # print("bias:" + name)
48 | # The bias term is initialized
49 | nn.init.zeros_(param)
50 | def forward(self, dense_embeddings_box, high_frequency, sparse_embeddings_box):  # note: a high_frequency argument was once removed here
51 |
52 | # dense_cat = torch.cat([dense_embeddings_boundary, dense_embeddings_box], dim=1)
53 | dense_cat_tmp = self.depthwise_conv_reduce_channels_1(dense_embeddings_box)
54 | dense_em = self.branch1(dense_embeddings_box)
55 | dense_em = dense_em + dense_cat_tmp
56 |
57 | dense_embeddings = torch.cat([dense_em, high_frequency], dim=1)  # note: one input here was previously commented out
58 |
59 | dense_embeddings = self.depthwise_conv_reduce_channels_2(dense_embeddings)
60 | sparse_embeddings = sparse_embeddings_box
61 | return dense_embeddings, sparse_embeddings
62 |
--------------------------------------------------------------------------------
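
A minimal usage sketch (not part of the repository) with the channel sizes used in build_sam.py (ME(256*2, 256)); the dummy tensors stand in for the box prompt embeddings and the DWT high-frequency features.

```python
import torch

from segment_anything.modeling.mix_enbedding import ME

me = ME(in_channels=256 * 2, out_channels=256)

dense_embeddings_box = torch.randn(1, 256, 64, 64)
high_frequency = torch.randn(1, 256, 64, 64)
sparse_embeddings_box = torch.randn(1, 2, 256)

dense_embeddings, sparse_embeddings = me(dense_embeddings_box, high_frequency, sparse_embeddings_box)
print(dense_embeddings.shape)   # torch.Size([1, 256, 64, 64])
print(sparse_embeddings.shape)  # torch.Size([1, 2, 256])
```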
/segment_anything/modeling/loop_finer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from .guided_filter import GuidedFilter
5 | from .agent_swin import AgentAttention
6 |
7 | class BasicConv2d(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, need_relu=True,
9 | bn=nn.BatchNorm2d):
10 | super(BasicConv2d, self).__init__()
11 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
12 | stride=stride, padding=padding)
13 | self.bn = bn(out_channels)
14 | self.relu = nn.ReLU()
15 | self.need_relu = need_relu
16 |
17 | def forward(self, x):
18 | x = self.conv(x)
19 | x = self.bn(x)
20 | if self.need_relu:
21 | x = self.relu(x)
22 | return x
23 |
24 | def _upsample_like_64(src):
25 | src = F.interpolate(src, size=(64, 64), mode='bilinear')
26 | return src
27 | def _upsample_like_128(src):
28 | src = F.interpolate(src, size=(128, 128), mode='bilinear')
29 | return src
30 | def _upsample_like_256(src):
31 | src = F.interpolate(src, size=(256, 256), mode='bilinear')
32 | return src
33 |
34 |
35 | class Loop_Finer(nn.Module):
36 | def __init__(self, mid_ch):
37 | super(Loop_Finer, self).__init__()
38 | self.gfp = GuidedFilter()
39 | self.finer0 = BasicConv2d(264, mid_ch, kernel_size=3, stride=1, padding=1)
40 | self.finer1 = BasicConv2d(264, mid_ch, kernel_size=3, stride=1, padding=1)
41 | self.finer2 = BasicConv2d(mid_ch*2, mid_ch, kernel_size=1, stride=1, padding=1)
42 | self.finer_atten = AgentAttention(dim=64, num_heads=2)
43 | self.finer4 = BasicConv2d(mid_ch, 1, kernel_size=1, stride=1, padding=0) # fini_ch
44 | def forward(self, pred, image_em, depth_em):
45 | pred = _upsample_like_64(pred)
46 | reversed_pred = 1 - pred
47 | image_cut = torch.chunk(image_em, 8, dim=1)
48 | image_cat = torch.cat((image_cut[0], reversed_pred, image_cut[1], reversed_pred, image_cut[2], reversed_pred, image_cut[3], reversed_pred
49 | , image_cut[4], reversed_pred, image_cut[5], reversed_pred, image_cut[6], reversed_pred, image_cut[7], reversed_pred), 1)
50 |
51 | depth_cut = torch.chunk(depth_em, 8, dim=1)
52 | depth_cat = torch.cat((depth_cut[0], reversed_pred, depth_cut[1], reversed_pred, depth_cut[2], reversed_pred, depth_cut[3], reversed_pred
53 | , depth_cut[4], reversed_pred, depth_cut[5], reversed_pred, depth_cut[6], reversed_pred, depth_cut[7], reversed_pred), 1)
54 |
55 | image_fliter = self.gfp(image_cat, depth_cat)
56 | image_fliter = image_fliter + image_cat
57 |
58 | i_f0 = self.finer0(image_fliter)
59 | i_f0 = _upsample_like_64(i_f0)
60 |
61 | d_f1 = self.finer1(depth_cat)
62 | d_f1 = _upsample_like_64(d_f1)
63 | tmp = d_f1
64 | d_f1 = self.finer_atten(d_f1)
65 | d_f1 = d_f1 + tmp
66 |
67 | i_f2 = self.finer2(torch.cat((i_f0, d_f1), 1))
68 | i_f2 = _upsample_like_128(i_f2)
69 |
70 | d_f3 = self.finer_atten(i_f2)
71 | d_f3 = _upsample_like_128(d_f3)
72 | d_f4 = self.finer4(d_f3)
73 | d_f4 = _upsample_like_256(d_f4)
74 |
75 | return d_f4
--------------------------------------------------------------------------------
/segment_anything/modeling/guided_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 |
6 |
7 |
8 | def diff_x(input, r):
9 | assert input.dim() == 4
10 |
11 | left = input[:, :, r:2 * r + 1]
12 | middle = input[:, :, 2 * r + 1: ] - input[:, :, :-2 * r - 1]
13 | right = input[:, :, -1: ] - input[:, :, -2 * r - 1: -r - 1]
14 |
15 | output = torch.cat([left, middle, right], dim=2)
16 |
17 | return output
18 |
19 | def diff_y(input, r):
20 | assert input.dim() == 4
21 |
22 | left = input[:, :, :, r:2 * r + 1]
23 | middle = input[:, :, :, 2 * r + 1: ] - input[:, :, :, :-2 * r - 1]
24 | right = input[:, :, :, -1: ] - input[:, :, :, -2 * r - 1: -r - 1]
25 |
26 | output = torch.cat([left, middle, right], dim=3)
27 |
28 | return output
29 |
30 | class BoxFilter(nn.Module):
31 | def __init__(self, r):
32 | super(BoxFilter, self).__init__()
33 |
34 | self.r = r
35 |
36 | def forward(self, x):
37 | assert x.dim() == 4
38 |
39 | return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
40 |
41 |
42 | class FastGuidedFilter(nn.Module):
43 | def __init__(self, r, eps=1e-8):
44 | super(FastGuidedFilter, self).__init__()
45 |
46 | self.r = r
47 | self.eps = eps
48 | self.boxfilter = BoxFilter(r)
49 |
50 |
51 | def forward(self, lr_x, lr_y, hr_x):
52 | n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
53 | n_lry, c_lry, h_lry, w_lry = lr_y.size()
54 | n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
55 |
56 | assert n_lrx == n_lry and n_lry == n_hrx
57 | assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
58 | assert h_lrx == h_lry and w_lrx == w_lry
59 | assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1
60 |
61 | ## N
62 | N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))
63 |
64 | ## mean_x
65 | mean_x = self.boxfilter(lr_x) / N
66 | ## mean_y
67 | mean_y = self.boxfilter(lr_y) / N
68 | ## cov_xy
69 | cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
70 | ## var_x
71 | var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
72 |
73 | ## A
74 | A = cov_xy / (var_x + self.eps)
75 | ## b
76 | b = mean_y - A * mean_x
77 |
78 | ## mean_A; mean_b
79 | mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
80 | mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
81 |
82 | return mean_A*hr_x+mean_b
83 |
84 |
85 | class GuidedFilter(nn.Module):
86 | def __init__(self, r=4, eps=1e-2):
87 | super(GuidedFilter, self).__init__()
88 |
89 | self.r = r
90 | self.eps = eps
91 | self.boxfilter = BoxFilter(r)
92 |
93 |
94 | def forward(self, x, y):
95 | n_x, c_x, h_x, w_x = x.size()
96 | n_y, c_y, h_y, w_y = y.size()
97 |
98 | assert n_x == n_y
99 | assert c_x == 1 or c_x == c_y
100 | assert h_x == h_y and w_x == w_y
101 | assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
102 |
103 | # N
104 | N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
105 |
106 | # mean_x
107 | mean_x = self.boxfilter(x) / N
108 | # mean_y
109 | mean_y = self.boxfilter(y) / N
110 | # cov_xy
111 | cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
112 | # var_x
113 | var_x = self.boxfilter(x * x) / N - mean_x * mean_x
114 |
115 | # A
116 | A = cov_xy / (var_x + self.eps)
117 | # b
118 | b = mean_y - A * mean_x
119 |
120 | # mean_A; mean_b
121 | mean_A = self.boxfilter(A) / N
122 | mean_b = self.boxfilter(b) / N
123 |
124 | return mean_A * x + mean_b
125 |
--------------------------------------------------------------------------------
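
A minimal usage sketch (not part of the repository): filtering one dummy feature map with another as the guide, the way Loop_Finer applies GuidedFilter to its image and depth embeddings (both 264 channels after interleaving the reversed prediction).

```python
import torch

from segment_anything.modeling.guided_filter import GuidedFilter

gf = GuidedFilter(r=4, eps=1e-2)

image_cat = torch.randn(1, 264, 64, 64)  # input to be filtered
depth_cat = torch.randn(1, 264, 64, 64)  # guidance signal

filtered = gf(image_cat, depth_cat)
print(filtered.shape)  # torch.Size([1, 264, 64, 64])
```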
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | import cv2
5 | from typing import Any, Tuple
6 | import torch
7 | from torch.utils.data import Dataset
8 | import torchvision.transforms.functional as TF
9 |
10 |
11 | class MedSamDataset(Dataset):
12 | def __init__(
13 | self,
14 | df: pd.DataFrame,
15 | image_col: str,
16 | mask_col: str,
17 | image_dir: Any = None,
18 | mask_dir: str = None,
19 | image_size: Tuple = (256, 256),
20 | ):
21 | """
22 | PyTorch dataset class for loading image, mask, and bbox triplets from a dataframe.
23 | The dataframe needs at least two columns holding the image and mask file names. The columns can contain either the full or relative
24 | paths of the images or just the file names.
25 | If only file names are given in the columns, the `image_dir` and `mask_dir` arguments should be specified.
26 |
27 | Args:
28 | df (pd.DataFrame): the pandas dataframe object
29 | image_col (str): the name of the column on the dataframe that holds the image file names.
30 | mask_col (str): the name of the column on the dataframe that holds the mask file names.
31 | image_dir (Any, optional): Path to the input image directory. Defaults to None.
32 | mask_dir (str, optional): Path to the mask images directory. Defaults to None.
33 | image_size (Tuple, optional): image size. Defaults to (256, 256).
34 | """
35 | self.df = df
36 | self.image_dir = image_dir
37 | self.mask_dir = mask_dir
38 | self.image_col = image_col
39 | self.mask_col = mask_col
40 | self.image_size = image_size
41 |
42 | def __len__(self):
43 | return len(self.df)
44 |
45 | def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
46 | # read dataframe row
47 | row = self.df.iloc[idx]
48 | # If the `image_dir` attribute is set, the path will be relative to that directory.
49 | # Otherwise, the path will be the value of the `row[self.image_col]` attribute.
50 | image_file = (
51 | os.path.join(self.image_dir, row[self.image_col])
52 | if self.image_dir
53 | else row[self.image_col]
54 | )
55 | mask_file = (
56 | os.path.join(self.mask_dir, row[self.mask_col])
57 | if self.mask_dir
58 | else row[self.mask_col]
59 | )
60 |
61 | if not os.path.exists(image_file):
62 | raise FileNotFoundError(f"Couldn't find image {image_file}")
63 | if not os.path.exists(mask_file):
64 | raise FileNotFoundError(f"Couldn't find image {mask_file}")
65 |
66 | # read image and mask files
67 | image_data = cv2.imread(image_file)
68 | # read mask as gray scale
69 | mask_data = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
70 |
71 | return self._preprocess(image_data, mask_data)
72 |
73 | def _preprocess(
74 | self, image: np.ndarray, mask: np.ndarray
75 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
76 | # Threshold mask to binary
77 | mask = cv2.threshold(mask, 127.0, 255.0, cv2.THRESH_BINARY)[1]
78 | # convert to tensor
79 | image = TF.to_tensor(image)
80 | mask = TF.to_tensor(mask)
81 | # min-max normalize and scale
82 | image = (image - image.min()) / (image.max() - image.min()) * 255.0
83 | # resize
84 | image = TF.resize(image, self.image_size, antialias=True)
85 | mask = TF.resize(mask, self.image_size, antialias=True)
86 |
87 | bbox = self._get_bbox(mask)
88 |
89 | return image, mask, bbox
90 |
91 | def _get_bbox(self, mask: torch.Tensor) -> torch.Tensor:
92 | _, y_indices, x_indices = torch.where(mask > 0)
93 |
94 | x_min, y_min = (x_indices.min(), y_indices.min())
95 | x_max, y_max = (x_indices.max(), y_indices.max())
96 |
97 | # add perturbation to bounding box coordinates
98 | H, W = mask.shape[1:]
99 | # add perturbation to the bbox
100 | assert H == W, f"{W} and {H} are not equal size!!"
101 | x_min = max(0, x_min - np.random.randint(0, 10))
102 | x_max = min(W, x_max + np.random.randint(0, 10))
103 | y_min = max(0, y_min - np.random.randint(0, 10))
104 | y_max = min(H, y_max + np.random.randint(0, 10))
105 |
106 | return torch.tensor([x_min, y_min, x_max, y_max])
107 |
--------------------------------------------------------------------------------
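
A minimal usage sketch (not part of the repository): the file names, temporary directory, and synthetic image/mask pair below are made up for illustration, and the import assumes the repository root is on the Python path.

```python
import os
import tempfile

import cv2
import numpy as np
import pandas as pd

from utils.dataset import MedSamDataset

tmp_dir = tempfile.mkdtemp()

# synthetic RGB image and a binary mask with a rectangular foreground region
image = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)
mask = np.zeros((300, 300), dtype=np.uint8)
mask[100:200, 120:220] = 255

cv2.imwrite(os.path.join(tmp_dir, "img_0.png"), image)
cv2.imwrite(os.path.join(tmp_dir, "mask_0.png"), mask)

df = pd.DataFrame({"image": ["img_0.png"], "mask": ["mask_0.png"]})
ds = MedSamDataset(df, image_col="image", mask_col="mask",
                   image_dir=tmp_dir, mask_dir=tmp_dir, image_size=(256, 256))

img_t, mask_t, bbox = ds[0]
print(img_t.shape, mask_t.shape, bbox)  # torch.Size([3, 256, 256]) torch.Size([1, 256, 256]) tensor([...])
```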
/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to the longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | # for pvt's image size
34 | def apply_image_pvt(self, image: np.ndarray) -> np.ndarray:
35 | """
36 | Expects a numpy array with shape HxWxC in uint8 format.
37 | """
38 | target_size = self.get_preprocess_shape(image.shape[1], image.shape[2], self.target_length)
39 | return np.array(resize(to_pil_image(image), target_size))
40 |
41 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
42 | """
43 | Expects a numpy array of length 2 in the final dimension. Requires the
44 | original image size in (H, W) format.
45 | """
46 | old_h, old_w = original_size
47 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
48 | new_coords = np.empty_like(coords)
49 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w)
50 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h)
51 | return new_coords
52 |
53 |
54 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
55 | """
56 | Expects a numpy array shape Bx4. Requires the original image size
57 | in (H, W) format.
58 | """
59 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
60 | return boxes.reshape(-1, 4)
61 |
62 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
63 | """
64 | Expects batched images with shape BxCxHxW and float format. This
65 | transformation may not exactly match apply_image. apply_image is
66 | the transformation expected by the model.
67 | """
68 | # Expects an image in BCHW format. May not exactly match apply_image.
69 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
70 | return F.interpolate(
71 | image, target_size, mode="bilinear", align_corners=False, antialias=True
72 | )
73 |
74 | def apply_coords_torch(
75 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
76 | ) -> torch.Tensor:
77 | """
78 | Expects a torch tensor with length 2 in the last dimension. Requires the
79 | original image size in (H, W) format.
80 | """
81 | old_h, old_w = original_size
82 | new_h, new_w = self.get_preprocess_shape(
83 | original_size[0], original_size[1], self.target_length
84 | )
85 | coords = deepcopy(coords).to(torch.float)
86 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
87 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
88 | return coords
89 |
90 | def apply_boxes_torch(
91 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
92 | ) -> torch.Tensor:
93 | """
94 | Expects a torch tensor with shape Bx4. Requires the original image
95 | size in (H, W) format.
96 | """
97 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
98 | return boxes.reshape(-1, 4)
99 |
100 | @staticmethod
101 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
102 | """
103 | Compute the output size given input size and target long side length.
104 | """
105 | scale = long_side_length * 1.0 / max(oldh, oldw)
106 | newh, neww = oldh * scale, oldw * scale
107 | neww = int(neww + 0.5)
108 | newh = int(newh + 0.5)
109 | return (newh, neww)
110 |
--------------------------------------------------------------------------------
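
A minimal usage sketch (not part of the repository): resizing a dummy 600x800 image and a box to SAM's 1024-pixel long side, assuming the repository root is on the Python path.

```python
import numpy as np

from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

image = np.random.randint(0, 255, (600, 800, 3), dtype=np.uint8)
resized = transform.apply_image(image)
print(resized.shape)  # (768, 1024, 3): the long side becomes 1024, the short side scales in proportion

box = np.array([[100, 50, 300, 200]])          # (x_min, y_min, x_max, y_max) in the original image
print(transform.apply_boxes(box, (600, 800)))  # [[128  64 384 256]], i.e. scaled by 1024 / 800
```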
/segment_anything/modeling/DWT.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | # affine_par = True
5 | class BasicConv2d(nn.Module):
6 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, need_relu=True,
7 | bn=nn.BatchNorm2d):
8 | super(BasicConv2d, self).__init__()
9 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
10 | stride=stride, padding=padding, dilation=dilation, bias=False)
11 | self.bn = bn(out_channels)
12 | self.relu = nn.ReLU()
13 | self.need_relu = need_relu
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = self.bn(x)
18 | if self.need_relu:
19 | x = self.relu(x)
20 | return x
21 |
22 | # class ETM(nn.Module):
23 | # def __init__(self, in_channels, out_channels):
24 | # super(ETM, self).__init__()
25 | # self.relu = nn.ReLU(True)
26 | # self.branch0 = BasicConv2d(in_channels, out_channels, 1)
27 | # self.branch1 = nn.Sequential(
28 | # BasicConv2d(in_channels, out_channels, 1),
29 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
30 | # BasicConv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)),
31 | # BasicConv2d(out_channels, out_channels, 3, padding=3, dilation=3)
32 | # )
33 | # self.branch2 = nn.Sequential(
34 | # BasicConv2d(in_channels, out_channels, 1),
35 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 5), padding=(0, 2)),
36 | # BasicConv2d(out_channels, out_channels, kernel_size=(5, 1), padding=(2, 0)),
37 | # BasicConv2d(out_channels, out_channels, 3, padding=5, dilation=5)
38 | # )
39 | # self.branch3 = nn.Sequential(
40 | # BasicConv2d(in_channels, out_channels, 1),
41 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 7), padding=(0, 3)),
42 | # BasicConv2d(out_channels, out_channels, kernel_size=(7, 1), padding=(3, 0)),
43 | # BasicConv2d(out_channels, out_channels, 3, padding=7, dilation=7)
44 | # )
45 | # self.conv_cat = BasicConv2d(4 * out_channels, out_channels, 3, padding=1)
46 | # self.conv_res = BasicConv2d(in_channels, out_channels, 1)
47 | #
48 | # def forward(self, x):
49 | # x0 = self.branch0(x)
50 | # x1 = self.branch1(x)
51 | # x2 = self.branch2(x)
52 | # x3 = self.branch3(x)
53 | # x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
54 | #
55 | # x = self.relu(x_cat + self.conv_res(x))
56 | # return x
57 | def resize_tensor(tensor, size):
58 | """
59 | Resize a tensor to the target size with bilinear interpolation.
60 |
61 | Args:
62 | tensor (torch.Tensor): input tensor of shape (batch_size, num_channels, height, width)
63 | size (tuple): target size, given as (new_height, new_width)
64 |
65 | Returns:
66 | torch.Tensor: resized tensor of shape (batch_size, num_channels, new_height, new_width)
67 | """
68 | resized_tensor = nn.functional.interpolate(tensor, size=size, mode='bilinear', align_corners=False)
69 | return resized_tensor
70 |
71 | class DWT(nn.Module):
72 | def __init__(self):
73 | super(DWT, self).__init__()
74 | self.requires_grad = False
75 |
76 | def forward(self, x):
77 | # x01, x02 are the low-frequency signals; x1, x2, x3, x4 are the high-frequency signals
78 | x01 = x[:, :, 0::2, :] / 2
79 | # select elements along the third dimension with a stride of 2, i.e. the elements at indices 0, 2, 4, ...
80 | x02 = x[:, :, 1::2, :] / 2
81 | # this selects only the elements at odd indices (1, 3, 5, ...); the '/ 2' is a division that
82 | # divides every element of the selected slice, e.g. x[:, :, 0::2, :], by 2 to shrink its value range.
83 | # In some signal-processing methods the high-frequency part tends to have large values, which can strongly affect, e.g., the response range of activations or the convergence of the optimizer.
84 | # Dividing the high-frequency part by 2 halves its value range, giving it a smaller weight relative to the low-frequency part and balancing their influence to some extent.
85 | x1 = x01[:, :, :, 0::2]
86 | # downsample x01 along the spatial (width) dimension with stride 2, which halves its width,
87 | # i.e. every row is sampled at intervals, giving shape (batch_size, num_channels, height//2, width//2).
88 | x2 = x02[:, :, :, 0::2]
89 | x3 = x01[:, :, :, 1::2]
90 | x4 = x02[:, :, :, 1::2]
91 | ll = x1 + x2 + x3 + x4
92 | lh = -x1 + x2 - x3 + x4
93 | hl = -x1 - x2 + x3 + x4
94 | hh = x1 - x2 - x3 + x4
95 | return ll, lh, hl, hh
96 | # the ll sub-band carries the low-frequency information; lh, hl, hh carry the high-frequency information in different directions.
97 |
98 | class extract_high_frequency(nn.Module):
99 | def __init__(self):
100 | super(extract_high_frequency, self).__init__()
101 | self.dwt = DWT()
102 | self.conv_in = nn.Conv2d(768, 256, 3, padding=1)
103 |
104 | def forward(self, x):
105 | x = self.conv_in(x)
106 | ll, lh, hl, hh = self.dwt(x)
107 | high_frequency = resize_tensor(hh, (64, 64))
108 | return high_frequency
109 |
--------------------------------------------------------------------------------
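
A minimal usage sketch (not part of the repository): extract_high_frequency projects a 768-channel ViT feature map to 256 channels, keeps the diagonal (hh) band of a single-level Haar decomposition, and resizes it back to 64x64; the raw DWT step halves each spatial dimension.

```python
import torch

from segment_anything.modeling.DWT import DWT, extract_high_frequency

ehf = extract_high_frequency()
x = torch.randn(1, 768, 64, 64)
print(ehf(x).shape)  # torch.Size([1, 256, 64, 64])

ll, lh, hl, hh = DWT()(torch.randn(1, 4, 32, 32))
print(hh.shape)      # torch.Size([1, 4, 16, 16])
```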
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: pytorch183
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | - defaults
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/
10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
12 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
13 | dependencies:
14 | - _libgcc_mutex=0.1=main
15 | - _openmp_mutex=4.5=1_gnu
16 | - blas=1.0=mkl
17 | - brotli=1.0.9=he6710b0_2
18 | - brotlipy=0.7.0=py37h27cfd23_1003
19 | - bzip2=1.0.8=h7b6447c_0
20 | - ca-certificates=2023.01.10=h06a4308_0
21 | - certifi=2022.12.7=py37h06a4308_0
22 | - cffi=1.15.0=py37hd667e15_1
23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
24 | - cloudpickle=2.0.0=pyhd3eb1b0_0
25 | - colorama=0.4.4=pyhd3eb1b0_0
26 | - cryptography=39.0.1=py37h9ce1e76_0
27 | - cudatoolkit=11.3.1=h2bc3f7f_2
28 | - cycler=0.11.0=pyhd3eb1b0_0
29 | - cytoolz=0.11.0=py37h7b6447c_0
30 | - dask-core=2021.10.0=pyhd3eb1b0_0
31 | - ffmpeg=4.3=hf484d3e_0
32 | - fonttools=4.25.0=pyhd3eb1b0_0
33 | - freetype=2.11.0=h70c0345_0
34 | - fsspec=2022.2.0=pyhd3eb1b0_0
35 | - gmp=6.2.1=h2531618_2
36 | - gnutls=3.6.15=he1e5248_0
37 | - idna=3.4=py37h06a4308_0
38 | - imageio=2.9.0=pyhd3eb1b0_0
39 | - intel-openmp=2021.4.0=h06a4308_3561
40 | - jpeg=9b=h024ee3a_2
41 | - kiwisolver=1.3.2=py37h295c915_0
42 | - lame=3.100=h7b6447c_0
43 | - lcms2=2.12=h3be6417_0
44 | - ld_impl_linux-64=2.35.1=h7274673_9
45 | - libffi=3.3=he6710b0_2
46 | - libgcc-ng=9.3.0=h5101ec6_17
47 | - libgfortran-ng=7.5.0=ha8ba4b0_17
48 | - libgfortran4=7.5.0=ha8ba4b0_17
49 | - libgomp=9.3.0=h5101ec6_17
50 | - libiconv=1.15=h63c8f33_5
51 | - libidn2=2.3.2=h7f8727e_0
52 | - libpng=1.6.37=hbc83047_0
53 | - libstdcxx-ng=9.3.0=hd4cf53a_17
54 | - libtasn1=4.16.0=h27cfd23_0
55 | - libtiff=4.2.0=h85742a9_0
56 | - libunistring=0.9.10=h27cfd23_0
57 | - libuv=1.40.0=h7b6447c_0
58 | - libwebp-base=1.2.2=h7f8727e_0
59 | - locket=0.2.1=py37h06a4308_2
60 | - lz4-c=1.9.3=h295c915_1
61 | - matplotlib-base=3.5.1=py37ha18d171_1
62 | - mkl=2021.4.0=h06a4308_640
63 | - mkl-service=2.4.0=py37h7f8727e_0
64 | - mkl_fft=1.3.1=py37hd3c417c_0
65 | - mkl_random=1.2.2=py37h51133e4_0
66 | - munkres=1.1.4=py_0
67 | - ncurses=6.3=h7f8727e_2
68 | - nettle=3.7.3=hbbd107a_1
69 | - networkx=2.6.3=pyhd3eb1b0_0
70 | - numpy=1.21.5=py37he7a7128_2
71 | - numpy-base=1.21.5=py37hf524024_2
72 | - olefile=0.46=py37_0
73 | - openh264=2.1.1=h4ff587b_0
74 | - openssl=1.1.1t=h7f8727e_0
75 | - packaging=21.3=pyhd3eb1b0_0
76 | - partd=1.2.0=pyhd3eb1b0_1
77 | - pillow=8.0.0=py37h9a89aac_0
78 | - pip=21.2.2=py37h06a4308_0
79 | - pycparser=2.21=pyhd3eb1b0_0
80 | - pyopenssl=23.0.0=py37h06a4308_0
81 | - pyparsing=3.0.4=pyhd3eb1b0_0
82 | - pysocks=1.7.1=py37_1
83 | - python=3.7.11=h12debd9_0
84 | - python-dateutil=2.8.2=pyhd3eb1b0_0
85 | - pytorch=1.12.1=py3.7_cuda11.3_cudnn8.3.2_0
86 | - pytorch-mutex=1.0=cuda
87 | - pywavelets=1.1.1=py37h7b6447c_2
88 | - pyyaml=6.0=py37h7f8727e_1
89 | - readline=8.1.2=h7f8727e_1
90 | - requests=2.28.1=py37h06a4308_0
91 | - scikit-image=0.15.0=py37hb3f55d8_2
92 | - scipy=1.7.3=py37hc147768_0
93 | - setuptools=58.0.4=py37h06a4308_0
94 | - six=1.16.0=pyhd3eb1b0_1
95 | - sqlite=3.38.0=hc218d9a_0
96 | - tk=8.6.11=h1ccaba5_0
97 | - toolz=0.11.2=pyhd3eb1b0_0
98 | - torchaudio=0.12.1=py37_cu113
99 | - torchvision=0.13.1=py37_cu113
100 | - urllib3=1.26.14=py37h06a4308_0
101 | - wheel=0.38.4=py37h06a4308_0
102 | - xz=5.2.5=h7b6447c_0
103 | - yaml=0.2.5=h7b6447c_0
104 | - zlib=1.2.11=h7f8727e_4
105 | - zstd=1.4.9=haebb681_0
106 | - pip:
107 | - addict==2.4.0
108 | - anykeystore==0.2
109 | - apex==0.1
110 | - causal-conv1d==1.2.2.post1
111 | - cryptacular==1.6.2
112 | - defusedxml==0.7.1
113 | - einops==0.6.1
114 | - filelock==3.12.2
115 | - greenlet==3.0.3
116 | - huggingface-hub==0.16.4
117 | - hupper==1.12.1
118 | - importlib-metadata==6.7.0
119 | - joblib==1.3.2
120 | - mamba-ssm==1.2.2
121 | - markupsafe==2.1.4
122 | - mmcv==1.7.0
123 | - monai==1.1.0
124 | - ninja==1.11.1.1
125 | - oauthlib==3.2.2
126 | - pandas==1.3.5
127 | - pastedeploy==3.1.0
128 | - pbkdf2==1.3
129 | - plaster==1.1.2
130 | - plaster-pastedeploy==1.0.1
131 | - platformdirs==4.0.0
132 | - protobuf==4.24.4
133 | - pyramid==2.0.2
134 | - pyramid-mailer==0.15.1
135 | - python3-openid==3.2.0
136 | - pytz==2024.1
137 | - regex==2024.4.16
138 | - repoze-sendmail==4.4.1
139 | - requests-oauthlib==1.3.1
140 | - safetensors==0.4.2
141 | - scikit-learn==1.0.2
142 | - seaborn==0.12.2
143 | - sklearn==0.0
144 | - sqlalchemy==2.0.25
145 | - tensorboardx==2.6.2.2
146 | - thop==0.1.1-2209072238
147 | - threadpoolctl==3.1.0
148 | - timm==0.9.12
149 | - tokenizers==0.13.3
150 | - tomli==2.0.1
151 | - torchsummary==1.5.1
152 | - tqdm==4.66.1
153 | - transaction==4.0
154 | - transformers==4.30.2
155 | - translationstring==1.4
156 | - triton==2.3.1
157 | - ttach==0.0.3
158 | - typing-extensions==4.7.1
159 | - velruse==1.1.1
160 | - venusian==3.1.0
161 | - webob==1.8.7
162 | - wtforms==3.0.1
163 | - wtforms-recaptcha==0.3.2
164 | - yapf==0.40.2
165 | - zipp==3.15.0
166 | - zope-deprecation==5.0
167 | - zope-interface==6.1
168 | - zope-sqlalchemy==3.1
169 | prefix: /root/anaconda3/envs/pytorch183
170 |
--------------------------------------------------------------------------------
/segment_anything/modeling/sd_merge.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | class BasicConv2d(nn.Module):
5 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, need_relu=True,
6 | bn=nn.BatchNorm2d):
7 | super(BasicConv2d, self).__init__()
8 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
9 | stride=stride, padding=padding, dilation=dilation, bias=False)
10 | self.bn = bn(out_channels)
11 | self.relu = nn.ReLU()
12 | self.need_relu = need_relu
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | x = self.bn(x)
17 | if self.need_relu:
18 | x = self.relu(x)
19 | return x
20 | # merge_dense_sparse
21 | class MDS(nn.Module):
22 | def __init__(self, in_channels, out_channels):
23 | super(MDS, self).__init__()
24 | '''sparse_embedding: tensor(bs, 2, 256) -> tensor(bs, 1, 2, 256)
25 | dense_embedding: tensor(bs, 256, 64, 64)
26 | the reshape factor 8 = 64*64 / (2*256)'''
27 | self.sparse_in_channels = 1
28 | self.dense_in_channels = in_channels*8
29 | self.real_in_channels = in_channels*8+self.sparse_in_channels
30 | self.real_out_channels = out_channels*8+self.sparse_in_channels
31 | self.relu = nn.ReLU(True)
32 | # self.branch0 = BasicConv2d(self.real_in_channels, self.real_out_channels, 1)
33 | self.branch1 = nn.Sequential(
34 | BasicConv2d(self.real_in_channels, self.real_out_channels, 1),
35 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(1, 3), padding=(0, 1)),
36 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(3, 1), padding=(1, 0)),
37 | BasicConv2d(self.real_out_channels, self.real_out_channels, 3, padding=3, dilation=3)
38 | )
39 | self.branch2 = nn.Sequential(
40 | BasicConv2d(self.real_in_channels, self.real_out_channels, 1),
41 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(1, 5), padding=(0, 2)),
42 | BasicConv2d(self.real_out_channels, self.real_out_channels, kernel_size=(5, 1), padding=(2, 0)),
43 | BasicConv2d(self.real_out_channels, self.real_out_channels, 3, padding=5, dilation=5)
44 | )
45 | # self.branch3 = nn.Sequential(
46 | # BasicConv2d(in_channels, out_channels, 1),
47 | # BasicConv2d(out_channels, out_channels, kernel_size=(1, 7), padding=(0, 3)),
48 | # BasicConv2d(out_channels, out_channels, kernel_size=(7, 1), padding=(3, 0)),
49 | # BasicConv2d(out_channels, out_channels, 3, padding=7, dilation=7)
50 | # )
51 | self.conv_cat = BasicConv2d(2 * self.real_out_channels, self.real_out_channels, 3, padding=1)
52 | # self.conv_res = BasicConv2d(in_channels, out_channels, 1)
53 | self.conv_spares_cat = BasicConv2d(2*self.sparse_in_channels, self.sparse_in_channels, 3, padding=1)
54 | self.conv_dense_cat = BasicConv2d(2*self.dense_in_channels, self.dense_in_channels, 3, padding=1)
55 | def initialize_parameters(self):
56 | for name, param in self.named_parameters():
57 | if 'weight' in name:
58 | # Initialize the weight parameters
59 | # One-dimensional vectors cannot be initialized, so we first convert them to two dimensions,
60 | # initialize them, and then convert them to one dimension
61 | if len(param.shape) == 1:
62 | param_unsqueeze = param.unsqueeze(0)
63 | nn.init.xavier_uniform_(param_unsqueeze)
64 | param.data.copy_(param_unsqueeze.squeeze(0))
65 | else:
66 | nn.init.xavier_uniform_(param)
67 |
68 | elif 'bias' in name:
69 | # print("bias:" + name)
70 | # The bias term is initialized
71 | nn.init.zeros_(param)
72 |
73 | def forward(self, sparse_embeddings, dense_embeddings):
74 | sparse_embeddings = sparse_embeddings.unsqueeze(1)
75 | bs_d, c_d, h_d, w_d = dense_embeddings.size()
76 | bs_s, c_s, h_s, w_s = sparse_embeddings.size()
77 | n = h_d * w_d // (h_s * w_s)
78 | dense_embeddings = dense_embeddings.view(bs_d, n*c_d, h_s, w_s)
79 | tmp_sparse = sparse_embeddings
80 | tmp_dense = dense_embeddings
81 | sd_cat = torch.cat([sparse_embeddings, dense_embeddings], dim=1)
82 |
83 | # x0 = self.branch0(sd_cat)
84 | x1 = self.branch1(sd_cat)
85 | x_cat = x1
86 | # x_cat = self.conv_cat(torch.cat((x0, x1), 1))
87 |
88 | sparse_embeddings_up = x_cat[:, :self.sparse_in_channels, :, :]
89 | sparse_embeddings = self.relu(sparse_embeddings * sparse_embeddings_up)
90 | sparse_cat = torch.cat([sparse_embeddings, tmp_sparse], dim=1)
91 | sparse_embeddings = self.conv_spares_cat(sparse_cat)
92 | # sparse_embeddings = sparse_embeddings.view(32, 1, -1, 256)
93 | sparse_embeddings = sparse_embeddings.squeeze(dim=1)
94 |
95 | dense_embeddings_down = x_cat[:, self.sparse_in_channels:, :, :]
96 | dense_embeddings = self.relu(dense_embeddings * dense_embeddings_down)
97 | dense_embeddings = self.conv_dense_cat(torch.cat([dense_embeddings, tmp_dense], dim=1))
98 | dense_embeddings = dense_embeddings.view(bs_d, c_d, h_d, w_d)
99 | # dense_embeddings = dense_embeddings + tmp_dense
100 | return sparse_embeddings, dense_embeddings
--------------------------------------------------------------------------------
/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from functools import partial
7 | from pathlib import Path
8 | import urllib.request
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | from .modeling import (
14 | ImageEncoderViT,
15 | MaskDecoder,
16 | PromptEncoder,
17 | Sam,
18 | TwoWayTransformer,
19 | # Shield,
20 | extract_high_frequency,
21 | ME,
22 | Loop_Finer,
23 | bias_correction,
24 | pvt_v2_b2,
25 | # InvertedResidual
26 |
27 | )
28 |
29 |
30 | def build_sam_vit_h(checkpoint=None):
31 | return _build_sam(
32 | encoder_embed_dim=1280,
33 | encoder_depth=32,
34 | encoder_num_heads=16,
35 | encoder_global_attn_indexes=[7, 15, 23, 31],
36 | checkpoint=checkpoint,
37 | )
38 |
39 |
40 | build_sam = build_sam_vit_h
41 |
42 |
43 | def build_sam_vit_l(checkpoint=None):
44 | return _build_sam(
45 | encoder_embed_dim=1024,
46 | encoder_depth=24,
47 | encoder_num_heads=16,
48 | encoder_global_attn_indexes=[5, 11, 17, 23],
49 | checkpoint=checkpoint,
50 | )
51 |
52 |
53 | def build_sam_vit_b(checkpoint=None):
54 | return _build_sam(
55 | encoder_embed_dim=768,
56 | encoder_depth=12,
57 | encoder_num_heads=12,
58 | encoder_global_attn_indexes=[2, 5, 8, 11],
59 | checkpoint=checkpoint,
60 | )
61 |
62 |
63 | sam_model_registry = {
64 | "default": build_sam_vit_h,
65 | "vit_h": build_sam_vit_h,
66 | "vit_l": build_sam_vit_l,
67 | "vit_b": build_sam_vit_b,
68 | }
69 |
70 |
71 | def _build_sam(
72 | encoder_embed_dim,
73 | encoder_depth,
74 | encoder_num_heads,
75 | encoder_global_attn_indexes,
76 | checkpoint=None,
77 | ):
78 | prompt_embed_dim = 256
79 | image_size = 1024
80 | vit_patch_size = 16
81 | image_embedding_size = image_size // vit_patch_size
82 | sam = Sam(
83 | image_encoder=ImageEncoderViT(
84 | depth=encoder_depth,
85 | embed_dim=encoder_embed_dim,
86 | img_size=image_size,
87 | mlp_ratio=4,
88 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
89 | num_heads=encoder_num_heads,
90 | patch_size=vit_patch_size,
91 | qkv_bias=True,
92 | use_rel_pos=True,
93 | global_attn_indexes=encoder_global_attn_indexes,
94 | window_size=14,
95 | out_chans=prompt_embed_dim,
96 | ),
97 | prompt_encoder=PromptEncoder(
98 | embed_dim=prompt_embed_dim,
99 | image_embedding_size=(image_embedding_size, image_embedding_size),
100 | input_image_size=(image_size, image_size),
101 | mask_in_chans=16,
102 | ),
103 | mask_decoder=MaskDecoder(
104 | num_multimask_outputs=3,
105 | transformer=TwoWayTransformer(
106 | depth=2,
107 | embedding_dim=prompt_embed_dim,
108 | mlp_dim=2048,
109 | num_heads=8,
110 | ),
111 | transformer_dim=prompt_embed_dim,
112 | iou_head_depth=3,
113 | iou_head_hidden_dim=256,
114 | ),
115 | pvt=pvt_v2_b2(),
116 | DWT=extract_high_frequency(),
117 | ME=ME(256*2, 256),
118 | Loop_finer=Loop_Finer(64),
119 | BC=bias_correction(512),
120 | pixel_mean=[123.675, 116.28, 103.53],
121 | pixel_std=[58.395, 57.12, 57.375],
122 | )
123 |
124 | sam.eval()
125 | checkpoint = Path(checkpoint)
126 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
127 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
128 | if len(cmd) == 0 or cmd.lower() == 'y':
129 | checkpoint.parent.mkdir(parents=True, exist_ok=True)
130 | print("Downloading SAM ViT-B checkpoint...")
131 | urllib.request.urlretrieve(
132 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
133 | checkpoint,
134 | )
135 | print(checkpoint.name, " is downloaded!")
136 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
137 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
138 | if len(cmd) == 0 or cmd.lower() == 'y':
139 | checkpoint.parent.mkdir(parents=True, exist_ok=True)
140 | print("Downloading SAM ViT-H checkpoint...")
141 | urllib.request.urlretrieve(
142 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
143 | checkpoint,
144 | )
145 | print(checkpoint.name, " is downloaded!")
146 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
147 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
148 | if len(cmd) == 0 or cmd.lower() == 'y':
149 | checkpoint.parent.mkdir(parents=True, exist_ok=True)
150 | print("Downloading SAM ViT-L checkpoint...")
151 | urllib.request.urlretrieve(
152 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
153 | checkpoint,
154 | )
155 | print(checkpoint.name, " is downloaded!")
156 |
157 |
158 | if checkpoint is not None:
159 |         # sam.ME = ME(256 * 3)  # initialize the ME module before loading the pretrained weights
160 | with open(checkpoint, "rb") as f:
161 | state_dict = torch.load(f)
162 |
163 | sam.ME.initialize_parameters()
164 |         sam.load_state_dict(state_dict, strict=False)  # strict=False allows loading only the weights that match
165 | path = 'work_dir_cod/SAM/pvt_v2_b2.pth'
166 | save_model = torch.load(path)
167 | model_dict = sam.pvt.state_dict()
168 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
169 | model_dict.update(state_dict)
170 | sam.pvt.load_state_dict(model_dict)
171 | return sam
172 |
--------------------------------------------------------------------------------
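
A minimal usage sketch of the registry above, mirroring how Mytrain.py and Mytest.py build the model. It assumes the SAM ViT-B checkpoint already sits at work_dir_cod/SAM/sam_vit_b_01ec64.pth and that the PVTv2 weights expected by _build_sam are at work_dir_cod/SAM/pvt_v2_b2.pth; the dummy input is only there to show the embedding shape.

import torch
from segment_anything import sam_model_registry

# Build the DSAM-augmented SAM; _build_sam loads both the SAM checkpoint and pvt_v2_b2.pth.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint="work_dir_cod/SAM/sam_vit_b_01ec64.pth").to(device)
sam.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 1024, 1024, device=device)
    image_embedding = sam.image_encoder(dummy)  # expected shape: (1, 256, 64, 64)
print(image_embedding.shape)
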
/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 |     options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
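
The score reweighting in select_masks above exists to avoid data-dependent control flow during ONNX tracing: with a single click the (num_points - 2.5) factor is negative, so the first (single-output) mask token is pushed far below the multimask candidates, while with several clicks it is pushed far above them. A small standalone sketch of that arithmetic with hypothetical IoU values:

import torch

iou_preds = torch.tensor([[0.30, 0.85, 0.70, 0.60]])      # hypothetical scores for 4 mask tokens
score_reweight = torch.tensor([[1000.0, 0.0, 0.0, 0.0]])  # same pattern as in select_masks

for num_points in (1, 5):
    score = iou_preds + (num_points - 2.5) * score_reweight
    best_idx = torch.argmax(score, dim=1).item()
    # 1 point  -> token 0 is penalised, the best multimask output (index 1) wins
    # 5 points -> token 0 is boosted, the single-output token (index 0) wins
    print(num_points, best_idx)
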
/segment_anything/modeling/agent_swin.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Agent Attention: On the Integration of Softmax and Linear Attention
8 | # Modified by Dongchen Han
9 | # -----------------------------------------------------------------------
10 |
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 |
16 |
17 | class AgentAttention(nn.Module):
18 |     r""" Agent attention module, adapted from the Swin Transformer window-based multi-head self attention (W-MSA).
19 |     Window partitioning and the relative position bias of the original are disabled here (commented out below).
20 |
21 | Args:
22 | dim (int): Number of input channels.
23 | num_heads (int): Number of attention heads.
24 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
25 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
26 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
27 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
28 | """
29 |
30 | def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
31 | shift_size=0, agent_num=49, **kwargs):
32 |
33 | super().__init__()
34 | self.dim = dim # embed_dim=96
35 | # self.window_size = window_size # Wh, Ww
36 | self.num_heads = num_heads # num_heads=[3, 6, 12, 24]
37 | head_dim = dim // num_heads
38 | self.scale = head_dim ** -0.5
39 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
40 | self.attn_drop = nn.Dropout(attn_drop)
41 | self.proj = nn.Linear(dim, dim)
42 | self.proj_drop = nn.Dropout(proj_drop)
43 | self.softmax = nn.Softmax(dim=-1)
44 | self.shift_size = shift_size
45 |
46 | self.agent_num = agent_num
47 | self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
48 | # self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
49 | # self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
50 | # self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0], 1))
51 | # self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1]))
52 | # self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
53 | # self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
54 | # trunc_normal_(self.an_bias, std=.02)
55 | # trunc_normal_(self.na_bias, std=.02)
56 | # trunc_normal_(self.ah_bias, std=.02)
57 | # trunc_normal_(self.aw_bias, std=.02)
58 | # trunc_normal_(self.ha_bias, std=.02)
59 | # trunc_normal_(self.wa_bias, std=.02)
60 | pool_size = int(agent_num ** 0.5)
61 | self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
62 |
63 | def forward(self, x, mask=None):
64 | """
65 | Args:
66 |             x: input features with shape (B, C, H, W); flattened internally to (B, H*W, C)
67 |             mask: unused here; kept for interface compatibility with the original window attention
68 | """
69 | x = x.permute(0, 2, 3, 1)
70 | b, h, w, c = x.shape
71 | x = x.reshape(b, h * w, c)
72 | # b, c, h, w = x.shape
73 | # n = h * w
74 | b, n, c = x.shape
75 | h = int(n ** 0.5)
76 | w = int(n ** 0.5)
77 | # print(b)
78 | # print(h)
79 | # print(w)
80 | # print(c)
81 | num_heads = self.num_heads
82 | head_dim = c // num_heads
83 | qkv = self.qkv(x)
84 | # print(qkv.shape)
85 | qkv = qkv.reshape(b, n, 3, c).permute(2, 0, 1, 3)
86 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
87 | # q, k, v: b, n, c
88 |
89 | agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
90 | q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
91 | k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
92 | v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
93 | agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)
94 |
95 | # position_bias1 = nn.functional.interpolate(self.an_bias, size=self.window_size, mode='bilinear')
96 | # position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
97 | # position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
98 | # position_bias = position_bias1 + position_bias2
99 | agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1))
100 | agent_attn = self.attn_drop(agent_attn)
101 | agent_v = agent_attn @ v
102 |
103 | # agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
104 | # agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
105 | # agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
106 | # agent_bias = agent_bias1 + agent_bias2
107 | q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1))
108 | q_attn = self.attn_drop(q_attn)
109 | x = q_attn @ agent_v
110 |
111 | x = x.transpose(1, 2).reshape(b, n, c)
112 | v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
113 | x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)
114 |
115 | x = self.proj(x)
116 | x = self.proj_drop(x)
117 | x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
118 | return x
119 |
120 | # def extra_repr(self) -> str:
121 | # return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
122 |
123 | def flops(self, N):
124 | # calculate flops for 1 window with token length of N
125 | flops = 0
126 | # qkv = self.qkv(x)
127 | flops += N * self.dim * 3 * self.dim
128 | # attn = (q @ k.transpose(-2, -1))
129 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
130 | # x = (attn @ v)
131 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
132 | # x = self.proj(x)
133 | flops += N * self.dim * self.dim
134 | return flops
135 |
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
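
A quick shape check for the AgentAttention module above: it takes and returns (B, C, H, W) feature maps, pooling the queries down to agent_num agent tokens (a 7x7 grid by default) so that the attention cost scales with agent_num rather than with H*W. The sizes below are purely illustrative.

import torch
from segment_anything.modeling.agent_swin import AgentAttention

attn = AgentAttention(dim=256, num_heads=8, agent_num=49)
x = torch.randn(2, 256, 64, 64)   # (B, C, H, W)
with torch.no_grad():
    y = attn(x)
print(y.shape)                    # torch.Size([2, 256, 64, 64])
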
/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 | transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 | # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 |
109 | # Prepare output
110 | return masks, iou_pred
111 |
112 | def predict_masks(
113 | self,
114 | image_embeddings: torch.Tensor,
115 | image_pe: torch.Tensor,
116 | sparse_prompt_embeddings: torch.Tensor,
117 | dense_prompt_embeddings: torch.Tensor,
118 | ) -> Tuple[torch.Tensor, torch.Tensor]:
119 | """Predicts masks. See 'forward' for more details."""
120 | # Concatenate output tokens
121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 |
125 | # Expand per-image data in batch direction to be per-mask
126 | if image_embeddings.shape[0] != tokens.shape[0]:
127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
128 |             # repeat the image embedding along the batch dimension so its batch size matches that of tokens
129 | else:
130 | src = image_embeddings
131 |
132 |         src = src + dense_prompt_embeddings  # image_embeddings: 32 512 64 64, dense_prompt_embeddings: 32 256 64 64; img_em and dense_em have matching spatial size
133 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
134 | b, c, h, w = src.shape
135 |
136 | # Run the transformer
137 | hs, src = self.transformer(src, pos_src, tokens)
138 | iou_token_out = hs[:, 0, :]
139 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
140 |
141 | # Upscale mask embeddings and predict masks using the mask tokens
142 | src = src.transpose(1, 2).view(b, c, h, w)
143 | upscaled_embedding = self.output_upscaling(src)
144 | hyper_in_list: List[torch.Tensor] = []
145 | for i in range(self.num_mask_tokens):
146 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
147 | hyper_in = torch.stack(hyper_in_list, dim=1)
148 | b, c, h, w = upscaled_embedding.shape
149 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
150 |
151 | # Generate mask quality predictions
152 | iou_pred = self.iou_prediction_head(iou_token_out)
153 |
154 | return masks, iou_pred
155 |
156 |
157 | # Lightly adapted from
158 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
159 | class MLP(nn.Module):
160 | def __init__(
161 | self,
162 | input_dim: int,
163 | hidden_dim: int,
164 | output_dim: int,
165 | num_layers: int,
166 | sigmoid_output: bool = False,
167 | ) -> None:
168 | super().__init__()
169 | self.num_layers = num_layers
170 | h = [hidden_dim] * (num_layers - 1)
171 | self.layers = nn.ModuleList(
172 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
173 | )
174 | self.sigmoid_output = sigmoid_output
175 |
176 | def forward(self, x):
177 | for i, layer in enumerate(self.layers):
178 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
179 | if self.sigmoid_output:
180 | x = F.sigmoid(x)
181 | return x
182 |
--------------------------------------------------------------------------------
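
The heart of predict_masks above is the hypernetwork step: each mask token is mapped by its own MLP to a (transformer_dim // 8)-dimensional filter, which is then dotted with the upscaled image embedding at every spatial location to produce that token's mask logits. A standalone sketch of just that matrix product, with sizes matching transformer_dim=256 and a 64x64 input embedding (upscaled 4x to 256x256):

import torch

b, c, h, w = 2, 32, 256, 256                     # upscaled embedding channels = 256 // 8
num_mask_tokens = 4                              # num_multimask_outputs + 1

upscaled_embedding = torch.randn(b, c, h, w)
hyper_in = torch.randn(b, num_mask_tokens, c)    # per-token filters produced by the MLPs

masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)                               # torch.Size([2, 4, 256, 256])
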
/Mytrain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 | join = os.path.join
5 | from tqdm import tqdm
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | import monai
9 | import torch.nn as nn
10 | from segment_anything import sam_model_registry
11 | from segment_anything.utils.transforms import ResizeLongestSide
12 | from segment_anything.modeling.CWDLoss import CriterionCWD
13 | from torch.nn import functional as F
14 | from torchvision.models.mobilenetv2 import InvertedResidual
15 | # set seeds
16 | torch.manual_seed(2024)
17 | np.random.seed(2024)
18 |
19 | def _upsample_like_1024(src):
20 | src = F.interpolate(src, size=(1024, 1024), mode='bilinear')
21 | return src
22 |
23 |
24 | class NpzDataset(Dataset):
25 | def __init__(self, data_root):
26 | self.data_root = data_root
27 | self.npz_files = sorted(os.listdir(self.data_root))
28 | self.npz_data = [np.load(join(data_root, f)) for f in self.npz_files]
29 | # this implementation is ugly but it works (and is also fast for feeding data to GPU)
30 | # if your server has enough RAM
31 | # as an alternative, you can also use a list of npy files and load them one by one
32 | self.ori_gts = np.vstack([d['gts'] for d in self.npz_data])
33 | self.ori_imgs = np.vstack([d['imgs'] for d in self.npz_data])
34 | self.img_embeddings = np.vstack([d['img_embeddings'] for d in self.npz_data])
35 | self.boundary = np.vstack([d['boundary'] for d in self.npz_data])
36 | self.depth_embeddings = np.vstack([d['depth_embeddings'] for d in self.npz_data])
37 | print(f"img_embeddings.shape={self.img_embeddings.shape}, ori_gts.shape={self.ori_gts.shape}, "
38 | f"boundary.shape={self.boundary.shape}", f"depth_embeddings.shape={self.depth_embeddings.shape}")
39 |
40 | def __len__(self):
41 | return self.ori_gts.shape[0]
42 |
43 | def __getitem__(self, index):
44 | img_embed = self.img_embeddings[index]
45 | gt2D = self.ori_gts[index]
46 | img = self.ori_imgs[index]
47 | boundary = self.boundary[index]
48 | depth_embed = self.depth_embeddings[index]
49 | y_indices, x_indices = np.where(gt2D > 0)
50 | x_min, x_max = np.min(x_indices), np.max(x_indices)
51 | y_min, y_max = np.min(y_indices), np.max(y_indices)
52 | # add perturbation to bounding box coordinates
53 | H, W = gt2D.shape
54 | x_min = max(0, x_min - np.random.randint(0, 20))
55 | x_max = min(W, x_max + np.random.randint(0, 20))
56 | y_min = max(0, y_min - np.random.randint(0, 20))
57 | y_max = min(H, y_max + np.random.randint(0, 20))
58 | bboxes = np.array([x_min, y_min, x_max, y_max])
59 | # convert img embedding, mask, bounding box to torch tensor
60 | return torch.tensor(img_embed).float(), torch.tensor(img).float(), torch.tensor(gt2D[None, :, :]).long(), torch.tensor(bboxes).float(),\
61 | torch.tensor(boundary[None, :, :]).long(), torch.tensor(depth_embed).float()
62 |
63 | # %% test dataset class and dataloader
64 | npz_tr_path = 'data/vit_b/COD_train'
65 | work_dir = './work_dir_cod'
66 | task_name = 'DSAM'
67 | # prepare SAM model
68 | model_type = 'vit_b'
69 | checkpoint = 'work_dir_cod/SAM/sam_vit_b_01ec64.pth'
70 | device = 'cuda:0'
71 | model_save_path = join(work_dir, task_name)
72 | os.makedirs(model_save_path, exist_ok=True)
73 | sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
74 |
75 | sam_model.train()
76 | # Set up the optimizer, hyperparameter tuning will improve performance here
77 | optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
78 | seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
79 | CWD_loss = CriterionCWD(norm_type='channel', divergence='kl', temperature=4.0)
80 |
81 | num_epochs = 100
82 | losses = []
83 | best_loss = 1e10
84 | train_dataset = NpzDataset(npz_tr_path)
85 | mask_threshold = 0.0
86 | train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
87 | for epoch in range(num_epochs+1):
88 | epoch_loss = 0
89 | # train
90 | for step, (image_embedding, img, gt2D, boxes, boundary, depth_embedding) in enumerate(tqdm(train_dataloader)):
91 |
92 | with torch.no_grad():
93 | box_np = boxes.numpy()
94 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
95 | box = sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1]))
96 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
97 | boundary = torch.as_tensor(boundary, dtype=torch.float, device=device)
98 | # boundary = boun_conv(boundary)
99 | image_embedding = torch.as_tensor(image_embedding, dtype=torch.float, device=device)
100 | depth_embedding = torch.as_tensor(depth_embedding, dtype=torch.float, device=device)
101 | if len(box_torch.shape) == 2:
102 | box_torch = box_torch[:, None, :] # (B, 1, 4)
103 | # get prompt embeddings
104 | sparse_embeddings_box, dense_embeddings_box = sam_model.prompt_encoder(
105 | points=None,
106 | boxes=box_torch,
107 | masks=None
108 | )
109 |
110 |         resize_img_tensor = img.permute(0, 3, 1, 2).to(device)  # NHWC -> NCHW
111 | input_image = _upsample_like_1024(resize_img_tensor)
112 | pvt_embedding = sam_model.pvt(input_image)[3]
113 |
114 | bc_embedding, pvt_64 = sam_model.BC(pvt_embedding)
115 | # bc_embedding shape:" 1, 256, 64, 64
116 | distill_loss = CWD_loss(bc_embedding, depth_embedding)
117 | hybrid_embedding = torch.cat([pvt_64, bc_embedding], dim=1)
118 | high_frequency = sam_model.DWT(hybrid_embedding)
119 |
120 | dense_embeddings, sparse_embeddings = sam_model.ME(dense_embeddings_box,
121 | high_frequency, sparse_embeddings_box)
122 |
123 | # predicted masks
124 | mask_predictions, _ = sam_model.mask_decoder(
125 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)
126 | image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
127 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
128 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
129 | multimask_output=False,
130 | )
131 |
132 | final_mask = sam_model.loop_finer(mask_predictions, depth_embedding, depth_embedding)
133 |
134 | mask_predictions = 0.1*final_mask + 0.9*mask_predictions
135 |
136 | loss = 0.9*seg_loss(mask_predictions, gt2D.to(device)) + 0.1*distill_loss
137 | optimizer.zero_grad()
138 | loss.backward()
139 | optimizer.step()
140 | epoch_loss += loss.item()
141 |
142 |     epoch_loss /= (step + 1)  # enumerate is zero-based, so the batch count is step + 1
143 | losses.append(epoch_loss)
144 | print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
145 |     # save periodic checkpoints every 10 epochs once training is past epoch 80
146 | if epoch >= 80 and epoch % 10 == 0:
147 | torch.save(sam_model.state_dict(), join(model_save_path, str(epoch) + 'sam_model.pth'))
148 | # save the best model
149 | if epoch_loss < best_loss:
150 | best_loss = epoch_loss
151 | torch.save(sam_model.state_dict(), join(model_save_path, 'sam_model_best.pth'))
152 | # plot loss
153 | plt.plot(losses)
154 | plt.title('Dice + Cross Entropy Loss')
155 | plt.xlabel('Epoch')
156 | plt.ylabel('Loss')
157 | # plt.show() # comment this line if you are running on a server
158 | plt.savefig(join(model_save_path, 'train_loss.png'))
159 | plt.close()
--------------------------------------------------------------------------------
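
NpzDataset above expects every file under data/vit_b/COD_train to be an .npz archive carrying the keys imgs, gts, boundary, img_embeddings and depth_embeddings, all stackable along the first axis. A hedged sketch of writing one such shard follows; the keys come from the dataset code, but the shapes and dtypes are assumptions for illustration only (the real arrays are produced by the preprocessing scripts pre_npz.py / utils/precompute_img_embed.py):

import numpy as np

n = 4  # number of samples in this shard (illustrative)
np.savez(
    "data/vit_b/COD_train/example_shard.npz",
    imgs=np.zeros((n, 256, 256, 3), dtype=np.uint8),                # resized RGB images (assumed H=W=256)
    gts=np.zeros((n, 256, 256), dtype=np.uint8),                    # binary ground-truth masks
    boundary=np.zeros((n, 256, 256), dtype=np.uint8),               # boundary maps
    img_embeddings=np.zeros((n, 256, 64, 64), dtype=np.float32),    # SAM image-encoder features
    depth_embeddings=np.zeros((n, 256, 64, 64), dtype=np.float32),  # features of the depth images
)
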
/segment_anything/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Any, Dict, List, Tuple
12 |
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 | # from .FUSE import FUSE
17 | from .pvtv2 import pvt_v2_b2
18 | from .DWT import extract_high_frequency
19 | from .mix_enbedding import ME
20 | from .loop_finer import Loop_Finer
21 | from .bias_correction import bias_correction
22 | from torchvision.models.mobilenetv2 import InvertedResidual
23 |
24 | # from .sd_merge import MDS
25 | class Sam(nn.Module):
26 | mask_threshold: float = 0.0
27 | image_format: str = "RGB"
28 |
29 | def __init__(
30 | self,
31 | image_encoder: ImageEncoderViT,
32 | prompt_encoder: PromptEncoder,
33 | mask_decoder: MaskDecoder,
34 | pvt: pvt_v2_b2,
35 | DWT: extract_high_frequency,
36 |         ME: ME,
37 | Loop_finer: Loop_Finer,
38 | BC: bias_correction,
39 | # IRB: InvertedResidual,
40 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
41 | pixel_std: List[float] = [58.395, 57.12, 57.375],
42 | ) -> None:
43 | """
44 | SAM predicts object masks from an image and input prompts.
45 |
46 | Arguments:
47 | image_encoder (ImageEncoderViT): The backbone used to encode the
48 | image into image embeddings that allow for efficient mask prediction.
49 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
50 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
51 | and encoded prompts.
52 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
53 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
54 | """
55 | super().__init__()
56 | self.image_encoder = image_encoder
57 | self.prompt_encoder = prompt_encoder
58 | self.mask_decoder = mask_decoder
59 | self.pvt = pvt
60 | self.BC = BC
61 | # self.IRB = IRB
62 |         self.DWT = DWT
63 | self.ME = ME
64 | self.loop_finer = Loop_finer
65 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
66 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
67 |
68 | @property
69 | def device(self) -> Any:
70 | return self.pixel_mean.device
71 |
72 | @torch.no_grad()
73 | def forward(
74 | self,
75 | batched_input: List[Dict[str, Any]],
76 | multimask_output: bool,
77 | ) -> List[Dict[str, torch.Tensor]]:
78 | """
79 | Predicts masks end-to-end from provided images and prompts.
80 | If prompts are not known in advance, using SamPredictor is
81 | recommended over calling the model directly.
82 |
83 | Arguments:
84 | batched_input (list(dict)): A list over input images, each a
85 | dictionary with the following keys. A prompt key can be
86 | excluded if it is not present.
87 | 'image': The image as a torch tensor in 3xHxW format,
88 | already transformed for input to the model.
89 | 'original_size': (tuple(int, int)) The original size of
90 | the image before transformation, as (H, W).
91 | 'point_coords': (torch.Tensor) Batched point prompts for
92 | this image, with shape BxNx2. Already transformed to the
93 | input frame of the model.
94 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
95 | with shape BxN.
96 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
97 | Already transformed to the input frame of the model.
98 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
99 | in the form Bx1xHxW.
100 | multimask_output (bool): Whether the model should predict multiple
101 | disambiguating masks, or return a single mask.
102 |
103 | Returns:
104 | (list(dict)): A list over input images, where each element is
105 |             a dictionary with the following keys.
106 | 'masks': (torch.Tensor) Batched binary mask predictions,
107 | with shape BxCxHxW, where B is the number of input prompts,
108 | C is determined by multimask_output, and (H, W) is the
109 | original size of the image.
110 | 'iou_predictions': (torch.Tensor) The model's predictions
111 | of mask quality, in shape BxC.
112 | 'low_res_logits': (torch.Tensor) Low resolution logits with
113 | shape BxCxHxW, where H=W=256. Can be passed as mask input
114 | to subsequent iterations of prediction.
115 | """
116 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
117 | image_embeddings = self.image_encoder(input_images)
118 |
119 | outputs = []
120 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
121 | if "point_coords" in image_record:
122 | points = (image_record["point_coords"], image_record["point_labels"])
123 | else:
124 | points = None
125 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
126 | points=points,
127 | boxes=image_record.get("boxes", None),
128 | masks=image_record.get("mask_inputs", None),
129 | )
130 | low_res_masks, iou_predictions = self.mask_decoder(
131 | image_embeddings=curr_embedding.unsqueeze(0),
132 | image_pe=self.prompt_encoder.get_dense_pe(),
133 | sparse_prompt_embeddings=sparse_embeddings,
134 | dense_prompt_embeddings=dense_embeddings,
135 | multimask_output=multimask_output,
136 | )
137 | masks = self.postprocess_masks(
138 | low_res_masks,
139 | input_size=image_record["image"].shape[-2:],
140 | original_size=image_record["original_size"],
141 | )
142 | masks = masks > self.mask_threshold
143 | outputs.append(
144 | {
145 | "masks": masks,
146 | "iou_predictions": iou_predictions,
147 | "low_res_logits": low_res_masks,
148 | }
149 | )
150 | return outputs
151 |
152 | def postprocess_masks(
153 | self,
154 | masks: torch.Tensor,
155 | input_size: Tuple[int, ...],
156 | original_size: Tuple[int, ...],
157 | ) -> torch.Tensor:
158 | """
159 | Remove padding and upscale masks to the original image size.
160 |
161 | Arguments:
162 | masks (torch.Tensor): Batched masks from the mask_decoder,
163 | in BxCxHxW format.
164 | input_size (tuple(int, int)): The size of the image input to the
165 | model, in (H, W) format. Used to remove padding.
166 | original_size (tuple(int, int)): The original size of the image
167 | before resizing for input to the model, in (H, W) format.
168 |
169 | Returns:
170 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
171 | is given by original_size.
172 | """
173 | masks = F.interpolate(
174 | masks,
175 | (self.image_encoder.img_size, self.image_encoder.img_size),
176 | mode="bilinear",
177 | align_corners=False,
178 | )
179 | masks = masks[..., : input_size[0], : input_size[1]]
180 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
181 | return masks
182 |
183 | # def preprocess(self, x: torch.Tensor) -> torch.Tensor:
184 | # """Normalize pixel values and pad to a square input."""
185 | # # Normalize colors
186 | # x = (x - self.pixel_mean) / self.pixel_std
187 | #
188 | # # Pad
189 | # h, w = x.shape[-2:]
190 | # padh = 352 - h
191 | # padw = 352 - w
192 | # x = F.pad(x, (0, padw, 0, padh))
193 | # return x
194 |
195 |     # 20240201: SAM's original preprocess code
196 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
197 | """Normalize pixel values and pad to a square input."""
198 | # Normalize colors
199 | x = (x - self.pixel_mean) / self.pixel_std
200 |
201 | # Pad
202 | h, w = x.shape[-2:]
203 | padh = self.image_encoder.img_size - h
204 | padw = self.image_encoder.img_size - w
205 | x = F.pad(x, (0, padw, 0, padh))
206 | return x
207 |
--------------------------------------------------------------------------------
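
A minimal sketch of calling Sam.forward directly with the batched_input format described in its docstring (box prompt only, shapes illustrative). The training and testing scripts in this repo drive the sub-modules individually rather than going through this entry point, and the checkpoints referenced in build_sam.py are assumed to be available locally.

import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="work_dir_cod/SAM/sam_vit_b_01ec64.pth")

image = torch.randint(0, 255, (3, 768, 1024)).float()            # 3xHxW, longest side already resized to 1024
batched_input = [
    {
        "image": image,
        "original_size": (600, 800),                             # (H, W) before resizing
        "boxes": torch.tensor([[100.0, 150.0, 500.0, 400.0]]),   # Bx4, in the resized input frame
    }
]
with torch.no_grad():
    outputs = sam(batched_input, multimask_output=False)
print(outputs[0]["masks"].shape, outputs[0]["iou_predictions"].shape)
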
/segment_anything/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 | # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
--------------------------------------------------------------------------------
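
A small shape sketch for the Attention layer above: with downsample_rate=2 the queries, keys and values are projected down to embedding_dim // 2 before the multi-head product and projected back afterwards, so the output keeps the caller's embedding_dim. Sizes are illustrative (5 prompt tokens attending to a 64x64 grid of image tokens).

import torch
from segment_anything.modeling.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(2, 5, 256)       # B x N_tokens x C (prompt tokens)
k = torch.randn(2, 4096, 256)    # B x N_image_tokens x C (64*64 image tokens)
v = torch.randn(2, 4096, 256)
with torch.no_grad():
    out = attn(q=q, k=k, v=v)
print(out.shape)                 # torch.Size([2, 5, 256])
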
/segment_anything/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 |             input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | points = torch.cat([points, padding_point], dim=1)
85 | labels = torch.cat([labels, padding_label], dim=1)
86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87 | point_embedding[labels == -1] = 0.0
88 | point_embedding[labels == -1] += self.not_a_point_embed.weight
89 | point_embedding[labels == 0] += self.point_embeddings[0].weight
90 | point_embedding[labels == 1] += self.point_embeddings[1].weight
91 | return point_embedding
92 |
93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94 | """Embeds box prompts."""
95 | boxes = boxes + 0.5 # Shift to center of pixel
96 | coords = boxes.reshape(-1, 2, 2)
97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100 | return corner_embedding
101 |
102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103 | """Embeds mask inputs."""
104 | mask_embedding = self.mask_downscaling(masks)
105 | return mask_embedding
106 |
107 | def _get_batch_size(
108 | self,
109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110 | boxes: Optional[torch.Tensor],
111 | masks: Optional[torch.Tensor],
112 | ) -> int:
113 | """
114 | Gets the batch size of the output given the batch size of the input prompts.
115 | """
116 | if points is not None:
117 | return points[0].shape[0]
118 | elif boxes is not None:
119 | return boxes.shape[0]
120 | elif masks is not None:
121 | return masks.shape[0]
122 | else:
123 | return 1
124 |
125 | def _get_device(self) -> torch.device:
126 | return self.point_embeddings[0].weight.device
127 |
128 | def forward(
129 | self,
130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131 | boxes: Optional[torch.Tensor],
132 | masks: Optional[torch.Tensor],
133 | ) -> Tuple[torch.Tensor, torch.Tensor]:
134 | """
135 | Embeds different types of prompts, returning both sparse and dense
136 | embeddings.
137 |
138 | Arguments:
139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140 | and labels to embed.
141 | boxes (torch.Tensor or none): boxes to embed
142 | masks (torch.Tensor or none): masks to embed
143 |
144 | Returns:
145 | torch.Tensor: sparse embeddings for the points and boxes, with shape
146 | BxNx(embed_dim), where N is determined by the number of input points
147 | and boxes.
148 | torch.Tensor: dense embeddings for the masks, in the shape
149 | Bx(embed_dim)x(embed_H)x(embed_W)
150 | """
151 | bs = self._get_batch_size(points, boxes, masks)
152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153 | if points is not None:
154 | coords, labels = points
155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157 | if boxes is not None:
158 | box_embeddings = self._embed_boxes(boxes)
159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160 |
161 | if masks is not None:
162 | dense_embeddings = self._embed_masks(masks)
163 | else:
164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166 | )
167 |
168 | return sparse_embeddings, dense_embeddings
169 |
170 |
171 | class PositionEmbeddingRandom(nn.Module):
172 | """
173 | Positional encoding using random spatial frequencies.
174 | """
175 |
176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177 | super().__init__()
178 | if scale is None or scale <= 0.0:
179 | scale = 1.0
180 | self.register_buffer(
181 | "positional_encoding_gaussian_matrix",
182 | scale * torch.randn((2, num_pos_feats)),
183 | )
184 |
185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186 | """Positionally encode points that are normalized to [0,1]."""
187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188 | coords = 2 * coords - 1
189 | coords = coords @ self.positional_encoding_gaussian_matrix
190 | coords = 2 * np.pi * coords
191 | # outputs d_1 x ... x d_n x C shape
192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193 |
194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195 | """Generate positional encoding for a grid of the specified size."""
196 | h, w = size
197 | device: Any = self.positional_encoding_gaussian_matrix.device
198 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
199 | y_embed = grid.cumsum(dim=0) - 0.5
200 | x_embed = grid.cumsum(dim=1) - 0.5
201 | y_embed = y_embed / h
202 | x_embed = x_embed / w
203 |
204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205 | return pe.permute(2, 0, 1) # C x H x W
206 |
207 | def forward_with_coords(
208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209 | ) -> torch.Tensor:
210 | """Positionally encode points that are not normalized to [0,1]."""
211 | coords = coords_input.clone()
212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
215 |
--------------------------------------------------------------------------------
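
A shape sketch for the PromptEncoder above with a single box prompt and no mask input: the box contributes two sparse tokens (its two corners), and the dense embedding falls back to the learned no_mask_embed broadcast over the 64x64 grid. The constructor arguments match the vit_b configuration in build_sam.py.

import torch
from segment_anything.modeling.prompt_encoder import PromptEncoder

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
boxes = torch.tensor([[100.0, 150.0, 500.0, 400.0]])   # (B, 4) in the 1024x1024 input frame
with torch.no_grad():
    sparse, dense = pe(points=None, boxes=boxes, masks=None)
print(sparse.shape, dense.shape)   # torch.Size([1, 2, 256]) torch.Size([1, 256, 64, 64])
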
/Mytest.py:
--------------------------------------------------------------------------------
1 | # %% load environment
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import os
5 |
6 | join = os.path.join
7 | import torch
8 | from segment_anything import sam_model_registry
9 | from segment_anything.utils.transforms import ResizeLongestSide
10 | from tqdm import tqdm
11 | import argparse
12 | import traceback
13 | import sys
14 | import torch.nn as nn
15 | import shutil
16 | from torchvision.models.mobilenetv2 import InvertedResidual
17 | import os
18 |
19 | torch.manual_seed(1)
20 | np.random.seed(1)
21 | # %% run inference
22 | # set up the parser
23 | parser = argparse.ArgumentParser(description='run inference on testing set')
24 | parser.add_argument('-i', '--data_path', type=str, default='data/vit_b/COD_test', help='path to the data folder')
25 | parser.add_argument('-o', '--seg_path_root', type=str, default='data/inference_npz/DSAM',
26 |                     help='path to the folder where segmentation results are saved as npz')
27 | parser.add_argument('--seg_png_path', type=str, default='data/inference_img/DSAM',
28 |                     help='path to the folder where segmentation results are saved as png images')
29 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type')
30 | parser.add_argument('--device', type=str, default='cuda:1', help='device')
31 | parser.add_argument('-chk', '--checkpoint', type=str, default='work_dir_cod/DSAM/DSAM.pth',
32 | help='path to the trained model')
33 | args = parser.parse_args()
34 |
35 |
36 | mDSC = []
37 | num_of_mDSC = 0
38 |
39 | def show_mask(mask, ax, random_color=False):
40 | if random_color:
41 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
42 | else:
43 | color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
44 | h, w = mask.shape[-2:]
45 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
46 | ax.imshow(mask_image)
47 |
48 |
49 | def show_box(box, ax):
50 | x0, y0 = box[0], box[1]
51 | w, h = box[2] - box[0], box[3] - box[1]
52 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2))
53 |
54 |
55 | def compute_dice(mask_gt, mask_pred):
56 |     """Compute the Sørensen-Dice coefficient.
57 |     Returns:
58 |         the Dice coefficient as a float. If both masks are empty, the result is NaN.
59 | """
60 | volume_sum = mask_gt.sum() + mask_pred.sum()
61 | if volume_sum == 0:
62 | return np.NaN
63 | volume_intersect = (mask_gt & mask_pred).sum()
64 | return 2 * volume_intersect / volume_sum
65 |
66 |
67 | def finetune_model_predict(img_np, box_np, depth_np, boundary, sam_trans, sam_model_tune, device=args.device):
68 | H, W = img_np.shape[:2]
69 | # Original image processing
70 | resize_img = sam_trans.apply_image(img_np)
71 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device) # (3, 1024, 1024)
72 | input_image = sam_model_tune.preprocess(resize_img_tensor[None, :, :, :]) # (1, 3, 1024, 1024)
73 | # Depth map processing
74 | resize_depth_img = sam_trans.apply_image(depth_np)
75 | resize_dep_tensor = torch.as_tensor(resize_depth_img.transpose(2, 0, 1)).to(device)
76 | depth_image = sam_model_tune.preprocess(resize_dep_tensor[None, :, :, :]) # (1, 3, 1024, 1024)
77 |
78 | with torch.no_grad():
79 | image_embedding = sam_model_tune.image_encoder(input_image.to(device)) # (1, 256, 64, 64)
80 | depth_embedding = sam_model_tune.image_encoder(depth_image.to(device)) # (1, 256, 64, 64)
81 |
82 | # convert box to 1024x1024 grid
83 | box = sam_trans.apply_boxes(box_np, (H, W))
84 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
85 |
86 | if len(box_torch.shape) == 2:
87 | box_torch = box_torch[:, None, :] # (B, 1, 4)
88 |
89 | sparse_embeddings_box, dense_embeddings_box = sam_model_tune.prompt_encoder(
90 | points=None,
91 | boxes=box_torch,
92 | masks=None
93 | )
94 |
95 | pvt_embedding = sam_model_tune.pvt(input_image)[3]
96 | bc_embedding, pvt_64 = sam_model_tune.BC(pvt_embedding)
97 | hybrid_embedding = torch.cat([pvt_64, bc_embedding], dim=1)
98 | high_frequency = sam_model_tune.DWT(hybrid_embedding)
99 | dense_embeddings, sparse_embeddings = sam_model_tune.ME(dense_embeddings_box,
100 | high_frequency, sparse_embeddings_box)
101 | # predicted masks
102 | seg_prob, _ = sam_model_tune.mask_decoder(
103 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64) hybrid_embedding:(1, 512, 64, 64)
104 |             image_pe=sam_model_tune.prompt_encoder.get_dense_pe(),
105 |             # (1, 256, 64, 64)
106 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
107 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
108 | multimask_output=False,
109 | )
110 | final_mask = sam_model_tune.loop_finer(seg_prob, depth_embedding, depth_embedding)
111 | seg_prob = 0.1 * final_mask + 0.9 * seg_prob
112 | seg_prob = torch.sigmoid(seg_prob)
113 | # convert soft mask to hard mask
114 | seg_prob = seg_prob.cpu().numpy().squeeze()
115 | seg = (seg_prob > 0.5).astype(np.uint8)
116 | return seg
117 |
118 | def delete_folder(folder_path):
119 | try:
120 |         # delete the folder and its contents
121 |         shutil.rmtree(folder_path)
122 |         print(f"deleted folder: {folder_path}")
123 |     except Exception as e:
124 |         print(f"failed to delete folder {folder_path}: {e}")
125 |
126 | def divide(x, y):
127 | try:
128 | result = x / y
129 | return result
130 | except ZeroDivisionError as e:
131 |         # handle the division-by-zero case by removing the incomplete output folders
132 | delete_folder(args.seg_path_root)
133 | delete_folder(args.seg_png_path)
134 | print("division by zero:", e)
135 |
136 |
137 | device = args.device
138 | sam_model_tune = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to(device)
139 |
140 | sam_trans = ResizeLongestSide(sam_model_tune.image_encoder.img_size)
141 |
142 | npz_folders = sorted(os.listdir(args.data_path))
143 | os.makedirs(args.seg_png_path, exist_ok=True)
144 | sam_dice_scores = []
145 | for npz_folder in npz_folders:
146 | npz_data_path = join(args.data_path, npz_folder)
147 | save_path = join(args.seg_path_root, npz_folder)
148 | if not os.path.exists(save_path):
149 | os.makedirs(save_path, exist_ok=True)
150 | npz_files = sorted(os.listdir(npz_data_path))
151 | for npz_file in tqdm(npz_files):
152 | try:
153 | npz = np.load(join(npz_data_path, npz_file))
154 | print(npz_file)
155 | ori_imgs = npz['imgs']
156 | ori_gts = npz['gts']
157 | ori_number = npz['number']
158 | boundary = npz['boundary']
159 | dep_imgs = npz['depth_imgs']
160 |
161 | sam_segs = []
162 | sam_bboxes = []
163 | sam_dice_scores = []
164 | for img_id, ori_img in enumerate(ori_imgs):
165 | # get bounding box from mask
166 | gt2D = ori_gts[img_id]
167 | bboundary = boundary[img_id]
168 | depth_img = dep_imgs[img_id]
169 |
170 | y_indices, x_indices = np.where(gt2D > 0)
171 | x_min, x_max = np.min(x_indices), np.max(x_indices)
172 | y_min, y_max = np.min(y_indices), np.max(y_indices)
173 | # add perturbation to bounding box coordinates
174 | H, W = gt2D.shape
175 | x_min = max(0, x_min - np.random.randint(0, 20))
176 | x_max = min(W, x_max + np.random.randint(0, 20))
177 | y_min = max(0, y_min - np.random.randint(0, 20))
178 | y_max = min(H, y_max + np.random.randint(0, 20))
179 | bbox = np.array([x_min, y_min, x_max, y_max])
180 | seg_mask = finetune_model_predict(ori_img, bbox, depth_img, bboundary, sam_trans, sam_model_tune, device=device)
181 | sam_segs.append(seg_mask)
182 | sam_bboxes.append(bbox)
183 | # these 2D dice scores are for debugging purpose.
184 | # 3D dice scores should be computed for 3D images
185 | sam_dice_scores.append(compute_dice(seg_mask > 0, gt2D > 0))
186 |
187 | # save npz, including sam_segs, sam_bboxes, sam_dice_scores
188 | np.savez_compressed(join(save_path, npz_file), medsam_segs=sam_segs, gts=ori_gts, number=ori_number, sam_bboxes=sam_bboxes)
189 |
190 | # visualize segmentation results
191 | img_id = np.random.randint(0, len(ori_imgs))
192 | # show ground truth and segmentation results in two subplots
193 | fig, axes = plt.subplots(1, 2, figsize=(10, 5))
194 | axes[0].imshow(ori_imgs[img_id])
195 | show_box(sam_bboxes[img_id], axes[0])
196 | show_mask(ori_gts[img_id], axes[0])
197 | axes[0].set_title('Ground Truth')
198 | axes[0].axis('off')
199 |
200 | axes[1].imshow(ori_imgs[img_id])
201 | show_box(sam_bboxes[img_id], axes[1])
202 | show_mask(sam_segs[img_id], axes[1])
203 | axes[1].set_title('DSAM: DSC={:.3f}'.format(sam_dice_scores[img_id]))
204 | axes[1].axis('off')
205 | # save figure
206 | fig.savefig(join(args.seg_png_path, npz_file.split('.npz')[0] + '.png'))
207 | # close figure
208 | plt.close(fig)
209 | except Exception:
210 | traceback.print_exc()
211 | print('error in {}'.format(npz_file))
212 |
213 | tmp_mDSC = divide(sum(sam_dice_scores), len(sam_dice_scores))
214 | mDSC.append(tmp_mDSC)
215 |
216 | print(str(npz_folder)+": " + str(tmp_mDSC))
217 |     # average mDSC over all test datasets
218 | print("final average mDSC: " + str(sum(mDSC) / len(mDSC)))
--------------------------------------------------------------------------------
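The test loop above turns each ground-truth mask into a jittered box prompt before calling `finetune_model_predict`. The following is a minimal, self-contained sketch of that step on a synthetic mask; the helper name is invented for illustration.

import numpy as np

def mask_to_jittered_box(gt2D: np.ndarray, max_jitter: int = 20) -> np.ndarray:
    """Return an [x_min, y_min, x_max, y_max] box around the foreground with random padding."""
    y_indices, x_indices = np.where(gt2D > 0)
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()
    H, W = gt2D.shape
    x_min = max(0, x_min - np.random.randint(0, max_jitter))
    x_max = min(W, x_max + np.random.randint(0, max_jitter))
    y_min = max(0, y_min - np.random.randint(0, max_jitter))
    y_max = min(H, y_max + np.random.randint(0, max_jitter))
    return np.array([x_min, y_min, x_max, y_max])

# toy 256x256 mask with one rectangular object
mask = np.zeros((256, 256), dtype=np.uint8)
mask[100:150, 80:160] = 1
print(mask_to_jittered_box(mask))  # e.g. [ 70  93 170 160]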
/segment_anything/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from segment_anything.modeling import Sam
11 |
12 | from typing import Optional, Tuple
13 |
14 | from .utils.transforms import ResizeLongestSide
15 |
16 |
17 | class SamPredictor:
18 | def __init__(
19 | self,
20 | sam_model: Sam,
21 | ) -> None:
22 | """
23 | Uses SAM to calculate the image embedding for an image, and then
24 |         allows repeated, efficient mask prediction given prompts.
25 |
26 | Arguments:
27 | sam_model (Sam): The model to use for mask prediction.
28 | """
29 | super().__init__()
30 | self.model = sam_model
31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32 | self.reset_image()
33 |
34 | def set_image(
35 | self,
36 | image: np.ndarray,
37 | image_format: str = "RGB",
38 | ) -> None:
39 | """
40 | Calculates the image embeddings for the provided image, allowing
41 | masks to be predicted with the 'predict' method.
42 |
43 | Arguments:
44 | image (np.ndarray): The image for calculating masks. Expects an
45 | image in HWC uint8 format, with pixel values in [0, 255].
46 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
47 | """
48 | assert image_format in [
49 | "RGB",
50 | "BGR",
51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52 | if image_format != self.model.image_format:
53 | image = image[..., ::-1]
54 |
55 | # Transform the image to the form expected by the model
56 | input_image = self.transform.apply_image(image)
57 | input_image_torch = torch.as_tensor(input_image, device=self.device)
58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59 |
60 | self.set_torch_image(input_image_torch, image.shape[:2])
61 |
62 | @torch.no_grad()
63 | def set_torch_image(
64 | self,
65 | transformed_image: torch.Tensor,
66 | original_image_size: Tuple[int, ...],
67 | ) -> None:
68 | """
69 | Calculates the image embeddings for the provided image, allowing
70 | masks to be predicted with the 'predict' method. Expects the input
71 | image to be already transformed to the format expected by the model.
72 |
73 | Arguments:
74 | transformed_image (torch.Tensor): The input image, with shape
75 | 1x3xHxW, which has been transformed with ResizeLongestSide.
76 | original_image_size (tuple(int, int)): The size of the image
77 | before transformation, in (H, W) format.
78 | """
79 | assert (
80 | len(transformed_image.shape) == 4
81 | and transformed_image.shape[1] == 3
82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 | self.reset_image()
85 |
86 | self.original_size = original_image_size
87 | self.input_size = tuple(transformed_image.shape[-2:])
88 | input_image = self.model.preprocess(transformed_image)
89 | self.features = self.model.image_encoder(input_image)
90 | self.is_image_set = True
91 |
92 | def predict(
93 | self,
94 | point_coords: Optional[np.ndarray] = None,
95 | point_labels: Optional[np.ndarray] = None,
96 | box: Optional[np.ndarray] = None,
97 | mask_input: Optional[np.ndarray] = None,
98 | multimask_output: bool = True,
99 | return_logits: bool = False,
100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101 | """
102 | Predict masks for the given input prompts, using the currently set image.
103 |
104 | Arguments:
105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106 | model. Each point is in (X,Y) in pixels.
107 | point_labels (np.ndarray or None): A length N array of labels for the
108 | point prompts. 1 indicates a foreground point and 0 indicates a
109 | background point.
110 | box (np.ndarray or None): A length 4 array given a box prompt to the
111 | model, in XYXY format.
112 | mask_input (np.ndarray): A low resolution mask input to the model, typically
113 | coming from a previous prediction iteration. Has form 1xHxW, where
114 | for SAM, H=W=256.
115 | multimask_output (bool): If true, the model will return three masks.
116 | For ambiguous input prompts (such as a single click), this will often
117 | produce better masks than a single prediction. If only a single
118 | mask is needed, the model's predicted quality score can be used
119 | to select the best mask. For non-ambiguous prompts, such as multiple
120 | input prompts, multimask_output=False can give better results.
121 | return_logits (bool): If true, returns un-thresholded masks logits
122 | instead of a binary mask.
123 |
124 | Returns:
125 | (np.ndarray): The output masks in CxHxW format, where C is the
126 | number of masks, and (H, W) is the original image size.
127 | (np.ndarray): An array of length C containing the model's
128 | predictions for the quality of each mask.
129 | (np.ndarray): An array of shape CxHxW, where C is the number
130 | of masks and H=W=256. These low resolution logits can be passed to
131 | a subsequent iteration as mask input.
132 | """
133 | if not self.is_image_set:
134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135 |
136 | # Transform input prompts
137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138 | if point_coords is not None:
139 | assert (
140 | point_labels is not None
141 | ), "point_labels must be supplied if point_coords is supplied."
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146 | if box is not None:
147 | box = self.transform.apply_boxes(box, self.original_size)
148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149 | box_torch = box_torch[None, :]
150 | if mask_input is not None:
151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152 | mask_input_torch = mask_input_torch[None, :, :, :]
153 |
154 | masks, iou_predictions, low_res_masks = self.predict_torch(
155 | coords_torch,
156 | labels_torch,
157 | box_torch,
158 | mask_input_torch,
159 | multimask_output,
160 | return_logits=return_logits,
161 | )
162 |
163 | masks_np = masks[0].detach().cpu().numpy()
164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166 | return masks_np, iou_predictions_np, low_res_masks_np
167 |
168 | @torch.no_grad()
169 | def predict_torch(
170 | self,
171 | point_coords: Optional[torch.Tensor],
172 | point_labels: Optional[torch.Tensor],
173 | boxes: Optional[torch.Tensor] = None,
174 | mask_input: Optional[torch.Tensor] = None,
175 | multimask_output: bool = True,
176 | return_logits: bool = False,
177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178 | """
179 | Predict masks for the given input prompts, using the currently set image.
180 | Input prompts are batched torch tensors and are expected to already be
181 | transformed to the input frame using ResizeLongestSide.
182 |
183 | Arguments:
184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185 | model. Each point is in (X,Y) in pixels.
186 | point_labels (torch.Tensor or None): A BxN array of labels for the
187 | point prompts. 1 indicates a foreground point and 0 indicates a
188 | background point.
189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the
190 | model, in XYXY format.
191 | mask_input (np.ndarray): A low resolution mask input to the model, typically
192 | coming from a previous prediction iteration. Has form Bx1xHxW, where
193 | for SAM, H=W=256. Masks returned by a previous iteration of the
194 | predict method do not need further transformation.
195 | multimask_output (bool): If true, the model will return three masks.
196 | For ambiguous input prompts (such as a single click), this will often
197 | produce better masks than a single prediction. If only a single
198 | mask is needed, the model's predicted quality score can be used
199 | to select the best mask. For non-ambiguous prompts, such as multiple
200 | input prompts, multimask_output=False can give better results.
201 | return_logits (bool): If true, returns un-thresholded masks logits
202 | instead of a binary mask.
203 |
204 | Returns:
205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
206 | number of masks, and (H, W) is the original image size.
207 | (torch.Tensor): An array of shape BxC containing the model's
208 | predictions for the quality of each mask.
209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
210 | of masks and H=W=256. These low res logits can be passed to
211 | a subsequent iteration as mask input.
212 | """
213 | if not self.is_image_set:
214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215 |
216 | if point_coords is not None:
217 | points = (point_coords, point_labels)
218 | else:
219 | points = None
220 |
221 | # Embed prompts
222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223 | points=points,
224 | boxes=boxes,
225 | masks=mask_input,
226 | )
227 |
228 | # Predict masks
229 | low_res_masks, iou_predictions = self.model.mask_decoder(
230 | image_embeddings=self.features,
231 | image_pe=self.model.prompt_encoder.get_dense_pe(),
232 | sparse_prompt_embeddings=sparse_embeddings,
233 | dense_prompt_embeddings=dense_embeddings,
234 | multimask_output=multimask_output,
235 | )
236 |
237 | # Upscale the masks to the original image resolution
238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239 |
240 | if not return_logits:
241 | masks = masks > self.model.mask_threshold
242 |
243 | return masks, iou_predictions, low_res_masks
244 |
245 | def get_image_embedding(self) -> torch.Tensor:
246 | """
247 | Returns the image embeddings for the currently set image, with
248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
250 | """
251 | if not self.is_image_set:
252 | raise RuntimeError(
253 | "An image must be set with .set_image(...) to generate an embedding."
254 | )
255 | assert self.features is not None, "Features must exist if an image has been set."
256 | return self.features
257 |
258 | @property
259 | def device(self) -> torch.device:
260 | return self.model.device
261 |
262 | def reset_image(self) -> None:
263 | """Resets the currently set image."""
264 | self.is_image_set = False
265 | self.features = None
266 | self.orig_h = None
267 | self.orig_w = None
268 | self.input_h = None
269 | self.input_w = None
270 |
--------------------------------------------------------------------------------
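For reference, the intended calling pattern for `SamPredictor` is: build a SAM model from the registry, cache the image embedding once with `set_image`, then issue as many prompt-based predictions as needed. A short usage sketch follows; it assumes the upstream SAM interface, the checkpoint path is only the ViT-B default used elsewhere in this repo (this repo's modified `build_sam` additionally expects the DSAM-specific modules in the checkpoint), and the dummy array stands in for a real HWC uint8 image.

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="work_dir/SAM/sam_vit_b_01ec64.pth").to("cuda")
predictor = SamPredictor(sam)

image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HWC uint8 RGB image
predictor.set_image(image)                       # heavy step: runs the image encoder once

box = np.array([100, 100, 300, 300])             # XYXY box prompt in original-image pixels
masks, scores, low_res_logits = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores.shape)                 # (1, 512, 512) (1,)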
/segment_anything/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | import math
11 | from copy import deepcopy
12 | from itertools import product
13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
14 |
15 |
16 | class MaskData:
17 | """
18 | A structure for storing masks and their related data in batched format.
19 | Implements basic filtering and concatenation.
20 | """
21 |
22 | def __init__(self, **kwargs) -> None:
23 | for v in kwargs.values():
24 | assert isinstance(
25 | v, (list, np.ndarray, torch.Tensor)
26 | ), "MaskData only supports list, numpy arrays, and torch tensors."
27 | self._stats = dict(**kwargs)
28 |
29 | def __setitem__(self, key: str, item: Any) -> None:
30 | assert isinstance(
31 | item, (list, np.ndarray, torch.Tensor)
32 | ), "MaskData only supports list, numpy arrays, and torch tensors."
33 | self._stats[key] = item
34 |
35 | def __delitem__(self, key: str) -> None:
36 | del self._stats[key]
37 |
38 | def __getitem__(self, key: str) -> Any:
39 | return self._stats[key]
40 |
41 | def items(self) -> ItemsView[str, Any]:
42 | return self._stats.items()
43 |
44 | def filter(self, keep: torch.Tensor) -> None:
45 | for k, v in self._stats.items():
46 | if v is None:
47 | self._stats[k] = None
48 | elif isinstance(v, torch.Tensor):
49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50 | elif isinstance(v, np.ndarray):
51 | self._stats[k] = v[keep.detach().cpu().numpy()]
52 | elif isinstance(v, list) and keep.dtype == torch.bool:
53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54 | elif isinstance(v, list):
55 | self._stats[k] = [v[i] for i in keep]
56 | else:
57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58 |
59 | def cat(self, new_stats: "MaskData") -> None:
60 | for k, v in new_stats.items():
61 | if k not in self._stats or self._stats[k] is None:
62 | self._stats[k] = deepcopy(v)
63 | elif isinstance(v, torch.Tensor):
64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65 | elif isinstance(v, np.ndarray):
66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67 | elif isinstance(v, list):
68 | self._stats[k] = self._stats[k] + deepcopy(v)
69 | else:
70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71 |
72 | def to_numpy(self) -> None:
73 | for k, v in self._stats.items():
74 | if isinstance(v, torch.Tensor):
75 | self._stats[k] = v.detach().cpu().numpy()
76 |
77 |
78 | def is_box_near_crop_edge(
79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80 | ) -> torch.Tensor:
81 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88 | return torch.any(near_crop_edge, dim=1)
89 |
90 |
91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92 | box_xywh = deepcopy(box_xyxy)
93 | box_xywh[2] = box_xywh[2] - box_xywh[0]
94 | box_xywh[3] = box_xywh[3] - box_xywh[1]
95 | return box_xywh
96 |
97 |
98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99 | assert len(args) > 0 and all(
100 | len(a) == len(args[0]) for a in args
101 | ), "Batched iteration must have inputs of all the same size."
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103 | for b in range(n_batches):
104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105 |
106 |
107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108 | """
109 | Encodes masks to an uncompressed RLE, in the format expected by
110 |     pycocotools.
111 | """
112 | # Put in fortran order and flatten h,w
113 | b, h, w = tensor.shape
114 | tensor = tensor.permute(0, 2, 1).flatten(1)
115 |
116 | # Compute change indices
117 | diff = tensor[:, 1:] ^ tensor[:, :-1]
118 | change_indices = diff.nonzero()
119 |
120 | # Encode run length
121 | out = []
122 | for i in range(b):
123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124 | cur_idxs = torch.cat(
125 | [
126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127 | cur_idxs + 1,
128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | ]
130 | )
131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132 | counts = [] if tensor[i, 0] == 0 else [0]
133 | counts.extend(btw_idxs.detach().cpu().tolist())
134 | out.append({"size": [h, w], "counts": counts})
135 | return out
136 |
137 |
138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139 | """Compute a binary mask from an uncompressed RLE."""
140 | h, w = rle["size"]
141 | mask = np.empty(h * w, dtype=bool)
142 | idx = 0
143 | parity = False
144 | for count in rle["counts"]:
145 | mask[idx : idx + count] = parity
146 | idx += count
147 | parity ^= True
148 | mask = mask.reshape(w, h)
149 | return mask.transpose() # Put in C order
150 |
151 |
152 | def area_from_rle(rle: Dict[str, Any]) -> int:
153 | return sum(rle["counts"][1::2])
154 |
155 |
156 | def calculate_stability_score(
157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158 | ) -> torch.Tensor:
159 | """
160 | Computes the stability score for a batch of masks. The stability
161 | score is the IoU between the binary masks obtained by thresholding
162 | the predicted mask logits at high and low values.
163 | """
164 | # One mask is always contained inside the other.
165 | # Save memory by preventing unnecessary cast to torch.int64
166 | intersections = (
167 | (masks > (mask_threshold + threshold_offset))
168 | .sum(-1, dtype=torch.int16)
169 | .sum(-1, dtype=torch.int32)
170 | )
171 | unions = (
172 | (masks > (mask_threshold - threshold_offset))
173 | .sum(-1, dtype=torch.int16)
174 | .sum(-1, dtype=torch.int32)
175 | )
176 | return intersections / unions
177 |
178 |
179 | def build_point_grid(n_per_side: int) -> np.ndarray:
180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181 | offset = 1 / (2 * n_per_side)
182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186 | return points
187 |
188 |
189 | def build_all_layer_point_grids(
190 | n_per_side: int, n_layers: int, scale_per_layer: int
191 | ) -> List[np.ndarray]:
192 | """Generates point grids for all crop layers."""
193 | points_by_layer = []
194 | for i in range(n_layers + 1):
195 | n_points = int(n_per_side / (scale_per_layer**i))
196 | points_by_layer.append(build_point_grid(n_points))
197 | return points_by_layer
198 |
199 |
200 | def generate_crop_boxes(
201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202 | ) -> Tuple[List[List[int]], List[int]]:
203 | """
204 | Generates a list of crop boxes of different sizes. Each layer
205 | has (2**i)**2 boxes for the ith layer.
206 | """
207 | crop_boxes, layer_idxs = [], []
208 | im_h, im_w = im_size
209 | short_side = min(im_h, im_w)
210 |
211 | # Original image
212 | crop_boxes.append([0, 0, im_w, im_h])
213 | layer_idxs.append(0)
214 |
215 | def crop_len(orig_len, n_crops, overlap):
216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217 |
218 | for i_layer in range(n_layers):
219 | n_crops_per_side = 2 ** (i_layer + 1)
220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221 |
222 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
223 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
224 |
225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227 |
228 | # Crops in XYWH format
229 | for x0, y0 in product(crop_box_x0, crop_box_y0):
230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231 | crop_boxes.append(box)
232 | layer_idxs.append(i_layer + 1)
233 |
234 | return crop_boxes, layer_idxs
235 |
236 |
237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238 | x0, y0, _, _ = crop_box
239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240 | # Check if boxes has a channel dimension
241 | if len(boxes.shape) == 3:
242 | offset = offset.unsqueeze(1)
243 | return boxes + offset
244 |
245 |
246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247 | x0, y0, _, _ = crop_box
248 | offset = torch.tensor([[x0, y0]], device=points.device)
249 | # Check if points has a channel dimension
250 | if len(points.shape) == 3:
251 | offset = offset.unsqueeze(1)
252 | return points + offset
253 |
254 |
255 | def uncrop_masks(
256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257 | ) -> torch.Tensor:
258 | x0, y0, x1, y1 = crop_box
259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260 | return masks
261 | # Coordinate transform masks
262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263 | pad = (x0, pad_x - x0, y0, pad_y - y0)
264 | return torch.nn.functional.pad(masks, pad, value=0)
265 |
266 |
267 | def remove_small_regions(
268 | mask: np.ndarray, area_thresh: float, mode: str
269 | ) -> Tuple[np.ndarray, bool]:
270 | """
271 | Removes small disconnected regions and holes in a mask. Returns the
272 |     mask and an indicator of whether the mask has been modified.
273 | """
274 | import cv2 # type: ignore
275 |
276 | assert mode in ["holes", "islands"]
277 | correct_holes = mode == "holes"
278 | working_mask = (correct_holes ^ mask).astype(np.uint8)
279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280 | sizes = stats[:, -1][1:] # Row 0 is background label
281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282 | if len(small_regions) == 0:
283 | return mask, False
284 | fill_labels = [0] + small_regions
285 | if not correct_holes:
286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287 | # If every region is below threshold, keep largest
288 | if len(fill_labels) == 0:
289 | fill_labels = [int(np.argmax(sizes)) + 1]
290 | mask = np.isin(regions, fill_labels)
291 | return mask, True
292 |
293 |
294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295 | from pycocotools import mask as mask_utils # type: ignore
296 |
297 | h, w = uncompressed_rle["size"]
298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300 | return rle
301 |
302 |
303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304 | """
305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307 | """
308 | # torch.max below raises an error on empty inputs, just skip in this case
309 | if torch.numel(masks) == 0:
310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311 |
312 | # Normalize shape to CxHxW
313 | shape = masks.shape
314 | h, w = shape[-2:]
315 | if len(shape) > 2:
316 | masks = masks.flatten(0, -3)
317 | else:
318 | masks = masks.unsqueeze(0)
319 |
320 | # Get top and bottom edges
321 | in_height, _ = torch.max(masks, dim=-1)
322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324 | in_height_coords = in_height_coords + h * (~in_height)
325 | top_edges, _ = torch.min(in_height_coords, dim=-1)
326 |
327 | # Get left and right edges
328 | in_width, _ = torch.max(masks, dim=-2)
329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330 | right_edges, _ = torch.max(in_width_coords, dim=-1)
331 | in_width_coords = in_width_coords + w * (~in_width)
332 | left_edges, _ = torch.min(in_width_coords, dim=-1)
333 |
334 | # If the mask is empty the right edge will be to the left of the left edge.
335 | # Replace these boxes with [0, 0, 0, 0]
336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338 | out = out * (~empty_filter).unsqueeze(-1)
339 |
340 | # Return to original shape
341 | if len(shape) > 2:
342 | out = out.reshape(*shape[:-2], 4)
343 | else:
344 | out = out[0]
345 |
346 | return out
347 |
--------------------------------------------------------------------------------
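A quick round-trip check of the RLE helpers above on a toy mask (shapes and values are illustrative only): `mask_to_rle_pytorch` produces the uncompressed, column-major RLE, `area_from_rle` sums the foreground runs, and `rle_to_mask` reconstructs the binary mask.

import torch
from segment_anything.utils.amg import mask_to_rle_pytorch, rle_to_mask, area_from_rle

masks = torch.zeros((1, 4, 6), dtype=torch.bool)  # batch of one 4x6 binary mask
masks[0, 1:3, 2:5] = True                         # a 2x3 foreground rectangle

rles = mask_to_rle_pytorch(masks)                 # one uncompressed RLE dict per mask
print(rles[0]["size"], area_from_rle(rles[0]))    # [4, 6] 6

recovered = rle_to_mask(rles[0])                  # back to an HxW boolean numpy array
assert recovered.sum() == 6 and recovered.shape == (4, 6)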
/pre_npz.py:
--------------------------------------------------------------------------------
1 | # %% import packages
2 | import numpy as np
3 | import os
4 | from glob import glob
5 | import pandas as pd
6 | import cv2
7 | join = os.path.join
8 | from skimage import transform, io, segmentation
9 | from tqdm import tqdm
10 | import torch
11 | from segment_anything import sam_model_registry
12 | from segment_anything.utils.transforms import ResizeLongestSide
13 | import argparse
14 | from PIL import Image
15 |
16 | # set up the parser
17 | parser = argparse.ArgumentParser(description="preprocess grey and RGB images")
18 |
19 | # add arguments to the parser
20 | parser.add_argument(
21 | "-i",
22 | "--img_path",
23 | type=str,
24 | default="data/TrainDataset/Imgs",
25 | help="path to the images",
26 | )
27 | parser.add_argument(
28 | "-gt",
29 | "--gt_path",
30 | type=str,
31 | default="data/TrainDataset/GT",
32 | help="path to the ground truth (gt)",
33 | )
34 |
35 | parser.add_argument(
36 | "-depth",
37 | "--depth_path",
38 | type=str,
39 | default="data/TrainDataset/depth",
40 | help="path to the ground truth (gt)",
41 | )
42 |
43 | parser.add_argument(
44 | "--csv",
45 | type=str,
46 | default=None,
47 | help="path to the csv file",
48 | )
49 |
50 | parser.add_argument(
51 | "-o",
52 | "--npz_path",
53 | type=str,
54 | default="data/npz",
55 | help="path to save the npz files",
56 | )
57 | parser.add_argument(
58 | "--data_name",
59 | type=str,
60 | default="COD_Test_CAMO",
61 | help="dataset name; used to name the final npz file, e.g., demo2d.npz",
62 | )
63 | parser.add_argument("--image_size", type=int, default=256, help="image size")
64 | # parser.add_argument("--depth_size", type=int, default=256, help="depth image size")
65 | parser.add_argument(
66 | "--img_name_suffix", type=str, default=".jpg", help="image name suffix"
67 | )
68 | parser.add_argument("--label_id", type=int, default=255, help="label id")
69 | parser.add_argument("--model_type", type=str, default="vit_b", help="model type")
70 | parser.add_argument(
71 | "--checkpoint",
72 | type=str,
73 | default="work_dir/SAM/sam_vit_b_01ec64.pth",
74 | help="checkpoint",
75 | )
76 | parser.add_argument("--device", type=str, default="cuda:2", help="device")
77 | parser.add_argument("--seed", type=int, default=2023, help="random seed")
78 |
79 | # parse the arguments
80 | args = parser.parse_args()
81 |
82 | # convert 2d grey or rgb images to npz file
83 | imgs = []
84 | gts = []
85 | depth_imgs = []
86 | number = []
87 | boundary = []
88 | img_embeddings = []
89 | depth_embeddings = []
90 | global num_of_processed_imgs
91 | num_of_processed_imgs = 0
92 |
93 |
94 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to(
95 | args.device
96 | )
97 | # create a directory to save the npz files
98 | save_path = args.npz_path + "_" + args.model_type
99 | os.makedirs(save_path, exist_ok=True)
100 |
101 |
102 | def find_bundary(img, mask, path, name):
103 | mask = mask * 255 # convert to 0-255
104 | kernel = np.ones((3, 3), dtype=np.uint8)
105 | fore = cv2.dilate(mask, kernel, 3)
106 | dilate = (fore - mask)
107 | kernel_again = np.ones((5, 5), dtype=np.uint8)
108 | dilate_again = cv2.dilate(dilate, kernel_again, 3)
109 |
110 | edges = cv2.Canny(img, 0.2, 0.6)
111 | # edges[edges > 0] = 1
112 | # edges = edges.astype(np.uint8)/255.0
113 | os.makedirs(os.path.join(path + '/dilate/'), exist_ok=True)
114 | os.makedirs(os.path.join(path + '/boundary/'), exist_ok=True)
115 | os.makedirs(os.path.join(path + '/canny/'), exist_ok=True)
116 | cv2.imwrite(os.path.join(path + '/dilate/', name), fore)
117 | cv2.imwrite(os.path.join(path + '/boundary/', name), dilate_again)
118 | cv2.imwrite(os.path.join(path + '/canny/', name), edges)
119 |
120 |     edges_2 = Image.open(os.path.join(path + '/canny/', name)).convert('1')
121 |     dilate_again_2 = Image.open(os.path.join(path + '/boundary/', name)).convert('1')
122 |
123 |     boundary_grads = Image.fromarray(np.array(edges_2) * np.array(dilate_again_2))
124 |     os.makedirs(os.path.join(path + '/boundary_grads/'), exist_ok=True)
125 |     boundary_grads.save(os.path.join(path + '/boundary_grads/', name))
126 |
127 | return boundary_grads
128 |
129 | def process(gt_name: str, image_name: str, num_of_processed_imgs:int):
130 |     if image_name is None:
131 |         image_name = gt_name.split(".")[0] + args.img_name_suffix  # derive the image name from the GT name
132 | gt_data = io.imread(join(args.gt_path, gt_name))
133 |
134 | # if it is rgb, select the first channel
135 | if len(gt_data.shape) == 3:
136 | gt_data = gt_data[:, :, 0]
137 | assert len(gt_data.shape) == 2, "ground truth should be 2D"
138 |
139 |     # resize the ground truth image
140 | gt_data = transform.resize(
141 | gt_data == args.label_id,
142 | (args.image_size, args.image_size),
143 | order=0,
144 | preserve_range=True,
145 | mode="constant",
146 | )
147 |
148 | # convert to uint8
149 | gt_data = np.uint8(gt_data)
150 |
151 |
152 |     if np.sum(gt_data) > 0:  # do not exclude tiny objects (polyps can be small, and COD datasets contain tiny objects)
153 |         # optional binary thresholding could be added here
154 | assert (
155 | np.max(gt_data) == 1 and np.unique(gt_data).shape[0] == 2
156 | ), "ground truth should be binary"
157 | # image preprocessing
158 | image_data = io.imread(join(args.img_path, image_name))
159 | # Remove any alpha channel if present.
160 | if image_data.shape[-1] > 3 and len(image_data.shape) == 3:
161 | image_data = image_data[:, :, :3]
162 | # If image is grayscale, then repeat the last channel to convert to rgb
163 | if len(image_data.shape) == 2:
164 | image_data = np.repeat(image_data[:, :, None], 3, axis=-1)
165 | # nii preprocess start
166 | lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(
167 | image_data, 99.5
168 | )
169 | image_data_pre = np.clip(image_data, lower_bound, upper_bound)
170 | # min-max normalize and scale
171 | image_data_pre = (
172 | (image_data_pre - np.min(image_data_pre))
173 | / (np.max(image_data_pre) - np.min(image_data_pre))
174 | * 255.0
175 | )
176 | image_data_pre[image_data == 0] = 0
177 |
178 | image_data_pre = transform.resize(
179 | image_data_pre,
180 | (args.image_size, args.image_size),
181 | order=3,
182 | preserve_range=True,
183 | mode="constant",
184 | anti_aliasing=True,
185 | )
186 | image_data_pre = np.uint8(image_data_pre)
187 |
188 | imgs.append(image_data_pre)
189 | # End of image preprocessing
190 |
191 | # depth image preprocessing
192 | depth_data = io.imread(join(args.depth_path, gt_name))
193 |
194 | gt_data_pre = transform.resize(
195 | gt_data,
196 | (args.image_size, args.image_size),
197 | order=1,
198 | preserve_range=True,
199 | mode="constant",
200 | anti_aliasing=True,
201 | )
202 | depth_data_pre = transform.resize(
203 | depth_data,
204 | (args.image_size, args.image_size),
205 | order=1,
206 | preserve_range=True,
207 | mode="constant",
208 | anti_aliasing=True,
209 | )
210 | y_indices, x_indices = np.where(gt_data_pre > 0)
211 | x_min, x_max = np.min(x_indices), np.max(x_indices)
212 | y_min, y_max = np.min(y_indices), np.max(y_indices)
213 | # add perturbation to bounding box coordinates
214 | H, W = args.image_size, args.image_size
215 | x_min = max(0, x_min - int(10))
216 | x_max = min(W, x_max + int(10))
217 | y_min = max(0, y_min - int(10))
218 | y_max = min(H, y_max + int(10))
219 |
220 | depth_data = depth_data_pre[y_min:y_max, x_min:x_max]
221 | os.makedirs(os.path.join(save_path+'/'+args.data_name + '/depth_crop/'), exist_ok=True)
222 | cv2.imwrite(os.path.join(save_path+'/'+args.data_name + '/depth_crop/', gt_name), depth_data)
223 | # Remove any alpha channel if present.
224 | if depth_data.shape[-1] > 3 and len(depth_data.shape) == 3:
225 | depth_data = depth_data[:, :, :3]
226 | # If image is grayscale, then repeat the last channel to convert to rgb
227 | if len(depth_data.shape) == 2:
228 | depth_data = np.repeat(depth_data[:, :, None], 3, axis=-1)
229 | # nii preprocess start
230 | lower_bound, upper_bound = np.percentile(depth_data, 0.5), np.percentile(
231 | depth_data, 99.5
232 | )
233 | depth_data_pre = np.clip(depth_data, lower_bound, upper_bound)
234 | # min-max normalize and scale
235 | depth_data_pre = (
236 | (depth_data_pre - np.min(depth_data_pre))
237 | / (np.max(depth_data_pre) - np.min(depth_data_pre))
238 | * 255.0
239 | )
240 | depth_data_pre[depth_data == 0] = 0
241 |
242 | depth_data_pre = transform.resize(
243 | depth_data_pre,
244 | (args.image_size, args.image_size),
245 | order=3,
246 | preserve_range=True,
247 | mode="constant",
248 | anti_aliasing=True,
249 | )
250 | depth_data_pre = np.uint8(depth_data_pre)
251 | depth_imgs.append(depth_data_pre)
252 |
253 | # End of depth image preprocessing
254 |
255 | number.append(image_name)
256 |
257 | print("the number of images: " + str(len(imgs)) + " and the name of image: " + str(image_name))
258 | num_of_processed_imgs = num_of_processed_imgs + 1
259 |         assert np.sum(gt_data) > 0, "ground truth should have more than 0 pixels, since tiny COD objects are kept"
260 |
261 | gts.append(gt_data)
262 | boundary.append(find_bundary(image_data_pre, gt_data, save_path+'/'+args.data_name, gt_name))
263 | print("the boundary_grad is produced now!!")
264 |
265 | # img -->img embedding
266 | # resize image to 3*1024*1024
267 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
268 | resize_img = sam_transform.apply_image(image_data_pre)
269 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(
270 | args.device
271 | )
272 | input_image = sam_model.preprocess(
273 | resize_img_tensor[None, :, :, :]
274 | ) # (1, 3, 1024, 1024)
275 | assert input_image.shape == (
276 | 1,
277 | 3,
278 | sam_model.image_encoder.img_size,
279 | sam_model.image_encoder.img_size,
280 | ), "input image should be resized to 1024*1024"
281 | # pre-compute the image embedding
282 | with torch.no_grad():
283 | embedding = sam_model.image_encoder(input_image)
284 | img_embeddings.append(embedding.cpu().numpy()[0])
285 | # end of img -->img embedding
286 |
287 | # depth_img -->depth_img embedding
288 | # resize image to 3*1024*1024
289 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
290 | resize_depth_img = sam_transform.apply_image(depth_data_pre)
291 | resize_depth_img_tensor = torch.as_tensor(resize_depth_img.transpose(2, 0, 1)).to(
292 | args.device
293 | )
294 | input_depth_image = sam_model.preprocess(
295 | resize_depth_img_tensor[None, :, :, :]
296 | ) # (1, 3, 1024, 1024)
297 | assert input_depth_image.shape == (
298 | 1,
299 | 3,
300 | sam_model.image_encoder.img_size,
301 | sam_model.image_encoder.img_size,
302 | ), "input depth image should be resized to 1024*1024"
303 | # pre-compute the image embedding
304 | with torch.no_grad():
305 | depth_embedding = sam_model.image_encoder(input_depth_image)
306 | depth_embeddings.append(depth_embedding.cpu().numpy()[0])
307 | # end of depth_img -->depth_img embedding
308 |
309 | return num_of_processed_imgs
310 |
311 |
312 | if args.csv is not None:
313 |     # if data is presented in csv format
314 |     # columns must be named image_filename and mask_filename respectively
315 |     # make sure the csv file exists before reading it
316 |     if not os.path.exists(args.csv):
317 |         raise FileNotFoundError(f"File {args.csv} not found!!")
318 |
319 |
320 | df = pd.read_csv(args.csv)
321 | bar = tqdm(df.iterrows(), total=len(df))
322 | for idx, row in bar:
323 |         num_of_processed_imgs = process(row.mask_filename, row.image_filename, num_of_processed_imgs)
324 |
325 | else:
326 | # get all the names of the images in the ground truth folder
327 | # names = sorted(os.listdir(args.gt_path), key=lambda x: int(x.split('.')[0])) # Sort by numeric size of filenames, but not all filenames are numeric
328 | names = sorted(os.listdir(args.gt_path))
329 | # print the number of images found in the ground truth folder
330 | print("image number:", len(names))
331 | for gt_name in tqdm(names):
332 | num_of_processed_imgs = process(gt_name, None, num_of_processed_imgs)
333 | print("the number of processed images: " + str(num_of_processed_imgs))
334 |
335 |
336 |
337 | # stack the list to array
338 | print("Num. of images:", len(imgs))
339 | if len(imgs) > 1:
340 |     # stack the per-image lists into batched arrays
341 | imgs = np.stack(imgs, axis=0) # (n, 256, 256, 3)
342 | gts = np.stack(gts, axis=0) # (n, 256, 256)
343 | depth_imgs = np.stack(depth_imgs, axis=0) # (n, 256, 256)
344 |     img_embeddings = np.stack(img_embeddings, axis=0)  # (n, 256, 64, 64)
345 |     depth_embeddings = np.stack(depth_embeddings, axis=0)  # (n, 256, 64, 64)
346 | boundary = np.stack(boundary, axis=0) # (n, 256, 256)
347 | np.savez_compressed(
348 | join(save_path, args.data_name + ".npz"),
349 | imgs=imgs,
350 | gts=gts,
351 | depth_imgs=depth_imgs,
352 | number=number,
353 | img_embeddings=img_embeddings,
354 | boundary=boundary,
355 | depth_embeddings=depth_embeddings,
356 | )
357 | # save an example image for sanity check
358 | idx = np.random.randint(imgs.shape[0])
359 | img_idx = imgs[idx, :, :, :]
360 | gt_idx = gts[idx, :, :]
361 | bd = segmentation.find_boundaries(gt_idx, mode="inner")
362 | img_idx[bd, :] = [255, 0, 0]
363 | io.imsave(save_path + ".png", img_idx, check_contrast=False)
364 | else:
365 | print(
366 | "Do not find image and ground-truth pairs. Please check your dataset and argument settings"
367 | )
--------------------------------------------------------------------------------
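Once the script finishes, the compressed archive can be inspected directly. A small loading sketch is shown below; the file name follows the default `--npz_path`, `--data_name`, and `--model_type` arguments, so the actual path depends on how the script was invoked, and the shapes are the ones expected from the preprocessing above.

import numpy as np

npz = np.load("data/npz_vit_b/COD_Test_CAMO.npz")
print(npz.files)                    # keys saved above: imgs, gts, depth_imgs, number, img_embeddings, boundary, depth_embeddings
print(npz["imgs"].shape)            # (n, 256, 256, 3) uint8 RGB images
print(npz["gts"].shape)             # (n, 256, 256) binary masks
print(npz["img_embeddings"].shape)  # (n, 256, 64, 64) pre-computed SAM image embeddings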
/segment_anything/modeling/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from typing import Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d, MLPBlock
14 |
15 |
16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17 | class ImageEncoderViT(nn.Module):
18 | def __init__(
19 | self,
20 | img_size: int = 1024,
21 | patch_size: int = 16,
22 | in_chans: int = 3,
23 | embed_dim: int = 768,
24 | depth: int = 12,
25 | num_heads: int = 12,
26 | mlp_ratio: float = 4.0,
27 | out_chans: int = 256,
28 | qkv_bias: bool = True,
29 | norm_layer: Type[nn.Module] = nn.LayerNorm,
30 | act_layer: Type[nn.Module] = nn.GELU,
31 | use_abs_pos: bool = True,
32 | use_rel_pos: bool = False,
33 | rel_pos_zero_init: bool = True,
34 | window_size: int = 0,
35 | global_attn_indexes: Tuple[int, ...] = (),
36 | ) -> None:
37 | """
38 | Args:
39 | img_size (int): Input image size.
40 | patch_size (int): Patch size.
41 | in_chans (int): Number of input image channels.
42 | embed_dim (int): Patch embedding dimension.
43 | depth (int): Depth of ViT.
44 | num_heads (int): Number of attention heads in each ViT block.
45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
47 | norm_layer (nn.Module): Normalization layer.
48 | act_layer (nn.Module): Activation layer.
49 | use_abs_pos (bool): If True, use absolute positional embeddings.
50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52 | window_size (int): Window size for window attention blocks.
53 | global_attn_indexes (list): Indexes for blocks using global attention.
54 | """
55 | super().__init__()
56 | self.img_size = img_size
57 |
58 | self.patch_embed = PatchEmbed(
59 | kernel_size=(patch_size, patch_size),
60 | stride=(patch_size, patch_size),
61 | in_chans=in_chans,
62 | embed_dim=embed_dim,
63 | )
64 |
65 | self.pos_embed: Optional[nn.Parameter] = None
66 | if use_abs_pos:
67 | # Initialize absolute positional embedding with pretrain image size.
68 | self.pos_embed = nn.Parameter(
69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70 | )
71 |
72 | self.blocks = nn.ModuleList()
73 | for i in range(depth):
74 | block = Block(
75 | dim=embed_dim,
76 | num_heads=num_heads,
77 | mlp_ratio=mlp_ratio,
78 | qkv_bias=qkv_bias,
79 | norm_layer=norm_layer,
80 | act_layer=act_layer,
81 | use_rel_pos=use_rel_pos,
82 | rel_pos_zero_init=rel_pos_zero_init,
83 | window_size=window_size if i not in global_attn_indexes else 0,
84 | input_size=(img_size // patch_size, img_size // patch_size),
85 | )
86 | self.blocks.append(block)
87 |
88 | self.neck = nn.Sequential(
89 | nn.Conv2d(
90 | embed_dim,
91 | out_chans,
92 | kernel_size=1,
93 | bias=False,
94 | ),
95 | LayerNorm2d(out_chans),
96 | nn.Conv2d(
97 | out_chans,
98 | out_chans,
99 | kernel_size=3,
100 | padding=1,
101 | bias=False,
102 | ),
103 | LayerNorm2d(out_chans),
104 | )
105 |
106 | def forward(self, x: torch.Tensor) -> torch.Tensor:
107 | x = self.patch_embed(x)
108 | if self.pos_embed is not None:
109 | x = x + self.pos_embed
110 |
111 | for blk in self.blocks:
112 | x = blk(x)
113 |
114 | x = self.neck(x.permute(0, 3, 1, 2))
115 |
116 | return x
117 |
118 |
119 | class Block(nn.Module):
120 | """Transformer blocks with support of window attention and residual propagation blocks"""
121 |
122 | def __init__(
123 | self,
124 | dim: int,
125 | num_heads: int,
126 | mlp_ratio: float = 4.0,
127 | qkv_bias: bool = True,
128 | norm_layer: Type[nn.Module] = nn.LayerNorm,
129 | act_layer: Type[nn.Module] = nn.GELU,
130 | use_rel_pos: bool = False,
131 | rel_pos_zero_init: bool = True,
132 | window_size: int = 0,
133 | input_size: Optional[Tuple[int, int]] = None,
134 | ) -> None:
135 | """
136 | Args:
137 | dim (int): Number of input channels.
138 | num_heads (int): Number of attention heads in each ViT block.
139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
140 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
141 | norm_layer (nn.Module): Normalization layer.
142 | act_layer (nn.Module): Activation layer.
143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
145 | window_size (int): Window size for window attention blocks. If it equals 0, then
146 | use global attention.
147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative
148 | positional parameter size.
149 | """
150 | super().__init__()
151 | self.norm1 = norm_layer(dim)
152 | self.attn = Attention(
153 | dim,
154 | num_heads=num_heads,
155 | qkv_bias=qkv_bias,
156 | use_rel_pos=use_rel_pos,
157 | rel_pos_zero_init=rel_pos_zero_init,
158 | input_size=input_size if window_size == 0 else (window_size, window_size),
159 | )
160 |
161 | self.norm2 = norm_layer(dim)
162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
163 |
164 | self.window_size = window_size
165 |
166 | def forward(self, x: torch.Tensor) -> torch.Tensor:
167 | shortcut = x
168 | x = self.norm1(x)
169 | # Window partition
170 | if self.window_size > 0:
171 | H, W = x.shape[1], x.shape[2]
172 | x, pad_hw = window_partition(x, self.window_size)
173 |
174 | x = self.attn(x)
175 | # Reverse window partition
176 | if self.window_size > 0:
177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
178 |
179 | x = shortcut + x
180 | x = x + self.mlp(self.norm2(x))
181 |
182 | return x
183 |
184 |
185 | class Attention(nn.Module):
186 | """Multi-head Attention block with relative position embeddings."""
187 |
188 | def __init__(
189 | self,
190 | dim: int,
191 | num_heads: int = 8,
192 | qkv_bias: bool = True,
193 | use_rel_pos: bool = False,
194 | rel_pos_zero_init: bool = True,
195 | input_size: Optional[Tuple[int, int]] = None,
196 | ) -> None:
197 | """
198 | Args:
199 | dim (int): Number of input channels.
200 | num_heads (int): Number of attention heads.
201 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
202 | rel_pos (bool): If True, add relative positional embeddings to the attention map.
203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative
205 | positional parameter size.
206 | """
207 | super().__init__()
208 | self.num_heads = num_heads
209 | head_dim = dim // num_heads
210 | self.scale = head_dim**-0.5
211 |
212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213 | self.proj = nn.Linear(dim, dim)
214 |
215 | self.use_rel_pos = use_rel_pos
216 | if self.use_rel_pos:
217 | assert (
218 | input_size is not None
219 | ), "Input size must be provided if using relative positional encoding."
220 | # initialize relative positional embeddings
221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
223 |
224 | def forward(self, x: torch.Tensor) -> torch.Tensor:
225 | B, H, W, _ = x.shape
226 | # qkv with shape (3, B, nHead, H * W, C)
227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
228 | # q, k, v with shape (B * nHead, H * W, C)
229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
230 |
231 | attn = (q * self.scale) @ k.transpose(-2, -1)
232 |
233 | if self.use_rel_pos:
234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
235 |
236 | attn = attn.softmax(dim=-1)
237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
238 | x = self.proj(x)
239 |
240 | return x
241 |
242 |
243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
244 | """
245 | Partition into non-overlapping windows with padding if needed.
246 | Args:
247 | x (tensor): input tokens with [B, H, W, C].
248 | window_size (int): window size.
249 |
250 | Returns:
251 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
252 | (Hp, Wp): padded height and width before partition
253 | """
254 | B, H, W, C = x.shape
255 |
256 | pad_h = (window_size - H % window_size) % window_size
257 | pad_w = (window_size - W % window_size) % window_size
258 | if pad_h > 0 or pad_w > 0:
259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
260 | Hp, Wp = H + pad_h, W + pad_w
261 |
262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
264 | return windows, (Hp, Wp)
265 |
266 |
267 | def window_unpartition(
268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
269 | ) -> torch.Tensor:
270 | """
271 | Window unpartition into original sequences and removing padding.
272 | Args:
273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
274 | window_size (int): window size.
275 | pad_hw (Tuple): padded height and width (Hp, Wp).
276 | hw (Tuple): original height and width (H, W) before padding.
277 |
278 | Returns:
279 | x: unpartitioned sequences with [B, H, W, C].
280 | """
281 | Hp, Wp = pad_hw
282 | H, W = hw
283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
286 |
287 | if Hp > H or Wp > W:
288 | x = x[:, :H, :W, :].contiguous()
289 | return x
290 |
291 |
292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
293 | """
294 | Get relative positional embeddings according to the relative positions of
295 | query and key sizes.
296 | Args:
297 | q_size (int): size of query q.
298 | k_size (int): size of key k.
299 | rel_pos (Tensor): relative position embeddings (L, C).
300 |
301 | Returns:
302 | Extracted positional embeddings according to relative positions.
303 | """
304 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
305 | # Interpolate rel pos if needed.
306 | if rel_pos.shape[0] != max_rel_dist:
307 | # Interpolate rel pos.
308 | rel_pos_resized = F.interpolate(
309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
310 | size=max_rel_dist,
311 | mode="linear",
312 | )
313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
314 | else:
315 | rel_pos_resized = rel_pos
316 |
317 | # Scale the coords with short length if shapes for q and k are different.
318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
321 |
322 | return rel_pos_resized[relative_coords.long()]
323 |
324 |
325 | def add_decomposed_rel_pos(
326 | attn: torch.Tensor,
327 | q: torch.Tensor,
328 | rel_pos_h: torch.Tensor,
329 | rel_pos_w: torch.Tensor,
330 | q_size: Tuple[int, int],
331 | k_size: Tuple[int, int],
332 | ) -> torch.Tensor:
333 | """
334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
336 | Args:
337 | attn (Tensor): attention map.
338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
343 |
344 | Returns:
345 | attn (Tensor): attention map with added relative positional embeddings.
346 | """
347 | q_h, q_w = q_size
348 | k_h, k_w = k_size
349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
351 |
352 | B, _, dim = q.shape
353 | r_q = q.reshape(B, q_h, q_w, dim)
354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
356 |
357 | attn = (
358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
359 | ).view(B, q_h * q_w, k_h * k_w)
360 |
361 | return attn
362 |
363 |
364 | class PatchEmbed(nn.Module):
365 | """
366 | Image to Patch Embedding.
367 | """
368 |
369 | def __init__(
370 | self,
371 | kernel_size: Tuple[int, int] = (16, 16),
372 | stride: Tuple[int, int] = (16, 16),
373 | padding: Tuple[int, int] = (0, 0),
374 | in_chans: int = 3,
375 | embed_dim: int = 768,
376 | ) -> None:
377 | """
378 | Args:
379 | kernel_size (Tuple): kernel size of the projection layer.
380 | stride (Tuple): stride of the projection layer.
381 | padding (Tuple): padding size of the projection layer.
382 | in_chans (int): Number of input image channels.
383 | embed_dim (int): Patch embedding dimension.
384 | """
385 | super().__init__()
386 |
387 | self.proj = nn.Conv2d(
388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
389 | )
390 |
391 | def forward(self, x: torch.Tensor) -> torch.Tensor:
392 | x = self.proj(x)
393 | # B C H W -> B H W C
394 | x = x.permute(0, 2, 3, 1)
395 | return x
396 |
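397 | 
398 | # Illustrative sketch (editor's addition, not from the upstream SAM code): a quick shape
399 | # walk-through of the helpers above, using a toy feature map and the default patch embedding.
400 | if __name__ == "__main__":
401 |     feats = torch.randn(2, 64, 64, 768)                         # (B, H, W, C)
402 |     windows, pad_hw = window_partition(feats, window_size=14)   # pads 64 -> 70, 25 windows per image
403 |     print(windows.shape, pad_hw)                                 # torch.Size([50, 14, 14, 768]) (70, 70)
404 |     restored = window_unpartition(windows, 14, pad_hw, (64, 64))
405 |     assert torch.equal(restored, feats)                          # the padding is cropped away exactly
406 |     patches = PatchEmbed()(torch.randn(1, 3, 1024, 1024))        # Conv2d 16x16 / stride 16, then BCHW -> BHWC
407 |     print(patches.shape)                                         # torch.Size([1, 64, 64, 768])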
--------------------------------------------------------------------------------
/segment_anything/automatic_mask_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10 |
11 | from typing import Any, Dict, List, Optional, Tuple
12 |
13 | from .modeling import Sam
14 | from .predictor import SamPredictor
15 | from .utils.amg import (
16 | MaskData,
17 | area_from_rle,
18 | batch_iterator,
19 | batched_mask_to_box,
20 | box_xyxy_to_xywh,
21 | build_all_layer_point_grids,
22 | calculate_stability_score,
23 | coco_encode_rle,
24 | generate_crop_boxes,
25 | is_box_near_crop_edge,
26 | mask_to_rle_pytorch,
27 | remove_small_regions,
28 | rle_to_mask,
29 | uncrop_boxes_xyxy,
30 | uncrop_masks,
31 | uncrop_points,
32 | )
33 |
34 |
35 | class SamAutomaticMaskGenerator:
36 | def __init__(
37 | self,
38 | model: Sam,
39 | points_per_side: Optional[int] = 32,
40 | points_per_batch: int = 64,
41 | pred_iou_thresh: float = 0.88,
42 | stability_score_thresh: float = 0.95,
43 | stability_score_offset: float = 1.0,
44 | box_nms_thresh: float = 0.7,
45 | crop_n_layers: int = 0,
46 | crop_nms_thresh: float = 0.7,
47 | crop_overlap_ratio: float = 512 / 1500,
48 | crop_n_points_downscale_factor: int = 1,
49 | point_grids: Optional[List[np.ndarray]] = None,
50 | min_mask_region_area: int = 0,
51 | output_mode: str = "binary_mask",
52 | ) -> None:
53 | """
54 | Using a SAM model, generates masks for the entire image.
55 | Generates a grid of point prompts over the image, then filters
56 | low quality and duplicate masks. The default settings are chosen
57 | for SAM with a ViT-H backbone.
58 |
59 | Arguments:
60 | model (Sam): The SAM model to use for mask prediction.
61 | points_per_side (int or None): The number of points to be sampled
62 | along one side of the image. The total number of points is
63 | points_per_side**2. If None, 'point_grids' must provide explicit
64 | point sampling.
65 | points_per_batch (int): Sets the number of points run simultaneously
66 | by the model. Higher numbers may be faster but use more GPU memory.
67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the
68 | model's predicted mask quality.
69 | stability_score_thresh (float): A filtering threshold in [0,1], using
70 | the stability of the mask under changes to the cutoff used to binarize
71 | the model's mask predictions.
72 | stability_score_offset (float): The amount to shift the cutoff when
73 | calculating the stability score.
74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal
75 | suppression to filter duplicate masks.
76 | crop_n_layers (int): If >0, mask prediction will be run again on
77 | crops of the image. Sets the number of layers to run, where each
78 | layer has 2**i_layer number of image crops.
79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80 | suppression to filter duplicate masks between different crops.
81 | crop_overlap_ratio (float): Sets the degree to which crops overlap.
82 | In the first crop layer, crops will overlap by this fraction of
83 | the image length. Later layers with more crops scale down this overlap.
84 | crop_n_points_downscale_factor (int): The number of points-per-side
85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86 | point_grids (list(np.ndarray) or None): A list over explicit grids
87 | of points used for sampling, normalized to [0,1]. The nth grid in the
88 | list is used in the nth crop layer. Exclusive with points_per_side.
89 | min_mask_region_area (int): If >0, postprocessing will be applied
90 | to remove disconnected regions and holes in masks with area smaller
91 | than min_mask_region_area. Requires opencv.
92 | output_mode (str): The form masks are returned in. Can be 'binary_mask',
93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94 | For large resolutions, 'binary_mask' may consume large amounts of
95 | memory.
96 | """
97 |
98 | assert (points_per_side is None) != (
99 | point_grids is None
100 | ), "Exactly one of points_per_side or point_grid must be provided."
101 | if points_per_side is not None:
102 | self.point_grids = build_all_layer_point_grids(
103 | points_per_side,
104 | crop_n_layers,
105 | crop_n_points_downscale_factor,
106 | )
107 | elif point_grids is not None:
108 | self.point_grids = point_grids
109 | else:
110 | raise ValueError("Can't have both points_per_side and point_grid be None.")
111 |
112 | assert output_mode in [
113 | "binary_mask",
114 | "uncompressed_rle",
115 | "coco_rle",
116 | ], f"Unknown output_mode {output_mode}."
117 | if output_mode == "coco_rle":
118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119 |
120 | if min_mask_region_area > 0:
121 | import cv2 # type: ignore # noqa: F401
122 |
123 | self.predictor = SamPredictor(model)
124 | self.points_per_batch = points_per_batch
125 | self.pred_iou_thresh = pred_iou_thresh
126 | self.stability_score_thresh = stability_score_thresh
127 | self.stability_score_offset = stability_score_offset
128 | self.box_nms_thresh = box_nms_thresh
129 | self.crop_n_layers = crop_n_layers
130 | self.crop_nms_thresh = crop_nms_thresh
131 | self.crop_overlap_ratio = crop_overlap_ratio
132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133 | self.min_mask_region_area = min_mask_region_area
134 | self.output_mode = output_mode
135 |
136 | @torch.no_grad()
137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138 | """
139 | Generates masks for the given image.
140 |
141 | Arguments:
142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143 |
144 | Returns:
145 | list(dict(str, any)): A list over records for masks. Each record is
146 | a dict containing the following keys:
147 | segmentation (dict(str, any) or np.ndarray): The mask. If
148 | output_mode='binary_mask', is an array of shape HW. Otherwise,
149 | is a dictionary containing the RLE.
150 | bbox (list(float)): The box around the mask, in XYWH format.
151 | area (int): The area in pixels of the mask.
152 | predicted_iou (float): The model's own prediction of the mask's
153 | quality. This is filtered by the pred_iou_thresh parameter.
154 | point_coords (list(list(float))): The point coordinates input
155 | to the model to generate this mask.
156 | stability_score (float): A measure of the mask's quality. This
157 | is filtered on using the stability_score_thresh parameter.
158 | crop_box (list(float)): The crop of the image used to generate
159 | the mask, given in XYWH format.
160 | """
161 |
162 | # Generate masks
163 | mask_data = self._generate_masks(image)
164 |
165 | # Filter small disconnected regions and holes in masks
166 | if self.min_mask_region_area > 0:
167 | mask_data = self.postprocess_small_regions(
168 | mask_data,
169 | self.min_mask_region_area,
170 | max(self.box_nms_thresh, self.crop_nms_thresh),
171 | )
172 |
173 | # Encode masks
174 | if self.output_mode == "coco_rle":
175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176 | elif self.output_mode == "binary_mask":
177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178 | else:
179 | mask_data["segmentations"] = mask_data["rles"]
180 |
181 | # Write mask records
182 | curr_anns = []
183 | for idx in range(len(mask_data["segmentations"])):
184 | ann = {
185 | "segmentation": mask_data["segmentations"][idx],
186 | "area": area_from_rle(mask_data["rles"][idx]),
187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188 | "predicted_iou": mask_data["iou_preds"][idx].item(),
189 | "point_coords": [mask_data["points"][idx].tolist()],
190 | "stability_score": mask_data["stability_score"][idx].item(),
191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192 | }
193 | curr_anns.append(ann)
194 |
195 | return curr_anns
196 |
197 | def _generate_masks(self, image: np.ndarray) -> MaskData:
198 | orig_size = image.shape[:2]
199 | crop_boxes, layer_idxs = generate_crop_boxes(
200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio
201 | )
202 |
203 | # Iterate over image crops
204 | data = MaskData()
205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207 | data.cat(crop_data)
208 |
209 | # Remove duplicate masks between crops
210 | if len(crop_boxes) > 1:
211 | # Prefer masks from smaller crops
212 | scores = 1 / box_area(data["crop_boxes"])
213 | scores = scores.to(data["boxes"].device)
214 | keep_by_nms = batched_nms(
215 | data["boxes"].float(),
216 | scores,
217 | torch.zeros_like(data["boxes"][:, 0]), # categories
218 | iou_threshold=self.crop_nms_thresh,
219 | )
220 | data.filter(keep_by_nms)
221 |
222 | data.to_numpy()
223 | return data
224 |
225 | def _process_crop(
226 | self,
227 | image: np.ndarray,
228 | crop_box: List[int],
229 | crop_layer_idx: int,
230 | orig_size: Tuple[int, ...],
231 | ) -> MaskData:
232 | # Crop the image and calculate embeddings
233 | x0, y0, x1, y1 = crop_box
234 | cropped_im = image[y0:y1, x0:x1, :]
235 | cropped_im_size = cropped_im.shape[:2]
236 | self.predictor.set_image(cropped_im)
237 |
238 | # Get points for this crop
239 | points_scale = np.array(cropped_im_size)[None, ::-1]
240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale
241 |
242 | # Generate masks for this crop in batches
243 | data = MaskData()
244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246 | data.cat(batch_data)
247 | del batch_data
248 | self.predictor.reset_image()
249 |
250 | # Remove duplicates within this crop.
251 | keep_by_nms = batched_nms(
252 | data["boxes"].float(),
253 | data["iou_preds"],
254 | torch.zeros_like(data["boxes"][:, 0]), # categories
255 | iou_threshold=self.box_nms_thresh,
256 | )
257 | data.filter(keep_by_nms)
258 |
259 | # Return to the original image frame
260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261 | data["points"] = uncrop_points(data["points"], crop_box)
262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263 |
264 | return data
265 |
266 | def _process_batch(
267 | self,
268 | points: np.ndarray,
269 | im_size: Tuple[int, ...],
270 | crop_box: List[int],
271 | orig_size: Tuple[int, ...],
272 | ) -> MaskData:
273 | orig_h, orig_w = orig_size
274 |
275 | # Run model on this batch
276 | transformed_points = self.predictor.transform.apply_coords(points, im_size)
277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279 | masks, iou_preds, _ = self.predictor.predict_torch(
280 | in_points[:, None, :],
281 | in_labels[:, None],
282 | multimask_output=True,
283 | return_logits=True,
284 | )
285 |
286 | # Serialize predictions and store in MaskData
287 | data = MaskData(
288 | masks=masks.flatten(0, 1),
289 | iou_preds=iou_preds.flatten(0, 1),
290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291 | )
292 | del masks
293 |
294 | # Filter by predicted IoU
295 | if self.pred_iou_thresh > 0.0:
296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh
297 | data.filter(keep_mask)
298 |
299 | # Calculate stability score
300 | data["stability_score"] = calculate_stability_score(
301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302 | )
303 | if self.stability_score_thresh > 0.0:
304 | keep_mask = data["stability_score"] >= self.stability_score_thresh
305 | data.filter(keep_mask)
306 |
307 | # Threshold masks and calculate boxes
308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309 | data["boxes"] = batched_mask_to_box(data["masks"])
310 |
311 | # Filter boxes that touch crop boundaries
312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313 | if not torch.all(keep_mask):
314 | data.filter(keep_mask)
315 |
316 | # Compress to RLE
317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318 | data["rles"] = mask_to_rle_pytorch(data["masks"])
319 | del data["masks"]
320 |
321 | return data
322 |
323 | @staticmethod
324 | def postprocess_small_regions(
325 | mask_data: MaskData, min_area: int, nms_thresh: float
326 | ) -> MaskData:
327 | """
328 | Removes small disconnected regions and holes in masks, then reruns
329 | box NMS to remove any new duplicates.
330 |
331 | Edits mask_data in place.
332 |
333 | Requires open-cv as a dependency.
334 | """
335 | if len(mask_data["rles"]) == 0:
336 | return mask_data
337 |
338 | # Filter small disconnected regions and holes
339 | new_masks = []
340 | scores = []
341 | for rle in mask_data["rles"]:
342 | mask = rle_to_mask(rle)
343 |
344 | mask, changed = remove_small_regions(mask, min_area, mode="holes")
345 | unchanged = not changed
346 | mask, changed = remove_small_regions(mask, min_area, mode="islands")
347 | unchanged = unchanged and not changed
348 |
349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350 | # Give score=0 to changed masks and score=1 to unchanged masks
351 | # so NMS will prefer ones that didn't need postprocessing
352 | scores.append(float(unchanged))
353 |
354 | # Recalculate boxes and remove any new duplicates
355 | masks = torch.cat(new_masks, dim=0)
356 | boxes = batched_mask_to_box(masks)
357 | keep_by_nms = batched_nms(
358 | boxes.float(),
359 | torch.as_tensor(scores),
360 | torch.zeros_like(boxes[:, 0]), # categories
361 | iou_threshold=nms_thresh,
362 | )
363 |
364 | # Only recalculate RLEs for masks that have changed
365 | for i_mask in keep_by_nms:
366 | if scores[i_mask] == 0.0:
367 | mask_torch = masks[i_mask].unsqueeze(0)
368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370 | mask_data.filter(keep_by_nms)
371 |
372 | return mask_data
373 |
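374 | 
375 | # Illustrative usage sketch (editor's addition, not from the upstream SAM repo): end-to-end mask
376 | # generation with the class above. The checkpoint and image paths are placeholders; a real ViT-B
377 | # checkpoint and opencv-python are needed to actually run it.
378 | if __name__ == "__main__":
379 |     import cv2  # used only for this sketch
380 | 
381 |     from segment_anything import sam_model_registry
382 | 
383 |     sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")  # placeholder path
384 |     generator = SamAutomaticMaskGenerator(sam, points_per_side=16, pred_iou_thresh=0.9)
385 |     image = cv2.cvtColor(cv2.imread("path/to/image.png"), cv2.COLOR_BGR2RGB)  # HWC uint8
386 |     masks = generator.generate(image)
387 |     # each record holds 'segmentation', 'bbox' (XYWH), 'area', 'predicted_iou', 'stability_score', ...
388 |     print(len(masks), max((m["area"] for m in masks), default=None))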
--------------------------------------------------------------------------------
/MSCAF_COD_evaluation/sod_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import convolve, distance_transform_edt as bwdist
3 |
4 | __version__ = '1.2.1'
5 |
6 | _EPS = 1e-16
7 | _TYPE = np.float64
8 |
9 |
10 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple:
11 | gt = gt > 128
12 | # im2double, mapminmax
13 | pred = pred / 255
14 | if pred.max() != pred.min():
15 | pred = (pred - pred.min()) / (pred.max() - pred.min())
16 | return pred, gt
17 |
18 |
19 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float:
20 | return min(2 * matrix.mean(), max_value)
21 |
22 |
23 | class Fmeasure(object):
24 | def __init__(self, beta: float = 0.3):
25 | self.beta = beta
26 | self.precisions = []
27 | self.recalls = []
28 | self.adaptive_fms = []
29 | self.changeable_fms = []
30 |
31 | def step(self, pred: np.ndarray, gt: np.ndarray):
32 | pred, gt = _prepare_data(pred, gt)
33 |
34 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt)
35 | self.adaptive_fms.append(adaptive_fm)
36 |
37 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt)
38 | self.precisions.append(precisions)
39 | self.recalls.append(recalls)
40 | self.changeable_fms.append(changeable_fms)
41 |
42 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float:
43 | # np.count_nonzero is the recommended way to quickly count the nonzero entries of a numpy array;
44 | # see tests/test_speed_for_count_nonzero.py for a small timing experiment
45 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
46 | binary_prediction = pred >= adaptive_threshold
47 | area_intersection = binary_prediction[gt].sum()
48 | if area_intersection == 0:
49 | adaptive_fm = 0
50 | else:
51 | pre = area_intersection / np.count_nonzero(binary_prediction)
52 | rec = area_intersection / np.count_nonzero(gt)
53 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec)
54 | return adaptive_fm
55 |
56 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple:
57 | # 1. Histogram the prediction separately over the ground-truth foreground and background regions
58 | pred = (pred * 255).astype(np.uint8)
59 | bins = np.linspace(0, 256, 257)
60 | fg_hist, _ = np.histogram(pred[gt], bins=bins) # the last bin covers [255, 256]
61 | bg_hist, _ = np.histogram(pred[~gt], bins=bins)
62 | # 2. Use cumulative histograms to count, for every threshold, the GT-foreground/background pixels
63 | # whose prediction is >= that threshold; flipping and cumsum-ing gives all thresholds in one pass
64 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0)
65 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0)
66 | # 3. Compute the precision and recall corresponding to each threshold
67 | # p and r share the same numerator (pixels with pred==1 & gt==1) and differ only in the
68 | # denominator: pred==1 for p, gt==1 for r; histogram + flip + cumsum gives these counts for all thresholds at once
69 | TPs = fg_w_thrs
70 | Ps = fg_w_thrs + bg_w_thrs
71 | # to avoid division by zero, denominators equal to 0 are set to 1; the numerator is necessarily 0 in those cases
72 | Ps[Ps == 0] = 1
73 | T = max(np.count_nonzero(gt), 1)
74 | # TODO: T=0, or fg_w_thrs=0 / bg_w_thrs=0 at a given threshold, are all covered by the TPs[i]=0 case,
75 | # but handling that through TPs is awkward with these arrays
76 | # T=0 -> fg_w_thrs=[0, ..., 0] -> TPs=[0, ..., 0]; workaround: reset T to 1
77 | # Ps[i] = 0 -> fg_w_thrs[i] = 0 and bg_w_thrs[i] = 0
78 | precisions = TPs / Ps
79 | recalls = TPs / T
80 |
81 | numerator = (1 + self.beta) * precisions * recalls
82 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls)
83 | changeable_fms = numerator / denominator
84 | return precisions, recalls, changeable_fms
85 |
86 | def get_results(self) -> dict:
87 | adaptive_fm = np.mean(np.array(self.adaptive_fms, dtype=_TYPE))
88 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0)
89 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256
90 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256
91 | return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm),
92 | pr=dict(p=precision, r=recall))
93 |
94 |
95 | class MAE(object):
96 | def __init__(self):
97 | self.maes = []
98 |
99 | def step(self, pred: np.ndarray, gt: np.ndarray):
100 | pred, gt = _prepare_data(pred, gt)
101 |
102 | mae = self.cal_mae(pred, gt)
103 | self.maes.append(mae)
104 |
105 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float:
106 | mae = np.mean(np.abs(pred - gt))
107 | return mae
108 |
109 | def get_results(self) -> dict:
110 | mae = np.mean(np.array(self.maes, dtype=_TYPE))
111 | return dict(mae=mae)
112 |
113 |
114 | class Smeasure(object):
115 | def __init__(self, alpha: float = 0.5):
116 | self.sms = []
117 | self.alpha = alpha
118 |
119 | def step(self, pred: np.ndarray, gt: np.ndarray):
120 | pred, gt = _prepare_data(pred=pred, gt=gt)
121 |
122 | sm = self.cal_sm(pred, gt)
123 | self.sms.append(sm)
124 |
125 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
126 | y = np.mean(gt)
127 | if y == 0:
128 | sm = 1 - np.mean(pred)
129 | elif y == 1:
130 | sm = np.mean(pred)
131 | else:
132 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
133 | sm = max(0, sm)
134 | return sm
135 |
136 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
137 | fg = pred * gt
138 | bg = (1 - pred) * (1 - gt)
139 | u = np.mean(gt)
140 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt)
141 | return object_score
142 |
143 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
144 | x = np.mean(pred[gt == 1])
145 | sigma_x = np.std(pred[gt == 1])
146 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS)
147 | return score
148 |
149 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
150 | x, y = self.centroid(gt)
151 | part_info = self.divide_with_xy(pred, gt, x, y)
152 | w1, w2, w3, w4 = part_info['weight']
153 | # assert np.isclose(w1 + w2 + w3 + w4, 1), (w1 + w2 + w3 + w4, pred.mean(), gt.mean())
154 |
155 | pred1, pred2, pred3, pred4 = part_info['pred']
156 | gt1, gt2, gt3, gt4 = part_info['gt']
157 | score1 = self.ssim(pred1, gt1)
158 | score2 = self.ssim(pred2, gt2)
159 | score3 = self.ssim(pred3, gt3)
160 | score4 = self.ssim(pred4, gt4)
161 |
162 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
163 |
164 | def centroid(self, matrix: np.ndarray) -> tuple:
165 | """
166 | To stay consistent with the MATLAB code, 1 is added to the centroid coordinates so that no extra
167 | +1 is needed when splitting the regions later, since MATLAB's 1:X sequence includes X itself
168 | """
169 | h, w = matrix.shape
170 | if matrix.sum() == 0:
171 | x = np.round(w / 2)
172 | y = np.round(h / 2)
173 | else:
174 | area_object = np.sum(matrix)
175 | row_ids = np.arange(h)
176 | col_ids = np.arange(w)
177 | x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object)
178 | y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object)
179 | return int(x) + 1, int(y) + 1
180 |
181 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict:
182 | h, w = gt.shape
183 | area = h * w
184 |
185 | gt_LT = gt[0:y, 0:x]
186 | gt_RT = gt[0:y, x:w]
187 | gt_LB = gt[y:h, 0:x]
188 | gt_RB = gt[y:h, x:w]
189 |
190 | pred_LT = pred[0:y, 0:x]
191 | pred_RT = pred[0:y, x:w]
192 | pred_LB = pred[y:h, 0:x]
193 | pred_RB = pred[y:h, x:w]
194 |
195 | w1 = x * y / area
196 | w2 = y * (w - x) / area
197 | w3 = (h - y) * x / area
198 | # w4 = (h - y) * (w - x) / area
199 | w4 = 1 - w1 - w2 - w3
200 |
201 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB),
202 | pred=(pred_LT, pred_RT, pred_LB, pred_RB),
203 | weight=(w1, w2, w3, w4))
204 |
205 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float:
206 | h, w = pred.shape
207 | N = h * w
208 |
209 | x = np.mean(pred)
210 | y = np.mean(gt)
211 |
212 | sigma_x = np.sum((pred - x) ** 2) / (N - 1)
213 | sigma_y = np.sum((gt - y) ** 2) / (N - 1)
214 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1)
215 |
216 | alpha = 4 * x * y * sigma_xy
217 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y)
218 |
219 | if alpha != 0:
220 | score = alpha / (beta + _EPS)
221 | elif alpha == 0 and beta == 0:
222 | score = 1
223 | else:
224 | score = 0
225 | return score
226 |
227 | def get_results(self) -> dict:
228 | sm = np.mean(np.array(self.sms, dtype=_TYPE))
229 | return dict(sm=sm)
230 |
231 |
232 | class Emeasure(object):
233 | def __init__(self):
234 | self.adaptive_ems = []
235 | self.changeable_ems = []
236 |
237 | def step(self, pred: np.ndarray, gt: np.ndarray):
238 | pred, gt = _prepare_data(pred=pred, gt=gt)
239 | self.gt_fg_numel = np.count_nonzero(gt)
240 | self.gt_size = gt.shape[0] * gt.shape[1]
241 |
242 | changeable_ems = self.cal_changeable_em(pred, gt)
243 | self.changeable_ems.append(changeable_ems)
244 | adaptive_em = self.cal_adaptive_em(pred, gt)
245 | self.adaptive_ems.append(adaptive_em)
246 |
247 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float:
248 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
249 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold)
250 | return adaptive_em
251 |
252 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
253 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt)
254 | return changeable_ems
255 |
256 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
257 | """
258 | Naming convention for variables inside this function:
259 | <pred attribute (foreground fg / background bg)>_<gt attribute (foreground fg / background bg)>_<meaning>
260 | if only pred or gt is relevant, the slot for the other attribute is filled with `_`
261 | """
262 | binarized_pred = pred >= threshold
263 | fg_fg_numel = np.count_nonzero(binarized_pred & gt)
264 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
265 |
266 | fg___numel = fg_fg_numel + fg_bg_numel
267 | bg___numel = self.gt_size - fg___numel
268 |
269 | if self.gt_fg_numel == 0:
270 | enhanced_matrix_sum = bg___numel
271 | elif self.gt_fg_numel == self.gt_size:
272 | enhanced_matrix_sum = fg___numel
273 | else:
274 | parts_numel, combinations = self.generate_parts_numel_combinations(
275 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
276 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
277 | )
278 |
279 | results_parts = []
280 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
281 | align_matrix_value = 2 * (combination[0] * combination[1]) / \
282 | (combination[0] ** 2 + combination[1] ** 2 + _EPS)
283 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
284 | results_parts.append(enhanced_matrix_value * part_numel)
285 | enhanced_matrix_sum = sum(results_parts)
286 |
287 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
288 | return em
289 |
290 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
291 | """
292 | Naming convention for variables inside this function:
293 | <pred attribute (foreground fg / background bg)>_<gt attribute (foreground fg / background bg)>_<meaning>
294 | if only pred or gt is relevant, the slot for the other attribute is filled with `_`
295 | """
296 | pred = (pred * 255).astype(np.uint8)
297 | bins = np.linspace(0, 256, 257)
298 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
299 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
300 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
301 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
302 |
303 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
304 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
305 |
306 | if self.gt_fg_numel == 0:
307 | enhanced_matrix_sum = bg___numel_w_thrs
308 | elif self.gt_fg_numel == self.gt_size:
309 | enhanced_matrix_sum = fg___numel_w_thrs
310 | else:
311 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
312 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
313 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
314 | )
315 |
316 | results_parts = np.empty(shape=(4, 256), dtype=np.float64)
317 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
318 | align_matrix_value = 2 * (combination[0] * combination[1]) / \
319 | (combination[0] ** 2 + combination[1] ** 2 + _EPS)
320 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
321 | results_parts[i] = enhanced_matrix_value * part_numel
322 | enhanced_matrix_sum = results_parts.sum(axis=0)
323 |
324 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
325 | return em
326 |
327 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
328 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel
329 | bg_bg_numel = pred_bg_numel - bg_fg_numel
330 |
331 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]
332 |
333 | mean_pred_value = pred_fg_numel / self.gt_size
334 | mean_gt_value = self.gt_fg_numel / self.gt_size
335 |
336 | demeaned_pred_fg_value = 1 - mean_pred_value
337 | demeaned_pred_bg_value = 0 - mean_pred_value
338 | demeaned_gt_fg_value = 1 - mean_gt_value
339 | demeaned_gt_bg_value = 0 - mean_gt_value
340 |
341 | combinations = [
342 | (demeaned_pred_fg_value, demeaned_gt_fg_value),
343 | (demeaned_pred_fg_value, demeaned_gt_bg_value),
344 | (demeaned_pred_bg_value, demeaned_gt_fg_value),
345 | (demeaned_pred_bg_value, demeaned_gt_bg_value)
346 | ]
347 | return parts_numel, combinations
348 |
349 | def get_results(self) -> dict:
350 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE))
351 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0)
352 | return dict(em=dict(adp=adaptive_em, curve=changeable_em))
353 |
354 |
355 | class WeightedFmeasure(object):
356 | def __init__(self, beta: float = 1):
357 | self.beta = beta
358 | self.weighted_fms = []
359 |
360 | def step(self, pred: np.ndarray, gt: np.ndarray):
361 | pred, gt = _prepare_data(pred=pred, gt=gt)
362 |
363 | if np.all(~gt):
364 | wfm = 0
365 | else:
366 | wfm = self.cal_wfm(pred, gt)
367 | self.weighted_fms.append(wfm)
368 |
369 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float:
370 | # [Dst,IDXT] = bwdist(dGT);
371 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
372 |
373 | # %Pixel dependency
374 | # E = abs(FG-dGT);
375 | E = np.abs(pred - gt)
376 | # Et = E;
377 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
378 | Et = np.copy(E)
379 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
380 |
381 | # K = fspecial('gaussian',7,5);
382 | # EA = imfilter(Et,K);
383 | K = self.matlab_style_gauss2D((7, 7), sigma=5)
384 | EA = convolve(Et, weights=K, mode="constant", cval=0)
385 | # MIN_E_EA = E;
386 | # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
387 | MIN_E_EA = np.where(gt & (EA < E), EA, E)
388 | 
389 | # %Pixel importance
390 | # B = ones(size(GT));
391 | # B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT));
392 | B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
393 | 
394 | # Ew = MIN_E_EA.*B;
395 | Ew = MIN_E_EA * B
396 | 
397 | # TPw = sum(dGT(:)) - sum(sum(Ew(GT)));
398 | # FPw = sum(sum(Ew(~GT)));
399 | TPw = np.sum(gt) - np.sum(Ew[gt == 1])
400 | FPw = np.sum(Ew[gt == 0])
401 | 
402 | # R = 1- mean2(Ew(GT)); %Weighted Recall
403 | # P = TPw./(eps+TPw+FPw); %Weighted Precision
404 | R = 1 - np.mean(Ew[gt == 1])
405 | P = TPw / (TPw + FPw + _EPS)
406 | 
407 | # Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
408 | Q = (1 + self.beta) * R * P / (R + self.beta * P + _EPS)
409 | 
410 | return Q
411 | 
412 | def matlab_style_gauss2D(self, shape: tuple = (7, 7), sigma: int = 5) -> np.ndarray:
413 | """
414 | 2D gaussian mask - should give the same result as MATLAB's
415 | fspecial('gaussian',[shape],[sigma])
416 | """
417 | m, n = [(ss - 1) / 2 for ss in shape]
418 | y, x = np.ogrid[-m: m + 1, -n: n + 1]
419 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
420 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
421 | sumh = h.sum()
422 | if sumh != 0:
423 | h /= sumh
424 | return h
425 |
426 | def get_results(self) -> dict:
427 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE))
428 | return dict(wfm=weighted_fm)
429 |
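430 | 
431 | # Illustrative driver sketch (editor's addition, not part of the original metric code): how these
432 | # classes are typically used, here with random stand-in maps; evaluation.py feeds real prediction/GT pairs.
433 | if __name__ == "__main__":
434 |     rng = np.random.default_rng(0)
435 |     pred = rng.integers(0, 256, size=(64, 64), dtype=np.uint8)    # stand-in grayscale prediction map
436 |     gt = (rng.random((64, 64)) > 0.5).astype(np.uint8) * 255      # stand-in binary GT map (0/255)
437 | 
438 |     fm, wfm, sm, em, mae = Fmeasure(), WeightedFmeasure(), Smeasure(), Emeasure(), MAE()
439 |     for metric in (fm, wfm, sm, em, mae):
440 |         metric.step(pred=pred, gt=gt)                              # call once per image pair
441 |     print(dict(Smeasure=sm.get_results()["sm"], MAE=mae.get_results()["mae"],
442 |                adpFm=fm.get_results()["fm"]["adp"], adpEm=em.get_results()["em"]["adp"],
443 |                wFm=wfm.get_results()["wfm"]))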
--------------------------------------------------------------------------------