├── README.md ├── __init__.py ├── configs └── cod-sam-vit-l.yaml ├── datasets ├── __init__.py ├── datasets.py ├── image_folder.py └── wrappers.py ├── models ├── __init__.py ├── iou_loss.py ├── models.py ├── sam.py └── sammodel │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── requirements.txt ├── sod_metric.py ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # RSAM-Seg 2 | Code for RSAM-Seg 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chief-byte/RSAM-Seg/7430868113de2f85d63a76911b69ab8085822b30/__init__.py -------------------------------------------------------------------------------- /configs/cod-sam-vit-l.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: "D:/datasets/DGROAD/archive/trainmod" 6 | root_path_2: "D:/datasets/DGROAD/archive/trainmodmask" 7 | cache: none 8 | split_key: train 9 | wrapper: 10 | name: train 11 | args: 12 | inp_size: 1024 13 | augment: false 14 | batch_size: 2 15 | 16 | val_dataset: 17 | dataset: 18 | name: paired-image-folders 19 | args: 20 | root_path_1: "D:/datasets/DGROAD/archive/valmod" 21 | root_path_2: "D:/datasets/DGROAD/archive/valmodmaskrename" 22 | cache: none 23 | split_key: test 24 | wrapper: 25 | name: val 26 | args: 27 | inp_size: 1024 28 | batch_size: 1 29 | 30 | test_dataset: 31 | dataset: 32 | name: paired-image-folders 33 | args: 34 | root_path_1: "D:/datasets/DGROAD/archive/valmod" 35 | root_path_2: "D:/datasets/DGROAD/archive/valmodmaskrename" 36 | cache: none 37 | split_key: test 38 | wrapper: 39 | name: val 40 | args: 41 | inp_size: 1024 42 | batch_size: 1 43 | 44 | eval_type: cod 45 | sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth 46 | data_norm: 47 | inp: 48 | sub: 49 | - 0.5 50 | div: 51 | - 0.5 52 | gt: 53 | sub: 54 | - 0.5 55 | div: 56 | - 0.5 57 | gt_rgb: 58 | sub: 59 | - 0.5 60 | div: 61 | - 0.5 62 | model: 63 | name: sam 64 | args: 65 | inp_size: 1024 66 | loss: iou 67 | encoder_mode: 68 | name: sam 69 | img_size: 1024 70 | mlp_ratio: 4 71 | patch_size: 16 72 | qkv_bias: true 73 | use_rel_pos: true 74 | window_size: 14 75 | out_chans: 256 76 | scale_factor: 32 77 | input_type: fft 78 | freq_nums: 0.25 79 | prompt_type: highpass 80 | prompt_embed_dim: 256 81 | tuning_stage: 1234 82 | handcrafted_tune: true 83 | embedding_tune: true 84 | adaptor: adaptor 85 | embed_dim: 1024 86 | depth: 24 87 | num_heads: 16 88 | global_attn_indexes: 89 | - 5 90 | - 11 91 | - 17 92 | - 23 93 | optimizer: 94 | name: adamw 95 | args: 96 | lr: 0.0002 97 | lr_min: 1.0e-7 98 | epoch_max: 120 99 | 100 | multi_step_lr: 101 | milestones: 102 | - 1 103 | gamma: 0.1 104 | epoch_val: 1 105 | epoch_save: 1 106 | 107 | #resume: 60 108 | #start_epoch: 60 109 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import register, make 2 | from . import image_folder 3 | from . import wrappers 4 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | datasets = {} 5 | 6 | 7 | def register(name): 8 | def decorator(cls): 9 | datasets[name] = cls 10 | return cls 11 | return decorator 12 | 13 | 14 | def make(dataset_spec, args=None): 15 | if args is not None: 16 | dataset_args = copy.deepcopy(dataset_spec['args']) 17 | dataset_args.update(args) 18 | else: 19 | dataset_args = dataset_spec['args'] 20 | dataset = datasets[dataset_spec['name']](**dataset_args) 21 | return dataset 22 | -------------------------------------------------------------------------------- /datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | 5 | import pickle 6 | import imageio 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | import random 12 | from datasets import register 13 | 14 | 15 | @register('image-folder') 16 | class ImageFolder(Dataset): 17 | def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None, 18 | repeat=1, cache='none', mask=False): 19 | self.repeat = repeat 20 | self.cache = cache 21 | self.path = path 22 | self.Train = False 23 | self.split_key = split_key 24 | 25 | self.size = size 26 | self.mask = mask 27 | if self.mask: 28 | self.img_transform = transforms.Compose([ 29 | transforms.Resize((self.size, self.size), interpolation=Image.NEAREST), 30 | transforms.ToTensor(), 31 | ]) 32 | else: 33 | self.img_transform = transforms.Compose([ 34 | transforms.Resize((self.size, self.size)), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | ]) 39 | 40 | if split_file is None: 41 | filenames = sorted(os.listdir(path)) 42 | else: 43 | with open(split_file, 'r') as f: 44 | filenames = json.load(f)[split_key] 45 | if first_k is not None: 46 | filenames = filenames[:first_k] 47 | 48 | self.files = [] 49 | 50 | for filename in filenames: 51 | file = os.path.join(path, filename) 52 | self.append_file(file) 53 | 54 | def append_file(self, file): 55 | if self.cache == 'none': 56 | self.files.append(file) 57 | elif self.cache == 'in_memory': 58 | self.files.append(self.img_process(file)) 59 | 60 | def __len__(self): 61 | return len(self.files) * self.repeat 62 | 63 | def __getitem__(self, idx): 64 | x = self.files[idx % len(self.files)] 65 | 66 | if self.cache == 'none': 67 | return self.img_process(x) 68 | elif self.cache == 'in_memory': 69 | return x 70 | 71 | def img_process(self, file): 72 | if self.mask: 73 | return Image.open(file).convert('L') 74 | else: 75 | return Image.open(file).convert('RGB') 76 | 77 | @register('paired-image-folders') 78 | class PairedImageFolders(Dataset): 79 | 80 | def __init__(self, root_path_1, root_path_2, **kwargs): 81 | self.dataset_1 = ImageFolder(root_path_1, **kwargs) 82 | self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True) 83 | 84 | def __len__(self): 85 | return len(self.dataset_1) 86 | 87 | def __getitem__(self, idx): 88 | return self.dataset_1[idx], self.dataset_2[idx] 89 | -------------------------------------------------------------------------------- /datasets/wrappers.py: -------------------------------------------------------------------------------- 1 | 2 | import functools 3 | import random 4 | import math 5 | from PIL import Image 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | import torchvision 12 | 13 | from datasets import register 14 | import cv2 15 | from math import pi 16 | from torchvision.transforms import InterpolationMode 17 | 18 | import torch.nn.functional as F 19 | def to_mask(mask): 20 | return transforms.ToTensor()( 21 | transforms.Grayscale(num_output_channels=1)( 22 | transforms.ToPILImage()(mask))) 23 | 24 | 25 | def resize_fn(img, size): 26 | return transforms.ToTensor()( 27 | transforms.Resize(size)( 28 | transforms.ToPILImage()(img))) 29 | 30 | 31 | @register('val') 32 | class ValDataset(Dataset): 33 | def __init__(self, dataset, inp_size=None, augment=False): 34 | self.dataset = dataset 35 | self.inp_size = inp_size 36 | self.augment = augment 37 | 38 | self.img_transform = transforms.Compose([ 39 | transforms.Resize((inp_size, inp_size)), 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | ]) 44 | self.mask_transform = transforms.Compose([ 45 | transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST), 46 | transforms.ToTensor(), 47 | ]) 48 | 49 | def __len__(self): 50 | return len(self.dataset) 51 | 52 | def __getitem__(self, idx): 53 | img, mask = self.dataset[idx] 54 | 55 | return { 56 | 'inp': self.img_transform(img), 57 | 'gt': self.mask_transform(mask) 58 | } 59 | 60 | 61 | @register('train') 62 | class TrainDataset(Dataset): 63 | def __init__(self, dataset, size_min=None, size_max=None, inp_size=None, 64 | augment=False, gt_resize=None): 65 | self.dataset = dataset 66 | self.size_min = size_min 67 | if size_max is None: 68 | size_max = size_min 69 | self.size_max = size_max 70 | self.augment = augment 71 | self.gt_resize = gt_resize 72 | 73 | self.inp_size = inp_size 74 | self.img_transform = transforms.Compose([ 75 | transforms.Resize((self.inp_size, self.inp_size)), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 78 | std=[0.229, 0.224, 0.225]) 79 | ]) 80 | self.inverse_transform = transforms.Compose([ 81 | transforms.Normalize(mean=[0., 0., 0.], 82 | std=[1/0.229, 1/0.224, 1/0.225]), 83 | transforms.Normalize(mean=[-0.485, -0.456, -0.406], 84 | std=[1, 1, 1]) 85 | ]) 86 | self.mask_transform = transforms.Compose([ 87 | transforms.Resize((self.inp_size, self.inp_size)), 88 | transforms.ToTensor(), 89 | ]) 90 | 91 | def __len__(self): 92 | return len(self.dataset) 93 | 94 | def __getitem__(self, idx): 95 | img, mask = self.dataset[idx] 96 | 97 | # random filp 98 | if random.random() < 0.5: 99 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 100 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 101 | 102 | img = transforms.Resize((self.inp_size, self.inp_size))(img) 103 | mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask) 104 | 105 | return { 106 | 'inp': self.img_transform(img), 107 | 'gt': self.mask_transform(mask) 108 | } -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import register, make 2 | from . import sam 3 | -------------------------------------------------------------------------------- /models/iou_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ################################################################### 6 | # ########################## iou loss ############################# 7 | ################################################################### 8 | class IOU(torch.nn.Module): 9 | def __init__(self): 10 | super(IOU, self).__init__() 11 | 12 | def _iou(self, pred, target): 13 | pred = torch.sigmoid(pred) 14 | inter = (pred * target).sum(dim=(2, 3)) 15 | union = (pred + target).sum(dim=(2, 3)) - inter 16 | iou = 1 - (inter / union) 17 | 18 | return iou.mean() 19 | 20 | def forward(self, pred, target): 21 | return self._iou(pred, target) 22 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | models = {} 5 | 6 | 7 | def register(name): 8 | def decorator(cls): 9 | models[name] = cls 10 | return cls 11 | return decorator 12 | 13 | 14 | def make(model_spec, args=None, load_sd=False): 15 | if args is not None: 16 | model_args = copy.deepcopy(model_spec['args']) 17 | model_args.update(args) 18 | else: 19 | model_args = model_spec['args'] 20 | model = models[model_spec['name']](**model_args) 21 | if load_sd: 22 | model.load_state_dict(model_spec['sd']) 23 | return model 24 | -------------------------------------------------------------------------------- /models/sam.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from models import register 10 | from models.sammodel import ImageEncoderViT, MaskDecoder, TwoWayTransformer 11 | 12 | logger = logging.getLogger(__name__) 13 | from .iou_loss import IOU 14 | from typing import Any, Optional, Tuple 15 | 16 | 17 | def init_weights(layer): 18 | if type(layer) == nn.Conv2d: 19 | nn.init.normal_(layer.weight, mean=0.0, std=0.02) 20 | nn.init.constant_(layer.bias, 0.0) 21 | elif type(layer) == nn.Linear: 22 | nn.init.normal_(layer.weight, mean=0.0, std=0.02) 23 | nn.init.constant_(layer.bias, 0.0) 24 | elif type(layer) == nn.BatchNorm2d: 25 | # print(layer) 26 | nn.init.normal_(layer.weight, mean=1.0, std=0.02) 27 | nn.init.constant_(layer.bias, 0.0) 28 | 29 | class BBCEWithLogitLoss(nn.Module): 30 | ''' 31 | Balanced BCEWithLogitLoss 32 | ''' 33 | def __init__(self): 34 | super(BBCEWithLogitLoss, self).__init__() 35 | 36 | def forward(self, pred, gt): 37 | eps = 1e-10 38 | count_pos = torch.sum(gt) + eps 39 | count_neg = torch.sum(1. - gt) 40 | ratio = count_neg / count_pos 41 | w_neg = count_pos / (count_pos + count_neg) 42 | 43 | bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio) 44 | loss = w_neg * bce1(pred, gt) 45 | 46 | return loss 47 | 48 | def _iou_loss(pred, target): 49 | pred = torch.sigmoid(pred) 50 | inter = (pred * target).sum(dim=(2, 3)) 51 | union = (pred + target).sum(dim=(2, 3)) - inter 52 | iou = 1 - (inter / union) 53 | 54 | return iou.mean() 55 | # from prompt_encoder.py 56 | class PositionEmbeddingRandom(nn.Module): 57 | """ 58 | Positional encoding using random spatial frequencies. 59 | 60 | removed forward_with_coords which is 这个方法可以用于在对图像做处理时,对非归一化的点坐标进行位置编码,以便在后续的模型中使用 61 | """ 62 | 63 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 64 | super().__init__() 65 | if scale is None or scale <= 0.0: 66 | scale = 1.0 67 | self.register_buffer( 68 | "positional_encoding_gaussian_matrix", 69 | scale * torch.randn((2, num_pos_feats)), 70 | ) 71 | 72 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 73 | """Positionally encode points that are normalized to [0,1].""" 74 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 75 | coords = 2 * coords - 1 76 | coords = coords @ self.positional_encoding_gaussian_matrix 77 | coords = 2 * np.pi * coords 78 | # outputs d_1 x ... x d_n x C shape 79 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 80 | 81 | def forward(self, size: int) -> torch.Tensor: 82 | """Generate positional encoding for a grid of the specified size.""" 83 | h, w = size, size 84 | device: Any = self.positional_encoding_gaussian_matrix.device 85 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 86 | y_embed = grid.cumsum(dim=0) - 0.5 87 | x_embed = grid.cumsum(dim=1) - 0.5 88 | y_embed = y_embed / h 89 | x_embed = x_embed / w 90 | 91 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 92 | return pe.permute(2, 0, 1) # C x H x W 93 | 94 | 95 | @register('sam') 96 | class SAM(nn.Module): 97 | def __init__(self, inp_size=None, encoder_mode=None, loss=None): 98 | super().__init__() 99 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 100 | self.embed_dim = encoder_mode['embed_dim'] 101 | self.image_encoder = ImageEncoderViT( 102 | img_size=inp_size, 103 | patch_size=encoder_mode['patch_size'], 104 | in_chans=3, 105 | embed_dim=encoder_mode['embed_dim'], 106 | depth=encoder_mode['depth'], 107 | num_heads=encoder_mode['num_heads'], 108 | mlp_ratio=encoder_mode['mlp_ratio'], 109 | out_chans=encoder_mode['out_chans'], 110 | qkv_bias=encoder_mode['qkv_bias'], 111 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 112 | act_layer=nn.GELU, 113 | use_rel_pos=encoder_mode['use_rel_pos'], 114 | rel_pos_zero_init=True, 115 | window_size=encoder_mode['window_size'], 116 | global_attn_indexes=encoder_mode['global_attn_indexes'], 117 | ) 118 | self.prompt_embed_dim = encoder_mode['prompt_embed_dim'] 119 | self.mask_decoder = MaskDecoder( 120 | num_multimask_outputs=3, 121 | transformer=TwoWayTransformer( 122 | depth=2, 123 | embedding_dim=self.prompt_embed_dim, 124 | mlp_dim=2048, 125 | num_heads=8, 126 | ), 127 | transformer_dim=self.prompt_embed_dim, 128 | iou_head_depth=3, 129 | iou_head_hidden_dim=256, 130 | ) 131 | 132 | if 'evp' in encoder_mode['name']: 133 | for k, p in self.encoder.named_parameters(): 134 | if "prompt" not in k and "mask_decoder" not in k and "prompt_encoder" not in k: 135 | p.requires_grad = False 136 | 137 | 138 | 139 | self.loss_mode = loss 140 | if self.loss_mode == 'bce': 141 | self.criterionBCE = torch.nn.BCEWithLogitsLoss() 142 | 143 | elif self.loss_mode == 'bbce': 144 | self.criterionBCE = BBCEWithLogitLoss() 145 | 146 | elif self.loss_mode == 'iou': 147 | self.criterionBCE = torch.nn.BCEWithLogitsLoss() 148 | self.criterionIOU = IOU() 149 | 150 | self.pe_layer = PositionEmbeddingRandom(encoder_mode['prompt_embed_dim'] // 2) 151 | self.inp_size = inp_size 152 | self.image_embedding_size = inp_size // encoder_mode['patch_size'] 153 | self.no_mask_embed = nn.Embedding(1, encoder_mode['prompt_embed_dim']) 154 | 155 | def set_input(self, input, gt_mask): 156 | self.input = input.to(self.device) 157 | self.gt_mask = gt_mask.to(self.device) 158 | 159 | def get_dense_pe(self) -> torch.Tensor: 160 | """ 161 | Returns the positional encoding used to encode point prompts, 162 | applied to a dense set of points the shape of the image encoding. 163 | 164 | Returns: 165 | torch.Tensor: Positional encoding with shape 166 | 1x(embed_dim)x(embedding_h)x(embedding_w) 167 | """ 168 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 169 | 170 | 171 | def forward(self): 172 | bs = 1 173 | 174 | # Embed prompts 175 | sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=self.input.device) 176 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 177 | bs, -1, self.image_embedding_size, self.image_embedding_size 178 | ) 179 | 180 | self.features = self.image_encoder(self.input) 181 | 182 | # Predict masks 183 | low_res_masks, iou_predictions = self.mask_decoder( 184 | image_embeddings=self.features, 185 | image_pe=self.get_dense_pe(), 186 | sparse_prompt_embeddings=sparse_embeddings, 187 | dense_prompt_embeddings=dense_embeddings, 188 | multimask_output=False, 189 | ) 190 | 191 | # Upscale the masks to the original image resolution 192 | masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size) 193 | self.pred_mask = masks 194 | 195 | def infer(self, input): 196 | bs = 1 197 | 198 | # Embed prompts 199 | sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=input.device) 200 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 201 | bs, -1, self.image_embedding_size, self.image_embedding_size 202 | ) 203 | 204 | self.features = self.image_encoder(input) 205 | 206 | # Predict masks 207 | low_res_masks, iou_predictions = self.mask_decoder( 208 | image_embeddings=self.features, 209 | image_pe=self.get_dense_pe(), 210 | sparse_prompt_embeddings=sparse_embeddings, 211 | dense_prompt_embeddings=dense_embeddings, 212 | multimask_output=False, 213 | ) 214 | 215 | # Upscale the masks to the original image resolution 216 | masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size) 217 | return masks 218 | 219 | def postprocess_masks( 220 | self, 221 | masks: torch.Tensor, 222 | input_size: Tuple[int, ...], 223 | original_size: Tuple[int, ...], 224 | ) -> torch.Tensor: 225 | """ 226 | Remove padding and upscale masks to the original image size. 227 | 228 | Arguments: 229 | masks (torch.Tensor): Batched masks from the mask_decoder, 230 | in BxCxHxW format. 231 | input_size (tuple(int, int)): The size of the image input to the 232 | model, in (H, W) format. Used to remove padding. 233 | original_size (tuple(int, int)): The original size of the image 234 | before resizing for input to the model, in (H, W) format. 235 | 236 | Returns: 237 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 238 | is given by original_size. 239 | """ 240 | masks = F.interpolate( 241 | masks, 242 | (self.image_encoder.img_size, self.image_encoder.img_size), 243 | mode="bilinear", 244 | align_corners=False, 245 | ) 246 | masks = masks[..., : input_size, : input_size] 247 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 248 | return masks 249 | 250 | def backward_G(self): 251 | """Calculate GAN and L1 loss for the generator""" 252 | self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) 253 | if self.loss_mode == 'iou': 254 | self.loss_G += _iou_loss(self.pred_mask, self.gt_mask) 255 | 256 | self.loss_G.backward() 257 | 258 | def optimize_parameters(self): 259 | self.forward() 260 | self.optimizer.zero_grad() # set G's gradients to zero 261 | self.backward_G() # calculate graidents for G 262 | self.optimizer.step() # udpate G's weights 263 | 264 | def set_requires_grad(self, nets, requires_grad=False): 265 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 266 | Parameters: 267 | nets (network list) -- a list of networks 268 | requires_grad (bool) -- whether the networks require gradients or not 269 | """ 270 | if not isinstance(nets, list): 271 | nets = [nets] 272 | for net in nets: 273 | if net is not None: 274 | for param in net.parameters(): 275 | param.requires_grad = requires_grad 276 | -------------------------------------------------------------------------------- /models/sammodel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /models/sammodel/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | class Adapter(nn.Module): 13 | def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): 14 | super().__init__() 15 | self.skip_connect = skip_connect 16 | D_hidden_features = int(D_features * mlp_ratio) 17 | self.act = act_layer() 18 | self.D_fc1 = nn.Linear(D_features, D_hidden_features) 19 | self.D_fc2 = nn.Linear(D_hidden_features, D_features) 20 | 21 | def forward(self, x): 22 | # x is (BT, HW+1, D) 23 | xs = self.D_fc1(x) 24 | xs = self.act(xs) 25 | xs = self.D_fc2(xs) 26 | if self.skip_connect: 27 | x = x + xs 28 | else: 29 | x = xs 30 | return x 31 | 32 | 33 | class MLPBlock(nn.Module): 34 | def __init__( 35 | self, 36 | embedding_dim: int, 37 | mlp_dim: int, 38 | act: Type[nn.Module] = nn.GELU, 39 | ) -> None: 40 | super().__init__() 41 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 42 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 43 | self.act = act() 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | return self.lin2(self.act(self.lin1(x))) 47 | 48 | 49 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 50 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 51 | class LayerNorm2d(nn.Module): 52 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 53 | super().__init__() 54 | self.weight = nn.Parameter(torch.ones(num_channels)) 55 | self.bias = nn.Parameter(torch.zeros(num_channels)) 56 | self.eps = eps 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | u = x.mean(1, keepdim=True) 60 | s = (x - u).pow(2).mean(1, keepdim=True) 61 | x = (x - u) / torch.sqrt(s + self.eps) 62 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 63 | return x 64 | -------------------------------------------------------------------------------- /models/sammodel/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock,Adapter 14 | import math 15 | import warnings 16 | from itertools import repeat 17 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 18 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 19 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 20 | from torch._six import container_abcs 21 | else: 22 | import collections.abc as container_abcs 23 | 24 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 25 | class ImageEncoderViT(nn.Module): 26 | def __init__( 27 | self, 28 | img_size: int = 1024, 29 | patch_size: int = 16, 30 | in_chans: int = 3, 31 | embed_dim: int = 768, 32 | depth: int = 12, 33 | num_heads: int = 12, 34 | mlp_ratio: float = 4.0, 35 | out_chans: int = 256, 36 | qkv_bias: bool = True, 37 | norm_layer: Type[nn.Module] = nn.LayerNorm, 38 | act_layer: Type[nn.Module] = nn.GELU, 39 | use_abs_pos: bool = True, 40 | use_rel_pos: bool = False, 41 | rel_pos_zero_init: bool = True, 42 | window_size: int = 0, 43 | global_attn_indexes: Tuple[int, ...] = (), 44 | ) -> None: 45 | """ 46 | Args: 47 | img_size (int): Input image size. 48 | patch_size (int): Patch size. 49 | in_chans (int): Number of input image channels. 50 | embed_dim (int): Patch embedding dimension. 51 | depth (int): Depth of ViT. 52 | num_heads (int): Number of attention heads in each ViT block. 53 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 54 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 55 | norm_layer (nn.Module): Normalization layer. 56 | act_layer (nn.Module): Activation layer. 57 | use_abs_pos (bool): If True, use absolute positional embeddings. 58 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 59 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 60 | window_size (int): Window size for window attention blocks. 61 | global_attn_indexes (list): Indexes for blocks using global attention. 62 | """ 63 | super().__init__() 64 | self.img_size = img_size 65 | self.embed_dim = embed_dim 66 | self.depth = depth 67 | 68 | self.patch_embed = PatchEmbed( 69 | kernel_size=(patch_size, patch_size), 70 | stride=(patch_size, patch_size), 71 | in_chans=in_chans, 72 | embed_dim=embed_dim, 73 | ) 74 | 75 | self.pos_embed: Optional[nn.Parameter] = None 76 | if use_abs_pos: 77 | # Initialize absolute positional embedding with pretrain image size. 78 | self.pos_embed = nn.Parameter( 79 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 80 | ) 81 | 82 | self.blocks = nn.ModuleList() 83 | for i in range(depth): 84 | block = Block( 85 | dim=embed_dim, 86 | num_heads=num_heads, 87 | mlp_ratio=mlp_ratio, 88 | qkv_bias=qkv_bias, 89 | norm_layer=norm_layer, 90 | act_layer=act_layer, 91 | use_rel_pos=use_rel_pos, 92 | rel_pos_zero_init=rel_pos_zero_init, 93 | window_size=window_size if i not in global_attn_indexes else 0, 94 | input_size=(img_size // patch_size, img_size // patch_size), 95 | ) 96 | self.blocks.append(block) 97 | 98 | self.neck = nn.Sequential( 99 | nn.Conv2d( 100 | embed_dim, 101 | out_chans, 102 | kernel_size=1, 103 | bias=False, 104 | ), 105 | LayerNorm2d(out_chans), 106 | nn.Conv2d( 107 | out_chans, 108 | out_chans, 109 | kernel_size=3, 110 | padding=1, 111 | bias=False, 112 | ), 113 | LayerNorm2d(out_chans), 114 | ) 115 | 116 | self.scale_factor = 32 117 | self.prompt_type = 'highpass' 118 | self.tuning_stage = 1234 119 | self.input_type = 'fft' 120 | self.freq_nums = 0.25 121 | self.handcrafted_tune = True 122 | self.embedding_tune = True 123 | self.adaptor = 'adaptor' 124 | self.prompt_generator = PromptGenerator(self.scale_factor, self.prompt_type, self.embed_dim, 125 | self.tuning_stage, self.depth, 126 | self.input_type, self.freq_nums, 127 | self.handcrafted_tune, self.embedding_tune, self.adaptor, 128 | img_size, patch_size) 129 | self.num_stages = self.depth 130 | self.out_indices = tuple(range(self.num_stages)) 131 | 132 | def forward(self, x: torch.Tensor) -> torch.Tensor: 133 | inp = x 134 | x = self.patch_embed(x) 135 | 136 | embedding_feature = self.prompt_generator.init_embeddings(x) 137 | handcrafted_feature = self.prompt_generator.init_handcrafted(inp) 138 | prompt = self.prompt_generator.get_prompt(handcrafted_feature, embedding_feature) 139 | if self.pos_embed is not None: 140 | x = x + self.pos_embed 141 | 142 | B, H, W = x.shape[0], x.shape[1], x.shape[2] 143 | outs = [] 144 | for i, blk in enumerate(self.blocks): 145 | x = prompt[i].reshape(B, H, W, -1) + x 146 | x = blk(x) 147 | if i in self.out_indices: 148 | outs.append(x) 149 | 150 | x = self.neck(x.permute(0, 3, 1, 2)) 151 | 152 | return x 153 | 154 | def to_2tuple(x): 155 | if isinstance(x, container_abcs.Iterable): 156 | return x 157 | return tuple(repeat(x, 2)) 158 | 159 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 160 | # type: (Tensor, float, float, float, float) -> Tensor 161 | r"""Fills the input Tensor with values drawn from a truncated 162 | normal distribution. The values are effectively drawn from the 163 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 164 | with values outside :math:`[a, b]` redrawn until they are within 165 | the bounds. The method used for generating the random values works 166 | best when :math:`a \leq \text{mean} \leq b`. 167 | Args: 168 | tensor: an n-dimensional `torch.Tensor` 169 | mean: the mean of the normal distribution 170 | std: the standard deviation of the normal distribution 171 | a: the minimum cutoff value 172 | b: the maximum cutoff value 173 | Examples: 174 | >>> w = torch.empty(3, 5) 175 | >>> nn.init.trunc_normal_(w) 176 | """ 177 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 178 | 179 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 180 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 181 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 182 | def norm_cdf(x): 183 | # Computes standard normal cumulative distribution function 184 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 185 | 186 | if (mean < a - 2 * std) or (mean > b + 2 * std): 187 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 188 | "The distribution of values may be incorrect.", 189 | stacklevel=2) 190 | 191 | with torch.no_grad(): 192 | # Values are generated by using a truncated uniform distribution and 193 | # then using the inverse CDF for the normal distribution. 194 | # Get upper and lower cdf values 195 | l = norm_cdf((a - mean) / std) 196 | u = norm_cdf((b - mean) / std) 197 | 198 | # Uniformly fill tensor with values from [l, u], then translate to 199 | # [2l-1, 2u-1]. 200 | tensor.uniform_(2 * l - 1, 2 * u - 1) 201 | 202 | # Use inverse cdf transform for normal distribution to get truncated 203 | # standard normal 204 | tensor.erfinv_() 205 | 206 | # Transform to proper mean, std 207 | tensor.mul_(std * math.sqrt(2.)) 208 | tensor.add_(mean) 209 | 210 | # Clamp to ensure it's in the proper range 211 | tensor.clamp_(min=a, max=b) 212 | return tensor 213 | 214 | 215 | class PromptGenerator(nn.Module): 216 | def __init__(self, scale_factor, prompt_type, embed_dim, tuning_stage, depth, input_type, 217 | freq_nums, handcrafted_tune, embedding_tune, adaptor, img_size, patch_size): 218 | """ 219 | Args: 220 | """ 221 | super(PromptGenerator, self).__init__() 222 | self.scale_factor = scale_factor 223 | self.prompt_type = prompt_type 224 | self.embed_dim = embed_dim 225 | self.input_type = input_type 226 | self.freq_nums = freq_nums 227 | self.tuning_stage = tuning_stage 228 | self.depth = depth 229 | self.handcrafted_tune = handcrafted_tune 230 | self.embedding_tune = embedding_tune 231 | self.adaptor = adaptor 232 | 233 | self.shared_mlp = nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim) 234 | self.embedding_generator = nn.Linear(self.embed_dim, self.embed_dim//self.scale_factor) 235 | for i in range(self.depth): 236 | lightweight_mlp = nn.Sequential( 237 | nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim//self.scale_factor), 238 | nn.GELU() 239 | ) 240 | setattr(self, 'lightweight_mlp_{}'.format(str(i)), lightweight_mlp) 241 | 242 | self.prompt_generator = PatchEmbed2(img_size=img_size, 243 | patch_size=patch_size, in_chans=3, 244 | embed_dim=self.embed_dim//self.scale_factor) 245 | 246 | self.apply(self._init_weights) 247 | 248 | def _init_weights(self, m): 249 | if isinstance(m, nn.Linear): 250 | trunc_normal_(m.weight, std=.02) 251 | if isinstance(m, nn.Linear) and m.bias is not None: 252 | nn.init.constant_(m.bias, 0) 253 | elif isinstance(m, nn.LayerNorm): 254 | nn.init.constant_(m.bias, 0) 255 | nn.init.constant_(m.weight, 1.0) 256 | elif isinstance(m, nn.Conv2d): 257 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 258 | fan_out //= m.groups 259 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 260 | if m.bias is not None: 261 | m.bias.data.zero_() 262 | 263 | def init_embeddings(self, x): 264 | N, C, H, W = x.permute(0, 3, 1, 2).shape 265 | x = x.reshape(N, C, H*W).permute(0, 2, 1) 266 | return self.embedding_generator(x) 267 | 268 | def init_handcrafted(self, x): 269 | x = self.fft(x, self.freq_nums) 270 | return self.prompt_generator(x) 271 | 272 | def get_prompt(self, handcrafted_feature, embedding_feature): 273 | N, C, H, W = handcrafted_feature.shape 274 | handcrafted_feature = handcrafted_feature.view(N, C, H*W).permute(0, 2, 1) 275 | prompts = [] 276 | for i in range(self.depth): 277 | lightweight_mlp = getattr(self, 'lightweight_mlp_{}'.format(str(i))) 278 | # prompt = proj_prompt(prompt) 279 | prompt = lightweight_mlp(handcrafted_feature + embedding_feature) 280 | prompts.append(self.shared_mlp(prompt)) 281 | return prompts 282 | 283 | def forward(self, x): 284 | if self.input_type == 'laplacian': 285 | pyr_A = self.lap_pyramid.pyramid_decom(img=x, num=self.freq_nums) 286 | x = pyr_A[:-1] 287 | laplacian = x[0] 288 | for x_i in x[1:]: 289 | x_i = F.interpolate(x_i, size=(laplacian.size(2), laplacian.size(3)), mode='bilinear', align_corners=True) 290 | laplacian = torch.cat([laplacian, x_i], dim=1) 291 | x = laplacian 292 | elif self.input_type == 'fft': 293 | x = self.fft(x, self.freq_nums) 294 | elif self.input_type == 'all': 295 | x = self.prompt.unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 296 | 297 | # get prompting 298 | prompt = self.prompt_generator(x) 299 | 300 | if self.mode == 'input': 301 | prompt = self.proj(prompt) 302 | return prompt 303 | elif self.mode == 'stack': 304 | prompts = [] 305 | for i in range(self.depth): 306 | proj = getattr(self, 'proj_{}'.format(str(i))) 307 | prompts.append(proj(prompt)) 308 | return prompts 309 | elif self.mode == 'hierarchical': 310 | prompts = [] 311 | for i in range(self.depth): 312 | proj_prompt = getattr(self, 'proj_prompt_{}'.format(str(i))) 313 | prompt = proj_prompt(prompt) 314 | prompts.append(self.proj_token(prompt)) 315 | return prompts 316 | 317 | def fft(self, x, rate): 318 | # the smaller rate, the smoother; the larger rate, the darker 319 | # rate = 4, 8, 16, 32 320 | mask = torch.zeros(x.shape).to(x.device) 321 | w, h = x.shape[-2:] 322 | line = int((w * h * rate) ** .5 // 2) 323 | mask[:, :, w//2-line:w//2+line, h//2-line:h//2+line] = 1 324 | 325 | fft = torch.fft.fftshift(torch.fft.fft2(x, norm="forward")) 326 | # mask[fft.float() > self.freq_nums] = 1 327 | # high pass: 1-mask, low pass: mask 328 | fft = fft * (1 - mask) 329 | # fft = fft * mask 330 | fr = fft.real 331 | fi = fft.imag 332 | 333 | fft_hires = torch.fft.ifftshift(torch.complex(fr, fi)) 334 | inv = torch.fft.ifft2(fft_hires, norm="forward").real 335 | 336 | inv = torch.abs(inv) 337 | 338 | return inv 339 | 340 | class PatchEmbed2(nn.Module): 341 | """ Image to Patch Embedding 342 | """ 343 | 344 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 345 | super().__init__() 346 | img_size = to_2tuple(img_size) 347 | patch_size = to_2tuple(patch_size) 348 | num_patches = (img_size[1] // patch_size[1]) * \ 349 | (img_size[0] // patch_size[0]) 350 | self.img_size = img_size 351 | self.patch_size = patch_size 352 | self.num_patches = num_patches 353 | 354 | self.proj = nn.Conv2d(in_chans, embed_dim, 355 | kernel_size=patch_size, stride=patch_size) 356 | 357 | def forward(self, x): 358 | B, C, H, W = x.shape 359 | # FIXME look at relaxing size constraints 360 | assert H == self.img_size[0] and W == self.img_size[1], \ 361 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 362 | 363 | # x = F.interpolate(x, size=2*x.shape[-1], mode='bilinear', align_corners=True) 364 | x = self.proj(x) 365 | return x 366 | 367 | 368 | class Block(nn.Module): 369 | """Transformer blocks with support of window attention and residual propagation blocks""" 370 | 371 | def __init__( 372 | self, 373 | dim: int, 374 | num_heads: int, 375 | mlp_ratio: float = 4.0, 376 | qkv_bias: bool = True, 377 | scale: float = 0.5, 378 | norm_layer: Type[nn.Module] = nn.LayerNorm, 379 | act_layer: Type[nn.Module] = nn.GELU, 380 | use_rel_pos: bool = False, 381 | rel_pos_zero_init: bool = True, 382 | window_size: int = 0, 383 | input_size: Optional[Tuple[int, int]] = None, 384 | ) -> None: 385 | """ 386 | Args: 387 | dim (int): Number of input channels. 388 | num_heads (int): Number of attention heads in each ViT block. 389 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 390 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 391 | norm_layer (nn.Module): Normalization layer. 392 | act_layer (nn.Module): Activation layer. 393 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 394 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 395 | window_size (int): Window size for window attention blocks. If it equals 0, then 396 | use global attention. 397 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 398 | positional parameter size. 399 | """ 400 | super().__init__() 401 | self.norm1 = norm_layer(dim) 402 | self.attn = Attention( 403 | dim, 404 | num_heads=num_heads, 405 | qkv_bias=qkv_bias, 406 | use_rel_pos=use_rel_pos, 407 | rel_pos_zero_init=rel_pos_zero_init, 408 | input_size=input_size if window_size == 0 else (window_size, window_size), 409 | ) 410 | 411 | self.MLP_Adapter = Adapter(dim, skip_connect=False) # MLP-adapter, no skip connection 412 | self.Space_Adapter = Adapter(dim) # with skip connection 413 | self.scale = scale 414 | self.Depth_Adapter = Adapter(dim, skip_connect=False) # no skip connection 415 | 416 | self.norm2 = norm_layer(dim) 417 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 418 | 419 | self.window_size = window_size 420 | 421 | def forward(self, x: torch.Tensor) -> torch.Tensor: 422 | shortcut = x 423 | x = self.norm1(x) 424 | # Window partition 425 | if self.window_size > 0: 426 | H, W = x.shape[1], x.shape[2] 427 | x, pad_hw = window_partition(x, self.window_size) 428 | 429 | x = self.attn(x) 430 | x = self.Space_Adapter(x) 431 | # Reverse window partition 432 | if self.window_size > 0: 433 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 434 | 435 | x = shortcut + x 436 | # x = x + self.mlp(self.norm2(x)) 437 | xn = self.norm2(x) 438 | x = x + self.mlp(xn) + self.scale * self.MLP_Adapter(xn) 439 | return x 440 | 441 | 442 | class Attention(nn.Module): 443 | """Multi-head Attention block with relative position embeddings.""" 444 | 445 | def __init__( 446 | self, 447 | dim: int, 448 | num_heads: int = 8, 449 | qkv_bias: bool = True, 450 | use_rel_pos: bool = False, 451 | rel_pos_zero_init: bool = True, 452 | input_size: Optional[Tuple[int, int]] = None, 453 | ) -> None: 454 | """ 455 | Args: 456 | dim (int): Number of input channels. 457 | num_heads (int): Number of attention heads. 458 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 459 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 460 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 461 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 462 | positional parameter size. 463 | """ 464 | super().__init__() 465 | self.num_heads = num_heads 466 | head_dim = dim // num_heads 467 | self.scale = head_dim**-0.5 468 | 469 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 470 | self.proj = nn.Linear(dim, dim) 471 | 472 | self.use_rel_pos = use_rel_pos 473 | if self.use_rel_pos: 474 | assert ( 475 | input_size is not None 476 | ), "Input size must be provided if using relative positional encoding." 477 | # initialize relative positional embeddings 478 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 479 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 480 | 481 | def forward(self, x: torch.Tensor) -> torch.Tensor: 482 | B, H, W, _ = x.shape 483 | # qkv with shape (3, B, nHead, H * W, C) 484 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 485 | # q, k, v with shape (B * nHead, H * W, C) 486 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 487 | 488 | attn = (q * self.scale) @ k.transpose(-2, -1) 489 | 490 | if self.use_rel_pos: 491 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 492 | 493 | attn = attn.softmax(dim=-1) 494 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 495 | x = self.proj(x) 496 | 497 | return x 498 | 499 | 500 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 501 | """ 502 | Partition into non-overlapping windows with padding if needed. 503 | Args: 504 | x (tensor): input tokens with [B, H, W, C]. 505 | window_size (int): window size. 506 | 507 | Returns: 508 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 509 | (Hp, Wp): padded height and width before partition 510 | """ 511 | B, H, W, C = x.shape 512 | 513 | pad_h = (window_size - H % window_size) % window_size 514 | pad_w = (window_size - W % window_size) % window_size 515 | if pad_h > 0 or pad_w > 0: 516 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 517 | Hp, Wp = H + pad_h, W + pad_w 518 | 519 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 520 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 521 | return windows, (Hp, Wp) 522 | 523 | 524 | def window_unpartition( 525 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 526 | ) -> torch.Tensor: 527 | """ 528 | Window unpartition into original sequences and removing padding. 529 | Args: 530 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 531 | window_size (int): window size. 532 | pad_hw (Tuple): padded height and width (Hp, Wp). 533 | hw (Tuple): original height and width (H, W) before padding. 534 | 535 | Returns: 536 | x: unpartitioned sequences with [B, H, W, C]. 537 | """ 538 | Hp, Wp = pad_hw 539 | H, W = hw 540 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 541 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 542 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 543 | 544 | if Hp > H or Wp > W: 545 | x = x[:, :H, :W, :].contiguous() 546 | return x 547 | 548 | 549 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 550 | """ 551 | Get relative positional embeddings according to the relative positions of 552 | query and key sizes. 553 | Args: 554 | q_size (int): size of query q. 555 | k_size (int): size of key k. 556 | rel_pos (Tensor): relative position embeddings (L, C). 557 | 558 | Returns: 559 | Extracted positional embeddings according to relative positions. 560 | """ 561 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 562 | # Interpolate rel pos if needed. 563 | if rel_pos.shape[0] != max_rel_dist: 564 | # Interpolate rel pos. 565 | rel_pos_resized = F.interpolate( 566 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 567 | size=max_rel_dist, 568 | mode="linear", 569 | ) 570 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 571 | else: 572 | rel_pos_resized = rel_pos 573 | 574 | # Scale the coords with short length if shapes for q and k are different. 575 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 576 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 577 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 578 | 579 | return rel_pos_resized[relative_coords.long()] 580 | 581 | 582 | def add_decomposed_rel_pos( 583 | attn: torch.Tensor, 584 | q: torch.Tensor, 585 | rel_pos_h: torch.Tensor, 586 | rel_pos_w: torch.Tensor, 587 | q_size: Tuple[int, int], 588 | k_size: Tuple[int, int], 589 | ) -> torch.Tensor: 590 | """ 591 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 592 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 593 | Args: 594 | attn (Tensor): attention map. 595 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 596 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 597 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 598 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 599 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 600 | 601 | Returns: 602 | attn (Tensor): attention map with added relative positional embeddings. 603 | """ 604 | q_h, q_w = q_size 605 | k_h, k_w = k_size 606 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 607 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 608 | 609 | B, _, dim = q.shape 610 | r_q = q.reshape(B, q_h, q_w, dim) 611 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 612 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 613 | 614 | attn = ( 615 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 616 | ).view(B, q_h * q_w, k_h * k_w) 617 | 618 | return attn 619 | 620 | 621 | class PatchEmbed(nn.Module): 622 | """ 623 | Image to Patch Embedding. 624 | """ 625 | 626 | def __init__( 627 | self, 628 | kernel_size: Tuple[int, int] = (16, 16), 629 | stride: Tuple[int, int] = (16, 16), 630 | padding: Tuple[int, int] = (0, 0), 631 | in_chans: int = 3, 632 | embed_dim: int = 768, 633 | ) -> None: 634 | """ 635 | Args: 636 | kernel_size (Tuple): kernel size of the projection layer. 637 | stride (Tuple): stride of the projection layer. 638 | padding (Tuple): padding size of the projection layer. 639 | in_chans (int): Number of input image channels. 640 | embed_dim (int): embed_dim (int): Patch embedding dimension. 641 | """ 642 | super().__init__() 643 | 644 | self.proj = nn.Conv2d( 645 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 646 | ) 647 | 648 | def forward(self, x: torch.Tensor) -> torch.Tensor: 649 | x = self.proj(x) 650 | # B C H W -> B H W C 651 | x = x.permute(0, 2, 3, 1) 652 | return x 653 | -------------------------------------------------------------------------------- /models/sammodel/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for outptu 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /models/sammodel/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /models/sammodel/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /models/sammodel/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.9.0 2 | ipython~=8.11.0 3 | matplotlib~=3.1.2 4 | opencv-python~=4.7.0.72 5 | PyYAML~=6.0 6 | scikit-learn~=1.2.2 7 | scipy~=1.10.1 8 | 9 | tqdm~=4.51.0 10 | torch~=1.13.0+cu116 11 | numpy~=1.20.3 12 | typing~=3.7.4.3 13 | terminaltables~=3.1.10 14 | Pillow~=9.4.0 15 | torchvision~=0.14.0+cu116 16 | tensorboardX~=2.6 17 | onnxruntime~=1.14.1 18 | setuptools~=67.6.1 19 | timm~=0.3.2 20 | easydict~=1.10 21 | attr~=0.3.2 22 | thop~=0.1.1.post2209072238 23 | torchsummary~=1.5.1 -------------------------------------------------------------------------------- /sod_metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from scipy.ndimage import convolve 5 | from scipy.ndimage import distance_transform_edt as bwdist 6 | 7 | import cv2 8 | _EPS = np.spacing(1) # the different implementation of epsilon (extreme min value) between numpy and matlab 9 | _TYPE = np.float64 10 | 11 | 12 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: 13 | """ 14 | A numpy-based function for preparing ``pred`` and ``gt``. 15 | - for ``pred``, it looks like ``mapminmax(im2double(...))`` of matlab; 16 | - ``gt`` will be binarized by 128. 17 | :param pred: prediction 18 | :param gt: mask 19 | :return: pred, gt 20 | """ 21 | gt = gt > 128 22 | # im2double, mapminmax 23 | pred = pred / 255 24 | if pred.max() != pred.min(): 25 | pred = (pred - pred.min()) / (pred.max() - pred.min()) 26 | return pred, gt 27 | 28 | 29 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: 30 | """ 31 | Return an adaptive threshold, which is equal to twice the mean of ``matrix``. 32 | :param matrix: a data array 33 | :param max_value: the upper limit of the threshold 34 | :return: min(2 * matrix.mean(), max_value) 35 | """ 36 | return min(2 * matrix.mean(), max_value) 37 | 38 | 39 | class Fmeasure(object): 40 | def __init__(self, beta: float = 1.0): 41 | """ 42 | F-measure for SOD. 43 | :: 44 | @inproceedings{Fmeasure, 45 | title={Frequency-tuned salient region detection}, 46 | author={Achanta, Radhakrishna and Hemami, Sheila and Estrada, Francisco and S{\"u}sstrunk, Sabine}, 47 | booktitle=CVPR, 48 | number={CONF}, 49 | pages={1597--1604}, 50 | year={2009} 51 | } 52 | :param beta: the weight of the precision 53 | """ 54 | self.beta = beta 55 | self.precisions = [] 56 | self.recalls = [] 57 | self.adaptive_fms = [] 58 | self.changeable_fms = [] 59 | 60 | def step(self, pred: np.ndarray, gt: np.ndarray): 61 | pred, gt = _prepare_data(pred, gt) 62 | 63 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) 64 | self.adaptive_fms.append(adaptive_fm) 65 | 66 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) 67 | self.precisions.append(precisions) 68 | self.recalls.append(recalls) 69 | self.changeable_fms.append(changeable_fms) 70 | 71 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: 72 | """ 73 | Calculate the adaptive F-measure. 74 | :return: adaptive_fm 75 | """ 76 | # ``np.count_nonzero`` is faster and better 77 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 78 | binary_predcition = pred >= adaptive_threshold 79 | area_intersection = binary_predcition[gt].sum() 80 | if area_intersection == 0: 81 | adaptive_fm = 0 82 | else: 83 | pre = area_intersection / np.count_nonzero(binary_predcition) 84 | rec = area_intersection / np.count_nonzero(gt) 85 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) 86 | return adaptive_fm 87 | 88 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 89 | """ 90 | Calculate the corresponding precision and recall when the threshold changes from 0 to 255. 91 | These precisions and recalls can be used to obtain the mean F-measure, maximum F-measure, 92 | precision-recall curve and F-measure-threshold curve. 93 | For convenience, ``changeable_fms`` is provided here, which can be used directly to obtain 94 | the mean F-measure, maximum F-measure and F-measure-threshold curve. 95 | :return: precisions, recalls, changeable_fms 96 | """ 97 | # 1. 获取预测结果在真值前背景区域中的直方图 98 | pred = (pred * 255).astype(np.uint8) 99 | bins = np.linspace(0, 256, 257) 100 | fg_hist, _ = np.histogram(pred[gt], bins=bins) # 最后一个bin为[255, 256] 101 | bg_hist, _ = np.histogram(pred[~gt], bins=bins) 102 | # 2. 使用累积直方图(Cumulative Histogram)获得对应真值前背景中大于不同阈值的像素数量 103 | # 这里使用累加(cumsum)就是为了一次性得出 >=不同阈值 的像素数量, 这里仅计算了前景区域 104 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) 105 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) 106 | # 3. 使用不同阈值的结果计算对应的precision和recall 107 | # p和r的计算的真值是pred==1>==1,二者仅有分母不同,分母前者是pred==1,后者是gt==1 108 | # 为了同时计算不同阈值的结果,这里使用hsitogram&flip&cumsum 获得了不同各自的前景像素数量 109 | TPs = fg_w_thrs 110 | Ps = fg_w_thrs + bg_w_thrs 111 | # 为防止除0,这里针对除0的情况分析后直接对于0分母设为1,因为此时分子必为0 112 | Ps[Ps == 0] = 1 113 | T = max(np.count_nonzero(gt), 1) 114 | # TODO: T=0 或者 特定阈值下fg_w_thrs=0或者bg_w_thrs=0,这些都会包含在TPs[i]=0的情况中, 115 | # 但是这里使用TPs不便于处理列表 116 | precisions = TPs / Ps 117 | recalls = TPs / T 118 | 119 | numerator = (1 + self.beta) * precisions * recalls 120 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) 121 | changeable_fms = numerator / denominator 122 | return precisions, recalls, changeable_fms 123 | 124 | def get_results(self) -> dict: 125 | """ 126 | Return the results about F-measure. 127 | :return: dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), pr=dict(p=precision, r=recall)) 128 | """ 129 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) 130 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) 131 | precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 132 | recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 133 | return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), pr=dict(p=precision, r=recall)) 134 | 135 | 136 | class MAE(object): 137 | def __init__(self): 138 | """ 139 | MAE(mean absolute error) for SOD. 140 | :: 141 | @inproceedings{MAE, 142 | title={Saliency filters: Contrast based filtering for salient region detection}, 143 | author={Perazzi, Federico and Kr{\"a}henb{\"u}hl, Philipp and Pritch, Yael and Hornung, Alexander}, 144 | booktitle=CVPR, 145 | pages={733--740}, 146 | year={2012} 147 | } 148 | """ 149 | self.maes = [] 150 | 151 | def step(self, pred: np.ndarray, gt: np.ndarray): 152 | pred, gt = _prepare_data(pred, gt) 153 | 154 | mae = self.cal_mae(pred, gt) 155 | # mae = np.sum(cv2.absdiff(gt.astype(float), pred.astype(float))) / (pred.shape[1] * pred.shape[0]) 156 | self.maes.append(mae) 157 | 158 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 159 | """ 160 | Calculate the mean absolute error. 161 | :return: mae 162 | """ 163 | mae = np.mean(np.abs(pred - gt)) 164 | return mae 165 | 166 | def get_results(self) -> dict: 167 | """ 168 | Return the results about MAE. 169 | :return: dict(mae=mae) 170 | """ 171 | mae = np.mean(np.array(self.maes, _TYPE)) 172 | return dict(mae=mae) 173 | 174 | 175 | class Smeasure(object): 176 | def __init__(self, alpha: float = 0.5): 177 | """ 178 | S-measure(Structure-measure) of SOD. 179 | :: 180 | @inproceedings{Smeasure, 181 | title={Structure-measure: A new way to eval foreground maps}, 182 | author={Fan, Deng-Ping and Cheng, Ming-Ming and Liu, Yun and Li, Tao and Borji, Ali}, 183 | booktitle=ICCV, 184 | pages={4548--4557}, 185 | year={2017} 186 | } 187 | :param alpha: the weight for balancing the object score and the region score 188 | """ 189 | self.sms = [] 190 | self.alpha = alpha 191 | 192 | def step(self, pred: np.ndarray, gt: np.ndarray): 193 | pred, gt = _prepare_data(pred=pred, gt=gt) 194 | 195 | sm = self.cal_sm(pred, gt) 196 | self.sms.append(sm) 197 | 198 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: 199 | """ 200 | Calculate the S-measure. 201 | :return: s-measure 202 | """ 203 | y = np.mean(gt) 204 | if y == 0: 205 | sm = 1 - np.mean(pred) 206 | elif y == 1: 207 | sm = np.mean(pred) 208 | else: 209 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 210 | sm = max(0, sm) 211 | return sm 212 | 213 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float: 214 | """ 215 | Calculate the object score. 216 | """ 217 | fg = pred * gt 218 | bg = (1 - pred) * (1 - gt) 219 | u = np.mean(gt) 220 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) 221 | return object_score 222 | 223 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: 224 | x = np.mean(pred[gt == 1]) 225 | sigma_x = np.std(pred[gt == 1], ddof=1) 226 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) 227 | return score 228 | 229 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float: 230 | """ 231 | Calculate the region score. 232 | """ 233 | x, y = self.centroid(gt) 234 | part_info = self.divide_with_xy(pred, gt, x, y) 235 | w1, w2, w3, w4 = part_info["weight"] 236 | # assert np.isclose(w1 + w2 + w3 + w4, 1), (w1 + w2 + w3 + w4, pred.mean(), gt.mean()) 237 | 238 | pred1, pred2, pred3, pred4 = part_info["pred"] 239 | gt1, gt2, gt3, gt4 = part_info["gt"] 240 | score1 = self.ssim(pred1, gt1) 241 | score2 = self.ssim(pred2, gt2) 242 | score3 = self.ssim(pred3, gt3) 243 | score4 = self.ssim(pred4, gt4) 244 | 245 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 246 | 247 | def centroid(self, matrix: np.ndarray) -> tuple: 248 | """ 249 | To ensure consistency with the matlab code, one is added to the centroid coordinate, 250 | so there is no need to use the redundant addition operation when dividing the region later, 251 | because the sequence generated by ``1:X`` in matlab will contain ``X``. 252 | :param matrix: a bool data array 253 | :return: the centroid coordinate 254 | """ 255 | h, w = matrix.shape 256 | area_object = np.count_nonzero(matrix) 257 | if area_object == 0: 258 | x = np.round(w / 2) 259 | y = np.round(h / 2) 260 | else: 261 | # More details can be found at: https://www.yuque.com/lart/blog/gpbigm 262 | y, x = np.argwhere(matrix).mean(axis=0).round() 263 | return int(x) + 1, int(y) + 1 264 | 265 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x: int, y: int) -> dict: 266 | """ 267 | Use (x,y) to divide the ``pred`` and the ``gt`` into four submatrices, respectively. 268 | """ 269 | h, w = gt.shape 270 | area = h * w 271 | 272 | gt_LT = gt[0:y, 0:x] 273 | gt_RT = gt[0:y, x:w] 274 | gt_LB = gt[y:h, 0:x] 275 | gt_RB = gt[y:h, x:w] 276 | 277 | pred_LT = pred[0:y, 0:x] 278 | pred_RT = pred[0:y, x:w] 279 | pred_LB = pred[y:h, 0:x] 280 | pred_RB = pred[y:h, x:w] 281 | 282 | w1 = x * y / area 283 | w2 = y * (w - x) / area 284 | w3 = (h - y) * x / area 285 | w4 = 1 - w1 - w2 - w3 286 | 287 | return dict( 288 | gt=(gt_LT, gt_RT, gt_LB, gt_RB), 289 | pred=(pred_LT, pred_RT, pred_LB, pred_RB), 290 | weight=(w1, w2, w3, w4), 291 | ) 292 | 293 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: 294 | """ 295 | Calculate the ssim score. 296 | """ 297 | h, w = pred.shape 298 | N = h * w 299 | 300 | x = np.mean(pred) 301 | y = np.mean(gt) 302 | 303 | sigma_x = np.sum((pred - x) ** 2) / (N - 1) 304 | sigma_y = np.sum((gt - y) ** 2) / (N - 1) 305 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) 306 | 307 | alpha = 4 * x * y * sigma_xy 308 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) 309 | 310 | if alpha != 0: 311 | score = alpha / (beta + _EPS) 312 | elif alpha == 0 and beta == 0: 313 | score = 1 314 | else: 315 | score = 0 316 | return score 317 | 318 | def get_results(self) -> dict: 319 | """ 320 | Return the results about S-measure. 321 | :return: dict(sm=sm) 322 | """ 323 | sm = np.mean(np.array(self.sms, dtype=_TYPE)) 324 | return dict(sm=sm) 325 | 326 | 327 | class Emeasure(object): 328 | def __init__(self): 329 | """ 330 | E-measure(Enhanced-alignment Measure) for SOD. 331 | More details about the implementation can be found in https://www.yuque.com/lart/blog/lwgt38 332 | :: 333 | @inproceedings{Emeasure, 334 | title="Enhanced-alignment Measure for Binary Foreground Map Evaluation", 335 | author="Deng-Ping {Fan} and Cheng {Gong} and Yang {Cao} and Bo {Ren} and Ming-Ming {Cheng} and Ali {Borji}", 336 | booktitle=IJCAI, 337 | pages="698--704", 338 | year={2018} 339 | } 340 | """ 341 | self.adaptive_ems = [] 342 | self.changeable_ems = [] 343 | 344 | def step(self, pred: np.ndarray, gt: np.ndarray): 345 | pred, gt = _prepare_data(pred=pred, gt=gt) 346 | 347 | self.gt_fg_numel = np.count_nonzero(gt) 348 | self.gt_size = gt.shape[0] * gt.shape[1] 349 | 350 | changeable_ems = self.cal_changeable_em(pred, gt) 351 | self.changeable_ems.append(changeable_ems) 352 | adaptive_em = self.cal_adaptive_em(pred, gt) 353 | self.adaptive_ems.append(adaptive_em) 354 | 355 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: 356 | """ 357 | Calculate the adaptive E-measure. 358 | :return: adaptive_em 359 | """ 360 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 361 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) 362 | return adaptive_em 363 | 364 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 365 | """ 366 | Calculate the changeable E-measure, which can be used to obtain the mean E-measure, 367 | the maximum E-measure and the E-measure-threshold curve. 368 | :return: changeable_ems 369 | """ 370 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) 371 | return changeable_ems 372 | 373 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: 374 | """ 375 | Calculate the E-measure corresponding to the specific threshold. 376 | Variable naming rules within the function: 377 | ``[pred attribute(foreground fg, background bg)]_[gt attribute(foreground fg, background bg)]_[meaning]`` 378 | If only ``pred`` or ``gt`` is considered, another corresponding attribute location is replaced with '``_``'. 379 | """ 380 | binarized_pred = pred >= threshold 381 | fg_fg_numel = np.count_nonzero(binarized_pred & gt) 382 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) 383 | 384 | fg___numel = fg_fg_numel + fg_bg_numel 385 | bg___numel = self.gt_size - fg___numel 386 | 387 | if self.gt_fg_numel == 0: 388 | enhanced_matrix_sum = bg___numel 389 | elif self.gt_fg_numel == self.gt_size: 390 | enhanced_matrix_sum = fg___numel 391 | else: 392 | parts_numel, combinations = self.generate_parts_numel_combinations( 393 | fg_fg_numel=fg_fg_numel, 394 | fg_bg_numel=fg_bg_numel, 395 | pred_fg_numel=fg___numel, 396 | pred_bg_numel=bg___numel, 397 | ) 398 | 399 | results_parts = [] 400 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): 401 | align_matrix_value = ( 402 | 2 403 | * (combination[0] * combination[1]) 404 | / (combination[0] ** 2 + combination[1] ** 2 + _EPS) 405 | ) 406 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 407 | results_parts.append(enhanced_matrix_value * part_numel) 408 | enhanced_matrix_sum = sum(results_parts) 409 | 410 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 411 | return em 412 | 413 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 414 | """ 415 | Calculate the E-measure corresponding to the threshold that varies from 0 to 255.. 416 | Variable naming rules within the function: 417 | ``[pred attribute(foreground fg, background bg)]_[gt attribute(foreground fg, background bg)]_[meaning]`` 418 | If only ``pred`` or ``gt`` is considered, another corresponding attribute location is replaced with '``_``'. 419 | """ 420 | pred = (pred * 255).astype(np.uint8) 421 | bins = np.linspace(0, 256, 257) 422 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) 423 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) 424 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) 425 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) 426 | 427 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs 428 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs 429 | 430 | if self.gt_fg_numel == 0: 431 | enhanced_matrix_sum = bg___numel_w_thrs 432 | elif self.gt_fg_numel == self.gt_size: 433 | enhanced_matrix_sum = fg___numel_w_thrs 434 | else: 435 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( 436 | fg_fg_numel=fg_fg_numel_w_thrs, 437 | fg_bg_numel=fg_bg_numel_w_thrs, 438 | pred_fg_numel=fg___numel_w_thrs, 439 | pred_bg_numel=bg___numel_w_thrs, 440 | ) 441 | 442 | results_parts = np.empty(shape=(4, 256), dtype=np.float64) 443 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): 444 | align_matrix_value = ( 445 | 2 446 | * (combination[0] * combination[1]) 447 | / (combination[0] ** 2 + combination[1] ** 2 + _EPS) 448 | ) 449 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 450 | results_parts[i] = enhanced_matrix_value * part_numel 451 | enhanced_matrix_sum = results_parts.sum(axis=0) 452 | 453 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 454 | return em 455 | 456 | def generate_parts_numel_combinations( 457 | self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel 458 | ): 459 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel 460 | bg_bg_numel = pred_bg_numel - bg_fg_numel 461 | 462 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] 463 | 464 | mean_pred_value = pred_fg_numel / self.gt_size 465 | mean_gt_value = self.gt_fg_numel / self.gt_size 466 | 467 | demeaned_pred_fg_value = 1 - mean_pred_value 468 | demeaned_pred_bg_value = 0 - mean_pred_value 469 | demeaned_gt_fg_value = 1 - mean_gt_value 470 | demeaned_gt_bg_value = 0 - mean_gt_value 471 | 472 | combinations = [ 473 | (demeaned_pred_fg_value, demeaned_gt_fg_value), 474 | (demeaned_pred_fg_value, demeaned_gt_bg_value), 475 | (demeaned_pred_bg_value, demeaned_gt_fg_value), 476 | (demeaned_pred_bg_value, demeaned_gt_bg_value), 477 | ] 478 | return parts_numel, combinations 479 | 480 | def get_results(self) -> dict: 481 | """ 482 | Return the results about E-measure. 483 | :return: dict(em=dict(adp=adaptive_em, curve=changeable_em)) 484 | """ 485 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) 486 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) 487 | return dict(em=dict(adp=adaptive_em, curve=changeable_em)) 488 | 489 | 490 | class WeightedFmeasure(object): 491 | def __init__(self, beta: float = 0.3): 492 | """ 493 | Weighted F-measure for SOD. 494 | :: 495 | @inproceedings{wFmeasure, 496 | title={How to eval foreground maps?}, 497 | author={Margolin, Ran and Zelnik-Manor, Lihi and Tal, Ayellet}, 498 | booktitle=CVPR, 499 | pages={248--255}, 500 | year={2014} 501 | } 502 | :param beta: the weight of the precision 503 | """ 504 | self.beta = beta 505 | self.weighted_fms = [] 506 | 507 | def step(self, pred: np.ndarray, gt: np.ndarray): 508 | pred, gt = _prepare_data(pred=pred, gt=gt) 509 | 510 | if np.all(~gt): 511 | wfm = 0 512 | else: 513 | wfm = self.cal_wfm(pred, gt) 514 | self.weighted_fms.append(wfm) 515 | 516 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: 517 | """ 518 | Calculate the weighted F-measure. 519 | """ 520 | # [Dst,IDXT] = bwdist(dGT); 521 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 522 | 523 | # %Pixel dependency 524 | # E = abs(FG-dGT); 525 | E = np.abs(pred - gt) 526 | # Et = E; 527 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 528 | Et = np.copy(E) 529 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 530 | 531 | # K = fspecial('gaussian',7,5); 532 | # EA = imfilter(Et,K); 533 | K = self.matlab_style_gauss2D((7, 7), sigma=5) 534 | EA = convolve(Et, weights=K, mode="constant", cval=0) 535 | # MIN_E_EA = E; 536 | # MIN_E_EA(GT & EA np.ndarray: 563 | """ 564 | 2D gaussian mask - should give the same result as MATLAB's 565 | fspecial('gaussian',[shape],[sigma]) 566 | """ 567 | m, n = [(ss - 1) / 2 for ss in shape] 568 | y, x = np.ogrid[-m : m + 1, -n : n + 1] 569 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 570 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 571 | sumh = h.sum() 572 | if sumh != 0: 573 | h /= sumh 574 | return h 575 | 576 | def get_results(self) -> dict: 577 | """ 578 | Return the results about weighted F-measure. 579 | :return: dict(wfm=weighted_fm) 580 | """ 581 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) 582 | return dict(wfm=weighted_fm) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import yaml 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | import datasets 10 | import models 11 | import utils 12 | 13 | from torchvision import transforms 14 | from mmcv.runner import load_checkpoint 15 | 16 | 17 | def batched_predict(model, inp, coord, bsize): 18 | with torch.no_grad(): 19 | model.gen_feat(inp) 20 | n = coord.shape[1] 21 | ql = 0 22 | preds = [] 23 | while ql < n: 24 | qr = min(ql + bsize, n) 25 | pred = model.query_rgb(coord[:, ql: qr, :]) 26 | preds.append(pred) 27 | ql = qr 28 | pred = torch.cat(preds, dim=1) 29 | return pred, preds 30 | 31 | 32 | def tensor2PIL(tensor): 33 | toPIL = transforms.ToPILImage() 34 | return toPIL(tensor) 35 | 36 | 37 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | 39 | 40 | def eval_psnr(loader, model, data_norm=None, eval_type=None, eval_bsize=None, 41 | verbose=False): 42 | model.eval() 43 | if data_norm is None: 44 | data_norm = { 45 | 'inp': {'sub': [0], 'div': [1]}, 46 | 'gt': {'sub': [0], 'div': [1]} 47 | } 48 | 49 | if eval_type == 'f1': 50 | metric_fn = utils.calc_f1 51 | metric1, metric2, metric3, metric4 = 'f1', 'auc', 'none', 'none' 52 | elif eval_type == 'fmeasure': 53 | metric_fn = utils.calc_fmeasure 54 | metric1, metric2, metric3, metric4 = 'f_mea', 'mae', 'none', 'none' 55 | elif eval_type == 'ber': 56 | metric_fn = utils.calc_ber 57 | metric1, metric2, metric3, metric4 = 'shadow', 'non_shadow', 'ber', 'none' 58 | elif eval_type == 'cod': 59 | metric_fn = utils.calc_cod 60 | metric1, metric2, metric3, metric4 = 'sm', 'em', 'wfm', 'mae' 61 | 62 | val_metric1 = utils.Averager() 63 | val_metric2 = utils.Averager() 64 | val_metric3 = utils.Averager() 65 | val_metric4 = utils.Averager() 66 | 67 | pbar = tqdm(loader, leave=False, desc='val') 68 | 69 | for batch in pbar: 70 | for k, v in batch.items(): 71 | batch[k] = v.cuda() 72 | 73 | inp = batch['inp'] 74 | 75 | pred = torch.sigmoid(model.infer(inp)) 76 | 77 | result1, result2, result3, result4 = metric_fn(pred, batch['gt']) 78 | val_metric1.add(result1.item(), inp.shape[0]) 79 | val_metric2.add(result2.item(), inp.shape[0]) 80 | val_metric3.add(result3.item(), inp.shape[0]) 81 | val_metric4.add(result4.item(), inp.shape[0]) 82 | 83 | if verbose: 84 | pbar.set_description('val {} {:.4f}'.format(metric1, val_metric1.item())) 85 | pbar.set_description('val {} {:.4f}'.format(metric2, val_metric2.item())) 86 | pbar.set_description('val {} {:.4f}'.format(metric3, val_metric3.item())) 87 | pbar.set_description('val {} {:.4f}'.format(metric4, val_metric4.item())) 88 | 89 | return val_metric1.item(), val_metric2.item(), val_metric3.item(), val_metric4.item() 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--config') 95 | parser.add_argument('--model') 96 | parser.add_argument('--prompt', default='none') 97 | args = parser.parse_args() 98 | 99 | with open(args.config, 'r') as f: 100 | config = yaml.load(f, Loader=yaml.FullLoader) 101 | spec = config['test_dataset'] 102 | dataset = datasets.make(spec['dataset']) 103 | dataset = datasets.make(spec['wrapper'], args={'dataset': dataset}) 104 | loader = DataLoader(dataset, batch_size=spec['batch_size'], 105 | num_workers=8) 106 | 107 | model = models.make(config['model']).cuda() 108 | sam_checkpoint = torch.load(args.model, map_location='cuda:0') 109 | model.load_state_dict(sam_checkpoint, strict=True) 110 | 111 | metric1, metric2, metric3, metric4 = eval_psnr(loader, model, 112 | data_norm=config.get('data_norm'), 113 | eval_type=config.get('eval_type'), 114 | eval_bsize=config.get('eval_bsize'), 115 | verbose=True) 116 | print('metric1: {:.4f}'.format(metric1)) 117 | print('metric2: {:.4f}'.format(metric2)) 118 | print('metric3: {:.4f}'.format(metric3)) 119 | print('metric4: {:.4f}'.format(metric4)) 120 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import yaml 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | from torch.optim.lr_scheduler import CosineAnnealingLR 8 | 9 | import datasets 10 | import models 11 | import utils 12 | from statistics import mean 13 | import torch 14 | import torch.distributed as dist 15 | 16 | torch.distributed.init_process_group(backend='nccl') 17 | local_rank = torch.distributed.get_rank() 18 | torch.cuda.set_device(local_rank) 19 | device = torch.device("cuda", local_rank) 20 | 21 | 22 | def make_data_loader(spec, tag=''): 23 | if spec is None: 24 | return None 25 | 26 | dataset = datasets.make(spec['dataset']) 27 | dataset = datasets.make(spec['wrapper'], args={'dataset': dataset}) 28 | if local_rank == 0: 29 | log('{} dataset: size={}'.format(tag, len(dataset))) 30 | for k, v in dataset[0].items(): 31 | log(' {}: shape={}'.format(k, tuple(v.shape))) 32 | 33 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 34 | loader = DataLoader(dataset, batch_size=spec['batch_size'], 35 | shuffle=False, num_workers=8, pin_memory=True, sampler=sampler) 36 | return loader 37 | 38 | 39 | def make_data_loaders(): 40 | train_loader = make_data_loader(config.get('train_dataset'), tag='train') 41 | val_loader = make_data_loader(config.get('val_dataset'), tag='val') 42 | return train_loader, val_loader 43 | 44 | 45 | def eval_psnr(loader, model, eval_type=None): 46 | model.eval() 47 | 48 | if eval_type == 'f1': 49 | metric_fn = utils.calc_f1 50 | metric1, metric2, metric3, metric4 = 'f1', 'auc', 'none', 'none' 51 | elif eval_type == 'fmeasure': 52 | metric_fn = utils.calc_fmeasure 53 | metric1, metric2, metric3, metric4 = 'f_mea', 'mae', 'none', 'none' 54 | elif eval_type == 'ber': 55 | metric_fn = utils.calc_ber 56 | metric1, metric2, metric3, metric4 = 'shadow', 'non_shadow', 'ber', 'none' 57 | elif eval_type == 'cod': 58 | metric_fn = utils.calc_cod 59 | metric1, metric2, metric3, metric4 = 'sm', 'em', 'wfm', 'mae' 60 | 61 | if local_rank == 0: 62 | pbar = tqdm(total=len(loader), leave=False, desc='val') 63 | else: 64 | pbar = None 65 | 66 | pred_list = [] 67 | gt_list = [] 68 | for batch in loader: 69 | for k, v in batch.items(): 70 | batch[k] = v.cuda() 71 | 72 | inp = batch['inp'] 73 | 74 | pred = torch.sigmoid(model.infer(inp)) 75 | 76 | batch_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())] 77 | batch_gt = [torch.zeros_like(batch['gt']) for _ in range(dist.get_world_size())] 78 | 79 | dist.all_gather(batch_pred, pred) 80 | pred_list.extend(batch_pred) 81 | dist.all_gather(batch_gt, batch['gt']) 82 | gt_list.extend(batch_gt) 83 | if pbar is not None: 84 | pbar.update(1) 85 | 86 | if pbar is not None: 87 | pbar.close() 88 | 89 | pred_list = torch.cat(pred_list, 1) 90 | gt_list = torch.cat(gt_list, 1) 91 | result1, result2, result3, result4 = metric_fn(pred_list, gt_list) 92 | 93 | return result1, result2, result3, result4, metric1, metric2, metric3, metric4 94 | 95 | 96 | def prepare_training(): 97 | if config.get('resume') is not None: 98 | model = models.make(config['model']).cuda() 99 | optimizer = utils.make_optimizer( 100 | model.parameters(), config['optimizer']) 101 | epoch_start = config.get('resume') + 1 102 | else: 103 | model = models.make(config['model']).cuda() 104 | optimizer = utils.make_optimizer( 105 | model.parameters(), config['optimizer']) 106 | epoch_start = 1 107 | max_epoch = config.get('epoch_max') 108 | lr_scheduler = CosineAnnealingLR(optimizer, max_epoch, eta_min=config.get('lr_min')) 109 | if local_rank == 0: 110 | log('model: #params={}'.format(utils.compute_num_params(model, text=True))) 111 | return model, optimizer, epoch_start, lr_scheduler 112 | 113 | def train(train_loader, model): 114 | model.train() 115 | 116 | if local_rank == 0: 117 | pbar = tqdm(total=len(train_loader), leave=False, desc='train') 118 | else: 119 | pbar = None 120 | 121 | loss_list = [] 122 | for batch in train_loader: 123 | for k, v in batch.items(): 124 | batch[k] = v.to(device) 125 | inp = batch['inp'] 126 | gt = batch['gt'] 127 | model.set_input(inp, gt) 128 | model.optimize_parameters() 129 | batch_loss = [torch.zeros_like(model.loss_G) for _ in range(dist.get_world_size())] 130 | dist.all_gather(batch_loss, model.loss_G) 131 | loss_list.extend(batch_loss) 132 | if pbar is not None: 133 | pbar.update(1) 134 | 135 | if pbar is not None: 136 | pbar.close() 137 | 138 | loss = [i.item() for i in loss_list] 139 | return mean(loss) 140 | 141 | 142 | def main(config_, save_path, args): 143 | global config, log, writer, log_info 144 | config = config_ 145 | log, writer = utils.set_save_path(save_path, remove=False) 146 | with open(os.path.join(save_path, 'config.yaml'), 'w') as f: 147 | yaml.dump(config, f, sort_keys=False) 148 | 149 | train_loader, val_loader = make_data_loaders() 150 | if config.get('data_norm') is None: 151 | config['data_norm'] = { 152 | 'inp': {'sub': [0], 'div': [1]}, 153 | 'gt': {'sub': [0], 'div': [1]} 154 | } 155 | 156 | model, optimizer, epoch_start, lr_scheduler = prepare_training() 157 | model.optimizer = optimizer 158 | lr_scheduler = CosineAnnealingLR(model.optimizer, config['epoch_max'], eta_min=config.get('lr_min')) 159 | model = model.cuda() 160 | model = torch.nn.parallel.DistributedDataParallel( 161 | model, 162 | device_ids=[args.local_rank], 163 | output_device=args.local_rank, 164 | find_unused_parameters=True, 165 | broadcast_buffers=False 166 | ) 167 | model = model.module 168 | 169 | sam_checkpoint = torch.load(config['sam_checkpoint']) 170 | model.load_state_dict(sam_checkpoint, strict=False) 171 | for name, para in model.named_parameters(): 172 | if "image_encoder" in name and "prompt_generator" not in name: 173 | para.requires_grad_(False) 174 | if local_rank == 0: 175 | model_total_params = sum(p.numel() for p in model.parameters()) 176 | model_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 177 | print('model_grad_params:' + str(model_grad_params), '\nmodel_total_params:' + str(model_total_params)) 178 | 179 | epoch_max = config['epoch_max'] 180 | epoch_val = config.get('epoch_val') 181 | max_val_v = -1e18 if config['eval_type'] != 'ber' else 1e8 182 | timer = utils.Timer() 183 | for epoch in range(epoch_start, epoch_max + 1): 184 | train_loader.sampler.set_epoch(epoch) 185 | t_epoch_start = timer.t() 186 | train_loss_G = train(train_loader, model) 187 | lr_scheduler.step() 188 | 189 | if local_rank == 0: 190 | log_info = ['epoch {}/{}'.format(epoch, epoch_max)] 191 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch) 192 | log_info.append('train G: loss={:.4f}'.format(train_loss_G)) 193 | writer.add_scalars('loss', {'train G': train_loss_G}, epoch) 194 | 195 | model_spec = config['model'] 196 | model_spec['sd'] = model.state_dict() 197 | optimizer_spec = config['optimizer'] 198 | optimizer_spec['sd'] = optimizer.state_dict() 199 | 200 | save(config, model, save_path, 'last') 201 | 202 | if (epoch_val is not None) and (epoch % epoch_val == 0): 203 | result1, result2, result3, result4, metric1, metric2, metric3, metric4 = eval_psnr(val_loader, model, 204 | eval_type=config.get('eval_type')) 205 | 206 | if local_rank == 0: 207 | log_info.append('val: {}={:.4f}'.format(metric1, result1)) 208 | writer.add_scalars(metric1, {'val': result1}, epoch) 209 | log_info.append('val: {}={:.4f}'.format(metric2, result2)) 210 | writer.add_scalars(metric2, {'val': result2}, epoch) 211 | log_info.append('val: {}={:.4f}'.format(metric3, result3)) 212 | writer.add_scalars(metric3, {'val': result3}, epoch) 213 | log_info.append('val: {}={:.4f}'.format(metric4, result4)) 214 | writer.add_scalars(metric4, {'val': result4}, epoch) 215 | 216 | if config['eval_type'] != 'ber': 217 | if result1 > max_val_v: 218 | max_val_v = result1 219 | save(config, model, save_path, 'best') 220 | else: 221 | if result3 < max_val_v: 222 | max_val_v = result3 223 | save(config, model, save_path, 'best') 224 | 225 | t = timer.t() 226 | prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1) 227 | t_epoch = utils.time_text(t - t_epoch_start) 228 | t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog) 229 | log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all)) 230 | 231 | log(', '.join(log_info)) 232 | writer.flush() 233 | 234 | 235 | def save(config, model, save_path, name): 236 | if config['model']['name'] == 'segformer' or config['model']['name'] == 'setr': 237 | if config['model']['args']['encoder_mode']['name'] == 'evp': 238 | prompt_generator = model.encoder.backbone.prompt_generator.state_dict() 239 | decode_head = model.encoder.decode_head.state_dict() 240 | torch.save({"prompt": prompt_generator, "decode_head": decode_head}, 241 | os.path.join(save_path, f"prompt_epoch_{name}.pth")) 242 | else: 243 | torch.save(model.state_dict(), os.path.join(save_path, f"model_epoch_{name}.pth")) 244 | else: 245 | torch.save(model.state_dict(), os.path.join(save_path, f"model_epoch_{name}.pth")) 246 | 247 | 248 | if __name__ == '__main__': 249 | parser = argparse.ArgumentParser() 250 | parser.add_argument('--config', default="configs/train/setr/train_setr_evp_cod.yaml") 251 | parser.add_argument('--name', default=None) 252 | parser.add_argument('--tag', default=None) 253 | parser.add_argument("--local_rank", type=int, default=-1, help="") 254 | args = parser.parse_args() 255 | 256 | with open(args.config, 'r') as f: 257 | config = yaml.load(f, Loader=yaml.FullLoader) 258 | if local_rank == 0: 259 | print('config loaded.') 260 | 261 | save_name = args.name 262 | if save_name is None: 263 | save_name = '_' + args.config.split('/')[-1][:-len('.yaml')] 264 | if args.tag is not None: 265 | save_name += '_' + args.tag 266 | save_path = os.path.join('./save', save_name) 267 | 268 | main(config, save_path, args=args) 269 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | 5 | import torch 6 | import numpy as np 7 | from torch.optim import SGD, Adam, AdamW 8 | from tensorboardX import SummaryWriter 9 | 10 | import sod_metric 11 | class Averager(): 12 | 13 | def __init__(self): 14 | self.n = 0.0 15 | self.v = 0.0 16 | 17 | def add(self, v, n=1.0): 18 | self.v = (self.v * self.n + v * n) / (self.n + n) 19 | self.n += n 20 | 21 | def item(self): 22 | return self.v 23 | 24 | 25 | class Timer(): 26 | 27 | def __init__(self): 28 | self.v = time.time() 29 | 30 | def s(self): 31 | self.v = time.time() 32 | 33 | def t(self): 34 | return time.time() - self.v 35 | 36 | 37 | def time_text(t): 38 | if t >= 3600: 39 | return '{:.1f}h'.format(t / 3600) 40 | elif t >= 60: 41 | return '{:.1f}m'.format(t / 60) 42 | else: 43 | return '{:.1f}s'.format(t) 44 | 45 | 46 | _log_path = None 47 | 48 | 49 | def set_log_path(path): 50 | global _log_path 51 | _log_path = path 52 | 53 | 54 | def log(obj, filename='log.txt'): 55 | print(obj) 56 | if _log_path is not None: 57 | with open(os.path.join(_log_path, filename), 'a') as f: 58 | print(obj, file=f) 59 | 60 | 61 | def ensure_path(path, remove=True): 62 | basename = os.path.basename(path.rstrip('/')) 63 | if os.path.exists(path): 64 | if remove and (basename.startswith('_') 65 | or input('{} exists, remove? (y/[n]): '.format(path)) == 'y'): 66 | shutil.rmtree(path) 67 | os.makedirs(path, exist_ok=True) 68 | else: 69 | os.makedirs(path, exist_ok=True) 70 | 71 | 72 | def set_save_path(save_path, remove=True): 73 | ensure_path(save_path, remove=remove) 74 | set_log_path(save_path) 75 | writer = SummaryWriter(os.path.join(save_path, 'tensorboard')) 76 | return log, writer 77 | 78 | 79 | def compute_num_params(model, text=False): 80 | tot = int(sum([np.prod(p.shape) for p in model.parameters()])) 81 | if text: 82 | if tot >= 1e6: 83 | return '{:.1f}M'.format(tot / 1e6) 84 | else: 85 | return '{:.1f}K'.format(tot / 1e3) 86 | else: 87 | return tot 88 | 89 | 90 | def make_optimizer(param_list, optimizer_spec, load_sd=False): 91 | Optimizer = { 92 | 'sgd': SGD, 93 | 'adam': Adam, 94 | 'adamw': AdamW 95 | }[optimizer_spec['name']] 96 | optimizer = Optimizer(param_list, **optimizer_spec['args']) 97 | if load_sd: 98 | optimizer.load_state_dict(optimizer_spec['sd']) 99 | return optimizer 100 | 101 | 102 | def make_coord(shape, ranges=None, flatten=True): 103 | """ Make coordinates at grid centers. 104 | """ 105 | coord_seqs = [] 106 | for i, n in enumerate(shape): 107 | if ranges is None: 108 | v0, v1 = -1, 1 109 | else: 110 | v0, v1 = ranges[i] 111 | r = (v1 - v0) / (2 * n) 112 | seq = v0 + r + (2 * r) * torch.arange(n).float() 113 | coord_seqs.append(seq) 114 | ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 115 | # if flatten: 116 | # ret = ret.view(-1, ret.shape[-1]) 117 | 118 | return ret 119 | 120 | 121 | 122 | def calc_cod(y_pred, y_true): 123 | batchsize = y_true.shape[0] 124 | 125 | metric_FM = sod_metric.Fmeasure() 126 | metric_WFM = sod_metric.WeightedFmeasure() 127 | metric_SM = sod_metric.Smeasure() 128 | metric_EM = sod_metric.Emeasure() 129 | metric_MAE = sod_metric.MAE() 130 | with torch.no_grad(): 131 | assert y_pred.shape == y_true.shape 132 | 133 | for i in range(batchsize): 134 | true, pred = \ 135 | y_true[i, 0].cpu().data.numpy() * 255, y_pred[i, 0].cpu().data.numpy() * 255 136 | 137 | metric_FM.step(pred=pred, gt=true) 138 | metric_WFM.step(pred=pred, gt=true) 139 | metric_SM.step(pred=pred, gt=true) 140 | metric_EM.step(pred=pred, gt=true) 141 | metric_MAE.step(pred=pred, gt=true) 142 | 143 | fm = metric_FM.get_results()["fm"] 144 | wfm = metric_WFM.get_results()["wfm"] 145 | sm = metric_SM.get_results()["sm"] 146 | em = metric_EM.get_results()["em"]["curve"].mean() 147 | mae = metric_MAE.get_results()["mae"] 148 | 149 | return sm, em, wfm, mae 150 | 151 | 152 | from sklearn.metrics import precision_recall_curve 153 | 154 | 155 | def calc_f1(y_pred,y_true): 156 | batchsize = y_true.shape[0] 157 | with torch.no_grad(): 158 | assert y_pred.shape == y_true.shape 159 | f1, auc = 0, 0 160 | y_true = y_true.cpu().numpy() 161 | y_pred = y_pred.cpu().numpy() 162 | for i in range(batchsize): 163 | true = y_true[i].flatten() 164 | true = true.astype(np.int) 165 | pred = y_pred[i].flatten() 166 | 167 | precision, recall, thresholds = precision_recall_curve(true, pred) 168 | 169 | # auc 170 | auc += roc_auc_score(true, pred) 171 | # auc += roc_auc_score(np.array(true>0).astype(np.int), pred) 172 | f1 += max([(2 * p * r) / (p + r+1e-10) for p, r in zip(precision, recall)]) 173 | 174 | return f1/batchsize, auc/batchsize, np.array(0), np.array(0) 175 | 176 | def calc_fmeasure(y_pred,y_true): 177 | batchsize = y_true.shape[0] 178 | 179 | mae, preds, gts = [], [], [] 180 | with torch.no_grad(): 181 | for i in range(batchsize): 182 | gt_float, pred_float = \ 183 | y_true[i, 0].cpu().data.numpy(), y_pred[i, 0].cpu().data.numpy() 184 | 185 | # # MAE 186 | mae.append(np.sum(cv2.absdiff(gt_float.astype(float), pred_float.astype(float))) / ( 187 | pred_float.shape[1] * pred_float.shape[0])) 188 | # mae.append(np.mean(np.abs(pred_float - gt_float))) 189 | # 190 | pred = np.uint8(pred_float * 255) 191 | gt = np.uint8(gt_float * 255) 192 | 193 | pred_float_ = np.where(pred > min(1.5 * np.mean(pred), 255), np.ones_like(pred_float), 194 | np.zeros_like(pred_float)) 195 | gt_float_ = np.where(gt > min(1.5 * np.mean(gt), 255), np.ones_like(pred_float), 196 | np.zeros_like(pred_float)) 197 | 198 | preds.extend(pred_float_.ravel()) 199 | gts.extend(gt_float_.ravel()) 200 | 201 | RECALL = recall_score(gts, preds) 202 | PERC = precision_score(gts, preds) 203 | 204 | fmeasure = (1 + 0.3) * PERC * RECALL / (0.3 * PERC + RECALL) 205 | MAE = np.mean(mae) 206 | 207 | return fmeasure, MAE, np.array(0), np.array(0) 208 | 209 | from sklearn.metrics import roc_auc_score,recall_score,precision_score 210 | import cv2 211 | def calc_ber(y_pred, y_true): 212 | batchsize = y_true.shape[0] 213 | y_pred, y_true = y_pred.permute(0, 2, 3, 1).squeeze(-1), y_true.permute(0, 2, 3, 1).squeeze(-1) 214 | with torch.no_grad(): 215 | assert y_pred.shape == y_true.shape 216 | pos_err, neg_err, ber = 0, 0, 0 217 | y_true = y_true.cpu().numpy() 218 | y_pred = y_pred.cpu().numpy() 219 | for i in range(batchsize): 220 | true = y_true[i].flatten() 221 | pred = y_pred[i].flatten() 222 | 223 | TP, TN, FP, FN, BER, ACC = get_binary_classification_metrics(pred * 255, 224 | true * 255, 125) 225 | pos_err += (1 - TP / (TP + FN)) * 100 226 | neg_err += (1 - TN / (TN + FP)) * 100 227 | 228 | return pos_err / batchsize, neg_err / batchsize, (pos_err + neg_err) / 2 / batchsize, np.array(0) 229 | 230 | def get_binary_classification_metrics(pred, gt, threshold=None): 231 | if threshold is not None: 232 | gt = (gt > threshold) 233 | pred = (pred > threshold) 234 | TP = np.logical_and(gt, pred).sum() 235 | TN = np.logical_and(np.logical_not(gt), np.logical_not(pred)).sum() 236 | FN = np.logical_and(gt, np.logical_not(pred)).sum() 237 | FP = np.logical_and(np.logical_not(gt), pred).sum() 238 | BER = cal_ber(TN, TP, FN, FP) 239 | ACC = cal_acc(TN, TP, FN, FP) 240 | return TP, TN, FP, FN, BER, ACC 241 | 242 | def cal_ber(tn, tp, fn, fp): 243 | return 0.5*(fp/(tn+fp) + fn/(fn+tp)) 244 | 245 | def cal_acc(tn, tp, fn, fp): 246 | return (tp + tn) / (tp + tn + fp + fn) 247 | 248 | def _sigmoid(x): 249 | return 1 / (1 + np.exp(-x)) 250 | 251 | 252 | def _eval_pr(y_pred, y, num): 253 | prec, recall = torch.zeros(num), torch.zeros(num) 254 | thlist = torch.linspace(0, 1 - 1e-10, num) 255 | for i in range(num): 256 | y_temp = (y_pred >= thlist[i]).float() 257 | tp = (y_temp * y).sum() 258 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 259 | 1e-20) 260 | return prec, recall 261 | 262 | def _S_object(pred, gt): 263 | fg = torch.where(gt == 0, torch.zeros_like(pred), pred) 264 | bg = torch.where(gt == 1, torch.zeros_like(pred), 1 - pred) 265 | o_fg = _object(fg, gt) 266 | o_bg = _object(bg, 1 - gt) 267 | u = gt.mean() 268 | Q = u * o_fg + (1 - u) * o_bg 269 | return Q 270 | 271 | def _object(pred, gt): 272 | temp = pred[gt == 1] 273 | x = temp.mean() 274 | sigma_x = temp.std() 275 | score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20) 276 | 277 | return score 278 | 279 | def _S_region(pred, gt): 280 | X, Y = _centroid(gt) 281 | gt1, gt2, gt3, gt4, w1, w2, w3, w4 = _divideGT(gt, X, Y) 282 | p1, p2, p3, p4 = _dividePrediction(pred, X, Y) 283 | Q1 = _ssim(p1, gt1) 284 | Q2 = _ssim(p2, gt2) 285 | Q3 = _ssim(p3, gt3) 286 | Q4 = _ssim(p4, gt4) 287 | Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4 288 | return Q 289 | 290 | def _centroid(gt): 291 | rows, cols = gt.size()[-2:] 292 | gt = gt.view(rows, cols) 293 | if gt.sum() == 0: 294 | X = torch.eye(1) * round(cols / 2) 295 | Y = torch.eye(1) * round(rows / 2) 296 | else: 297 | total = gt.sum() 298 | i = torch.from_numpy(np.arange(0, cols)).float().cuda() 299 | j = torch.from_numpy(np.arange(0, rows)).float().cuda() 300 | X = torch.round((gt.sum(dim=0) * i).sum() / total + 1e-20) 301 | Y = torch.round((gt.sum(dim=1) * j).sum() / total + 1e-20) 302 | return X.long(), Y.long() 303 | 304 | 305 | def _divideGT(gt, X, Y): 306 | h, w = gt.size()[-2:] 307 | area = h * w 308 | gt = gt.view(h, w) 309 | LT = gt[:Y, :X] 310 | RT = gt[:Y, X:w] 311 | LB = gt[Y:h, :X] 312 | RB = gt[Y:h, X:w] 313 | X = X.float() 314 | Y = Y.float() 315 | w1 = X * Y / area 316 | w2 = (w - X) * Y / area 317 | w3 = X * (h - Y) / area 318 | w4 = 1 - w1 - w2 - w3 319 | return LT, RT, LB, RB, w1, w2, w3, w4 320 | 321 | 322 | def _dividePrediction(pred, X, Y): 323 | h, w = pred.size()[-2:] 324 | pred = pred.view(h, w) 325 | LT = pred[:Y, :X] 326 | RT = pred[:Y, X:w] 327 | LB = pred[Y:h, :X] 328 | RB = pred[Y:h, X:w] 329 | return LT, RT, LB, RB 330 | 331 | 332 | def _ssim(pred, gt): 333 | gt = gt.float() 334 | h, w = pred.size()[-2:] 335 | N = h * w 336 | x = pred.mean() 337 | y = gt.mean() 338 | sigma_x2 = ((pred - x) * (pred - x)).sum() / (N - 1 + 1e-20) 339 | sigma_y2 = ((gt - y) * (gt - y)).sum() / (N - 1 + 1e-20) 340 | sigma_xy = ((pred - x) * (gt - y)).sum() / (N - 1 + 1e-20) 341 | 342 | aplha = 4 * x * y * sigma_xy 343 | beta = (x * x + y * y) * (sigma_x2 + sigma_y2) 344 | 345 | if aplha != 0: 346 | Q = aplha / (beta + 1e-20) 347 | elif aplha == 0 and beta == 0: 348 | Q = 1.0 349 | else: 350 | Q = 0 351 | return Q 352 | 353 | def _eval_e(y_pred, y, num): 354 | score = torch.zeros(num) 355 | thlist = torch.linspace(0, 1 - 1e-10, num) 356 | for i in range(num): 357 | y_pred_th = (y_pred >= thlist[i]).float() 358 | fm = y_pred_th - y_pred_th.mean() 359 | gt = y - y.mean() 360 | align_matrix = 2 * gt * fm / (gt * gt + fm * fm + 1e-20) 361 | enhanced = ((align_matrix + 1) * (align_matrix + 1)) / 4 362 | score[i] = torch.sum(enhanced) / (y.numel() - 1 + 1e-20) 363 | return score 364 | --------------------------------------------------------------------------------