├── Model ├── ImageEncoder │ ├── __init__.py │ ├── tinyvit │ │ ├── adapter_block.py │ │ ├── block.py │ │ ├── tiny_vit.py │ │ └── utils.py │ └── vit │ │ ├── __init__.py │ │ ├── adapter_block.py │ │ └── block.py ├── common │ ├── MaskDecoder │ │ ├── __init__.py │ │ └── two_way_transformer.py │ ├── __init__.py │ ├── adapter.py │ ├── layer_norm.py │ ├── mlp.py │ └── two_way_transformer.py ├── discriminator.py ├── model.py ├── prompt.py ├── sam │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── sam.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── unet.py └── vnet.py ├── README.md ├── SampleData.rar ├── dataloader ├── TwoStreamBatchSampler.py ├── dataset.py └── transforms.py ├── prediction.py ├── prediction_ACDC.py ├── requirements.txt ├── train_semi_SAM.py ├── train_semi_SAM_ACDC.py ├── trainer.py └── utils ├── losses.py └── utils.py /Model/ImageEncoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .tinyvit.tiny_vit import TinyViT 2 | from .vit import AdapterBlock, Block 3 | -------------------------------------------------------------------------------- /Model/ImageEncoder/tinyvit/adapter_block.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ...common import Adapter 8 | from .utils import Conv2d_BN, DropPath, Mlp 9 | 10 | 11 | class Attention(torch.nn.Module): 12 | def __init__(self, dim, key_dim, num_heads=8, 13 | attn_ratio=4, 14 | resolution=(14, 14), 15 | ): 16 | super().__init__() 17 | # (h, w) 18 | assert isinstance(resolution, tuple) and len(resolution) == 2 19 | self.num_heads = num_heads 20 | self.scale = key_dim ** -0.5 21 | self.key_dim = key_dim 22 | self.nh_kd = nh_kd = key_dim * num_heads 23 | self.d = int(attn_ratio * key_dim) 24 | self.dh = int(attn_ratio * key_dim) * num_heads 25 | self.attn_ratio = attn_ratio 26 | h = self.dh + nh_kd * 2 27 | 28 | self.norm = nn.LayerNorm(dim) 29 | self.qkv = nn.Linear(dim, h) 30 | self.proj = nn.Linear(self.dh, dim) 31 | 32 | points = list(itertools.product( 33 | range(resolution[0]), range(resolution[1]))) 34 | N = len(points) 35 | attention_offsets = {} 36 | idxs = [] 37 | for p1 in points: 38 | for p2 in points: 39 | offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) 40 | if offset not in attention_offsets: 41 | attention_offsets[offset] = len(attention_offsets) 42 | idxs.append(attention_offsets[offset]) 43 | self.attention_biases = torch.nn.Parameter( 44 | torch.zeros(num_heads, len(attention_offsets))) 45 | self.register_buffer('attention_bias_idxs', 46 | torch.LongTensor(idxs).view(N, N), 47 | persistent=False) 48 | 49 | @torch.no_grad() 50 | def train(self, mode=True): 51 | super().train(mode) 52 | if mode and hasattr(self, 'ab'): 53 | del self.ab 54 | else: 55 | self.ab = self.attention_biases[:, self.attention_bias_idxs] 56 | # self.register_buffer('ab', 57 | # self.attention_biases[:, self.attention_bias_idxs], 58 | # persistent=False) 59 | def forward(self, x): # x (B,N,C) 60 | B, N, _ = x.shape 61 | 62 | # Normalization 63 | x = self.norm(x) 64 | 65 | qkv = self.qkv(x) 66 | # (B, N, num_heads, d) 67 | q, k, v = qkv.view(B, N, self.num_heads, - 68 | 1).split([self.key_dim, self.key_dim, self.d], dim=3) 69 | # (B, num_heads, N, 
d) 70 | q = q.permute(0, 2, 1, 3) 71 | k = k.permute(0, 2, 1, 3) 72 | v = v.permute(0, 2, 1, 3) 73 | 74 | attn = ( 75 | (q @ k.transpose(-2, -1)) * self.scale 76 | + 77 | (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab) 78 | ) 79 | attn = attn.softmax(dim=-1) 80 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) 81 | x = self.proj(x) 82 | return x 83 | 84 | class TinyViTAdapterBlock(nn.Module): 85 | r""" TinyViT Block. 86 | 87 | Args: 88 | dim (int): Number of input channels. 89 | input_resolution (tuple[int, int]): Input resulotion. 90 | num_heads (int): Number of attention heads. 91 | window_size (int): Window size. 92 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 93 | drop (float, optional): Dropout rate. Default: 0.0 94 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 95 | local_conv_size (int): the kernel size of the convolution between 96 | Attention and MLP. Default: 3 97 | activation: the activation function. Default: nn.GELU 98 | """ 99 | 100 | def __init__(self, args, dim, input_resolution, num_heads, window_size=7, 101 | mlp_ratio=4., drop=0., drop_path=0., 102 | local_conv_size=3, 103 | activation=nn.GELU, 104 | ): 105 | super().__init__() 106 | self.args = args, 107 | self.dim = dim 108 | self.input_resolution = input_resolution 109 | self.num_heads = num_heads 110 | assert window_size > 0, 'window_size must be greater than 0' 111 | self.window_size = window_size 112 | self.mlp_ratio = mlp_ratio 113 | 114 | self.drop_path = DropPath( 115 | drop_path) if drop_path > 0. else nn.Identity() 116 | 117 | assert dim % num_heads == 0, 'dim must be divisible by num_heads' 118 | head_dim = dim // num_heads 119 | 120 | window_resolution = (window_size, window_size) 121 | self.attn = Attention(dim, head_dim, num_heads, 122 | attn_ratio=1, resolution=window_resolution) 123 | 124 | mlp_hidden_dim = int(dim * mlp_ratio) 125 | mlp_activation = activation 126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 127 | act_layer=mlp_activation, drop=drop) 128 | 129 | self.MLP_Adapter = Adapter(dim, skip_connect=False) # MLP-adapter, no skip connection 130 | self.Space_Adapter = Adapter(dim) # with skip connection 131 | self.Depth_Adapter = Adapter(dim, skip_connect=False) # no skip connection 132 | 133 | pad = local_conv_size // 2 134 | self.local_conv = Conv2d_BN( 135 | dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) 136 | 137 | def forward(self, x): 138 | H, W = self.input_resolution 139 | B, L, C = x.shape 140 | assert L == H * W, "input feature has wrong size" 141 | res_x = x 142 | if H == self.window_size and W == self.window_size: 143 | x = self.attn(x) 144 | else: 145 | x = x.view(B, H, W, C) 146 | pad_b = (self.window_size - H % 147 | self.window_size) % self.window_size 148 | pad_r = (self.window_size - W % 149 | self.window_size) % self.window_size 150 | padding = pad_b > 0 or pad_r > 0 151 | 152 | if padding: 153 | x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) 154 | 155 | pH, pW = H + pad_b, W + pad_r 156 | nH = pH // self.window_size 157 | nW = pW // self.window_size 158 | # window partition 159 | x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( 160 | B * nH * nW, self.window_size * self.window_size, C) 161 | 162 | ## 3d branch 163 | if self.args[0].thd: 164 | from einops import rearrange 165 | hh, ww = x.shape[1], x.shape[2] 166 | depth = self.args.chunk 167 | xd = rearrange(x, '(b d) h w c -> (b h w) d c ', d=depth) 168 | # xd = rearrange(xd, '(b d) n 
c -> (b n) d c', d=self.in_chans) 169 | xd = self.norm1(xd) 170 | dh, _ = closest_numbers(depth) 171 | xd = rearrange(xd, 'bhw (dh dw) c -> bhw dh dw c', dh= dh) 172 | xd = self.Depth_Adapter(self.attn(xd)) 173 | xd = rearrange(xd, '(b n) dh dw c ->(b dh dw) n c', n= hh * ww ) 174 | 175 | x = self.attn(x) 176 | x = self.Space_Adapter(x) 177 | 178 | if self.args[0].thd: 179 | xd = rearrange(xd, 'b (hh ww) c -> b hh ww c', hh= hh ) 180 | x = x + xd 181 | 182 | # window reverse 183 | x = x.view(B, nH, nW, self.window_size, self.window_size, 184 | C).transpose(2, 3).reshape(B, pH, pW, C) 185 | 186 | if padding: 187 | x = x[:, :H, :W].contiguous() 188 | 189 | x = x.view(B, L, C) 190 | 191 | x = res_x + self.drop_path(x) 192 | 193 | x = x.transpose(1, 2).reshape(B, C, H, W) 194 | x = self.local_conv(x) 195 | x = x.view(B, C, L).transpose(1, 2) 196 | 197 | x = x + self.drop_path(self.mlp(x)) + 0.5 * self.MLP_Adapter(x) 198 | return x 199 | 200 | def extra_repr(self) -> str: 201 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 202 | f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" 203 | 204 | def closest_numbers(target): 205 | a = int(target ** 0.5) 206 | b = a + 1 207 | while True: 208 | if a * b == target: 209 | return (a, b) 210 | elif a * b < target: 211 | b += 1 212 | else: 213 | a -= 1 -------------------------------------------------------------------------------- /Model/ImageEncoder/tinyvit/block.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.models.layers import DropPath as TimmDropPath 7 | 8 | from .utils import Conv2d_BN, DropPath, Mlp 9 | 10 | 11 | class Attention(torch.nn.Module): 12 | def __init__(self, dim, key_dim, num_heads=8, 13 | attn_ratio=4, 14 | resolution=(14, 14), 15 | ): 16 | super().__init__() 17 | # (h, w) 18 | assert isinstance(resolution, tuple) and len(resolution) == 2 19 | self.num_heads = num_heads 20 | self.scale = key_dim ** -0.5 21 | self.key_dim = key_dim 22 | self.nh_kd = nh_kd = key_dim * num_heads 23 | self.d = int(attn_ratio * key_dim) 24 | self.dh = int(attn_ratio * key_dim) * num_heads 25 | self.attn_ratio = attn_ratio 26 | h = self.dh + nh_kd * 2 27 | 28 | self.norm = nn.LayerNorm(dim) 29 | self.qkv = nn.Linear(dim, h) 30 | self.proj = nn.Linear(self.dh, dim) 31 | 32 | points = list(itertools.product( 33 | range(resolution[0]), range(resolution[1]))) 34 | N = len(points) 35 | attention_offsets = {} 36 | idxs = [] 37 | for p1 in points: 38 | for p2 in points: 39 | offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) 40 | if offset not in attention_offsets: 41 | attention_offsets[offset] = len(attention_offsets) 42 | idxs.append(attention_offsets[offset]) 43 | self.attention_biases = torch.nn.Parameter( 44 | torch.zeros(num_heads, len(attention_offsets))) 45 | self.register_buffer('attention_bias_idxs', 46 | torch.LongTensor(idxs).view(N, N), 47 | persistent=False) 48 | 49 | @torch.no_grad() 50 | def train(self, mode=True): 51 | super().train(mode) 52 | if mode and hasattr(self, 'ab'): 53 | del self.ab 54 | else: 55 | self.ab = self.attention_biases[:, self.attention_bias_idxs] 56 | # self.register_buffer('ab', 57 | # self.attention_biases[:, self.attention_bias_idxs], 58 | # persistent=False) 59 | def forward(self, x): # x (B,N,C) 60 | B, N, _ = x.shape 61 | 62 | # Normalization 63 | x = self.norm(x) 64 | 65 | qkv = self.qkv(x) 66 | # (B, N, 
num_heads, d) 67 | q, k, v = qkv.view(B, N, self.num_heads, - 68 | 1).split([self.key_dim, self.key_dim, self.d], dim=3) 69 | # (B, num_heads, N, d) 70 | q = q.permute(0, 2, 1, 3) 71 | k = k.permute(0, 2, 1, 3) 72 | v = v.permute(0, 2, 1, 3) 73 | 74 | attn = ( 75 | (q @ k.transpose(-2, -1)) * self.scale 76 | + 77 | (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab) 78 | ) 79 | attn = attn.softmax(dim=-1) 80 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) 81 | x = self.proj(x) 82 | return x 83 | 84 | class TinyViTBlock(nn.Module): 85 | r""" TinyViT Block. 86 | 87 | Args: 88 | dim (int): Number of input channels. 89 | input_resolution (tuple[int, int]): Input resulotion. 90 | num_heads (int): Number of attention heads. 91 | window_size (int): Window size. 92 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 93 | drop (float, optional): Dropout rate. Default: 0.0 94 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 95 | local_conv_size (int): the kernel size of the convolution between 96 | Attention and MLP. Default: 3 97 | activation: the activation function. Default: nn.GELU 98 | """ 99 | 100 | def __init__(self, args, dim, input_resolution, num_heads, window_size=7, 101 | mlp_ratio=4., drop=0., drop_path=0., 102 | local_conv_size=3, 103 | activation=nn.GELU, 104 | ): 105 | super().__init__() 106 | self.dim = dim 107 | self.input_resolution = input_resolution 108 | self.num_heads = num_heads 109 | assert window_size > 0, 'window_size must be greater than 0' 110 | self.window_size = window_size 111 | self.mlp_ratio = mlp_ratio 112 | 113 | self.drop_path = DropPath( 114 | drop_path) if drop_path > 0. else nn.Identity() 115 | 116 | assert dim % num_heads == 0, 'dim must be divisible by num_heads' 117 | head_dim = dim // num_heads 118 | 119 | window_resolution = (window_size, window_size) 120 | self.attn = Attention(dim, head_dim, num_heads, 121 | attn_ratio=1, resolution=window_resolution) 122 | 123 | mlp_hidden_dim = int(dim * mlp_ratio) 124 | mlp_activation = activation 125 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 126 | act_layer=mlp_activation, drop=drop) 127 | 128 | pad = local_conv_size // 2 129 | self.local_conv = Conv2d_BN( 130 | dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) 131 | 132 | def forward(self, x): 133 | H, W = self.input_resolution 134 | B, L, C = x.shape 135 | assert L == H * W, "input feature has wrong size" 136 | res_x = x 137 | if H == self.window_size and W == self.window_size: 138 | x = self.attn(x) 139 | else: 140 | x = x.view(B, H, W, C) 141 | pad_b = (self.window_size - H % 142 | self.window_size) % self.window_size 143 | pad_r = (self.window_size - W % 144 | self.window_size) % self.window_size 145 | padding = pad_b > 0 or pad_r > 0 146 | 147 | if padding: 148 | x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) 149 | 150 | pH, pW = H + pad_b, W + pad_r 151 | nH = pH // self.window_size 152 | nW = pW // self.window_size 153 | # window partition 154 | x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( 155 | B * nH * nW, self.window_size * self.window_size, C) 156 | x = self.attn(x) 157 | # window reverse 158 | x = x.view(B, nH, nW, self.window_size, self.window_size, 159 | C).transpose(2, 3).reshape(B, pH, pW, C) 160 | 161 | if padding: 162 | x = x[:, :H, :W].contiguous() 163 | 164 | x = x.view(B, L, C) 165 | 166 | x = res_x + self.drop_path(x) 167 | 168 | x = x.transpose(1, 2).reshape(B, C, H, W) 169 | x = self.local_conv(x) 170 | x 
= x.view(B, C, L).transpose(1, 2) 171 | 172 | x = x + self.drop_path(self.mlp(x)) 173 | return x 174 | 175 | def extra_repr(self) -> str: 176 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 177 | f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" -------------------------------------------------------------------------------- /Model/ImageEncoder/tinyvit/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath as TimmDropPath 5 | 6 | 7 | class Conv2d_BN(torch.nn.Sequential): 8 | def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, 9 | groups=1, bn_weight_init=1): 10 | super().__init__() 11 | self.add_module('c', torch.nn.Conv2d( 12 | a, b, ks, stride, pad, dilation, groups, bias=False)) 13 | bn = torch.nn.BatchNorm2d(b) 14 | torch.nn.init.constant_(bn.weight, bn_weight_init) 15 | torch.nn.init.constant_(bn.bias, 0) 16 | self.add_module('bn', bn) 17 | 18 | @torch.no_grad() 19 | def fuse(self): 20 | c, bn = self._modules.values() 21 | w = bn.weight / (bn.running_var + bn.eps)**0.5 22 | w = c.weight * w[:, None, None, None] 23 | b = bn.bias - bn.running_mean * bn.weight / \ 24 | (bn.running_var + bn.eps)**0.5 25 | m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( 26 | 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 27 | m.weight.data.copy_(w) 28 | m.bias.data.copy_(b) 29 | return m 30 | 31 | class Mlp(nn.Module): 32 | def __init__(self, in_features, hidden_features=None, 33 | out_features=None, act_layer=nn.GELU, drop=0.): 34 | super().__init__() 35 | out_features = out_features or in_features 36 | hidden_features = hidden_features or in_features 37 | self.norm = nn.LayerNorm(in_features) 38 | self.fc1 = nn.Linear(in_features, hidden_features) 39 | self.fc2 = nn.Linear(hidden_features, out_features) 40 | self.act = act_layer() 41 | self.drop = nn.Dropout(drop) 42 | 43 | def forward(self, x): 44 | x = self.norm(x) 45 | 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | class DropPath(TimmDropPath): 54 | def __init__(self, drop_prob=None): 55 | super().__init__(drop_prob=drop_prob) 56 | self.drop_prob = drop_prob 57 | 58 | def __repr__(self): 59 | msg = super().__repr__() 60 | msg += f'(drop_prob={self.drop_prob})' 61 | return msg -------------------------------------------------------------------------------- /Model/ImageEncoder/vit/__init__.py: -------------------------------------------------------------------------------- 1 | from .adapter_block import AdapterBlock 2 | from .block import Block -------------------------------------------------------------------------------- /Model/ImageEncoder/vit/adapter_block.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple, Type 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | from ...common import Adapter, LayerNorm2d 10 | 11 | 12 | class AdapterBlock(nn.Module): 13 | """Transformer blocks with support of window attention and residual propagation blocks""" 14 | 15 | def __init__( 16 | self, 17 | args, 18 | dim: int, 19 | num_heads: int, 20 | mlp_ratio: float = 4.0, 21 | scale: float = 0.5, 22 | qkv_bias: bool = True, 23 | 
norm_layer: Type[nn.Module] = nn.LayerNorm, 24 | act_layer: Type[nn.Module] = nn.GELU, 25 | use_rel_pos: bool = False, 26 | rel_pos_zero_init: bool = True, 27 | window_size: int = 0, 28 | input_size: Optional[Tuple[int, int]] = None, 29 | ) -> None: 30 | """ 31 | Args: 32 | dim (int): Number of input channels. 33 | num_heads (int): Number of attention heads in each ViT block. 34 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 35 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 36 | norm_layer (nn.Module): Normalization layer. 37 | act_layer (nn.Module): Activation layer. 38 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 39 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 40 | window_size (int): Window size for window attention blocks. If it equals 0, then 41 | use global attention. 42 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 43 | positional parameter size. 44 | """ 45 | super().__init__() 46 | self.args = args 47 | self.norm1 = norm_layer(dim) 48 | self.attn = Attention( 49 | dim, 50 | num_heads=num_heads, 51 | qkv_bias=qkv_bias, 52 | use_rel_pos=use_rel_pos, 53 | rel_pos_zero_init=rel_pos_zero_init, 54 | input_size=input_size if window_size == 0 else (window_size, window_size), 55 | ) 56 | self.MLP_Adapter = Adapter(dim, skip_connect=False) # MLP-adapter, no skip connection 57 | self.Space_Adapter = Adapter(dim) # with skip connection 58 | self.scale = scale 59 | self.Depth_Adapter = Adapter(dim, skip_connect=False) # no skip connection 60 | self.norm2 = norm_layer(dim) 61 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 62 | 63 | self.window_size = window_size 64 | 65 | def forward(self, x: torch.Tensor) -> torch.Tensor: 66 | shortcut = x 67 | # Window partition 68 | if self.window_size > 0: 69 | H, W = x.shape[1], x.shape[2] 70 | x, pad_hw = window_partition(x, self.window_size) 71 | 72 | ## 3d branch 73 | if self.args.thd: 74 | hh, ww = x.shape[1], x.shape[2] 75 | depth = self.args.chunk 76 | xd = rearrange(x, '(b d) h w c -> (b h w) d c ', d=depth) 77 | # xd = rearrange(xd, '(b d) n c -> (b n) d c', d=self.in_chans) 78 | xd = self.norm1(xd) 79 | dh, _ = closest_numbers(depth) 80 | xd = rearrange(xd, 'bhw (dh dw) c -> bhw dh dw c', dh=dh) 81 | xd = self.Depth_Adapter(self.attn(xd)) 82 | xd = rearrange(xd, '(b n) dh dw c ->(b dh dw) n c', n=hh * ww) 83 | 84 | x = self.norm1(x) 85 | x = self.attn(x) 86 | x = self.Space_Adapter(x) 87 | 88 | if self.args.thd: 89 | xd = rearrange(xd, 'b (hh ww) c -> b hh ww c', hh= hh ) 90 | x = x + xd 91 | 92 | # Reverse window partition 93 | if self.window_size > 0: 94 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 95 | 96 | x = shortcut + x 97 | xn = self.norm2(x) 98 | x = x + self.mlp(xn) + self.scale * self.MLP_Adapter(xn) 99 | return x 100 | 101 | 102 | class Attention(nn.Module): 103 | """Multi-head Attention block with relative position embeddings.""" 104 | 105 | def __init__( 106 | self, 107 | dim: int, 108 | num_heads: int = 8, 109 | qkv_bias: bool = True, 110 | use_rel_pos: bool = False, 111 | rel_pos_zero_init: bool = True, 112 | input_size: Optional[Tuple[int, int]] = None, 113 | ) -> None: 114 | """ 115 | Args: 116 | dim (int): Number of input channels. 117 | num_heads (int): Number of attention heads. 118 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 
119 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 120 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 121 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 122 | positional parameter size. 123 | """ 124 | super().__init__() 125 | self.num_heads = num_heads 126 | head_dim = dim // num_heads 127 | self.scale = head_dim**-0.5 128 | 129 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 130 | self.proj = nn.Linear(dim, dim) 131 | 132 | self.use_rel_pos = use_rel_pos 133 | if self.use_rel_pos: 134 | assert ( 135 | input_size is not None 136 | ), "Input size must be provided if using relative positional encoding." 137 | # initialize relative positional embeddings 138 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 139 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 140 | 141 | def forward(self, x: torch.Tensor) -> torch.Tensor: 142 | B, H, W, _ = x.shape 143 | # qkv with shape (3, B, nHead, H * W, C) 144 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 145 | # q, k, v with shape (B * nHead, H * W, C) 146 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 147 | 148 | attn = (q * self.scale) @ k.transpose(-2, -1) 149 | 150 | if self.use_rel_pos: 151 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 152 | 153 | attn = attn.softmax(dim=-1) 154 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 155 | x = self.proj(x) 156 | 157 | return x 158 | 159 | 160 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 161 | """ 162 | Partition into non-overlapping windows with padding if needed. 163 | Args: 164 | x (tensor): input tokens with [B, H, W, C]. 165 | window_size (int): window size. 166 | 167 | Returns: 168 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 169 | (Hp, Wp): padded height and width before partition 170 | """ 171 | B, H, W, C = x.shape 172 | 173 | pad_h = (window_size - H % window_size) % window_size 174 | pad_w = (window_size - W % window_size) % window_size 175 | if pad_h > 0 or pad_w > 0: 176 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 177 | Hp, Wp = H + pad_h, W + pad_w 178 | 179 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 180 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 181 | return windows, (Hp, Wp) 182 | 183 | 184 | def window_unpartition( 185 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 186 | ) -> torch.Tensor: 187 | """ 188 | Window unpartition into original sequences and removing padding. 189 | Args: 190 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 191 | window_size (int): window size. 192 | pad_hw (Tuple): padded height and width (Hp, Wp). 193 | hw (Tuple): original height and width (H, W) before padding. 194 | 195 | Returns: 196 | x: unpartitioned sequences with [B, H, W, C]. 
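    This is the exact inverse of window_partition(); any padding added there
    is cropped back off so the output matches the original (H, W).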
197 | """ 198 | Hp, Wp = pad_hw 199 | H, W = hw 200 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 201 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 202 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 203 | 204 | if Hp > H or Wp > W: 205 | x = x[:, :H, :W, :].contiguous() 206 | return x 207 | 208 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 209 | """ 210 | Get relative positional embeddings according to the relative positions of 211 | query and key sizes. 212 | Args: 213 | q_size (int): size of query q. 214 | k_size (int): size of key k. 215 | rel_pos (Tensor): relative position embeddings (L, C). 216 | 217 | Returns: 218 | Extracted positional embeddings according to relative positions. 219 | """ 220 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 221 | # Interpolate rel pos if needed. 222 | if rel_pos.shape[0] != max_rel_dist: 223 | # Interpolate rel pos. 224 | rel_pos_resized = F.interpolate( 225 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 226 | size=max_rel_dist, 227 | mode="linear", 228 | ) 229 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 230 | else: 231 | rel_pos_resized = rel_pos 232 | 233 | # Scale the coords with short length if shapes for q and k are different. 234 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 235 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 236 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 237 | 238 | return rel_pos_resized[relative_coords.long()] 239 | 240 | def add_decomposed_rel_pos( 241 | attn: torch.Tensor, 242 | q: torch.Tensor, 243 | rel_pos_h: torch.Tensor, 244 | rel_pos_w: torch.Tensor, 245 | q_size: Tuple[int, int], 246 | k_size: Tuple[int, int], 247 | ) -> torch.Tensor: 248 | """ 249 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 250 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 251 | Args: 252 | attn (Tensor): attention map. 253 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 254 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 255 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 256 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 257 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 258 | 259 | Returns: 260 | attn (Tensor): attention map with added relative positional embeddings. 
261 | """ 262 | q_h, q_w = q_size 263 | k_h, k_w = k_size 264 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 265 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 266 | 267 | B, _, dim = q.shape 268 | r_q = q.reshape(B, q_h, q_w, dim) 269 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 270 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 271 | 272 | attn = ( 273 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 274 | ).view(B, q_h * q_w, k_h * k_w) 275 | 276 | return attn 277 | 278 | def closest_numbers(target): 279 | a = int(target ** 0.5) 280 | b = a + 1 281 | while True: 282 | if a * b == target: 283 | return (a, b) 284 | elif a * b < target: 285 | b += 1 286 | else: 287 | a -= 1 288 | 289 | 290 | class MLPBlock(nn.Module): 291 | def __init__( 292 | self, 293 | embedding_dim: int, 294 | mlp_dim: int, 295 | act: Type[nn.Module] = nn.GELU, 296 | ) -> None: 297 | super().__init__() 298 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 299 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 300 | self.act = act() 301 | 302 | def forward(self, x: torch.Tensor) -> torch.Tensor: 303 | return self.lin2(self.act(self.lin1(x))) -------------------------------------------------------------------------------- /Model/ImageEncoder/vit/block.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Block(nn.Module): 9 | """Transformer blocks with support of window attention and residual propagation blocks""" 10 | 11 | def __init__( 12 | self, 13 | args, 14 | dim: int, 15 | num_heads: int, 16 | mlp_ratio: float = 4.0, 17 | qkv_bias: bool = True, 18 | norm_layer: Type[nn.Module] = nn.LayerNorm, 19 | act_layer: Type[nn.Module] = nn.GELU, 20 | use_rel_pos: bool = False, 21 | rel_pos_zero_init: bool = True, 22 | window_size: int = 0, 23 | input_size: Optional[Tuple[int, int]] = None, 24 | ) -> None: 25 | """ 26 | Args: 27 | dim (int): Number of input channels. 28 | num_heads (int): Number of attention heads in each ViT block. 29 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 30 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 31 | norm_layer (nn.Module): Normalization layer. 32 | act_layer (nn.Module): Activation layer. 33 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 34 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 35 | window_size (int): Window size for window attention blocks. If it equals 0, then 36 | use global attention. 37 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 38 | positional parameter size. 
39 | """ 40 | super().__init__() 41 | self.norm1 = norm_layer(dim) 42 | self.attn = Attention( 43 | dim, 44 | num_heads=num_heads, 45 | qkv_bias=qkv_bias, 46 | use_rel_pos=use_rel_pos, 47 | rel_pos_zero_init=rel_pos_zero_init, 48 | input_size=input_size if window_size == 0 else (window_size, window_size), 49 | ) 50 | 51 | self.norm2 = norm_layer(dim) 52 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 53 | 54 | self.window_size = window_size 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | shortcut = x 58 | x = self.norm1(x) 59 | # Window partition 60 | if self.window_size > 0: 61 | H, W = x.shape[1], x.shape[2] 62 | x, pad_hw = window_partition(x, self.window_size) 63 | 64 | x = self.attn(x) 65 | # Reverse window partition 66 | if self.window_size > 0: 67 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 68 | 69 | x = shortcut + x 70 | x = x + self.mlp(self.norm2(x)) 71 | 72 | return x 73 | 74 | class MLPBlock(nn.Module): 75 | def __init__( 76 | self, 77 | embedding_dim: int, 78 | mlp_dim: int, 79 | act: Type[nn.Module] = nn.GELU, 80 | ) -> None: 81 | super().__init__() 82 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 83 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 84 | self.act = act() 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | return self.lin2(self.act(self.lin1(x))) 88 | 89 | 90 | class Attention(nn.Module): 91 | """Multi-head Attention block with relative position embeddings.""" 92 | 93 | def __init__( 94 | self, 95 | dim: int, 96 | num_heads: int = 8, 97 | qkv_bias: bool = True, 98 | use_rel_pos: bool = False, 99 | rel_pos_zero_init: bool = True, 100 | input_size: Optional[Tuple[int, int]] = None, 101 | ) -> None: 102 | """ 103 | Args: 104 | dim (int): Number of input channels. 105 | num_heads (int): Number of attention heads. 106 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 107 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 108 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 109 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 110 | positional parameter size. 111 | """ 112 | super().__init__() 113 | self.num_heads = num_heads 114 | head_dim = dim // num_heads 115 | self.scale = head_dim**-0.5 116 | 117 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 118 | self.proj = nn.Linear(dim, dim) 119 | 120 | self.use_rel_pos = use_rel_pos 121 | if self.use_rel_pos: 122 | assert ( 123 | input_size is not None 124 | ), "Input size must be provided if using relative positional encoding." 
125 | # initialize relative positional embeddings 126 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 127 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 128 | 129 | def forward(self, x: torch.Tensor) -> torch.Tensor: 130 | B, H, W, _ = x.shape 131 | # qkv with shape (3, B, nHead, H * W, C) 132 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 133 | # q, k, v with shape (B * nHead, H * W, C) 134 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 135 | 136 | attn = (q * self.scale) @ k.transpose(-2, -1) 137 | 138 | if self.use_rel_pos: 139 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 140 | 141 | attn = attn.softmax(dim=-1) 142 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 143 | x = self.proj(x) 144 | 145 | return x 146 | 147 | 148 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 149 | """ 150 | Partition into non-overlapping windows with padding if needed. 151 | Args: 152 | x (tensor): input tokens with [B, H, W, C]. 153 | window_size (int): window size. 154 | 155 | Returns: 156 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 157 | (Hp, Wp): padded height and width before partition 158 | """ 159 | B, H, W, C = x.shape 160 | 161 | pad_h = (window_size - H % window_size) % window_size 162 | pad_w = (window_size - W % window_size) % window_size 163 | if pad_h > 0 or pad_w > 0: 164 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 165 | Hp, Wp = H + pad_h, W + pad_w 166 | 167 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 168 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 169 | return windows, (Hp, Wp) 170 | 171 | 172 | def window_unpartition( 173 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 174 | ) -> torch.Tensor: 175 | """ 176 | Window unpartition into original sequences and removing padding. 177 | Args: 178 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 179 | window_size (int): window size. 180 | pad_hw (Tuple): padded height and width (Hp, Wp). 181 | hw (Tuple): original height and width (H, W) before padding. 182 | 183 | Returns: 184 | x: unpartitioned sequences with [B, H, W, C]. 185 | """ 186 | Hp, Wp = pad_hw 187 | H, W = hw 188 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 189 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 190 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 191 | 192 | if Hp > H or Wp > W: 193 | x = x[:, :H, :W, :].contiguous() 194 | return x 195 | 196 | 197 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 198 | """ 199 | Get relative positional embeddings according to the relative positions of 200 | query and key sizes. 201 | Args: 202 | q_size (int): size of query q. 203 | k_size (int): size of key k. 204 | rel_pos (Tensor): relative position embeddings (L, C). 205 | 206 | Returns: 207 | Extracted positional embeddings according to relative positions. 208 | """ 209 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 210 | # Interpolate rel pos if needed. 211 | if rel_pos.shape[0] != max_rel_dist: 212 | # Interpolate rel pos. 
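        # The stored (L, C) table was sized for the training resolution; treat it
        # as a 1-D signal with C channels and linearly resample it to
        # max_rel_dist rows before indexing.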
213 | rel_pos_resized = F.interpolate( 214 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 215 | size=max_rel_dist, 216 | mode="linear", 217 | ) 218 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 219 | else: 220 | rel_pos_resized = rel_pos 221 | 222 | # Scale the coords with short length if shapes for q and k are different. 223 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 224 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 225 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 226 | 227 | return rel_pos_resized[relative_coords.long()] 228 | 229 | 230 | def add_decomposed_rel_pos( 231 | attn: torch.Tensor, 232 | q: torch.Tensor, 233 | rel_pos_h: torch.Tensor, 234 | rel_pos_w: torch.Tensor, 235 | q_size: Tuple[int, int], 236 | k_size: Tuple[int, int], 237 | ) -> torch.Tensor: 238 | """ 239 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 240 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 241 | Args: 242 | attn (Tensor): attention map. 243 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 244 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 245 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 246 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 247 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 248 | 249 | Returns: 250 | attn (Tensor): attention map with added relative positional embeddings. 251 | """ 252 | q_h, q_w = q_size 253 | k_h, k_w = k_size 254 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 255 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 256 | 257 | B, _, dim = q.shape 258 | r_q = q.reshape(B, q_h, q_w, dim) 259 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 260 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 261 | 262 | attn = ( 263 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 264 | ).view(B, q_h * q_w, k_h * k_w) 265 | 266 | return attn -------------------------------------------------------------------------------- /Model/common/MaskDecoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .two_way_transformer import TwoWayTransformer -------------------------------------------------------------------------------- /Model/common/MaskDecoder/two_way_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Type 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | 7 | from ..mlp import MLPBlock 8 | 9 | 10 | class TwoWayTransformer(nn.Module): 11 | def __init__( 12 | self, 13 | depth: int, 14 | embedding_dim: int, 15 | num_heads: int, 16 | mlp_dim: int, 17 | activation: Type[nn.Module] = nn.ReLU, 18 | normalize_before_activation: bool = False, 19 | attention_downsample_rate: int = 2, 20 | ) -> None: 21 | """ 22 | A transformer decoder that attends to an input image using 23 | queries whose positional embedding is supplied. 24 | 25 | Args: 26 | depth (int): number of layers in the transformer 27 | embedding_dim (int): the channel dimension for the input embeddings 28 | num_heads (int): the number of heads for multihead attention. 
Must 29 | divide embedding_dim 30 | mlp_dim (int): the channel dimension internal to the MLP block 31 | activation (nn.Module): the activation to use in the MLP block 32 | """ 33 | super().__init__() 34 | self.depth = depth 35 | self.embedding_dim = embedding_dim 36 | self.num_heads = num_heads 37 | self.mlp_dim = mlp_dim 38 | self.layers = nn.ModuleList() 39 | 40 | for i in range(depth): 41 | curr_layer = TwoWayAttentionBlock( 42 | embedding_dim=embedding_dim, 43 | num_heads=num_heads, 44 | mlp_dim=mlp_dim, 45 | activation=activation, 46 | normalize_before_activation=normalize_before_activation, 47 | attention_downsample_rate=attention_downsample_rate, 48 | skip_first_layer_pe=(i == 0), 49 | ) 50 | self.layers.append(curr_layer) 51 | 52 | self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock( 53 | embedding_dim, 54 | num_heads, 55 | downsample_rate=attention_downsample_rate, 56 | ) 57 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 58 | 59 | def forward( 60 | self, 61 | image_embedding: Tensor, 62 | image_pe: Tensor, 63 | point_embedding: Tensor, 64 | ) -> Tuple[Tensor, Tensor]: 65 | """ 66 | Args: 67 | image_embedding (torch.Tensor): image to attend to. Should be shape 68 | B x embedding_dim x h x w for any h and w. 69 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 70 | have the same shape as image_embedding. 71 | point_embedding (torch.Tensor): the embedding to add to the query points. 72 | Must have shape B x N_points x embedding_dim for any N_points. 73 | 74 | Returns: 75 | torch.Tensor: the processed point_embedding 76 | torch.Tensor: the processed image_embedding 77 | """ 78 | 79 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 80 | bs, c, h, w = image_embedding.shape 81 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 82 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 83 | 84 | # Prepare queries 85 | queries = point_embedding 86 | keys = image_embedding 87 | 88 | # Apply transformer blocks and final layernorm 89 | for idx, layer in enumerate(self.layers): 90 | queries, keys = layer( 91 | queries=queries, 92 | keys=keys, 93 | query_pe=point_embedding, 94 | key_pe=image_pe, 95 | ) 96 | 97 | # Apply the final attention layer from the points to the image 98 | q = queries + point_embedding 99 | k = keys + image_pe 100 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 101 | queries = queries + attn_out 102 | queries = self.norm_final_attn(queries) 103 | return queries, keys 104 | 105 | 106 | class TwoWayAttentionBlock(nn.Module): 107 | def __init__( 108 | self, 109 | embedding_dim: int, 110 | num_heads: int, 111 | mlp_dim: int, 112 | activation: Type[nn.Module], 113 | normalize_before_activation: bool, 114 | attention_downsample_rate: int = 2, 115 | skip_first_layer_pe: bool = False, 116 | ) -> None: 117 | """ 118 | A transformer block with four layers: (1) self-attention of sparse 119 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 120 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 121 | inputs. 
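        Each of the four layers is followed by a residual connection and a LayerNorm.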
122 | 123 | Arguments: 124 | embedding_dim (int): the channel dimension of the embeddings 125 | num_heads (int): the number of heads in the attention layers 126 | mlp_dim (int): the hidden dimension of the mlp block 127 | activation (nn.Module): the activation of the mlp block 128 | skip_first_layer_pe (bool): skip the PE on the first layer 129 | """ 130 | super().__init__() 131 | self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) 132 | self.norm1 = nn.LayerNorm(embedding_dim) 133 | 134 | self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( 135 | embedding_dim, 136 | num_heads, 137 | downsample_rate=attention_downsample_rate, 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock( 142 | embedding_dim, 143 | mlp_dim, 144 | embedding_dim, 145 | 1, 146 | activation, 147 | ) 148 | 149 | self.norm3 = nn.LayerNorm(embedding_dim) 150 | 151 | self.norm4 = nn.LayerNorm(embedding_dim) 152 | self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( 153 | embedding_dim, 154 | num_heads, 155 | downsample_rate=attention_downsample_rate, 156 | ) 157 | 158 | self.skip_first_layer_pe = skip_first_layer_pe 159 | 160 | def forward( 161 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 162 | ) -> Tuple[Tensor, Tensor]: 163 | # Self attention block 164 | if not self.skip_first_layer_pe: 165 | queries = queries + query_pe 166 | attn_out = self.self_attn(q=queries, k=queries, v=queries) 167 | queries = queries + attn_out 168 | queries = self.norm1(queries) 169 | 170 | # Cross attention block, tokens attending to image embedding 171 | q = queries + query_pe 172 | k = keys + key_pe 173 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 174 | queries = queries + attn_out 175 | queries = self.norm2(queries) 176 | 177 | # MLP block 178 | mlp_out = self.mlp(queries) 179 | queries = queries + mlp_out 180 | queries = self.norm3(queries) 181 | 182 | # Cross attention block, image embedding attending to tokens 183 | q = queries + query_pe 184 | k = keys + key_pe 185 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 186 | keys = keys + attn_out 187 | keys = self.norm4(keys) 188 | 189 | return queries, keys 190 | 191 | 192 | class AttentionForTwoWayAttentionBlock(nn.Module): 193 | """ 194 | An attention layer that allows for downscaling the size of the embedding 195 | after projection to queries, keys, and values. 196 | """ 197 | 198 | def __init__( 199 | self, 200 | embedding_dim: int, 201 | num_heads: int, 202 | downsample_rate: int = 1, 203 | ) -> None: 204 | super().__init__() 205 | self.embedding_dim = embedding_dim 206 | self.internal_dim = embedding_dim // downsample_rate 207 | self.num_heads = num_heads 208 | assert ( 209 | self.internal_dim % num_heads == 0 210 | ), "num_heads must divide embedding_dim." 
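        # downsample_rate shrinks the width used inside attention
        # (internal_dim = embedding_dim // downsample_rate); out_proj maps the
        # concatenated heads back to embedding_dim.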
211 | self.c_per_head = self.internal_dim / num_heads 212 | self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head) 213 | 214 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 215 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 216 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 217 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 218 | self._reset_parameters() 219 | 220 | def _reset_parameters(self) -> None: 221 | # The fan_out is incorrect, but matches pytorch's initialization 222 | # for which qkv is a single 3*embedding_dim x embedding_dim matrix 223 | fan_in = self.embedding_dim 224 | fan_out = 3 * self.internal_dim 225 | # Xavier uniform with our custom fan_out 226 | bnd = math.sqrt(6 / (fan_in + fan_out)) 227 | nn.init.uniform_(self.q_proj.weight, -bnd, bnd) 228 | nn.init.uniform_(self.k_proj.weight, -bnd, bnd) 229 | nn.init.uniform_(self.v_proj.weight, -bnd, bnd) 230 | # out_proj.weight is left with default initialization, like pytorch attention 231 | nn.init.zeros_(self.q_proj.bias) 232 | nn.init.zeros_(self.k_proj.bias) 233 | nn.init.zeros_(self.v_proj.bias) 234 | nn.init.zeros_(self.out_proj.bias) 235 | 236 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 237 | b, n, c = x.shape 238 | x = x.reshape(b, n, num_heads, c // num_heads) 239 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 240 | 241 | def _recombine_heads(self, x: Tensor) -> Tensor: 242 | b, n_heads, n_tokens, c_per_head = x.shape 243 | x = x.transpose(1, 2) 244 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 245 | 246 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 247 | # Input projections 248 | q = self.q_proj(q) 249 | k = self.k_proj(k) 250 | v = self.v_proj(v) 251 | 252 | # Separate into heads 253 | q = self._separate_heads(q, self.num_heads) 254 | k = self._separate_heads(k, self.num_heads) 255 | v = self._separate_heads(v, self.num_heads) 256 | 257 | # Attention 258 | _, _, _, c_per_head = q.shape 259 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 260 | attn = attn * self.inv_sqrt_c_per_head 261 | attn = torch.softmax(attn, dim=-1) 262 | # Get output 263 | out = attn @ v 264 | out = self._recombine_heads(out) 265 | out = self.out_proj(out) 266 | return out 267 | -------------------------------------------------------------------------------- /Model/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
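# Shared building blocks re-exported at package level; e.g. the ViT adapter
# blocks import Adapter and LayerNorm2d from here, and the two-way
# transformers build on MLPBlock.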
3 | 4 | from .mlp import MLPBlock 5 | from .two_way_transformer import TwoWayTransformer 6 | from .adapter import Adapter 7 | from .layer_norm import LayerNorm2d -------------------------------------------------------------------------------- /Model/common/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Adapter(nn.Module): 6 | def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): 7 | super().__init__() 8 | self.skip_connect = skip_connect 9 | D_hidden_features = int(D_features * mlp_ratio) 10 | self.act = act_layer() 11 | self.D_fc1 = nn.Linear(D_features, D_hidden_features) 12 | self.D_fc2 = nn.Linear(D_hidden_features, D_features) 13 | 14 | def forward(self, x): 15 | # x is (BT, HW+1, D) 16 | xs = self.D_fc1(x) 17 | xs = self.act(xs) 18 | xs = self.D_fc2(xs) 19 | if self.skip_connect: 20 | x = x + xs 21 | else: 22 | x = xs 23 | return x -------------------------------------------------------------------------------- /Model/common/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LayerNorm2d(nn.Module): 6 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 7 | super().__init__() 8 | self.weight = nn.Parameter(torch.ones(num_channels)) 9 | self.bias = nn.Parameter(torch.zeros(num_channels)) 10 | self.eps = eps 11 | 12 | def forward(self, x: torch.Tensor) -> torch.Tensor: 13 | u = x.mean(1, keepdim=True) 14 | s = (x - u).pow(2).mean(1, keepdim=True) 15 | x = (x - u) / torch.sqrt(s + self.eps) 16 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 17 | return x 18 | -------------------------------------------------------------------------------- /Model/common/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from torch import nn 4 | 5 | 6 | # Lightly adapted from 7 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 8 | class MLPBlock(nn.Module): 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | hidden_dim: int, 13 | output_dim: int, 14 | num_layers: int, 15 | act: Type[nn.Module], 16 | ) -> None: 17 | super().__init__() 18 | self.num_layers = num_layers 19 | h = [hidden_dim] * (num_layers - 1) 20 | self.layers = nn.ModuleList( 21 | nn.Sequential(nn.Linear(n, k), act()) 22 | for n, k in zip([input_dim] + h, [hidden_dim] * num_layers) 23 | ) 24 | self.fc = nn.Linear(hidden_dim, output_dim) 25 | 26 | def forward(self, x): 27 | for layer in self.layers: 28 | x = layer(x) 29 | return self.fc(x) 30 | -------------------------------------------------------------------------------- /Model/common/two_way_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Type 3 | import torch 4 | from torch import nn, Tensor 5 | from .mlp import MLPBlock 6 | from .adapter import Adapter 7 | class TwoWayTransformer(nn.Module): 8 | def __init__( 9 | self, 10 | depth: int, 11 | embedding_dim: int, 12 | num_heads: int, 13 | mlp_dim: int, 14 | activation: Type[nn.Module] = nn.ReLU, 15 | normalize_before_activation: bool = False, 16 | attention_downsample_rate: int = 2, 17 | ) -> None: 18 | """ 19 | A transformer decoder that attends to an input image using 20 | queries whose positional embedding is supplied. 
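        Unlike the plain version in Model/common/MaskDecoder/two_way_transformer.py,
        this variant inserts lightweight Adapter modules (a Space_Adapter on the
        attention outputs and a 0.5-scaled MLP_Adapter beside each MLP block).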
21 | 22 | Args: 23 | depth (int): number of layers in the transformer 24 | embedding_dim (int): the channel dimension for the input embeddings 25 | num_heads (int): the number of heads for multihead attention. Must 26 | divide embedding_dim 27 | mlp_dim (int): the channel dimension internal to the MLP block 28 | activation (nn.Module): the activation to use in the MLP block 29 | """ 30 | super().__init__() 31 | self.depth = depth 32 | self.embedding_dim = embedding_dim 33 | self.num_heads = num_heads 34 | self.mlp_dim = mlp_dim 35 | self.layers = nn.ModuleList() 36 | 37 | for i in range(depth): 38 | curr_layer = TwoWayAttentionBlock( 39 | embedding_dim=embedding_dim, 40 | num_heads=num_heads, 41 | mlp_dim=mlp_dim, 42 | activation=activation, 43 | normalize_before_activation=normalize_before_activation, 44 | attention_downsample_rate=attention_downsample_rate, 45 | skip_first_layer_pe=(i == 0), 46 | ) 47 | self.layers.append(curr_layer) 48 | 49 | self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock( 50 | embedding_dim, 51 | num_heads, 52 | downsample_rate=attention_downsample_rate, 53 | ) 54 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 55 | 56 | self.MLP_Adapter = Adapter(embedding_dim, skip_connect=False) # MLP-adapter, no skip connection 57 | self.Space_Adapter = Adapter(embedding_dim) # with skip connection 58 | self.scale = 0.5 59 | 60 | def forward( 61 | self, 62 | image_embedding: Tensor, 63 | image_pe: Tensor, 64 | point_embedding: Tensor, 65 | ) -> Tuple[Tensor, Tensor]: 66 | """ 67 | Args: 68 | image_embedding (torch.Tensor): image to attend to. Should be shape 69 | B x embedding_dim x h x w for any h and w. 70 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 71 | have the same shape as image_embedding. 72 | point_embedding (torch.Tensor): the embedding to add to the query points. 73 | Must have shape B x N_points x embedding_dim for any N_points. 
74 | 75 | Returns: 76 | torch.Tensor: the processed point_embedding 77 | torch.Tensor: the processed image_embedding 78 | """ 79 | 80 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 81 | bs, c, h, w = image_embedding.shape 82 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 83 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 84 | 85 | # Prepare queries 86 | queries = point_embedding 87 | keys = image_embedding 88 | 89 | # Apply transformer blocks and final layernorm 90 | for idx, layer in enumerate(self.layers): 91 | queries, keys = layer( 92 | queries=queries, 93 | keys=keys, 94 | query_pe=point_embedding, 95 | key_pe=image_pe, 96 | ) 97 | 98 | # Apply the final attention layer from the points to the image 99 | q = queries + point_embedding 100 | k = keys + image_pe 101 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 102 | attn_out = self.Space_Adapter(attn_out) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | queries = self.Space_Adapter(queries) 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int, 115 | activation: Type[nn.Module], 116 | normalize_before_activation: bool, 117 | attention_downsample_rate: int = 2, 118 | skip_first_layer_pe: bool = False, 119 | ) -> None: 120 | """ 121 | A transformer block with four layers: (1) self-attention of sparse 122 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 123 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 124 | inputs. 125 | 126 | Arguments: 127 | embedding_dim (int): the channel dimension of the embeddings 128 | num_heads (int): the number of heads in the attention layers 129 | mlp_dim (int): the hidden dimension of the mlp block 130 | activation (nn.Module): the activation of the mlp block 131 | skip_first_layer_pe (bool): skip the PE on the first layer 132 | """ 133 | super().__init__() 134 | self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) 135 | self.norm1 = nn.LayerNorm(embedding_dim) 136 | 137 | self.MLP_Adapter = Adapter(embedding_dim, skip_connect=False) # MLP-adapter, no skip connection 138 | self.Space_Adapter = Adapter(embedding_dim) # with skip connection 139 | self.scale = 0.5 140 | 141 | self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( 142 | embedding_dim, 143 | num_heads, 144 | downsample_rate=attention_downsample_rate, 145 | ) 146 | self.norm2 = nn.LayerNorm(embedding_dim) 147 | 148 | self.mlp = MLPBlock( 149 | embedding_dim, 150 | mlp_dim, 151 | embedding_dim, 152 | 1, 153 | activation, 154 | ) 155 | 156 | self.norm3 = nn.LayerNorm(embedding_dim) 157 | self.norm4 = nn.LayerNorm(embedding_dim) 158 | self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( 159 | embedding_dim, 160 | num_heads, 161 | downsample_rate=attention_downsample_rate, 162 | ) 163 | self.skip_first_layer_pe = skip_first_layer_pe 164 | 165 | def forward( 166 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 167 | ) -> Tuple[Tensor, Tensor]: 168 | # Self attention block 169 | if not self.skip_first_layer_pe: 170 | queries = queries + query_pe 171 | attn_out = self.self_attn(q=queries, k=queries, v=queries) 172 | attn_out = self.Space_Adapter(attn_out) 173 | queries = queries + attn_out 174 | queries = self.norm1(queries) 175 | 176 | # Cross attention block, tokens attending to image embedding 177 | q = queries + query_pe 178 | k = 
keys + key_pe 179 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 180 | attn_out = self.Space_Adapter(attn_out) 181 | queries = queries + attn_out 182 | queries = self.norm2(queries) 183 | 184 | # MLP block 185 | mlp_out = self.mlp(queries) + self.scale * self.MLP_Adapter(queries) 186 | # mlp_out = self.mlp(queries) 187 | queries = queries + mlp_out 188 | queries = self.norm3(queries) 189 | 190 | # Cross attention block, image embedding attending to tokens 191 | q = queries + query_pe 192 | k = keys + key_pe 193 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 194 | attn_out = self.Space_Adapter(attn_out) 195 | keys = keys + attn_out 196 | keys = self.norm4(keys) 197 | 198 | return queries, keys 199 | 200 | 201 | class AttentionForTwoWayAttentionBlock(nn.Module): 202 | """ 203 | An attention layer that allows for downscaling the size of the embedding 204 | after projection to queries, keys, and values. 205 | """ 206 | 207 | def __init__( 208 | self, 209 | embedding_dim: int, 210 | num_heads: int, 211 | downsample_rate: int = 1, 212 | ) -> None: 213 | super().__init__() 214 | self.embedding_dim = embedding_dim 215 | self.internal_dim = embedding_dim // downsample_rate 216 | self.num_heads = num_heads 217 | assert ( 218 | self.internal_dim % num_heads == 0 219 | ), "num_heads must divide embedding_dim." 220 | self.c_per_head = self.internal_dim / num_heads 221 | self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head) 222 | 223 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 224 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 225 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 226 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 227 | self._reset_parameters() 228 | 229 | def _reset_parameters(self) -> None: 230 | # The fan_out is incorrect, but matches pytorch's initialization 231 | # for which qkv is a single 3*embedding_dim x embedding_dim matrix 232 | fan_in = self.embedding_dim 233 | fan_out = 3 * self.internal_dim 234 | # Xavier uniform with our custom fan_out 235 | bnd = math.sqrt(6 / (fan_in + fan_out)) 236 | nn.init.uniform_(self.q_proj.weight, -bnd, bnd) 237 | nn.init.uniform_(self.k_proj.weight, -bnd, bnd) 238 | nn.init.uniform_(self.v_proj.weight, -bnd, bnd) 239 | # out_proj.weight is left with default initialization, like pytorch attention 240 | nn.init.zeros_(self.q_proj.bias) 241 | nn.init.zeros_(self.k_proj.bias) 242 | nn.init.zeros_(self.v_proj.bias) 243 | nn.init.zeros_(self.out_proj.bias) 244 | 245 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 246 | b, n, c = x.shape 247 | x = x.reshape(b, n, num_heads, c // num_heads) 248 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 249 | 250 | def _recombine_heads(self, x: Tensor) -> Tensor: 251 | b, n_heads, n_tokens, c_per_head = x.shape 252 | x = x.transpose(1, 2) 253 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 254 | 255 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 256 | # Input projections 257 | q = self.q_proj(q) 258 | k = self.k_proj(k) 259 | v = self.v_proj(v) 260 | 261 | # Separate into heads 262 | q = self._separate_heads(q, self.num_heads) 263 | k = self._separate_heads(k, self.num_heads) 264 | v = self._separate_heads(v, self.num_heads) 265 | 266 | # Attention 267 | _, _, _, c_per_head = q.shape 268 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 269 | attn = attn * self.inv_sqrt_c_per_head 270 | attn = torch.softmax(attn, dim=-1) 271 | # 
Get output 272 | out = attn @ v 273 | out = self._recombine_heads(out) 274 | out = self.out_proj(out) 275 | return out 276 | -------------------------------------------------------------------------------- /Model/discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | class Discriminator(torch.nn.Module): 5 | def __init__(self, in_channels=2, out_conv_channels=1): 6 | super(Discriminator, self).__init__() 7 | ambiguous_channels = 16 8 | entmap_channels = 16 9 | logits_channels = 16 10 | embedding_channels = 64 11 | 12 | self.out_conv_channels = out_conv_channels 13 | 14 | self.ambiguous_conv = nn.Sequential( 15 | nn.Conv2d(in_channels=in_channels, out_channels=ambiguous_channels, kernel_size=3, padding=1, bias=False), 16 | nn.BatchNorm2d(ambiguous_channels), 17 | nn.LeakyReLU(0.2, inplace=True), 18 | nn.Conv2d(in_channels=ambiguous_channels, out_channels=ambiguous_channels, kernel_size=1, stride=1, padding=0, bias=False), 19 | nn.BatchNorm2d(ambiguous_channels), 20 | nn.LeakyReLU(0.2, inplace=True) 21 | ) 22 | self.entmap_conv = nn.Sequential( 23 | nn.Conv2d(in_channels=2, out_channels=entmap_channels, kernel_size=3, padding=1, bias=False), 24 | nn.BatchNorm2d(entmap_channels), 25 | nn.LeakyReLU(0.2, inplace=True), 26 | nn.Conv2d(in_channels=entmap_channels, out_channels=entmap_channels, kernel_size=1, stride=1, padding=0, bias=False), 27 | nn.BatchNorm2d(entmap_channels), 28 | nn.LeakyReLU(0.2, inplace=True) 29 | ) 30 | self.logits_conv = nn.Sequential( 31 | nn.Conv2d(in_channels=in_channels * 2, out_channels=logits_channels, kernel_size=3, padding=1, bias=False), 32 | nn.BatchNorm2d(logits_channels), 33 | nn.LeakyReLU(0.2, inplace=True), 34 | nn.Conv2d(in_channels=logits_channels, out_channels=logits_channels * 2, kernel_size=1, stride=1, padding=0, bias=False), 35 | nn.BatchNorm2d(logits_channels * 2), 36 | nn.LeakyReLU(0.2, inplace=True) 37 | ) 38 | self.final_conv = nn.Sequential( 39 | nn.Conv2d(in_channels=embedding_channels, out_channels=embedding_channels, kernel_size=3, padding=1, bias=False), 40 | nn.BatchNorm2d(embedding_channels), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | 43 | nn.Conv2d(in_channels=embedding_channels, out_channels=out_conv_channels, kernel_size=3, padding=1, bias=False), 44 | nn.BatchNorm2d(out_conv_channels), 45 | nn.LeakyReLU(0.2, inplace=True) 46 | ) 47 | 48 | def get_entropy_map(self, p): 49 | ent_map = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1, keepdim=True) 50 | return ent_map 51 | 52 | def forward(self, pred_UNet, pred_YNet, pred_UNet_soft, pred_VNet_soft, entmap1, entmap2, thr=0.5): 53 | 54 | pred_UNet_bool = torch.where(pred_UNet_soft > thr, torch.tensor(1), torch.tensor(0)) 55 | pred_YNet_bool = torch.where(pred_VNet_soft > thr, torch.tensor(1), torch.tensor(0)) 56 | 57 | ambiguous_area = torch.bitwise_xor(pred_UNet_bool, pred_YNet_bool).to(dtype=torch.float32) 58 | uncertainty_area = torch.cat((entmap1, entmap2), dim=1) 59 | pred_logits = torch.cat((pred_UNet, pred_YNet), dim=1) 60 | 61 | ambiguous_info = self.ambiguous_conv(1 - ambiguous_area) 62 | uncertainty_info = self.entmap_conv(1 - uncertainty_area) 63 | pred_info = self.logits_conv(pred_logits) 64 | 65 | x = torch.cat((ambiguous_info, uncertainty_info, pred_info), dim=1) 66 | 67 | x = self.final_conv(x) 68 | 69 | return x -------------------------------------------------------------------------------- /Model/model.py: 
-------------------------------------------------------------------------------- 1 | from Model.unet import DownBlock, UpBlock 2 | from Model.vnet import DownsamplingConvBlock, UpsamplingDeconvBlock 3 | from Model.vnet import ConvBlock as vnet_ConvBlock 4 | from Model.unet import ConvBlock as unet_ConvBlock 5 | from Model.discriminator import Discriminator 6 | import torch 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | from torch import nn, einsum 10 | 11 | 12 | class FeedForward(nn.Module): 13 | def __init__(self, dim, hidden_dim, dropout=0.): 14 | super().__init__() 15 | self.net = nn.Sequential( 16 | nn.Linear(dim, hidden_dim), 17 | nn.GELU(), 18 | nn.Dropout(dropout), 19 | nn.Linear(hidden_dim, dim), 20 | nn.Dropout(dropout) 21 | ) 22 | 23 | def forward(self, x): 24 | return self.net(x) 25 | 26 | 27 | class Attention(nn.Module): 28 | def __init__(self, dim, num_heads=8, dim_head=64, dropout=0.1): 29 | super().__init__() 30 | inner_dim = dim_head * num_heads 31 | self.num_heads = num_heads 32 | 33 | self.w_q = nn.Linear(dim, inner_dim) 34 | self.w_k = nn.Linear(dim, inner_dim) 35 | self.w_v = nn.Linear(dim, inner_dim) 36 | 37 | self.scale = dim_head ** -0.5 38 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 39 | 40 | project_out = not (num_heads == 1 and dim_head == dim) 41 | self.to_out = nn.Sequential( 42 | nn.Linear(inner_dim, dim), 43 | nn.Dropout(dropout), 44 | ) if project_out else nn.Identity() 45 | 46 | def forward(self, p1, p2): 47 | q_p1 = self.w_q(p1) 48 | k_p2 = self.w_k(p2) 49 | v_p2 = self.w_v(p2) 50 | q_p1 = rearrange(q_p1, 'b n (h d) -> b h n d', h=self.num_heads) 51 | k_p2 = rearrange(k_p2, 'b n (h d) -> b h n d', h=self.num_heads) 52 | v_p2 = rearrange(v_p2, 'b n (h d) -> b h n d', h=self.num_heads) 53 | 54 | attn_p1p2 = einsum('b h i d, b h j d -> b h i j', q_p1, k_p2) * self.scale 55 | attn_p1p2 = attn_p1p2.softmax(dim=-1) 56 | # show(attn_p1p2) 57 | attn_p1p2 = einsum('b h i j, b h j d -> b h i d', attn_p1p2, v_p2) 58 | attn_p1p2 = rearrange(attn_p1p2, 'b h n d -> b n (h d)') 59 | 60 | attn_p1p2 = self.to_out(attn_p1p2) 61 | return attn_p1p2 62 | 63 | 64 | class Cross_Attention_block(nn.Module): 65 | def __init__(self, input_size, in_channels, patch_size=16, num_heads=8, channel_attn_drop=0.1, pos_embed=True, dim=96, dim_head=64, hid_dim=384): 66 | super(Cross_Attention_block, self).__init__() 67 | self.patch_size = patch_size 68 | input_size = int(input_size) 69 | assert input_size % self.patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
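# Shape sketch for the patch embedding set up below (illustrative values, not fixed by the code):
#   with input_size = 256, patch_size = 16 and in_channels = C,
#   num_patches = (256 // 16) ** 2 = 256 and patch_dim = C * 16 * 16,
#   so to_patch_embedding maps (B, C, 256, 256) -> (B, 256, dim).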
70 | num_patches = (input_size // patch_size) ** 2 71 | 72 | patch_dim = in_channels * patch_size ** 2 73 | self.to_patch_embedding = nn.Sequential( 74 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 75 | nn.Linear(patch_dim, dim) 76 | ) 77 | self.dropout = nn.Dropout(channel_attn_drop) 78 | 79 | self.attn = Attention(dim, num_heads, dim_head, channel_attn_drop) 80 | 81 | if pos_embed: 82 | # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, dim)) 83 | self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim)) 84 | else: 85 | self.pos_embed = None 86 | 87 | self.MLP = FeedForward(dim, hid_dim) 88 | 89 | self.to_out = nn.Sequential( 90 | nn.Linear(dim, patch_dim), 91 | nn.Dropout(channel_attn_drop), 92 | Rearrange('b (h w) (p1 p2 c)-> b c (h p1) (w p2) ', h=(input_size // patch_size), w=(input_size // patch_size), p1=patch_size, p2=patch_size), 93 | ) 94 | 95 | def forward(self, p1, p2): 96 | 97 | p1 = self.to_patch_embedding(p1) 98 | p2 = self.to_patch_embedding(p2) 99 | _, n, _ = p1.shape # n is the number of patch tokens (the spatial resolution of the patch grid) 100 | 101 | if self.pos_embed is not None: 102 | p1 = p1 + self.pos_embed 103 | p2 = p2 + self.pos_embed 104 | p1 = self.dropout(p1) 105 | p2 = self.dropout(p2) 106 | 107 | attn_p1p2 = self.attn(p1, p2) 108 | 109 | attn_p1p2 = self.MLP(attn_p1p2) + attn_p1p2 110 | attn_p1p2 = self.to_out(attn_p1p2) 111 | # show_attn(attn_p1p2) 112 | return attn_p1p2 113 | 114 | 115 | class VNet(nn.Module): 116 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False): 117 | super(VNet, self).__init__() 118 | self.has_dropout = has_dropout 119 | self.block_one = vnet_ConvBlock(1, n_channels, n_filters, normalization=normalization) 120 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 121 | 122 | self.block_two = vnet_ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 123 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 124 | 125 | self.block_three = vnet_ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 126 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 127 | 128 | self.block_four = vnet_ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 129 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 130 | 131 | self.block_five = vnet_ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 132 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 133 | 134 | self.block_six = vnet_ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 135 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 136 | 137 | self.block_seven = vnet_ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 138 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 139 | 140 | self.block_eight = vnet_ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 141 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 142 | 143 | self.block_nine = vnet_ConvBlock(1, n_filters, n_filters, normalization=normalization) 144 | self.out_conv = nn.Conv2d(n_filters, n_classes, 1, padding=0) 145 | 146 | self.dropout =
nn.Dropout2d(p=0.5, inplace=False) 147 | 148 | self.__init_weight() 149 | 150 | def __init_weight(self): 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d) or isinstance(m,nn.ConvTranspose2d): 153 | torch.nn.init.kaiming_normal_(m.weight) 154 | elif isinstance(m, nn.BatchNorm2d): 155 | m.weight.data.fill_(1) 156 | m.bias.data.zero_() 157 | 158 | 159 | class UNet(nn.Module): 160 | def __init__(self, in_chns, class_num, bilinear): 161 | super(UNet, self).__init__() 162 | params = {'in_chns': in_chns, 163 | 'feature_chns': [16, 32, 64, 128, 256], 164 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 165 | 'class_num': class_num, 166 | 'bilinear': bilinear, 167 | 'acti_func': 'relu', 168 | } 169 | self.params = params 170 | self.n_class = self.params['class_num'] 171 | self.in_chns = self.params['in_chns'] 172 | self.ft_chns = self.params['feature_chns'] 173 | self.bilinear = self.params['bilinear'] 174 | self.dropout = self.params['dropout'] 175 | assert (len(self.ft_chns) == 5) 176 | 177 | self.in_conv = unet_ConvBlock( 178 | self.in_chns, self.ft_chns[0], self.dropout[0]) 179 | self.down1 = DownBlock( 180 | self.ft_chns[0], self.ft_chns[1], self.dropout[1]) 181 | self.down2 = DownBlock( 182 | self.ft_chns[1], self.ft_chns[2], self.dropout[2]) 183 | self.down3 = DownBlock( 184 | self.ft_chns[2], self.ft_chns[3], self.dropout[3]) 185 | self.down4 = DownBlock( 186 | self.ft_chns[3], self.ft_chns[4], self.dropout[4]) 187 | 188 | self.up1 = UpBlock( 189 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, bilinear=self.bilinear) 190 | self.up2 = UpBlock( 191 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, bilinear=self.bilinear) 192 | self.up3 = UpBlock( 193 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, bilinear=self.bilinear) 194 | self.up4 = UpBlock( 195 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, bilinear=self.bilinear) 196 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 197 | kernel_size=3, padding=1) 198 | 199 | class KnowSAM(nn.Module): 200 | def __init__(self, args, bilinear=False, has_dropout=False): 201 | super(KnowSAM, self).__init__() 202 | self.has_dropout = has_dropout 203 | self.UNet = UNet(in_chns=args.in_channels, class_num=args.num_classes, bilinear=bilinear) 204 | self.VNet = VNet(n_channels=args.in_channels, n_classes=args.num_classes) 205 | self.Discriminator = Discriminator(in_channels=args.num_classes, out_conv_channels=args.num_classes) 206 | 207 | def get_entropy_map(self, p): 208 | ent_map = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1, keepdim=True) 209 | return ent_map 210 | 211 | def forward(self, x): 212 | x0_u = self.UNet.in_conv(x) 213 | x0_v = self.VNet.block_one(x) 214 | 215 | x1_u = self.UNet.down1(x0_u) 216 | x1_v = self.VNet.block_one_dw(x0_v) 217 | 218 | x2_u = self.UNet.down2(x1_u) 219 | x2_v = self.VNet.block_two_dw(self.VNet.block_two(x1_v)) 220 | 221 | x3_u = self.UNet.down3(x2_u) 222 | x3_v = self.VNet.block_three_dw(self.VNet.block_three(x2_v)) 223 | 224 | x4_u = self.UNet.down4(x3_u) 225 | x4_v = self.VNet.block_four_dw(self.VNet.block_four(x3_v)) 226 | 227 | # unet decoder 228 | x_u = self.UNet.up1(x4_u, x3_u) 229 | x_u = self.UNet.up2(x_u, x2_u) 230 | x_u = self.UNet.up3(x_u, x1_u) 231 | x_u = self.UNet.up4(x_u, x0_u) 232 | pred_UNet = self.UNet.out_conv(x_u) 233 | 234 | # vnet_decoder 235 | x_v = self.VNet.block_five_up(self.VNet.block_five(x4_v)) 236 | x_v = x_v + x3_v 237 | x_v = self.VNet.block_six_up(self.VNet.block_six(x_v)) 238 | x_v = x_v + x2_v 
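# Each VNet decoder stage follows the same pattern: conv block, transposed-conv upsampling,
# then an additive skip connection with the encoder feature at the matching scale
# (x3_v, x2_v, x1_v, x0_v in turn).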
239 | x_v = self.VNet.block_seven_up(self.VNet.block_seven(x_v)) 240 | x_v = x_v + x1_v 241 | x_v = self.VNet.block_eight_up(self.VNet.block_eight(x_v)) 242 | x_v = x_v + x0_v 243 | x_v = self.VNet.block_nine(x_v) 244 | if self.has_dropout: 245 | x_v = self.VNet.dropout(x_v) 246 | pred_VNet = self.VNet.out_conv(x_v) 247 | 248 | pred_UNet_soft = torch.softmax(pred_UNet, dim=1) 249 | pred_VNet_soft = torch.softmax(pred_VNet, dim=1) 250 | 251 | entmap1 = self.get_entropy_map(pred_UNet_soft) 252 | entmap2 = self.get_entropy_map(pred_VNet_soft) 253 | 254 | fusion_map = self.Discriminator(pred_UNet, pred_VNet, pred_UNet_soft, pred_VNet_soft, entmap1, entmap2) 255 | return pred_UNet, pred_VNet, pred_UNet_soft, pred_VNet_soft, fusion_map 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | -------------------------------------------------------------------------------- /Model/prompt.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.distributions.uniform import Uniform 7 | import torch.nn.functional as F 8 | 9 | def kaiming_normal_init_weight(model): 10 | for m in model.modules(): 11 | if isinstance(m, nn.Conv2d): 12 | torch.nn.init.kaiming_normal_(m.weight) 13 | elif isinstance(m, nn.BatchNorm2d): 14 | m.weight.data.fill_(1) 15 | m.bias.data.zero_() 16 | return model 17 | 18 | 19 | def sparse_init_weight(model): 20 | for m in model.modules(): 21 | if isinstance(m, nn.Conv2d): 22 | torch.nn.init.sparse_(m.weight, sparsity=0.1) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | return model 27 | 28 | 29 | class ConvBlock(nn.Module): 30 | """two convolution layers with batch norm and leaky relu""" 31 | 32 | def __init__(self, in_channels, out_channels, dropout_p): 33 | super(ConvBlock, self).__init__() 34 | self.conv_conv = nn.Sequential( 35 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(out_channels), 37 | nn.LeakyReLU(), 38 | nn.Dropout(dropout_p), 39 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 40 | nn.BatchNorm2d(out_channels), 41 | nn.LeakyReLU() 42 | ) 43 | 44 | def forward(self, x): 45 | return self.conv_conv(x) 46 | 47 | 48 | class DownBlock(nn.Module): 49 | """Downsampling followed by ConvBlock""" 50 | 51 | def __init__(self, in_channels, out_channels, dropout_p): 52 | super(DownBlock, self).__init__() 53 | self.maxpool_conv = nn.Sequential( 54 | nn.MaxPool2d(2), 55 | ConvBlock(in_channels, out_channels, dropout_p) 56 | ) 57 | 58 | def forward(self, x): 59 | return self.maxpool_conv(x) 60 | 61 | 62 | class UpBlock(nn.Module): 63 | """Upssampling followed by ConvBlock""" 64 | 65 | def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, 66 | bilinear=True): 67 | super(UpBlock, self).__init__() 68 | self.bilinear = bilinear 69 | if bilinear: 70 | self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) 71 | self.up = nn.Upsample( 72 | scale_factor=2, mode='bilinear', align_corners=True) 73 | else: 74 | self.up = nn.ConvTranspose2d( 75 | in_channels1, in_channels2, kernel_size=2, stride=2) 76 | self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) 77 | 78 | def forward(self, x1, x2): 79 | if self.bilinear: 80 | x1 = self.conv1x1(x1) 81 | x1 = self.up(x1) 82 | x = torch.cat([x2, x1], dim=1) 83 | return self.conv(x) 84 | 85 | 86 | class Encoder(nn.Module): 87 | def 
__init__(self, params): 88 | super(Encoder, self).__init__() 89 | self.params = params 90 | self.in_chns = self.params['in_chns'] 91 | self.ft_chns = self.params['feature_chns'] 92 | # self.n_class = self.params['class_num'] 93 | self.bilinear = self.params['bilinear'] 94 | self.dropout = self.params['dropout'] 95 | assert (len(self.ft_chns) == 5) 96 | self.in_conv = ConvBlock( 97 | self.in_chns, self.ft_chns[0], self.dropout[0]) 98 | self.down1 = DownBlock( 99 | self.ft_chns[0], self.ft_chns[1], self.dropout[1]) 100 | self.down2 = DownBlock( 101 | self.ft_chns[1], self.ft_chns[2], self.dropout[2]) 102 | self.down3 = DownBlock( 103 | self.ft_chns[2], self.ft_chns[3], self.dropout[3]) 104 | self.down4 = DownBlock( 105 | self.ft_chns[3], self.ft_chns[4], self.dropout[4]) 106 | 107 | def forward(self, x): 108 | x0 = self.in_conv(x) 109 | x1 = self.down1(x0) 110 | x2 = self.down2(x1) 111 | x3 = self.down3(x2) 112 | x4 = self.down4(x3) 113 | return [x0, x1, x2, x3, x4] 114 | # return x4 115 | 116 | 117 | class Decoder(nn.Module): 118 | def __init__(self, params): 119 | super(Decoder, self).__init__() 120 | self.params = params 121 | self.in_chns = self.params['in_chns'] 122 | self.ft_chns = self.params['feature_chns'] 123 | self.n_class = self.params['class_num'] 124 | self.bilinear = self.params['bilinear'] 125 | assert (len(self.ft_chns) == 5) 126 | 127 | self.up1 = UpBlock( 128 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) 129 | self.up2 = UpBlock( 130 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) 131 | self.up3 = UpBlock( 132 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) 133 | self.up4 = UpBlock( 134 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) 135 | 136 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 137 | kernel_size=3, padding=1) 138 | 139 | def forward(self, feature): 140 | # def forward(self, x0, x1, x2, x3, x4): 141 | x0 = feature[0] 142 | x1 = feature[1] 143 | x2 = feature[2] 144 | x3 = feature[3] 145 | x4 = feature[4] 146 | 147 | x = self.up1(x4, x3) 148 | x = self.up2(x, x2) 149 | x = self.up3(x, x1) 150 | x = self.up4(x, x0) 151 | output = self.out_conv(x) 152 | return output, x 153 | # return output 154 | 155 | 156 | class FeedForward(nn.Sequential): 157 | """ 158 | Feed forward module used in the transformer encoder. 159 | """ 160 | 161 | def __init__(self, 162 | in_features: int, 163 | hidden_features: int, 164 | out_features: int, 165 | dropout: float = 0.) 
-> None: 166 | """ 167 | Constructor method 168 | :param in_features: (int) Number of input features 169 | :param hidden_features: (int) Number of hidden features 170 | :param out_features: (int) Number of output features 171 | :param dropout: (float) Dropout factor 172 | """ 173 | # Call super constructor and init modules 174 | super().__init__( 175 | nn.Linear(in_features=in_features, out_features=hidden_features), 176 | nn.GELU(), 177 | nn.Dropout(p=dropout), 178 | nn.Linear(in_features=hidden_features, out_features=out_features), 179 | nn.Dropout(p=dropout) 180 | ) 181 | 182 | 183 | class box_decoder_embedding(nn.Module): 184 | def __init__(self, params): 185 | super(box_decoder_embedding, self).__init__() 186 | 187 | self.sam_box_embedding_channels = 256 188 | self.box_nums = params['box_nums'] 189 | f_channel = params['feature_chns'][-1] 190 | 191 | self.conv_1 = ConvBlock(in_channels=params['feature_chns'][-1], out_channels=params['feature_chns'][-1], 192 | dropout_p=0.1) 193 | self.down_1 = DownBlock(params['feature_chns'][-1], params['feature_chns'][-1], 0.1) 194 | self.conv_2 = ConvBlock(in_channels=params['feature_chns'][-1], out_channels=params['feature_chns'][-1], 195 | dropout_p=0.1) 196 | self.down_2 = DownBlock(params['feature_chns'][-1], params['feature_chns'][-1], 0.1) 197 | self.conv_3 = nn.Conv2d(params['feature_chns'][-1], params['feature_chns'][-1], kernel_size=1) 198 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 199 | self.ffn = FeedForward(in_features=params['feature_chns'][-1], hidden_features=self.box_nums * f_channel * 4, 200 | out_features=self.box_nums * self.sam_box_embedding_channels * 2, dropout=0.1) 201 | 202 | def forward(self, feature): 203 | b = feature.shape[0] 204 | x = self.conv_1(feature) 205 | x = self.down_1(x) 206 | x = self.conv_2(x) 207 | x = self.down_2(x) 208 | x = self.conv_3(x) 209 | x = self.global_pooling(x).view(b, -1) 210 | boxes_embedding = self.ffn(x) 211 | 212 | return boxes_embedding.view(b, self.box_nums * 2, self.sam_box_embedding_channels) 213 | 214 | 215 | class Super_Prompt(nn.Module): 216 | def __init__(self, in_chns, class_num, point_nums=5, box_nums=1): 217 | super(Super_Prompt, self).__init__() 218 | 219 | params = {'in_chns': in_chns, 220 | 'feature_chns': [16, 32, 64, 128, 256], 221 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 222 | 'class_num': class_num, 223 | 'bilinear': False, 224 | 'acti_func': 'relu', 225 | 'point_nums': point_nums, 226 | 'box_nums': box_nums 227 | } 228 | self.class_num = class_num 229 | self.box_decoder = nn.ModuleList() 230 | for _ in range(class_num): 231 | self.box_decoder.append(box_decoder_embedding(params)) 232 | 233 | def forward(self, x): 234 | feature = x 235 | boxes_embedding = [] 236 | for i in range(self.class_num): 237 | boxes_embedding.append(self.box_decoder[i](feature)) 238 | return None, boxes_embedding, None 239 | 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /Model/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
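# Minimal usage sketch (the `args` object is an assumption; it must carry the fields read in
# build_sam.py, e.g. image_size, in_channels, num_classes, point_nums, box_nums, mod):
#
#   from Model.sam import sam_model_registry, SamPredictor
#   sam = sam_model_registry["vit_b"](args, checkpoint="sam_vit_b_01ec64.pth")
#   predictor = SamPredictor(sam)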
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /Model/sam/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import urllib.request 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | from functools import partial 8 | from pathlib import Path 9 | 10 | import torch 11 | from Model.prompt import Super_Prompt 12 | from ..common import TwoWayTransformer 13 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam 14 | 15 | 16 | def build_sam_vit_h(args = None, checkpoint=None): 17 | return _build_sam( 18 | args, 19 | encoder_embed_dim=1280, 20 | encoder_depth=32, 21 | encoder_num_heads=16, 22 | encoder_global_attn_indexes=[7, 15, 23, 31], 23 | checkpoint=checkpoint, 24 | ) 25 | 26 | 27 | build_sam = build_sam_vit_h 28 | 29 | 30 | def build_sam_vit_l(args, checkpoint=None): 31 | return _build_sam( 32 | args, 33 | encoder_embed_dim=1024, 34 | encoder_depth=24, 35 | encoder_num_heads=16, 36 | encoder_global_attn_indexes=[5, 11, 17, 23], 37 | checkpoint=checkpoint, 38 | ) 39 | 40 | 41 | def build_sam_vit_b(args, checkpoint="sam_vit_b_01ec64.pth"): 42 | return _build_sam( 43 | args, 44 | encoder_embed_dim=768, 45 | encoder_depth=12, 46 | encoder_num_heads=12, 47 | encoder_global_attn_indexes=[2, 5, 8, 11], 48 | checkpoint=checkpoint, 49 | ) 50 | 51 | 52 | sam_model_registry = { 53 | "default": build_sam_vit_b, 54 | "vit_h": build_sam_vit_h, 55 | "vit_l": build_sam_vit_l, 56 | "vit_b": build_sam_vit_b, 57 | } 58 | 59 | 60 | def _build_sam( 61 | args, 62 | encoder_embed_dim, 63 | encoder_depth, 64 | encoder_num_heads, 65 | encoder_global_attn_indexes, 66 | checkpoint=None, 67 | ): 68 | prompt_embed_dim = 256 69 | image_size = args.image_size 70 | vit_patch_size = 16 71 | image_embedding_size = image_size // vit_patch_size 72 | sam = Sam( 73 | args, 74 | image_encoder=ImageEncoderViT( 75 | args=args, 76 | depth=encoder_depth, 77 | embed_dim=encoder_embed_dim, 78 | img_size=image_size, 79 | mlp_ratio=4, 80 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 81 | num_heads=encoder_num_heads, 82 | patch_size=vit_patch_size, 83 | qkv_bias=True, 84 | # use_rel_pos=True, 85 | use_rel_pos=False, 86 | global_attn_indexes=encoder_global_attn_indexes, 87 | window_size=14, 88 | out_chans=prompt_embed_dim, 89 | ), 90 | prompt_encoder=PromptEncoder( 91 | embed_dim=prompt_embed_dim, 92 | image_embedding_size=(image_embedding_size, image_embedding_size), 93 | input_image_size=(image_size, image_size), 94 | mask_in_chans=16, 95 | ), 96 | mask_decoder=MaskDecoder( 97 | num_multimask_outputs=3, 98 | transformer=TwoWayTransformer( 99 | depth=2, 100 | embedding_dim=prompt_embed_dim, 101 | mlp_dim=2048, 102 | num_heads=8, 103 | ), 104 | transformer_dim=prompt_embed_dim, 105 | iou_head_depth=3, 106 | iou_head_hidden_dim=256, 107 | ), 108 | super_prompt=Super_Prompt(in_chns=args.in_channels, class_num=args.num_classes, point_nums=args.point_nums, box_nums=args.box_nums), 109 | # smooth_model=Smooth_model(in_channels=args.num_classes, out_channels=args.num_classes), 110 | pixel_mean=[123.675, 116.28, 103.53], 
111 | pixel_std=[58.395, 57.12, 57.375], 112 | ) 113 | 114 | sam.eval() 115 | checkpoint = Path(checkpoint) 116 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists(): 117 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ") 118 | if len(cmd) == 0 or cmd.lower() == 'y': 119 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 120 | print("Downloading SAM ViT-B checkpoint...") 121 | urllib.request.urlretrieve( 122 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 123 | checkpoint, 124 | ) 125 | print(checkpoint.name, " is downloaded!") 126 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists(): 127 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ") 128 | if len(cmd) == 0 or cmd.lower() == 'y': 129 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 130 | print("Downloading SAM ViT-H checkpoint...") 131 | urllib.request.urlretrieve( 132 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 133 | checkpoint, 134 | ) 135 | print(checkpoint.name, " is downloaded!") 136 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists(): 137 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ") 138 | if len(cmd) == 0 or cmd.lower() == 'y': 139 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 140 | print("Downloading SAM ViT-L checkpoint...") 141 | urllib.request.urlretrieve( 142 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 143 | checkpoint, 144 | ) 145 | print(checkpoint.name, " is downloaded!") 146 | 147 | 148 | if checkpoint is not None: 149 | with open(checkpoint, "rb") as f: 150 | state_dict = torch.load(f) 151 | 152 | model_dict = sam.state_dict() 153 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 154 | model_dict.update(pretrained_dict) 155 | sam.load_state_dict(model_dict, strict=False) 156 | # sam.load_state_dict(state_dict, strict = False) 157 | return sam 158 | # if checkpoint is not None: 159 | # with open(checkpoint, "rb") as f: 160 | # state_dict = torch.load(f) 161 | # sam.load_state_dict(state_dict,strict=False) 162 | # return sam 163 | -------------------------------------------------------------------------------- /Model/sam/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | -------------------------------------------------------------------------------- /Model/sam/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
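# Note: ImageEncoderViT below stacks AdapterBlock layers when args.mod == 'sam_adpt'
# and plain ViT Block layers otherwise (see the block construction in __init__).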
6 | 7 | import math 8 | from typing import Optional, Tuple, Type 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from einops import rearrange 14 | 15 | from ...common import Adapter, LayerNorm2d 16 | from ...ImageEncoder import AdapterBlock, Block 17 | 18 | 19 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 20 | class ImageEncoderViT(nn.Module): 21 | def __init__( 22 | self, 23 | args, 24 | img_size: int = 1024, 25 | patch_size: int = 16, 26 | in_chans: int = 3, 27 | embed_dim: int = 768, 28 | depth: int = 12, 29 | num_heads: int = 12, 30 | mlp_ratio: float = 4.0, 31 | out_chans: int = 256, 32 | qkv_bias: bool = True, 33 | norm_layer: Type[nn.Module] = nn.LayerNorm, 34 | act_layer: Type[nn.Module] = nn.GELU, 35 | use_abs_pos: bool = True, 36 | use_rel_pos: bool = False, 37 | rel_pos_zero_init: bool = True, 38 | window_size: int = 0, 39 | global_attn_indexes: Tuple[int, ...] = (), 40 | ) -> None: 41 | """ 42 | Args: 43 | img_size (int): Input image size. 44 | patch_size (int): Patch size. 45 | in_chans (int): Number of input image channels. 46 | embed_dim (int): Patch embedding dimension. 47 | depth (int): Depth of 48 | ViT. 49 | num_heads (int): Number of attention heads in each ViT block. 50 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 51 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 52 | norm_layer (nn.Module): Normalization layer. 53 | act_layer (nn.Module): Activation layer. 54 | use_abs_pos (bool): If True, use absolute positional embeddings. 55 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 56 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 57 | window_size (int): Window size for window attention blocks. 58 | global_attn_indexes (list): Indexes for blocks using global attention. 59 | """ 60 | super().__init__() 61 | self.img_size = img_size 62 | self.args = args 63 | 64 | self.patch_embed = PatchEmbed( 65 | kernel_size=(patch_size, patch_size), 66 | stride=(patch_size, patch_size), 67 | in_chans=self.args.in_channels, 68 | embed_dim=embed_dim, 69 | ) 70 | 71 | self.pos_embed: Optional[nn.Parameter] = None 72 | if use_abs_pos: 73 | # Initialize absolute positional embedding with pretrain image size. 
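# The grid is fixed at (1024 // patch_size) tokens per side here; forward() below resizes it
# with bicubic interpolation to the actual token grid, so img_size does not have to be 1024.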
74 | self.pos_embed = nn.Parameter( 75 | torch.zeros(1, 1024 // patch_size, 1024 // patch_size, embed_dim) 76 | ) 77 | 78 | self.blocks = nn.ModuleList() 79 | if args.mod == 'sam_adpt': 80 | block_class = AdapterBlock 81 | else: 82 | block_class = Block 83 | 84 | for i in range(depth): 85 | block = block_class( 86 | args=self.args, 87 | dim=embed_dim, 88 | num_heads=num_heads, 89 | mlp_ratio=mlp_ratio, 90 | qkv_bias=qkv_bias, 91 | norm_layer=norm_layer, 92 | act_layer=act_layer, 93 | use_rel_pos=use_rel_pos, 94 | rel_pos_zero_init=rel_pos_zero_init, 95 | window_size=window_size if i not in global_attn_indexes else 0, 96 | input_size=(img_size // patch_size, img_size // patch_size), 97 | ) 98 | self.blocks.append(block) 99 | 100 | self.neck = nn.Sequential( 101 | nn.Conv2d( 102 | embed_dim, 103 | out_chans, 104 | kernel_size=1, 105 | bias=False, 106 | ), 107 | LayerNorm2d(out_chans), 108 | nn.Conv2d( 109 | out_chans, 110 | out_chans, 111 | kernel_size=3, 112 | padding=1, 113 | bias=False, 114 | ), 115 | LayerNorm2d(out_chans), 116 | ) 117 | 118 | def forward(self, x: torch.Tensor) -> torch.Tensor: 119 | 120 | x = self.patch_embed(x) 121 | if self.pos_embed is not None: 122 | # resize position embedding to match the input 123 | new_abs_pos = F.interpolate( 124 | self.pos_embed.permute(0, 3, 1, 2), 125 | size=(x.shape[1], x.shape[2]), 126 | mode="bicubic", 127 | align_corners=False, 128 | ).permute(0, 2, 3, 1) 129 | x = x + new_abs_pos 130 | 131 | for blk in self.blocks: 132 | x = blk(x) 133 | 134 | x = self.neck(x.permute(0, 3, 1, 2)) 135 | return x 136 | 137 | class PatchEmbed(nn.Module): 138 | """ 139 | Image to Patch Embedding. 140 | """ 141 | 142 | def __init__( 143 | self, 144 | kernel_size: Tuple[int, int] = (16, 16), 145 | stride: Tuple[int, int] = (16, 16), 146 | padding: Tuple[int, int] = (0, 0), 147 | in_chans: int = 3, 148 | embed_dim: int = 768, 149 | ) -> None: 150 | """ 151 | Args: 152 | kernel_size (Tuple): kernel size of the projection layer. 153 | stride (Tuple): stride of the projection layer. 154 | padding (Tuple): padding size of the projection layer. 155 | in_chans (int): Number of input image channels. 156 | embed_dim (int): Patch embedding dimension. 157 | """ 158 | super().__init__() 159 | 160 | self.proj = nn.Conv2d( 161 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 162 | ) 163 | 164 | def forward(self, x: torch.Tensor) -> torch.Tensor: 165 | x = self.proj(x) 166 | # B C H W -> B H W C 167 | x = x.permute(0, 2, 3, 1) 168 | return x 169 | 170 | -------------------------------------------------------------------------------- /Model/sam/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from ...common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | if image_embeddings.shape[0] != tokens.shape[0]: 127 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 128 | else: 129 | src = image_embeddings 130 | src = src + dense_prompt_embeddings 131 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 132 | b, c, h, w = src.shape 133 | 134 | # Run the transformer 135 | hs, src = self.transformer(src, pos_src, tokens) 136 | iou_token_out = hs[:, 0, :] 137 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 138 | 139 | # Upscale mask embeddings and predict masks using the mask tokens 140 | src = src.transpose(1, 2).view(b, c, h, w) 141 | upscaled_embedding = self.output_upscaling(src) 142 | hyper_in_list: List[torch.Tensor] = [] 143 | for i in range(self.num_mask_tokens): 144 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 145 | hyper_in = torch.stack(hyper_in_list, dim=1) 146 | b, c, h, w = upscaled_embedding.shape 147 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 148 | 149 | # Generate mask quality predictions 150 | iou_pred = self.iou_prediction_head(iou_token_out) 151 | 152 | return masks, iou_pred 153 | 154 | 155 | # Lightly adapted from 156 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 157 | class MLP(nn.Module): 158 | def __init__( 159 | self, 160 | input_dim: int, 161 | hidden_dim: int, 162 | output_dim: int, 163 | num_layers: int, 164 | sigmoid_output: bool = False, 165 | ) -> None: 166 | super().__init__() 167 | self.num_layers = num_layers 168 | h = [hidden_dim] * (num_layers - 1) 169 | self.layers = nn.ModuleList( 170 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 171 | ) 172 | self.sigmoid_output = sigmoid_output 173 | 174 | def forward(self, x): 175 | for i, layer in enumerate(self.layers): 176 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 177 | if self.sigmoid_output: 178 | x = F.sigmoid(x) 179 | return x 180 | -------------------------------------------------------------------------------- /Model/sam/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from ...common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 
32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | # a = torch.tensor([0.5], requires_grad=True, device=points.device) 81 | 82 | # points = points + 0.002 # Shift to center of pixel 83 | # points = torch.add(points, 0.5) 84 | # if pad: 85 | # # padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 86 | # padding_point = torch.zeros((points.shape[0], 1, 256), device=points.device) 87 | # padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 88 | # points = torch.cat([points, padding_point], dim=1) 89 | # labels = torch.cat([labels, padding_label], dim=1) 90 | # # point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 91 | point_embedding = points 92 | # point_embedding[labels == -1] = 0.0 93 | # point_embedding[labels == -1] += self.not_a_point_embed.weight 94 | point_embedding[labels == 0] += self.point_embeddings[0].weight 95 | point_embedding[labels == 1] += self.point_embeddings[1].weight 96 | return point_embedding 97 | 98 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 99 | """Embeds box prompts.""" 100 | boxes = boxes + 0.5 # Shift to center of pixel 101 | coords = boxes.reshape(-1, 2, 2) 102 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 103 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 104 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 105 | return corner_embedding 106 | 107 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 108 | """Embeds mask inputs.""" 109 | mask_embedding = self.mask_downscaling(masks) 110 | return mask_embedding 111 | 112 | def _get_batch_size( 113 | self, 114 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 115 | 
boxes: Optional[torch.Tensor], 116 | masks: Optional[torch.Tensor], 117 | ) -> int: 118 | """ 119 | Gets the batch size of the output given the batch size of the input prompts. 120 | """ 121 | if points is not None: 122 | return points[0].shape[0] 123 | elif boxes is not None: 124 | return boxes.shape[0] 125 | elif masks is not None: 126 | return masks.shape[0] 127 | else: 128 | return 1 129 | 130 | def _get_device(self) -> torch.device: 131 | return self.point_embeddings[0].weight.device 132 | 133 | def forward( 134 | self, 135 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 136 | boxes: Optional[torch.Tensor], 137 | masks: Optional[torch.Tensor], 138 | ) -> Tuple[torch.Tensor, torch.Tensor]: 139 | """ 140 | Embeds different types of prompts, returning both sparse and dense 141 | embeddings. 142 | 143 | Arguments: 144 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 145 | and labels to embed. 146 | boxes (torch.Tensor or none): boxes to embed 147 | masks (torch.Tensor or none): masks to embed 148 | 149 | Returns: 150 | torch.Tensor: sparse embeddings for the points and boxes, with shape 151 | BxNx(embed_dim), where N is determined by the number of input points 152 | and boxes. 153 | torch.Tensor: dense embeddings for the masks, in the shape 154 | Bx(embed_dim)x(embed_H)x(embed_W) 155 | """ 156 | bs = self._get_batch_size(points, boxes, masks) 157 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 158 | if points is not None: 159 | # coords, labels = points #coords:B,N,2 labels:B,N 160 | # point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 161 | point_embeddings = points[0] 162 | 163 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 164 | if boxes is not None: 165 | # box_embeddings = self._embed_boxes(boxes) 166 | box_embeddings = boxes 167 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 168 | 169 | if masks is not None: 170 | # dense_embeddings = masks 171 | dense_embeddings = self._embed_masks(masks) 172 | else: 173 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 174 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 175 | ) 176 | 177 | return sparse_embeddings, dense_embeddings 178 | 179 | 180 | class PositionEmbeddingRandom(nn.Module): 181 | """ 182 | Positional encoding using random spatial frequencies. 183 | """ 184 | 185 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 186 | super().__init__() 187 | if scale is None or scale <= 0.0: 188 | scale = 1.0 189 | self.register_buffer( 190 | "positional_encoding_gaussian_matrix", 191 | scale * torch.randn((2, num_pos_feats)), 192 | # scale * torch.randn((256, num_pos_feats)), 193 | ) 194 | 195 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 196 | """Positionally encode points that are normalized to [0,1].""" 197 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 198 | coords = 2 * coords - 1 199 | coords = coords @ self.positional_encoding_gaussian_matrix 200 | coords = 2 * np.pi * coords 201 | # outputs d_1 x ... 
x d_n x C shape 202 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 203 | 204 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 205 | """Generate positional encoding for a grid of the specified size.""" 206 | h, w = size 207 | device: Any = self.positional_encoding_gaussian_matrix.device 208 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 209 | y_embed = grid.cumsum(dim=0) - 0.5 210 | x_embed = grid.cumsum(dim=1) - 0.5 211 | y_embed = y_embed / h 212 | x_embed = x_embed / w 213 | 214 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 215 | return pe.permute(2, 0, 1) # C x H x W 216 | 217 | def forward_with_coords( 218 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 219 | ) -> torch.Tensor: 220 | """Positionally encode points that are not normalized to [0,1].""" 221 | # coords = coords_input.clone() 222 | coords = coords_input.clone() 223 | # coords[:, :, 0] = coords[:, :, 0] / image_size[1] 224 | # coords[:, :, 1] = coords[:, :, 1] / image_size[0] 225 | coords[:, :, 0] = coords[:, :, 0] 226 | coords[:, :, 1] = coords[:, :, 1] 227 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 228 | -------------------------------------------------------------------------------- /Model/sam/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | from Model.prompt import Super_Prompt 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | args, 25 | image_encoder: ImageEncoderViT, 26 | prompt_encoder: PromptEncoder, 27 | mask_decoder: MaskDecoder, 28 | super_prompt: Super_Prompt, 29 | # smooth_model: Smooth_model, 30 | 31 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 32 | pixel_std: List[float] = [58.395, 57.12, 57.375], 33 | ) -> None: 34 | """ 35 | SAM predicts object masks from an image and input prompts. 36 | 37 | Arguments: 38 | image_encoder (ImageEncoderViT): The backbone used to encode the 39 | image into image embeddings that allow for efficient mask prediction. 40 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 41 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 42 | and encoded prompts. 43 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 44 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 
45 | """ 46 | super().__init__() 47 | self.args = args 48 | self.image_encoder = image_encoder 49 | self.prompt_encoder = prompt_encoder 50 | self.mask_decoder = mask_decoder 51 | self.super_prompt = super_prompt 52 | # self.smooth_model = smooth_model 53 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 54 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 55 | 56 | @property 57 | def device(self) -> Any: 58 | return self.pixel_mean.device 59 | 60 | @torch.no_grad() 61 | def forward( 62 | self, 63 | batched_input: List[Dict[str, Any]], 64 | multimask_output: bool, 65 | ) -> List[Dict[str, torch.Tensor]]: 66 | """ 67 | Predicts masks end-to-end from provided images and prompts. 68 | If prompts are not known in advance, using SamPredictor is 69 | recommended over calling the model directly. 70 | 71 | Arguments: 72 | batched_input (list(dict)): A list over input images, each a 73 | dictionary with the following keys. A prompt key can be 74 | excluded if it is not present. 75 | 'image': The image as a torch tensor in 3xHxW format, 76 | already transformed for input to the model. 77 | 'original_size': (tuple(int, int)) The original size of 78 | the image before transformation, as (H, W). 79 | 'point_coords': (torch.Tensor) Batched point prompts for 80 | this image, with shape BxNx2. Already transformed to the 81 | input frame of the model. 82 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 83 | with shape BxN. 84 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 85 | Already transformed to the input frame of the model. 86 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 87 | in the form Bx1xHxW. 88 | multimask_output (bool): Whether the model should predict multiple 89 | disambiguating masks, or return a single mask. 90 | 91 | Returns: 92 | (list(dict)): A list over input images, where each element is 93 | as dictionary with the following keys. 94 | 'masks': (torch.Tensor) Batched binary mask predictions, 95 | with shape BxCxHxW, where B is the number of input prompts, 96 | C is determined by multimask_output, and (H, W) is the 97 | original size of the image. 98 | 'iou_predictions': (torch.Tensor) The model's predictions 99 | of mask quality, in shape BxC. 100 | 'low_res_logits': (torch.Tensor) Low resolution logits with 101 | shape BxCxHxW, where H=W=256. Can be passed as mask input 102 | to subsequent iterations of prediction. 
103 | """ 104 | 105 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 106 | image_embeddings = self.image_encoder(input_images) 107 | outputs = [] 108 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 109 | if "point_coords" in image_record: 110 | points = (image_record["point_coords"], image_record["point_labels"]) 111 | else: 112 | points = None 113 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 114 | points=points, 115 | boxes=image_record.get("boxes", None), 116 | masks=image_record.get("mask_inputs", None), 117 | ) 118 | low_res_masks, iou_predictions = self.mask_decoder( 119 | image_embeddings=curr_embedding.unsqueeze(0), 120 | image_pe=self.prompt_encoder.get_dense_pe(), 121 | sparse_prompt_embeddings=sparse_embeddings, 122 | dense_prompt_embeddings=dense_embeddings, 123 | multimask_output=multimask_output, 124 | ) 125 | masks = self.postprocess_masks( 126 | low_res_masks, 127 | input_size=image_record["image"].shape[-2:], 128 | original_size=image_record["original_size"], 129 | ) 130 | masks = masks > self.mask_threshold 131 | outputs.append( 132 | { 133 | "masks": masks, 134 | "iou_predictions": iou_predictions, 135 | "low_res_logits": low_res_masks, 136 | } 137 | ) 138 | return outputs 139 | 140 | def postprocess_masks( 141 | self, 142 | masks: torch.Tensor, 143 | input_size: Tuple[int, ...], 144 | original_size: Tuple[int, ...], 145 | ) -> torch.Tensor: 146 | """ 147 | Remove padding and upscale masks to the original image size. 148 | 149 | Arguments: 150 | masks (torch.Tensor): Batched masks from the mask_decoder, 151 | in BxCxHxW format. 152 | input_size (tuple(int, int)): The size of the image input to the 153 | model, in (H, W) format. Used to remove padding. 154 | original_size (tuple(int, int)): The original size of the image 155 | before resizing for input to the model, in (H, W) format. 156 | 157 | Returns: 158 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 159 | is given by original_size. 160 | """ 161 | masks = F.interpolate( 162 | masks, 163 | (self.image_encoder.img_size, self.image_encoder.img_size), 164 | mode="bilinear", 165 | align_corners=False, 166 | ) 167 | masks = masks[..., : input_size[0], : input_size[1]] 168 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 169 | return masks 170 | 171 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 172 | """Normalize pixel values and pad to a square input.""" 173 | # Normalize colors 174 | x = (x - self.pixel_mean) / self.pixel_std 175 | 176 | # Pad 177 | h, w = x.shape[-2:] 178 | padh = self.image_encoder.img_size - h 179 | padw = self.image_encoder.img_size - w 180 | x = F.pad(x, (0, padw, 0, padh)) 181 | return x 182 | -------------------------------------------------------------------------------- /Model/sam/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 
107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 
180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 
256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /Model/sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /Model/sam/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, 
has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /Model/sam/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
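# A small worked sketch of the ResizeLongestSide transform defined below: the longest
# image side is scaled to `target_length`, and prompt coordinates are rescaled by the
# same factors. The concrete sizes and the helper name are illustrative assumptions.
def _example_resize_longest_side():
    import numpy as np
    transform = ResizeLongestSide(target_length=1024)
    # For a 480x640 (H, W) image the scale is 1024 / 640 = 1.6, so the output is 768x1024.
    assert transform.get_preprocess_shape(480, 640, 1024) == (768, 1024)
    # A point at (x=320, y=240) in the original image maps to (512.0, 384.0).
    coords = np.array([[320.0, 240.0]])
    return transform.apply_coords(coords, original_size=(480, 640))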
6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length) 40 | new_coords = np.empty_like(coords) 41 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w) 42 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h) 43 | return new_coords 44 | 45 | 46 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 47 | """ 48 | Expects a numpy array shape Bx4. Requires the original image size 49 | in (H, W) format. 50 | """ 51 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 52 | return boxes.reshape(-1, 4) 53 | 54 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Expects batched images with shape BxCxHxW and float format. This 57 | transformation may not exactly match apply_image. apply_image is 58 | the transformation expected by the model. 59 | """ 60 | # Expects an image in BCHW format. May not exactly match apply_image. 61 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 62 | return F.interpolate( 63 | image, target_size, mode="bilinear", align_corners=False, antialias=True 64 | ) 65 | 66 | def apply_coords_torch( 67 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a torch tensor with length 2 in the last dimension. Requires the 71 | original image size in (H, W) format. 72 | """ 73 | old_h, old_w = original_size 74 | new_h, new_w = self.get_preprocess_shape( 75 | original_size[0], original_size[1], self.target_length 76 | ) 77 | coords = deepcopy(coords).to(torch.float) 78 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 79 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 80 | return coords 81 | 82 | def apply_boxes_torch( 83 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 84 | ) -> torch.Tensor: 85 | """ 86 | Expects a torch tensor with shape Bx4. Requires the original image 87 | size in (H, W) format. 88 | """ 89 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 90 | return boxes.reshape(-1, 4) 91 | 92 | @staticmethod 93 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 94 | """ 95 | Compute the output size given input size and target long side length. 
96 | """ 97 | scale = long_side_length * 1.0 / max(oldh, oldw) 98 | newh, neww = oldh * scale, oldw * scale 99 | neww = int(neww + 0.5) 100 | newh = int(newh + 0.5) 101 | return (newh, neww) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KnowSAM 2 | Official code for "[Learnable Prompting SAM-induced Knowledge Distillation for Semi-supervised Medical Image Segmentation](https://arxiv.org/pdf/2412.13742)" 3 | 4 | ## Installation 5 | 6 | To set up the environment and install dependencies, run: 7 | 8 | ```bash 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## Extract Sample Data 13 | 14 | We provide a reference sample dataset (SampleData.rar) that allows users to quickly test and run the model. Extract the dataset using the following command: 15 | ```bash 16 | unrar x SampleData.rar 17 | ``` 18 | For processed ACDC dataset, you can download it from the [ACDC](https://github.com/HiLab-git/SSL4MIS/tree/master/data/ACDC), and place it directly in the `SampleData` folder. 19 | 20 | 21 | ## Training 22 | To train the model on a dataset, execute: 23 | ```bash 24 | python train_semi_SAM.py 25 | ``` 26 | 27 | For ACDC dataset training: 28 | ```bash 29 | python train_semi_SAM_ACDC.py 30 | ``` 31 | 32 | ## Prediction 33 | After training, you can make predictions using: 34 | ```bash 35 | python prediction.py 36 | ``` 37 | 38 | For ACDC dataset inference: 39 | ```bash 40 | python prediction_ACDC.py 41 | ``` 42 | 43 | ## Acknowledgements 44 | Our code is based on [SSL4MIS](https://github.com/HiLab-git/SSL4MIS). 45 | 46 | ## Questions 47 | If you have any questions, welcome contact me at 'taozhou.dreams@gmail.com' 48 | -------------------------------------------------------------------------------- /SampleData.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taozh2017/KnowSAM/5ceef5a9479bac93f5a2f7a61fba4689cebbf1bf/SampleData.rar -------------------------------------------------------------------------------- /dataloader/TwoStreamBatchSampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | def iterate_once(iterable): 7 | return np.random.permutation(iterable) 8 | 9 | 10 | def iterate_eternally(indices): 11 | def infinite_shuffles(): 12 | while True: 13 | yield np.random.permutation(indices) 14 | 15 | return itertools.chain.from_iterable(infinite_shuffles()) 16 | 17 | 18 | def grouper(iterable, n): 19 | "Collect data into fixed-length chunks or blocks" 20 | # grouper('ABCDEFG', 3) --> ABC DEF" 21 | args = [iter(iterable)] * n 22 | return zip(*args) 23 | 24 | 25 | class TwoStreamBatchSampler(Sampler): 26 | """Iterate two sets of indices 27 | 28 | An 'epoch' is one iteration through the primary indices. 29 | During the epoch, the secondary indices are iterated through 30 | as many times as needed. 
31 | """ 32 | 33 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 34 | self.primary_indices = primary_indices 35 | self.secondary_indices = secondary_indices 36 | self.secondary_batch_size = secondary_batch_size 37 | self.primary_batch_size = batch_size - secondary_batch_size 38 | 39 | assert len(self.primary_indices) >= self.primary_batch_size > 0 40 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 41 | 42 | def __iter__(self): 43 | primary_iter = iterate_once(self.primary_indices) 44 | secondary_iter = iterate_eternally(self.secondary_indices) 45 | return ( 46 | primary_batch + secondary_batch 47 | for (primary_batch, secondary_batch) in zip( 48 | grouper(primary_iter, self.primary_batch_size), 49 | grouper(secondary_iter, self.secondary_batch_size), 50 | ) 51 | ) 52 | 53 | def __len__(self): 54 | return len(self.primary_indices) // self.primary_batch_size -------------------------------------------------------------------------------- /dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import cv2 5 | import numpy as np 6 | from torch.utils.data import Dataset, DataLoader 7 | from utils.utils import patients_to_slices 8 | 9 | class build_Dataset(Dataset): 10 | def __init__(self, args, data_dir, split, transform=None, labeled_slice=None, model="None"): 11 | self.data_dir = data_dir 12 | self.split = split 13 | self.transform = transform 14 | self.sample_list = [] 15 | self.model = model 16 | self.pixel_mean = [123.675, 116.28, 103.53] 17 | self.pixel_std = [58.395, 57.12, 57.375] 18 | self.args = args 19 | 20 | if self.split == "train": 21 | labeled_path = os.path.join(self.data_dir + "/labeled/image") 22 | sample_list_labeled = os.listdir(labeled_path) 23 | sample_list_labeled = [os.path.join(labeled_path, item) for item in sample_list_labeled] 24 | self.sample_list = sample_list_labeled 25 | print("train total {} samples".format(len(self.sample_list))) 26 | elif self.split == "train_semi": 27 | labeled_path = os.path.join(self.data_dir + "/labeled/image") 28 | unlabeled_path = os.path.join(self.data_dir + "/unlabeled/image") 29 | sample_list_labeled = os.listdir(labeled_path) 30 | sample_list_unlabeled = os.listdir(unlabeled_path) 31 | self.sample_list_labeled = [os.path.join(labeled_path, item) for item in sample_list_labeled] 32 | self.sample_list_unlabeled = [os.path.join(unlabeled_path, item) for item in sample_list_unlabeled] 33 | self.sample_list = self.sample_list_labeled + self.sample_list_unlabeled 34 | print("train total {} labeled samples, {} unlabeled samples". 
35 | format(len(sample_list_labeled), len(sample_list_unlabeled))) 36 | elif self.split == "val": 37 | val_path = os.path.join(self.data_dir + "/val/image") 38 | sample_list_val = os.listdir(val_path) 39 | self.sample_list = [os.path.join(val_path, item) for item in sample_list_val] 40 | print("val total {} samples".format(len(self.sample_list))) 41 | elif self.split == "train_semi_list": 42 | labeled_path = os.path.join(self.data_dir + "/train.list") 43 | with open(labeled_path, 'r') as f: 44 | self.image_list = f.readlines() 45 | self.image_list = [item.replace('\n', '') for item in self.image_list] 46 | self.sample_list = [self.data_dir + "/images/" + image_name for image_name in self.image_list] 47 | self.sample_list_labeled = patients_to_slices(args.dataset, args.labeled_num) 48 | 49 | print("train total {} samples".format(len(self.sample_list))) 50 | elif self.split == "val_semi_list": 51 | labeled_path = os.path.join(self.data_dir + "/val.list") 52 | with open(labeled_path, 'r') as f: 53 | self.image_list = f.readlines() 54 | self.image_list = [item.replace('\n', '') for item in self.image_list] 55 | self.sample_list = [self.data_dir + "/images/" + image_name for image_name in self.image_list] 56 | print("val total {} samples".format(len(self.sample_list))) 57 | elif self.split == "train_acdc_list": 58 | labeled_path = os.path.join(self.data_dir + "/train_slices.list") 59 | with open(labeled_path, 'r') as f: 60 | self.image_list = f.readlines() 61 | self.image_list = [item.replace('\n', '') for item in self.image_list] 62 | self.sample_list = [self.data_dir + "/data/slices/" + image_name + ".h5" for image_name in self.image_list] 63 | self.sample_list_labeled = patients_to_slices(args.dataset, args.labeled_num) 64 | print("train total {} samples".format(len(self.sample_list))) 65 | elif self.split == "val_acdc_list": 66 | labeled_path = os.path.join(self.data_dir + "/val.list") 67 | with open(labeled_path, 'r') as f: 68 | self.image_list = f.readlines() 69 | self.image_list = [item.replace('\n', '') for item in self.image_list] 70 | self.sample_list = [self.data_dir + "/data/" + image_name + ".h5" for image_name in self.image_list] 71 | print("val total {} samples".format(len(self.sample_list))) 72 | elif self.split == "test_acdc_list": 73 | labeled_path = os.path.join(self.data_dir + "/test.list") 74 | with open(labeled_path, 'r') as f: 75 | self.image_list = f.readlines() 76 | self.image_list = [item.replace('\n', '') for item in self.image_list] 77 | self.sample_list = [self.data_dir + "/data/" + image_name + ".h5" for image_name in self.image_list] 78 | print("test total {} samples".format(len(self.sample_list))) 79 | elif "test" in self.split: 80 | if "CVC-300" in self.split: 81 | test_path = os.path.join(self.data_dir + "/TestDataset/CVC-300/image") 82 | elif "CVC-ClinicDB" in self.split: 83 | test_path = os.path.join(self.data_dir + "/TestDataset/CVC-ClinicDB/image") 84 | elif "CVC-ColonDB" in self.split: 85 | test_path = os.path.join(self.data_dir + "/TestDataset/CVC-ColonDB/image") 86 | elif "ETIS-LaribPolypDB" in self.split: 87 | test_path = os.path.join(self.data_dir + "/TestDataset/ETIS-LaribPolypDB/image") 88 | elif "Kvasir" in self.split: 89 | test_path = os.path.join(self.data_dir + "/TestDataset/Kvasir/image") 90 | elif "ISIC2018" in self.split: 91 | test_path = os.path.join(self.data_dir + "/TestDataset/image") 92 | elif "DDTI" in self.split: 93 | test_path = os.path.join(self.data_dir + "/TestDataset/DDTI/image") 94 | elif "tn3k" in self.split: 95 | test_path = 
os.path.join(self.data_dir + "/TestDataset/tn3k/image") 96 | elif "BrainMRI" in self.split: 97 | test_path = os.path.join(self.data_dir + "/TestDataset/image") 98 | elif "MRI_Hippocampus" in self.split: 99 | test_path = os.path.join(self.data_dir + "/TestDataset/image") 100 | else: 101 | test_path = None 102 | 103 | if test_path: 104 | print('test_path: ', test_path) 105 | sample_list_val = os.listdir(test_path) 106 | self.sample_list = [os.path.join(test_path, item) for item in sample_list_val] 107 | print("test total {} samples".format(len(self.sample_list))) 108 | else: 109 | if "BCSS" in self.split: 110 | test_list_path = os.path.join(self.data_dir + "/test.list") 111 | 112 | with open(test_list_path, 'r') as f: 113 | self.image_list = f.readlines() 114 | self.image_list = [item.replace('\n', '') for item in self.image_list] 115 | self.sample_list = [self.data_dir + "/images/" + image_name for image_name in self.image_list] 116 | print("test total {} samples".format(len(self.sample_list))) 117 | 118 | def __len__(self): 119 | return len(self.sample_list) 120 | 121 | def __getitem__(self, idx): 122 | 123 | if "_list" not in self.split: 124 | case = self.sample_list[idx] 125 | image = cv2.cvtColor(cv2.imread(case), cv2.COLOR_BGR2RGB) 126 | ori_image = cv2.cvtColor(cv2.resize(image.copy(), (256, 256)), cv2.COLOR_RGB2BGR) 127 | # image = (image - self.pixel_mean) / self.pixel_std 128 | image = image / 255.0 129 | image = image.astype(np.float32) 130 | label_path = case.replace("image", "mask") 131 | 132 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) / 255 133 | if "val" in self.split or "test" in self.split: 134 | if self.transform: 135 | data = self.transform(image=image, mask=label) 136 | image = data['image'] 137 | label = data['mask'] 138 | else: 139 | if self.transform: 140 | if idx < len(self.sample_list_labeled): 141 | data = self.transform["train_weak"](image=image, mask=label) 142 | image = data['image'] 143 | label = data['mask'] 144 | else: 145 | data = self.transform["train_strong"](image=image, mask=label) 146 | image = data['image'] 147 | label = data['mask'] 148 | 149 | label[label < 0.5] = 0 150 | label[label > 0.5] = 1 151 | image = image.transpose(2, 0, 1).astype('float32') 152 | ori_image = ori_image.transpose(2, 0, 1).astype('float32') 153 | image, label = torch.tensor(image), torch.tensor(label) 154 | if "test" in self.split: 155 | sample = {"image": image, "label": label, "ori_image": ori_image} 156 | else: 157 | sample = {"image": image, "label": label,} 158 | return sample 159 | 160 | else: 161 | if "val" not in self.split and "test" not in self.split: 162 | case = self.sample_list[idx] 163 | h5f = h5py.File(case) 164 | image = h5f['image'][:].astype(np.float32) 165 | label = h5f['label'][:].astype(np.float32) 166 | if self.transform: 167 | if idx < self.sample_list_labeled: 168 | data = self.transform["train_weak"](image=image, mask=label) 169 | image = data['image'] 170 | label = data['mask'] 171 | else: 172 | data = self.transform["train_strong"](image=image, mask=label) 173 | image = data['image'] 174 | label = data['mask'] 175 | image = np.expand_dims(image, axis=0) 176 | label = label 177 | image, label = torch.tensor(image), torch.tensor(label) 178 | image = image.repeat(3, 1, 1) 179 | sample = {"image": image, "label": label, "idx": idx} 180 | return sample 181 | else: 182 | case = self.sample_list[idx] 183 | h5f = h5py.File(case) 184 | image = h5f['image'][:].astype('float32') 185 | label = h5f['label'][:].astype('float32') 186 | 187 | image, label 
= torch.tensor(image), torch.tensor(label) 188 | if self.model != "CAML": 189 | sample = {"image": image, "label": label} 190 | else: 191 | sample = {"image": image, "label": label, "idx": idx} 192 | return sample 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /dataloader/transforms.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import cv2 3 | import random 4 | 5 | def build_transforms(args): 6 | data_transforms = { 7 | "train": A.Compose([ 8 | A.OneOf([ 9 | A.Resize(*[args.image_size, args.image_size], interpolation=cv2.INTER_NEAREST, p=1.0), 10 | ], p=1), 11 | 12 | A.HorizontalFlip(p=0.5), 13 | A.VerticalFlip(p=0.5), 14 | A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5), 15 | A.OneOf([ 16 | A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0), 17 | A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0) 18 | ], p=0.25), 19 | A.CoarseDropout(max_holes=8, max_height=args.image_size // 20, max_width=args.image_size // 20, 20 | min_holes=5, fill_value=0, mask_fill_value=0, p=0.5), 21 | ], p=1.0), 22 | 23 | 24 | "valid_test": A.Compose([ 25 | A.Resize(*[args.image_size, args.image_size], interpolation=cv2.INTER_NEAREST), 26 | ], p=1.0) 27 | } 28 | return data_transforms 29 | 30 | 31 | def build_weak_strong_transforms(args): 32 | data_transforms = { 33 | "train_weak": 34 | A.Compose([ 35 | A.OneOf([A.Resize(*[args.image_size, args.image_size], interpolation=cv2.INTER_NEAREST, p=1.0),], p=1), 36 | A.HorizontalFlip(p=0.5), 37 | A.VerticalFlip(p=0.5), 38 | A.RandomBrightnessContrast(p=0.2), 39 | A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5), 40 | # A.OneOf([A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0), 41 | # A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0) 42 | # ], p=0.25), 43 | A.CoarseDropout(max_holes=8, max_height=args.image_size // 20, max_width=args.image_size // 20, 44 | min_holes=5, fill_value=0, mask_fill_value=0, p=0.5), 45 | ], p=1.0), 46 | 47 | "train_strong": 48 | A.Compose([ 49 | A.OneOf([A.Resize(*[args.image_size, args.image_size], interpolation=cv2.INTER_NEAREST, p=1.0), ], p=1), 50 | 51 | A.HorizontalFlip(p=0.5), 52 | A.VerticalFlip(p=0.5), 53 | A.RandomBrightnessContrast(p=0.6), 54 | A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5), 55 | 56 | A.OneOf([ 57 | A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0), 58 | A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0) 59 | ], p=0.5), 60 | A.CoarseDropout(max_holes=20, max_height=256 // 20, max_width=256 // 20, 61 | min_holes=10, fill_value=0, mask_fill_value=0, p=0.7), 62 | ], p=1.0), 63 | 64 | 65 | "valid_test": A.Compose([ 66 | A.Resize(*[args.image_size, args.image_size], interpolation=cv2.INTER_NEAREST), 67 | ], p=1.0) 68 | } 69 | return data_transforms 70 | 71 | 72 | -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | import torch.nn.functional as F 6 | from dataloader.dataset import build_Dataset 7 | from dataloader.transforms import build_transforms 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | from utils.utils import eval 11 | from Model.model import KnowSAM 12 | 13 | 14 | def 
get_entropy_map(p): 15 | ent_map = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1, keepdim=True) 16 | return ent_map 17 | 18 | 19 | from skimage.measure import label 20 | def get_ACDC_2DLargestCC(segmentation): 21 | batch_list = [] 22 | N = segmentation.shape[0] 23 | for i in range(0, N): 24 | class_list = [] 25 | for c in range(1, 2): 26 | temp_seg = segmentation[i] # == c * torch.ones_like(segmentation[i]) 27 | temp_prob = torch.zeros_like(temp_seg) 28 | temp_prob[temp_seg == c] = 1 29 | temp_prob = temp_prob.detach().cpu().numpy() 30 | labels = label(temp_prob) 31 | if labels.max() != 0: 32 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 33 | class_list.append(largestCC * c) 34 | else: 35 | class_list.append(temp_prob) 36 | 37 | n_batch = class_list[0] 38 | batch_list.append(n_batch) 39 | 40 | return torch.Tensor(batch_list).cuda() 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--data_path', type=str, 46 | default='./SampleData', 47 | help='Name of Experiment') 48 | 49 | parser.add_argument('--dataset', type=str, default='/tumor_1', 50 | help='Name of Experiment') 51 | 52 | parser.add_argument('--num_classes', type=int, default=2, 53 | help='output channel of network') 54 | parser.add_argument('--in_channels', type=int, default=3, 55 | help='input channel of network') 56 | parser.add_argument('--image_size', type=list, default=256, 57 | help='patch size of network input') 58 | parser.add_argument('--point_nums', type=int, default=10, help='points number') 59 | parser.add_argument('--box_nums', type=int, default=1, help='boxes number') 60 | parser.add_argument('--mod', type=str, default='sam_adpt', help='mod type:seg,cls,val_ad') 61 | parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type") 62 | parser.add_argument('--thd', type=bool, default=False, help='3d or not') 63 | 64 | parser.add_argument('--sam_model_path', type=str, 65 | default="./Results/Result_tumor_10/fold_0/sam_best_model.pth", 66 | help='model weight path') 67 | 68 | parser.add_argument('--SGDL_model_path', type=str, 69 | default="./Results/Result_tumor_10/fold_0/SGDL_best_model.pth", 70 | help='model weight path') 71 | 72 | parser.add_argument('--device', type=str, default='cuda') 73 | args = parser.parse_args() 74 | bilinear = True 75 | Largest = False 76 | data_transforms = build_transforms(args) 77 | 78 | test_dataset_list = ["test_CVC-300", "test_CVC-ClinicDB",] 79 | 80 | for test_dataset_name in test_dataset_list: 81 | test_dataset = build_Dataset(args, data_dir=args.data_path + args.dataset, split=test_dataset_name, 82 | transform=data_transforms["valid_test"]) 83 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2) 84 | 85 | # print(args.SGDL_model_path) 86 | SGDL_model = KnowSAM(args, bilinear=bilinear).to(args.device).train() 87 | SGDL_checkpoint = torch.load(args.SGDL_model_path) 88 | SGDL_model.load_state_dict(SGDL_checkpoint) 89 | SGDL_model.eval() 90 | 91 | avg_dice_list = [] 92 | avg_hd95_list = [] 93 | avg_iou_list = [] 94 | avg_sp_list = [] 95 | avg_se_list = [] 96 | avg_prec_list = [] 97 | avg_recall_list = [] 98 | for i_batch, sampled_batch in enumerate(test_loader): 99 | test_image, test_label, ori_image = sampled_batch["image"].cuda(), sampled_batch["label"].cuda(), sampled_batch["ori_image"].cuda() 100 | pred_UNet, pred_VNet, pred_UNet_soft, pred_VNet_soft, fusion_map = SGDL_model(test_image) 101 | fusion_map_soft = torch.softmax(fusion_map, dim=1) 102 | 103 | 
if Largest: 104 | pseudo_label = torch.argmax(fusion_map_soft, dim=1) 105 | fusion_map_soft = get_ACDC_2DLargestCC(pseudo_label).unsqueeze(0) 106 | 107 | eval_list = eval(test_label, fusion_map_soft, thr=0.5) 108 | 109 | avg_dice_list.append(eval_list[0]) 110 | avg_iou_list.append(eval_list[1]) 111 | avg_hd95_list.append(eval_list[2]) 112 | 113 | avg_dice = np.mean(avg_dice_list) 114 | avg_hd95 = np.mean(avg_hd95_list) 115 | avg_iou = np.mean(avg_iou_list) 116 | 117 | print(test_dataset_name, " :") 118 | print("avg_dice: ", avg_dice) 119 | print("avg_iou: ", avg_iou) 120 | print("avg_hd95: ", avg_hd95) 121 | -------------------------------------------------------------------------------- /prediction_ACDC.py: -------------------------------------------------------------------------------- 1 | 2 | from medpy import metric 3 | from scipy.ndimage import zoom 4 | 5 | 6 | def getLargestCC(segmentation): 7 | from skimage.measure import label 8 | labels = label(segmentation) 9 | #assert( labels.max() != 0 ) # assume at least 1 CC 10 | if labels.max() != 0: 11 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 12 | else: 13 | largestCC = segmentation 14 | return largestCC 15 | 16 | def calculate_metric_percase(sam_pred, SGDL_pred, gt): 17 | sam_pred[sam_pred > 0] = 1 18 | SGDL_pred[SGDL_pred > 0] = 1 19 | gt[gt > 0] = 1 20 | dice_res = [] 21 | if sam_pred.sum() > 0: 22 | dice_res.append(metric.binary.dc(sam_pred, gt)) 23 | else: 24 | dice_res.append(0) 25 | 26 | if SGDL_pred.sum() > 0: 27 | dice_res.append(metric.binary.dc(SGDL_pred, gt)) 28 | else: 29 | dice_res.append(0) 30 | 31 | return dice_res 32 | 33 | 34 | def get_entropy_map(p): 35 | ent_map = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1, keepdim=True) 36 | return ent_map 37 | 38 | 39 | def test_single_volume(args, image, label, sam_model, SGDL): 40 | classes = args.num_classes 41 | patch_size = [args.image_size, args.image_size] 42 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 43 | sam_prediction = np.zeros_like(label) 44 | SGDL_prediction = np.zeros_like(label) 45 | for ind in range(image.shape[0]): 46 | slice = image[ind, :, :] 47 | x, y = slice.shape[0], slice.shape[1] 48 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 49 | 50 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 51 | input = input.repeat(1,3,1,1) 52 | with torch.no_grad(): 53 | pred_UNet, pred_VNet, pred_UNet_soft, pred_VNet_soft, fusion_map = SGDL(input) 54 | image_embeddings = sam_model.image_encoder(input) 55 | points_embedding, boxes_embedding, mask_embedding = sam_model.super_prompt(image_embeddings) 56 | 57 | low_res_masks_all = torch.empty( 58 | (1, 0, int(args.image_size / 4), int(args.image_size / 4)), 59 | device=args.device) 60 | with torch.no_grad(): 61 | for i in range(args.num_classes): 62 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 63 | # points=points_embedding[i].unsqueeze(0), 64 | points=None, 65 | # boxes=None, 66 | boxes=boxes_embedding[i], 67 | # masks=mask_embedding[i], 68 | masks=F.interpolate(fusion_map[:, i, ...].unsqueeze(1).clone().detach(), size=(64, 64), 69 | mode='bilinear'), 70 | # masks=None, 71 | ) 72 | low_res_masks, iou_predictions = sam_model.mask_decoder( 73 | image_embeddings=image_embeddings, 74 | image_pe=sam_model.prompt_encoder.get_dense_pe(), 75 | sparse_prompt_embeddings=sparse_embeddings, 76 | dense_prompt_embeddings=dense_embeddings, 77 | multimask_output=args.multimask, 78 | ) 79 | 
low_res_masks_all = torch.cat((low_res_masks_all, low_res_masks), dim=1) 80 | 81 | pred_sam = F.interpolate(low_res_masks_all, size=(args.image_size, args.image_size)) 82 | pred_sam_soft = torch.softmax(pred_sam, dim=1) 83 | fusion_map_soft = torch.softmax(fusion_map, dim=1) 84 | 85 | out_SGDL = torch.argmax(fusion_map_soft, dim=1).squeeze(0).cpu().detach().numpy() 86 | out_sam = torch.argmax(pred_sam_soft, dim=1).squeeze(0).cpu().detach().numpy() 87 | 88 | pred_SGDL = zoom(out_SGDL, (x / patch_size[0], y / patch_size[1]), order=0) 89 | pred_sam = zoom(out_sam, (x / patch_size[0], y / patch_size[1]), order=0) 90 | 91 | SGDL_prediction[ind] = pred_SGDL 92 | sam_prediction[ind] = pred_sam 93 | 94 | metric_list = [] 95 | for i in range(1, classes): 96 | metric_list.append(calculate_metric_percase(sam_prediction == i, SGDL_prediction == i, label == i)) 97 | return metric_list 98 | 99 | 100 | if __name__ == '__main__': 101 | import cv2 102 | import torch 103 | import argparse 104 | 105 | import torch.nn.functional as F 106 | from dataloader.dataset import build_Dataset 107 | from dataloader.transforms import build_transforms 108 | from torch.utils.data import DataLoader 109 | import numpy as np 110 | 111 | from Model.model import KnowSAM 112 | 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument('--data_path', type=str, 115 | default='./SampleData', 116 | help='Name of Experiment') 117 | parser.add_argument('--dataset', type=str, default='/ACDC', 118 | help='Name of Experiment') 119 | parser.add_argument('--num_classes', type=int, default=4, 120 | help='output channel of network') 121 | parser.add_argument('--in_channels', type=int, default=3, 122 | help='input channel of network') 123 | parser.add_argument('--image_size', type=list, default=256, 124 | help='patch size of network input') 125 | parser.add_argument('--point_nums', type=int, default=5, help='points number') 126 | parser.add_argument('--box_nums', type=int, default=1, help='boxes number') 127 | parser.add_argument('--mod', type=str, default='sam_adpt', help='mod type:seg,cls,val_ad') 128 | parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type") 129 | parser.add_argument('--thd', type=bool, default=False, help='3d or not') 130 | parser.add_argument('--device', type=str, default='cuda') 131 | parser.add_argument("--multimask", type=bool, default=False, help="ouput multimask") 132 | 133 | parser.add_argument('--sam_model_path', type=str, 134 | default="./sam_best_model.pth", 135 | help='model weight path') 136 | parser.add_argument('--SGDL_model_path', type=str, 137 | default="./SGDL_iter_16400.pth", 138 | help='model weight path') 139 | 140 | args = parser.parse_args() 141 | data_transforms = build_transforms(args) 142 | 143 | test_dataset = build_Dataset(data_dir=args.data_path + args.dataset, split="test_list", 144 | transform=data_transforms["valid_test"]) 145 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2) 146 | 147 | model = "SGDL" 148 | if model == "SGDL": 149 | SGDL_model = KnowSAM(args, bilinear=True).to(args.device).train() 150 | SGDL_checkpoint = torch.load(args.SGDL_model_path) 151 | SGDL_model.load_state_dict(SGDL_checkpoint) 152 | SGDL_model.eval() 153 | 154 | avg_dice_list = 0.0 155 | avg_iou_list = 0.0 156 | avg_hd95_list = 0.0 157 | avg_asd_list = 0.0 158 | classes = args.num_classes 159 | patch_size = [args.image_size, args.image_size] 160 | final_res = [0, 0, 0, 0, 0] 161 | 162 | for i_batch, sampled_batch in enumerate(test_loader): 163 
| test_image, test_label = sampled_batch["image"].cuda(), sampled_batch["label"].cuda() 164 | image, label = test_image.squeeze(0).cpu().detach().numpy(), test_label.squeeze(0).cpu().detach().numpy() 165 | SGDL_prediction = np.zeros_like(label) 166 | for ind in range(image.shape[0]): 167 | slice = image[ind, :, :] 168 | x, y = slice.shape[0], slice.shape[1] 169 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 170 | 171 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 172 | input = input.repeat(1, 3, 1, 1) 173 | with torch.no_grad(): 174 | pred_UNet, pred_VNet, pred_UNet_soft, pred_VNet_soft, fusion_map = SGDL_model(input) 175 | fusion_map_soft = torch.softmax(fusion_map, dim=1) 176 | out_SGDL = torch.argmax(fusion_map_soft, dim=1).squeeze(0).cpu().detach().numpy() 177 | pred_SGDL = zoom(out_SGDL, (x / patch_size[0], y / patch_size[1]), order=0) 178 | SGDL_prediction[ind] = pred_SGDL 179 | 180 | metric_list = [] 181 | for i in range(1, classes): 182 | disc_pred = SGDL_prediction == i 183 | gt = label == i 184 | disc_pred[disc_pred > 0] = 1 185 | if 1: 186 | disc_pred = getLargestCC(disc_pred) 187 | gt[gt > 0] = 1 188 | single_class_res = [] 189 | if disc_pred.sum() > 0: 190 | single_class_res.append(metric.binary.dc(disc_pred, gt)) 191 | single_class_res.append(metric.binary.jc(disc_pred, gt)) 192 | single_class_res.append(metric.binary.asd(disc_pred, gt)) 193 | single_class_res.append(metric.binary.hd95(disc_pred, gt)) 194 | else: 195 | single_class_res = [0, 0, 0, 0, 0] 196 | metric_list.append(single_class_res) 197 | 198 | metric_list = np.array(metric_list).astype("float32") 199 | metric_list = np.mean(metric_list, axis=0) 200 | 201 | print(metric_list) 202 | final_res += metric_list 203 | final_res = [x / len(test_loader) for x in final_res] 204 | print("avg_dice: ", final_res[0]) 205 | print("avg_iou: ", final_res[1]) 206 | print("avg_asd: ", final_res[2]) 207 | print("avg_hd95: ", final_res[3]) 208 | 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | addict==2.4.0 3 | aiofiles==23.2.1 4 | albumentations==1.3.1 5 | aliyun-python-sdk-core==2.15.2 6 | aliyun-python-sdk-kms==2.16.5 7 | annotated-types==0.7.0 8 | antlr4-python3-runtime==4.9.3 9 | anyio==4.4.0 10 | apex==0.1 11 | apptools==5.2.1 12 | asttokens==2.4.1 13 | attrs==24.2.0 14 | batchgenerators==0.25 15 | black==24.8.0 16 | blinker==1.8.2 17 | cachetools==5.3.1 18 | certifi==2022.12.7 19 | cffi==1.17.1 20 | charset-normalizer==2.1.1 21 | click==8.1.7 22 | clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 23 | cloudpickle==3.0.0 24 | cmake==3.29.2 25 | colorama==0.4.6 26 | coloredlogs==15.0.1 27 | comm==0.2.2 28 | ConfigArgParse==1.7 29 | configobj==5.0.8 30 | contourpy==1.1.1 31 | coverage==7.3.4 32 | crcmod==1.7 33 | cryptography==43.0.1 34 | cycler==0.11.0 35 | Cython==3.0.11 36 | dash==2.18.1 37 | dash-core-components==2.0.0 38 | dash-html-components==2.0.0 39 | dash-table==5.0.0 40 | datasketch==1.6.5 41 | decorator==5.1.1 42 | deepdiff==8.1.1 43 | Deprecated==1.2.14 44 | dipy==1.10.0 45 | docker-pycreds==0.4.0 46 | edt==2.4.0 47 | efficientnet-pytorch==0.7.1 48 | einops==0.6.1 49 | envisage==7.0.3 50 | et-xmlfile==1.1.0 51 | exceptiongroup==1.2.2 52 | executing==2.1.0 53 | fastapi==0.112.2 54 | fastjsonschema==2.20.0 55 | ffmpy==0.4.0 56 | filelock==3.14.0 57 | 
Flask==3.0.3 58 | flatbuffers==24.3.25 59 | fonttools==4.42.1 60 | fsspec==2023.9.2 61 | ftfy==6.2.3 62 | future==0.18.3 63 | fvcore==0.1.5.post20221221 64 | gitdb==4.0.11 65 | GitPython==3.1.43 66 | google-auth==2.23.1 67 | google-auth-oauthlib==1.0.0 68 | gradio==4.42.0 69 | gradio_client==1.3.0 70 | grpcio==1.58.0 71 | gviz-api==1.10.0 72 | h11==0.14.0 73 | h5py==3.10.0 74 | hausdorff==0.2.6 75 | httpcore==1.0.5 76 | httpx==0.27.2 77 | huggingface-hub==0.24.6 78 | humanfriendly==10.0 79 | humanize==4.9.0 80 | hydra-core==1.3.2 81 | icecream==2.1.3 82 | idna==3.4 83 | imageio==2.31.4 84 | importlib_metadata==8.5.0 85 | importlib_resources==6.4.4 86 | iopath==0.1.9 87 | ipdb==0.13.13 88 | ipython==8.27.0 89 | ipywidgets==8.1.5 90 | itsdangerous==2.2.0 91 | jedi==0.19.1 92 | Jinja2==3.1.2 93 | jmespath==0.10.0 94 | joblib==1.3.2 95 | jsonschema==4.23.0 96 | jsonschema-specifications==2023.12.1 97 | jupyter_core==5.7.2 98 | jupyterlab_widgets==3.0.13 99 | kiwisolver==1.4.5 100 | kneed==0.8.5 101 | kornia==0.7.1 102 | lazy_loader==0.3 103 | linecache2==1.0.0 104 | littleutils==0.2.2 105 | llvmlite==0.41.1 106 | lxml==5.3.0 107 | Markdown==3.4.4 108 | markdown-it-py==3.0.0 109 | MarkupSafe==2.1.2 110 | matplotlib==3.8.0 111 | matplotlib-inline==0.1.7 112 | mayavi==4.8.1 113 | mdurl==0.1.2 114 | MedPy==0.4.0 115 | medutils==0.1.21 116 | mmengine==0.10.5 117 | mmsegmentation==1.2.2 118 | model-index==0.1.11 119 | monai==1.3.0 120 | mpmath==1.3.0 121 | munch==4.0.0 122 | mypy-extensions==1.0.0 123 | nbformat==5.10.4 124 | nest-asyncio==1.6.0 125 | networkx==3.0 126 | nibabel==5.1.0 127 | nilearn==0.10.4 128 | numba==0.58.1 129 | numpy==1.23.2 130 | oauthlib==3.2.2 131 | omegaconf==2.3.0 132 | onnx==1.16.1 133 | onnxruntime==1.19.0 134 | open3d==0.18.0 135 | opencv-python==4.8.0.76 136 | opencv-python-headless==4.8.0.76 137 | opendatalab==0.0.10 138 | openmim==0.3.9 139 | openpyxl==3.1.2 140 | openxlab==0.1.1 141 | ordered-set==4.1.0 142 | orderly-set==5.2.3 143 | orjson==3.10.7 144 | oss2==2.17.0 145 | outdated==0.2.2 146 | packaging==24.1 147 | pandas==2.1.4 148 | pandas-flavor==0.6.0 149 | parameterized==0.9.0 150 | parso==0.8.4 151 | pathspec==0.12.1 152 | patsy==0.5.4 153 | Pillow==9.3.0 154 | pingouin==0.5.3 155 | platformdirs==4.2.1 156 | plotly==5.24.1 157 | portalocker==2.10.1 158 | prefetch-generator==1.0.3 159 | pretrainedmodels==0.7.4 160 | prettytable==3.11.0 161 | prompt_toolkit==3.0.48 162 | protobuf==4.24.3 163 | psutil==5.9.8 164 | pure_eval==0.2.3 165 | pyasn1==0.5.0 166 | pyasn1-modules==0.3.0 167 | pycocotools==2.0.8 168 | pycparser==2.22 169 | pycryptodome==3.21.0 170 | pydantic==2.8.2 171 | pydantic_core==2.20.1 172 | pydicom==3.0.1 173 | pydub==0.25.1 174 | pyface==8.0.0 175 | Pygments==2.17.2 176 | pyparsing==3.1.1 177 | PyQt5==5.15.11 178 | PyQt5-Qt5==5.15.2 179 | PyQt5_sip==12.15.0 180 | pyreadline3==3.4.1 181 | PySimpleGUI==4.60.5 182 | python-dateutil==2.8.2 183 | python-multipart==0.0.9 184 | pytorch-ssim==0.1 185 | pytz==2023.3.post1 186 | PyWavelets==1.4.1 187 | pywin32==306 188 | PyYAML==6.0.1 189 | qudida==0.0.4 190 | read-roi==1.6.0 191 | referencing==0.35.1 192 | regex==2024.9.11 193 | requests==2.28.2 194 | requests-oauthlib==1.3.1 195 | retrying==1.3.4 196 | rich==13.4.2 197 | roifile==2024.1.10 198 | rpds-py==0.20.0 199 | rsa==4.9 200 | ruff==0.6.2 201 | safetensors==0.3.3 202 | scikit-image==0.21.0 203 | scikit-learn==1.3.1 204 | scipy==1.11.3 205 | seaborn==0.13.0 206 | seg-metrics==1.1.6 207 | segmentation-models-pytorch==0.3.3 208 | 
semantic-version==2.10.0 209 | sentry-sdk==2.1.1 210 | setproctitle==1.3.3 211 | setuptools-scm==8.1.0 212 | shapely==2.0.3 213 | shellingham==1.5.4 214 | SimpleITK==2.3.0 215 | six==1.16.0 216 | smmap==5.0.1 217 | sniffio==1.3.1 218 | stack-data==0.6.3 219 | starlette==0.38.2 220 | statsmodels==0.14.1 221 | Surface-Distance-Based-Measures @ file:///C:/Users/admin/Desktop/surface-distance-master 222 | sympy==1.12 223 | tabulate==0.9.0 224 | tenacity==9.0.0 225 | tensorboard==2.14.1 226 | tensorboard-data-server==0.7.1 227 | tensorboard-plugin-profile==2.13.1 228 | tensorboardX==2.6.2.2 229 | termcolor==2.4.0 230 | thop==0.1.1.post2209072238 231 | threadpoolctl==3.2.0 232 | tifffile==2023.9.26 233 | timm==0.9.2 234 | tomli==2.0.1 235 | tomlkit==0.12.0 236 | torch==2.0.1+cu118 237 | torch-cka==0.21 238 | torchaudio==2.0.2+cu118 239 | torchcam==0.4.0 240 | torchio==0.19.6 241 | torchsummary==1.5.1 242 | torchvision==0.15.2+cu118 243 | tqdm==4.65.2 244 | traceback2==1.4.0 245 | traitlets==5.14.3 246 | traits==6.4.3 247 | traitsui==8.0.0 248 | trx-python==0.3 249 | tsnecuda==3.0.1 250 | typer==0.12.3 251 | typing_extensions==4.12.2 252 | tzdata==2023.3 253 | unittest2==1.1.0 254 | urllib3==1.26.20 255 | uvicorn==0.30.6 256 | vtk==9.3.0 257 | wandb==0.17.0 258 | wcwidth==0.2.13 259 | websockets==12.0 260 | Werkzeug==3.0.4 261 | widgetsnbextension==4.0.13 262 | wrapt==1.16.0 263 | xarray==2023.12.0 264 | yacs==0.1.8 265 | yapf==0.40.2 266 | zipp==3.20.2 267 | -------------------------------------------------------------------------------- /train_semi_SAM.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import random 4 | import torch 5 | import os 6 | import logging 7 | import sys 8 | from tqdm import tqdm 9 | from dataloader.dataset import build_Dataset 10 | from torch.utils.data import DataLoader 11 | from utils.utils import patients_to_slices 12 | from dataloader.transforms import build_transforms, build_weak_strong_transforms 13 | from dataloader.TwoStreamBatchSampler import TwoStreamBatchSampler 14 | 15 | from trainer import Trainer 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--data_path', type=str, default='./SampleData', 20 | help='Name of Experiment') 21 | parser.add_argument('--labeled_num', type=int, default=1, 22 | help='Percentage of label quantity') 23 | 24 | parser.add_argument('--dataset', type=str, default='/tumor_1', 25 | help='Name of Experiment') 26 | # parser.add_argument('--dataset', type=str, default='/ISIC_TrainDataset_10', 27 | # help='Name of Experiment') 28 | # parser.add_argument('--dataset', type=str, default='/thyroid_30', 29 | # help='Name of Experiment') 30 | # parser.add_argument('--dataset', type=str, default='/BrainMRI_30', 31 | # help='Name of Experiment') 32 | 33 | parser.add_argument('--num_classes', type=int, default=2, 34 | help='output channel of network') 35 | parser.add_argument('--in_channels', type=int, default=3, 36 | help='input channel of network') 37 | 38 | parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate') 39 | parser.add_argument('-UNet_lr', type=float, default=0.01, help='initial learning rate') 40 | parser.add_argument('-VNet_lr', type=float, default=0.01, help='initial learning rate') 41 | parser.add_argument('--image_size', type=int, default=256, help='image_size') 42 | parser.add_argument('--point_nums', type=int, default=5, help='points number') 43 | parser.add_argument('--box_nums', type=int, 
default=1, help='boxes number') 44 | parser.add_argument('--mod', type=str, default='sam_adpt', help='mod type:seg,cls,val_ad') 45 | parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type") 46 | parser.add_argument('-thd', type=bool, default=False, help='3d or not') 47 | parser.add_argument('--batch_size', type=int, default=24, 48 | help='batch_size per gpu') 49 | parser.add_argument('--labeled_bs', type=int, default=12, 50 | help='labeled_batch_size per gpu') 51 | parser.add_argument('--seed', type=int, default=42, 52 | help='random seed') 53 | 54 | parser.add_argument('--mixed_iterations', type=int, default=12000, 55 | help='maximum epoch number to train') 56 | parser.add_argument('--max_iterations', type=int, default=50000, 57 | help='maximum epoch number to train') 58 | 59 | parser.add_argument('--n_fold', type=int, default=1, 60 | help='maximum epoch number to train') 61 | parser.add_argument('--consistency', type=float, default=0.1, 62 | help='consistency') 63 | parser.add_argument('--consistency_rampup', type=float, 64 | default=200.0, help='consistency_rampup') 65 | parser.add_argument('--device', type=str, default='cuda') 66 | parser.add_argument("--multimask", type=bool, default=False, help="ouput multimask") 67 | parser.add_argument("--encoder_adapter", type=bool, default=True, help="use adapter") 68 | parser.add_argument("--sam_checkpoint", type=str, default="./sam_vit_b_01ec64.pth", help="sam checkpoint") 69 | 70 | args = parser.parse_args() 71 | 72 | 73 | def sigmoid_rampup(current, rampup_length): 74 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 75 | if rampup_length == 0: 76 | return 1.0 77 | else: 78 | current = np.clip(current, 0.0, rampup_length) 79 | phase = 1.0 - current / rampup_length 80 | return float(np.exp(-5.0 * phase * phase)) 81 | 82 | 83 | def worker_init_fn(worker_id): 84 | random.seed(args.seed + worker_id) 85 | 86 | 87 | def get_current_consistency_weight(epoch): 88 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 89 | return args.consistency * sigmoid_rampup(epoch, args.consistency_rampup) 90 | 91 | 92 | def train(args, snapshot_path): 93 | batch_size = args.batch_size 94 | max_iterations = args.max_iterations 95 | # model 96 | trainer = Trainer(args) 97 | # dataset 98 | data_transforms = build_weak_strong_transforms(args) 99 | train_dataset = build_Dataset(args=args, data_dir=args.data_path + args.dataset, split="train_semi", 100 | transform=data_transforms) 101 | val_dataset = build_Dataset(args=args, data_dir=args.data_path + args.dataset, split="val", 102 | transform=data_transforms["valid_test"]) 103 | 104 | # sampler 105 | total_slices = len(train_dataset) 106 | labeled_slice = patients_to_slices(args.dataset, args.labeled_num) 107 | labeled_idxs = list(range(0, labeled_slice)) 108 | unlabeled_idxs = list(range(labeled_slice, total_slices)) 109 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-args.labeled_bs) 110 | 111 | # dataloader 112 | train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler, num_workers=2, pin_memory=True, worker_init_fn=worker_init_fn) 113 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) 114 | logging.info("{} iterations per epoch".format(len(train_loader))) 115 | max_epoch = max_iterations // len(train_loader) + 1 116 | iterator = tqdm(range(max_epoch), ncols=70) 117 | 118 | # else: 119 | iter_num = 0 120 | for _ in iterator: 121 | for i_batch, sampled_batch in 
enumerate(train_loader): 122 | volume_batch, label_batch = sampled_batch['image'].cuda(), sampled_batch['label'].cuda() 123 | trainer.train(volume_batch, label_batch, iter_num) 124 | iter_num = iter_num + 1 125 | if iter_num > 0 and iter_num % 200 == 0: 126 | if "ACDC" not in args.dataset: 127 | trainer.val(val_loader, snapshot_path, iter_num) 128 | else: 129 | trainer.val_ACDC(val_loader, snapshot_path, iter_num) 130 | 131 | 132 | if __name__ == '__main__': 133 | import shutil 134 | for fold in range(args.n_fold): 135 | torch.autograd.set_detect_anomaly(True) 136 | random.seed(2024) 137 | np.random.seed(2024) 138 | torch.manual_seed(2024) 139 | torch.cuda.manual_seed(2024) 140 | 141 | snapshot_path = "./Results/Result_tumor_1/fold_" + str(fold) 142 | 143 | if not os.path.exists(snapshot_path): 144 | os.makedirs(snapshot_path) 145 | if os.path.exists(snapshot_path + '/code'): 146 | shutil.rmtree(snapshot_path + '/code') 147 | if not os.path.exists(snapshot_path + '/code'): 148 | os.makedirs(snapshot_path + '/code') 149 | 150 | shutil.copyfile("./train_semi_SAM.py", snapshot_path + "/code/train_semi_SAM.py") 151 | shutil.copyfile("./trainer.py", snapshot_path + "/code/trainer.py") 152 | 153 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 154 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 155 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 156 | logging.info(str(args)) 157 | train(args, snapshot_path) 158 | -------------------------------------------------------------------------------- /train_semi_SAM_ACDC.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import random 4 | import torch 5 | import os 6 | import logging 7 | import sys 8 | from tqdm import tqdm 9 | from dataloader.dataset import build_Dataset 10 | from torch.utils.data import DataLoader 11 | from utils.utils import patients_to_slices 12 | from dataloader.transforms import build_transforms, build_weak_strong_transforms 13 | from dataloader.TwoStreamBatchSampler import TwoStreamBatchSampler 14 | 15 | 16 | from trainer import Trainer 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--data_path', type=str, default='./SampleData', 21 | help='Name of Experiment') 22 | parser.add_argument('--labeled_num', type=int, default=7, 23 | help='Percentage of label quantity') 24 | parser.add_argument('--dataset', type=str, default='/ACDC', 25 | help='Name of Experiment') 26 | 27 | 28 | parser.add_argument('--num_classes', type=int, default=4, 29 | help='output channel of network') 30 | parser.add_argument('--in_channels', type=int, default=3, 31 | help='input channel of network') 32 | 33 | parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate') 34 | parser.add_argument('-UNet_lr', type=float, default=0.01, help='initial learning rate') 35 | parser.add_argument('-VNet_lr', type=float, default=0.01, help='initial learning rate') 36 | parser.add_argument('--image_size', type=int, default=256, help='image_size') 37 | parser.add_argument('--point_nums', type=int, default=5, help='points number') 38 | parser.add_argument('--box_nums', type=int, default=1, help='boxes number') 39 | parser.add_argument('--mod', type=str, default='sam_adpt', help='mod type:seg,cls,val_ad') 40 | parser.add_argument("--model_type", type=str, default="vit_b", help="sam model_type") 41 | parser.add_argument('-thd', type=bool, default=False, help='3d or not') 42 | 
parser.add_argument('--batch_size', type=int, default=24, 43 | help='batch_size per gpu') 44 | parser.add_argument('--labeled_bs', type=int, default=12, 45 | help='labeled_batch_size per gpu') 46 | parser.add_argument('--seed', type=int, default=42, 47 | help='random seed') 48 | 49 | parser.add_argument('--mixed_iterations', type=int, default=12000, 50 | help='maximum epoch number to train') 51 | parser.add_argument('--max_iterations', type=int, default=50000, 52 | help='maximum epoch number to train') 53 | 54 | parser.add_argument('--n_fold', type=int, default=1, 55 | help='maximum epoch number to train') 56 | parser.add_argument('--consistency', type=float, default=0.1, 57 | help='consistency') 58 | parser.add_argument('--consistency_rampup', type=float, 59 | default=200.0, help='consistency_rampup') 60 | parser.add_argument('--device', type=str, default='cuda') 61 | parser.add_argument("--multimask", type=bool, default=False, help="ouput multimask") 62 | parser.add_argument("--encoder_adapter", type=bool, default=True, help="use adapter") 63 | parser.add_argument("--sam_checkpoint", type=str, default="./sam_vit_b_01ec64.pth", help="sam checkpoint") 64 | 65 | 66 | args = parser.parse_args() 67 | 68 | 69 | def sigmoid_rampup(current, rampup_length): 70 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 71 | if rampup_length == 0: 72 | return 1.0 73 | else: 74 | current = np.clip(current, 0.0, rampup_length) 75 | phase = 1.0 - current / rampup_length 76 | return float(np.exp(-5.0 * phase * phase)) 77 | 78 | 79 | def worker_init_fn(worker_id): 80 | random.seed(args.seed + worker_id) 81 | 82 | 83 | def get_current_consistency_weight(epoch): 84 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 85 | return args.consistency * sigmoid_rampup(epoch, args.consistency_rampup) 86 | 87 | 88 | def train(args, snapshot_path): 89 | batch_size = args.batch_size 90 | max_iterations = args.max_iterations 91 | # model 92 | trainer = Trainer(args) 93 | 94 | labeled_slice = patients_to_slices(args.dataset, args.labeled_num) 95 | # dataset 96 | data_transforms = build_weak_strong_transforms(args) 97 | train_dataset = build_Dataset(args=args, data_dir=args.data_path + args.dataset, split="train_acdc_list", 98 | transform=data_transforms) 99 | val_dataset = build_Dataset(args=args, data_dir=args.data_path + args.dataset, split="val_acdc_list", 100 | transform=data_transforms["valid_test"]) 101 | 102 | # sampler 103 | total_slices = len(train_dataset) 104 | labeled_idxs = list(range(0, labeled_slice)) 105 | unlabeled_idxs = list(range(labeled_slice, total_slices)) 106 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-args.labeled_bs) 107 | 108 | # dataloader 109 | train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler, 110 | num_workers=2, pin_memory=True, worker_init_fn=worker_init_fn) 111 | val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) 112 | logging.info("{} iterations per epoch".format(len(train_loader))) 113 | max_epoch = max_iterations // len(train_loader) + 1 114 | iterator = tqdm(range(max_epoch), ncols=70) 115 | 116 | # else: 117 | 118 | iter_num = 0 119 | for _ in iterator: 120 | for i_batch, sampled_batch in enumerate(train_loader): 121 | volume_batch, label_batch = sampled_batch['image'].cuda(), sampled_batch['label'].cuda() 122 | trainer.train(volume_batch, label_batch, iter_num) 123 | iter_num = iter_num + 1 124 | if iter_num > 0 and iter_num % 200 == 0: 125 | if "ACDC" not in 
args.dataset: 126 | trainer.val(val_loader, snapshot_path, iter_num) 127 | else: 128 | trainer.val_ACDC(val_loader, snapshot_path, iter_num) 129 | 130 | 131 | if __name__ == '__main__': 132 | import shutil 133 | for fold in range(args.n_fold): 134 | torch.autograd.set_detect_anomaly(True) 135 | random.seed(2024) 136 | np.random.seed(2024) 137 | torch.manual_seed(2024) 138 | torch.cuda.manual_seed(2024) 139 | 140 | snapshot_path = "./Results/results_ACDC_10/fold_" + str(fold) 141 | 142 | if not os.path.exists(snapshot_path): 143 | os.makedirs(snapshot_path) 144 | if os.path.exists(snapshot_path + '/code'): 145 | shutil.rmtree(snapshot_path + '/code') 146 | if not os.path.exists(snapshot_path + '/code'): 147 | os.makedirs(snapshot_path + '/code') 148 | 149 | shutil.copyfile("./train_semi_SAM_ACDC.py", snapshot_path + "/code/train_semi_SAM_ACDC.py") 150 | shutil.copyfile("./trainer.py", snapshot_path + "/code/trainer.py") 151 | 152 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 153 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 154 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 155 | logging.info(str(args)) 156 | train(args, snapshot_path) -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn import functional as F 5 | CE = torch.nn.BCELoss() 6 | mse = torch.nn.MSELoss() 7 | 8 | class KDLoss(nn.Module): 9 | """ 10 | Distilling the Knowledge in a Neural Network 11 | https://arxiv.org/pdf/1503.02531.pdf 12 | """ 13 | 14 | def __init__(self, T): 15 | super(KDLoss, self).__init__() 16 | self.T = T 17 | 18 | def forward(self, out_s, out_t): 19 | loss = ( 20 | F.kl_div(F.log_softmax(out_s / self.T, dim=1), 21 | F.softmax(out_t / self.T, dim=1), reduction="batchmean") # , reduction="batchmean" 22 | * self.T 23 | * self.T 24 | ) 25 | return loss 26 | 27 | 28 | def loss_diff1(u_prediction_1, u_prediction_2): 29 | loss_a = 0.0 30 | 31 | for i in range(u_prediction_2.size(1)): 32 | loss_a = CE(u_prediction_1[:, i, ...].clamp(1e-8, 1 - 1e-7), 33 | Variable(u_prediction_2[:, i, ...].float(), requires_grad=False)) 34 | 35 | loss_diff_avg = loss_a.mean() 36 | return loss_diff_avg 37 | 38 | 39 | def loss_diff2(u_prediction_1, u_prediction_2): 40 | loss_b = 0.0 41 | 42 | for i in range(u_prediction_2.size(1)): 43 | loss_b = CE(u_prediction_2[:, i, ...].clamp(1e-8, 1 - 1e-7), 44 | Variable(u_prediction_1[:, i, ...], requires_grad=False)) 45 | 46 | loss_diff_avg = loss_b.mean() 47 | return loss_diff_avg 48 | 49 | 50 | class DiceLoss(nn.Module): 51 | def __init__(self, n_classes): 52 | super(DiceLoss, self).__init__() 53 | self.n_classes = n_classes 54 | 55 | def _one_hot_encoder(self, input_tensor): 56 | tensor_list = [] 57 | for i in range(self.n_classes): 58 | temp_prob = input_tensor * i == i * torch.ones_like(input_tensor) 59 | tensor_list.append(temp_prob) 60 | output_tensor = torch.cat(tensor_list, dim=1) 61 | return output_tensor.float() 62 | 63 | def _dice_loss(self, score, target): 64 | target = target.float() 65 | smooth = 1e-10 66 | intersect = torch.sum(score * target) 67 | y_sum = torch.sum(target * target) 68 | z_sum = torch.sum(score * score) 69 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 70 | loss = 1 - loss 71 | return loss 72 | 73 | def _dice_mask_loss(self, score, target, mask): 
74 | target = target.float() 75 | mask = mask.float() 76 | smooth = 1e-10 77 | intersect = torch.sum(score * target * mask) 78 | y_sum = torch.sum(target * target * mask) 79 | z_sum = torch.sum(score * score * mask) 80 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 81 | loss = 1 - loss 82 | return loss 83 | 84 | def forward(self, inputs, target, weight=None, softmax=False): 85 | if softmax: 86 | inputs = torch.softmax(inputs, dim=1) 87 | target = self._one_hot_encoder(target.unsqueeze(1)) 88 | if weight is None: 89 | weight = [1] * self.n_classes 90 | assert inputs.size() == target.size(), 'predict & target shape do not match' 91 | class_wise_dice = [] 92 | loss = 0.0 93 | for i in range(0, self.n_classes): 94 | dice = self._dice_loss(inputs[:, i], target[:, i]) 95 | class_wise_dice.append(1.0 - dice.item()) 96 | loss += dice * weight[i] 97 | return loss / self.n_classes 98 | 99 | 100 | def dice_loss(pred, label, epsilon=1e-5): 101 | intersection = torch.sum(pred * label, dim=(2, 3)) 102 | union = torch.sum(pred, dim=(2, 3)) + torch.sum(label, dim=(2, 3)) 103 | dice_coefficient = (2.0 * intersection + epsilon) / (union + epsilon) 104 | dice_loss = 1.0 - dice_coefficient 105 | return dice_loss.mean() 106 | 107 | 108 | def sigmoid_mse_loss_map(input_logits, target_logits): 109 | assert input_logits.size() == target_logits.size() 110 | input_softmax = torch.nn.Sigmoid()(input_logits) 111 | target_softmax = torch.nn.Sigmoid()(target_logits) 112 | mse_loss_map = (input_softmax-target_softmax)**2 113 | return mse_loss_map 114 | 115 | 116 | def mse_loss(input1, input2): 117 | return torch.mean((input1 - input2) ** 2) 118 | 119 | 120 | class DiceLoss(nn.Module): 121 | def __init__(self, n_classes): 122 | super(DiceLoss, self).__init__() 123 | self.n_classes = n_classes 124 | 125 | def _one_hot_encoder(self, input_tensor): 126 | tensor_list = [] 127 | for i in range(self.n_classes): 128 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 129 | tensor_list.append(temp_prob) 130 | output_tensor = torch.cat(tensor_list, dim=1) 131 | return output_tensor.float() 132 | 133 | def _one_hot_mask_encoder(self, input_tensor): 134 | tensor_list = [] 135 | for i in range(self.n_classes): 136 | temp_prob = input_tensor * i == i * torch.ones_like(input_tensor) 137 | tensor_list.append(temp_prob) 138 | output_tensor = torch.cat(tensor_list, dim=1) 139 | return output_tensor.float() 140 | 141 | def _dice_loss(self, score, target): 142 | target = target.float() 143 | smooth = 1e-10 144 | intersect = torch.sum(score * target) 145 | y_sum = torch.sum(target * target) 146 | z_sum = torch.sum(score * score) 147 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 148 | loss = 1 - loss 149 | return loss 150 | 151 | def _dice_mask_loss(self, score, target, mask): 152 | target = target.float() 153 | mask = mask.float() 154 | smooth = 1e-10 155 | intersect = torch.sum(score * target * mask) 156 | y_sum = torch.sum(target * target * mask) 157 | z_sum = torch.sum(score * score * mask) 158 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 159 | loss = 1 - loss 160 | return loss 161 | 162 | def forward(self, inputs, target, mask=None, weight=None, softmax=False): 163 | if softmax: 164 | inputs = torch.softmax(inputs, dim=1) 165 | target = self._one_hot_encoder(target.unsqueeze(1)) 166 | if weight is None: 167 | weight = [1] * self.n_classes 168 | assert inputs.size() == target.size(), 'predict & target shape do not match' 169 | class_wise_dice = [] 170 | loss = 0.0 171 | if mask is 
not None: 172 | mask = self._one_hot_mask_encoder(mask) 173 | for i in range(0, self.n_classes): 174 | dice = self._dice_mask_loss(inputs[:, i], target[:, i], mask[:, i]) 175 | class_wise_dice.append(1.0 - dice.item()) 176 | loss += dice * weight[i] 177 | else: 178 | for i in range(0, self.n_classes): 179 | dice = self._dice_loss(inputs[:, i], target[:, i]) 180 | class_wise_dice.append(1.0 - dice.item()) 181 | loss += dice * weight[i] 182 | return loss / self.n_classes -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from medpy import metric 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | from hausdorff import hausdorff_distance 8 | 9 | 10 | 11 | 12 | 13 | def eval(y_true, y_pred, thr=0.5, epsilon=0.001): 14 | if y_pred.shape[1] == 1: 15 | y_true = y_true.to(torch.float32).squeeze(0).cpu().detach().numpy() 16 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 17 | else: 18 | y_true = y_true.to(torch.float32).squeeze(0).cpu().detach().numpy() 19 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0)[1].cpu().detach().numpy() 20 | single_class_res = [] 21 | single_class_res.append(metric.binary.dc(y_pred, y_true)) 22 | single_class_res.append(metric.binary.jc(y_pred, y_true)) 23 | single_class_res.append(hausdorff_distance(y_true, y_pred) * 0.95)  # note: 0.95 * the full Hausdorff distance, not the 95th-percentile HD (cf. metric.binary.hd95) 24 | return single_class_res 25 | 26 | 27 | 28 | def dice_coef(y_true, y_pred, thr=0.5, epsilon=0.001): 29 | if y_pred.shape[1] > 1: 30 | y_true = y_true.to(torch.float32).squeeze(0).cpu().detach().numpy() 31 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0)[1].cpu().detach().numpy() 32 | else: 33 | y_true = y_true.to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 34 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 35 | inter_map = y_true * y_pred 36 | inter = inter_map.sum() 37 | den = y_true.sum() + y_pred.sum() 38 | dice = ((2 * inter) / (den + epsilon)) if den > 0 else 0 39 | return dice 40 | 41 | 42 | def evaluate_95hd(y_true, y_pred, thr=0.5): 43 | if y_true.shape[1] > 1: 44 | y_true = y_true.to(torch.float32).squeeze(0)[1].cpu().detach().numpy() 45 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0)[1].cpu().detach().numpy() 46 | else: 47 | y_true = y_true.to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 48 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 49 | hd = hausdorff_distance(y_true, y_pred) 50 | return hd * 0.95  # note: scales the full Hausdorff distance by 0.95; the true 95th-percentile HD is given by metric.binary.hd95 51 | 52 | 53 | def calculate_iou(y_true, y_pred, thr=0.5): 54 | if y_true.shape[1] > 1: 55 | y_true = y_true.to(torch.float32).squeeze(0)[1].cpu().detach().numpy() 56 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0)[1].cpu().detach().numpy() 57 | else: 58 | y_true = y_true.to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 59 | y_pred = (y_pred > thr).to(torch.float32).squeeze(0).squeeze(0).cpu().detach().numpy() 60 | intersection = np.logical_and(y_true > 0, y_pred > 0) 61 | intersection = np.sum(intersection) 62 | union = np.logical_or(y_true > 0, y_pred > 0) 63 | union = np.sum(union) 64 | # compute the Jaccard index (IoU) 65 | iou = intersection / union if union > 0 else 0 66 | return iou 67 | 68 | 69 | 70 | def calculate_metric_percase(pred, gt): 71 | pred[pred > 0] = 1 72 | gt[gt > 0] = 1 73 | if pred.sum() > 0: 74 | dice = metric.binary.dc(pred, gt) 75
| hd95 = metric.binary.hd95(pred, gt) 76 | return dice, hd95 77 | else: 78 | return 0, 0 79 | 80 | 81 | def test_single_2D_colorImage(image, label, net, classes): 82 | thr = 0.5 83 | net.eval() 84 | label = label.squeeze(0).cpu().detach().numpy() 85 | with torch.no_grad(): 86 | out = net(image.cuda()) 87 | out = torch.nn.Sigmoid()(out) 88 | out = (out > thr).to(torch.float32) 89 | out = torch.argmax(out, dim=1).squeeze(0) 90 | prediction = out.cpu().detach().numpy() 91 | metric_list = [] 92 | 93 | for i in range(1, classes): 94 | metric_list.append(calculate_metric_percase( 95 | prediction == i, label == i)) 96 | return metric_list 97 | 98 | 99 | def patients_to_slices(dataset, patients_num): 100 | ref_dict = {} 101 | if "ACDC" in dataset: 102 | ref_dict = {"3": 68, "7": 136, 103 | "14": 256, "21": 396, "28": 512, "35": 664, "140": 1312} 104 | elif "tumor" in dataset: 105 | ref_dict = {"1": 15, "10": 145, "20": 290, "30": 435, } 106 | elif "ISIC" in dataset: 107 | ref_dict = {"10": 207, "30": 622, } 108 | elif "thyroid" in dataset: 109 | ref_dict = {"10": 613, "30": 1841, } 110 | elif "BrainMRI" in dataset: 111 | ref_dict = {"10": 103, "30": 310, } 112 | elif "MRI_Hippocampus_Seg" in dataset: 113 | ref_dict = {"10": 282, "30": 846, } 114 | else: 115 | print("Error") 116 | return ref_dict[str(patients_num)] 117 | 118 | 119 | def get_uncertainty_map(model, image_batch, num_classes, T=8, uncertainty_bs=2): # T must be divisible by uncertainty_bs 120 | _, _, w, h = image_batch.shape 121 | volume_batch_r = image_batch.repeat(uncertainty_bs, 1, 1, 1) 122 | stride = volume_batch_r.shape[0] // uncertainty_bs 123 | preds = torch.zeros([stride * T, num_classes, w, h]).cuda() # init preds 124 | for i in range(T // uncertainty_bs): 125 | ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2) # add noise 126 | with torch.no_grad(): 127 | preds[uncertainty_bs * stride * i: uncertainty_bs * stride * (i + 1)] = model(ema_inputs) 128 | preds = F.softmax(preds, dim=1) 129 | preds = preds.reshape(T, stride, num_classes, w, h) 130 | preds = torch.mean(preds, dim=0) 131 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) 132 | return uncertainty 133 | 134 | 135 | def get_no_noise_uncertainty_map(model, image_batch, num_classes, T=8): # averages T forward passes on the same input, without added noise 136 | b, _, w, h = image_batch.shape 137 | preds = torch.zeros([T * b, num_classes, w, h]).cuda() # init preds 138 | for i in range(T): 139 | ema_inputs = image_batch 140 | with torch.no_grad(): 141 | preds[i * b: i * b + b] = model(ema_inputs) 142 | # preds = F.softmax(preds, dim=1) 143 | preds = torch.nn.Sigmoid()(preds) 144 | preds = preds.reshape(T, b, num_classes, w, h) 145 | preds = torch.mean(preds, dim=0) 146 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) 147 | return uncertainty 148 | 149 | 150 | def generate_mask(img): 151 | batch_size, channel, img_x, img_y = img.shape[0], img.shape[1], img.shape[2], img.shape[3] 152 | loss_mask = torch.ones(batch_size, img_x, img_y).cuda() 153 | mask = torch.ones(img_x, img_y).cuda() 154 | patch_x, patch_y = int(img_x*2/3), int(img_y*2/3) 155 | w = np.random.randint(0, img_x - patch_x) 156 | h = np.random.randint(0, img_y - patch_y) 157 | mask[w:w+patch_x, h:h+patch_y] = 0 158 | loss_mask[:, w:w+patch_x, h:h+patch_y] = 0 159 | return mask.long(), loss_mask.long() 160 | 161 | --------------------------------------------------------------------------------
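A minimal usage sketch, not part of the repository code, showing how the masked DiceLoss from utils/losses.py and generate_mask from utils/utils.py above fit together; the batch size, the 256x256 resolution, the four-class setting, and the availability of a CUDA device are assumptions chosen to mirror the defaults in the training scripts.

import torch

from utils.losses import DiceLoss
from utils.utils import generate_mask

num_classes = 4                       # e.g. the ACDC setting
criterion = DiceLoss(num_classes)

logits = torch.randn(2, num_classes, 256, 256).cuda()         # raw network outputs (B, C, H, W)
labels = torch.randint(0, num_classes, (2, 256, 256)).cuda()  # integer label map (B, H, W)
images = torch.randn(2, 3, 256, 256).cuda()

# generate_mask zeroes a random patch covering 2/3 of each spatial dimension;
# mask is (H, W) and loss_mask is (B, H, W), both valued in {0, 1}.
mask, loss_mask = generate_mask(images)

# Plain Dice loss: softmax=True applies softmax to the logits inside the loss.
loss_all = criterion(logits, labels, softmax=True)

# Masked Dice loss: the (B, 1, H, W) mask is expanded per class by _one_hot_mask_encoder,
# so the Dice terms of the foreground classes are restricted to the region kept by
# loss_mask (the background channel is left unmasked by that encoder).
loss_masked = criterion(logits, labels, mask=loss_mask.unsqueeze(1), softmax=True)
print(loss_all.item(), loss_masked.item())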