├── overview.jpg ├── efficiency.jpg ├── asset ├── overview.jpg ├── efficiency.jpg ├── performance.jpg └── structure.jpg ├── LICENSE ├── README.md └── mim_network.py /overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txchen-USTC/MiM-ISTD/HEAD/overview.jpg -------------------------------------------------------------------------------- /efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txchen-USTC/MiM-ISTD/HEAD/efficiency.jpg -------------------------------------------------------------------------------- /asset/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txchen-USTC/MiM-ISTD/HEAD/asset/overview.jpg -------------------------------------------------------------------------------- /asset/efficiency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txchen-USTC/MiM-ISTD/HEAD/asset/efficiency.jpg -------------------------------------------------------------------------------- /asset/performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txchen-USTC/MiM-ISTD/HEAD/asset/performance.jpg -------------------------------------------------------------------------------- /asset/structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txchen-USTC/MiM-ISTD/HEAD/asset/structure.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 freq 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Official PyTorch code of our TGRS 2024 paper "MiM-ISTD: Mamba-in-Mamba for Efficient Infrared Small Target Detection". 2 | 3 | [https://ieeexplore.ieee.org/abstract/document/10740056] 4 | 5 | ## News 6 | 7 | * 24-11-01. Our paper has been published in IEEE Transactions on Geoscience and Remote Sensing [IF=7.5]. 8 | 9 | * 24-03-15. We have corrected some errors and updated the whole network structure code of our MiM-ISTD. Feel free to use it, including for other tasks! 10 | 11 | * 24-03-08. Our paper has been released on arXiv.
12 | 13 | ## A Quick Overview 14 | 15 | ![image](https://github.com/txchen-USTC/MiM-ISTD/blob/main/asset/overview.jpg) 16 | 17 | ## Efficiency Advantages 18 | 19 | ![image](https://github.com/txchen-USTC/MiM-ISTD/blob/main/asset/efficiency.jpg) 20 | 21 | ## Detailed Structure of Our Mamba-in-Mamba Design 22 | 23 | ![image](https://github.com/txchen-USTC/MiM-ISTD/blob/main/asset/structure.jpg) 24 | 25 | ## Performance Comparison 26 | 27 | ![image](https://github.com/txchen-USTC/MiM-ISTD/blob/main/asset/performance.jpg) 28 | 29 | ## Required Environments 30 | 31 | ``` 32 | conda create -n mim python=3.8 33 | conda activate mim 34 | pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 35 | pip install packaging 36 | pip install timm==0.4.12 37 | pip install pytest chardet yacs termcolor 38 | pip install submitit tensorboardX 39 | pip install triton==2.0.0 40 | pip install causal_conv1d==1.0.0 # causal_conv1d-1.0.0+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl 41 | pip install mamba_ssm==1.0.1 # mamba_ssm-1.0.1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl 42 | pip install scikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs 43 | ``` 44 | 45 | The .whl files of causal_conv1d and mamba_ssm can be found here: {[Baidu](https://pan.baidu.com/s/1Uza8g1pkVcbXG1F-2tB0xQ?pwd=p3h9)} 46 | 47 | ## Checkpoint 48 | 49 | A newly retrained MiM checkpoint that maintains relatively high accuracy (around 80% IoU) on the SIRST dataset is available at Baidu Disk: {[Baidu](https://pan.baidu.com/s/13g2v_M9tPxq_ze02fpaYGw)}, extraction code: DY4h. 50 | 51 | ## Citation 52 | 53 | Please cite our paper if you find the repository helpful.
54 | ``` 55 | @article{chen2024mim, 56 | title={Mim-istd: Mamba-in-mamba for efficient infrared small target detection}, 57 | author={Chen, Tianxiang and Ye, Zi and Tan, Zhentao and Gong, Tao and Wu, Yue and Chu, Qi and Liu, Bin and Yu, Nenghai and Ye, Jieping}, 58 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 59 | year={2024}, 60 | publisher={IEEE} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /mim_network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.models.helpers import load_pretrained 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.resnet import resnet26d, resnet50d 11 | from timm.models.registry import register_model 12 | from einops import rearrange, repeat 13 | 14 | try: 15 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 16 | except Exception: 17 | pass 18 | 19 | # an alternative for mamba_ssm (in which causal_conv1d is needed) 20 | try: 21 | from selective_scan import selective_scan_fn as selective_scan_fn_v1 22 | from selective_scan import selective_scan_ref as selective_scan_ref_v1 23 | except Exception: 24 | pass 25 | 26 | 27 | 28 | 29 | class SS2D(nn.Module): 30 | def __init__( 31 | self, 32 | d_model, 33 | d_state=16, 34 | # d_state="auto", # 20240109 35 | d_conv=3, 36 | expand=2, 37 | dt_rank="auto", 38 | dt_min=0.001, 39 | dt_max=0.1, 40 | dt_init="random", 41 | dt_scale=1.0, 42 | dt_init_floor=1e-4, 43 | dropout=0., 44 | conv_bias=True, 45 | bias=False, 46 | device=None, 47 | dtype=None, 48 | **kwargs, 49 | ): 50 | factory_kwargs = {"device": device, "dtype": dtype} 51 | super().__init__() 52 |
self.d_model = d_model 53 | self.d_state = d_state 54 | # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 55 | self.d_conv = d_conv 56 | self.expand = expand 57 | self.d_inner = int(self.expand * self.d_model) 58 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 59 | 60 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 61 | self.conv2d = nn.Conv2d( 62 | in_channels=self.d_inner, 63 | out_channels=self.d_inner, 64 | groups=self.d_inner, 65 | bias=conv_bias, 66 | kernel_size=d_conv, 67 | padding=(d_conv - 1) // 2, 68 | **factory_kwargs, 69 | ) 70 | self.act = nn.SiLU() 71 | 72 | self.x_proj = ( 73 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 74 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 75 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 76 | nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 77 | ) 78 | self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) 79 | del self.x_proj 80 | 81 | self.dt_projs = ( 82 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 83 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 84 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 85 | self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), 86 | ) 87 | self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) 88 | self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) 89 | del self.dt_projs 90 | 91 | self.A_logs = 
self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K * D, N) after merging the K=4 copies 92 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K * D) after merging the K=4 copies 93 | 94 | # self.selective_scan = selective_scan_fn 95 | self.forward_core = self.forward_corev0 96 | 97 | self.out_norm = nn.LayerNorm(self.d_inner) 98 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 99 | self.dropout = nn.Dropout(dropout) if dropout > 0. else None 100 | 101 | @staticmethod 102 | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): 103 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 104 | 105 | # Initialize special dt projection to preserve variance at initialization 106 | dt_init_std = dt_rank**-0.5 * dt_scale 107 | if dt_init == "constant": 108 | nn.init.constant_(dt_proj.weight, dt_init_std) 109 | elif dt_init == "random": 110 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 111 | else: 112 | raise NotImplementedError 113 | 114 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 115 | dt = torch.exp( 116 | torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 117 | + math.log(dt_min) 118 | ).clamp(min=dt_init_floor) 119 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 120 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 121 | with torch.no_grad(): 122 | dt_proj.bias.copy_(inv_dt) 123 | # Our initialization would set all Linear.bias to zero, so mark this one as _no_reinit 124 | dt_proj.bias._no_reinit = True 125 | 126 | return dt_proj 127 | 128 | @staticmethod 129 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 130 | # S4D real initialization 131 | A = repeat( 132 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 133 | "n -> d n", 134 | d=d_inner, 135 | ).contiguous() 136 | A_log = torch.log(A) # Keep A_log in fp32 137 | if copies >
1: 138 | A_log = repeat(A_log, "d n -> r d n", r=copies) 139 | if merge: 140 | A_log = A_log.flatten(0, 1) 141 | A_log = nn.Parameter(A_log) 142 | A_log._no_weight_decay = True 143 | return A_log 144 | 145 | @staticmethod 146 | def D_init(d_inner, copies=1, device=None, merge=True): 147 | # D "skip" parameter 148 | D = torch.ones(d_inner, device=device) 149 | if copies > 1: 150 | D = repeat(D, "n1 -> r n1", r=copies) 151 | if merge: 152 | D = D.flatten(0, 1) 153 | D = nn.Parameter(D) # Keep in fp32 154 | D._no_weight_decay = True 155 | return D 156 | 157 | def forward_corev0(self, x: torch.Tensor): 158 | self.selective_scan = selective_scan_fn 159 | 160 | B, C, H, W = x.shape 161 | L = H * W 162 | K = 4 163 | 164 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 165 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 166 | 167 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 168 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 169 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 170 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 171 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) 172 | 173 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 174 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 175 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) 176 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 177 | Ds = self.Ds.float().view(-1) # (k * d) 178 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) 179 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 180 | 181 | out_y = self.selective_scan( 182 | xs, dts, 183 | As, Bs, Cs, Ds, z=None, 184 | delta_bias=dt_projs_bias, 185 | delta_softplus=True, 186 | return_last_state=False, 187 | ).view(B, K, -1, L) 
188 | assert out_y.dtype == torch.float 189 | 190 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 191 | wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 192 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 193 | 194 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 195 | 196 | # an alternative to forward_corev0, using selective_scan_fn_v1 from the standalone selective_scan package 197 | def forward_corev1(self, x: torch.Tensor): 198 | self.selective_scan = selective_scan_fn_v1 199 | 200 | B, C, H, W = x.shape 201 | L = H * W 202 | K = 4 203 | 204 | x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) 205 | xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) 206 | 207 | x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) 208 | # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) 209 | dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) 210 | dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) 211 | # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) 212 | 213 | xs = xs.float().view(B, -1, L) # (b, k * d, l) 214 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 215 | Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) 216 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 217 | Ds = self.Ds.float().view(-1) # (k * d) 218 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) 219 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 220 | 221 | out_y = self.selective_scan( 222 | xs, dts, 223 | As, Bs, Cs, Ds, 224 | delta_bias=dt_projs_bias, 225 | delta_softplus=True, 226 | ).view(B, K, -1, L) 227 | assert out_y.dtype == torch.float 228 | 229 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 230 | wh_y = torch.transpose(out_y[:,
1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 231 | invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) 232 | 233 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 234 | 235 | def forward(self, x, H, W, relative_pos=None): 236 | B, N, C = x.shape 237 | #print('x input',x.shape) 238 | x = x.permute(0, 2, 1).reshape(B, H, W, C) 239 | 240 | B, H, W, C = x.shape 241 | 242 | xz = self.in_proj(x) 243 | x, z = xz.chunk(2, dim=-1) # (b, h, w, d) 244 | 245 | x = x.permute(0, 3, 1, 2).contiguous() 246 | x = self.act(self.conv2d(x)) # (b, d, h, w) 247 | y1, y2, y3, y4 = self.forward_core(x) 248 | assert y1.dtype == torch.float32 249 | y = y1 + y2 + y3 + y4 250 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 251 | y = self.out_norm(y) 252 | y = y * F.silu(z) 253 | out = self.out_proj(y) 254 | if self.dropout is not None: 255 | out = self.dropout(out) 256 | out=out.reshape(B,N,C) 257 | #print('x output',out.shape) 258 | return out 259 | 260 | class Mlp(nn.Module): 261 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 262 | super().__init__() 263 | out_features = out_features or in_features 264 | hidden_features = hidden_features or in_features 265 | self.fc1 = nn.Linear(in_features, hidden_features) 266 | self.act = act_layer() 267 | self.fc2 = nn.Linear(hidden_features, out_features) 268 | self.drop = nn.Dropout(drop) 269 | 270 | def forward(self, x): 271 | x = self.fc1(x) 272 | x = self.act(x) 273 | x = self.drop(x) 274 | x = self.fc2(x) 275 | x = self.drop(x) 276 | return x 277 | 278 | def make_pairs(x): 279 | """make the int -> tuple 280 | """ 281 | return x if isinstance(x, tuple) else (x, x) 282 | 283 | class InvertedResidualFeedForward(nn.Module): 284 | def __init__(self, dim, dim_ratio=2.): 285 | super(InvertedResidualFeedForward, self).__init__() 286 | output_dim = int(dim_ratio * dim) 287 | self.conv1x1_gelu_bn = ConvGeluBN( 
288 | in_channel=dim, 289 | out_channel=output_dim, 290 | kernel_size=1, 291 | stride_size=1, 292 | padding=0 293 | ) 294 | self.conv3x3_dw = ConvDW3x3(dim=output_dim) 295 | self.act = nn.Sequential( 296 | nn.GELU(), 297 | nn.BatchNorm2d(output_dim) 298 | ) 299 | self.conv1x1_pw = nn.Sequential( 300 | nn.Conv2d(output_dim, dim, 1, 1, 0), 301 | nn.BatchNorm2d(dim) 302 | ) 303 | 304 | def forward(self, x): 305 | x = self.conv1x1_gelu_bn(x) 306 | out = x + self.act(self.conv3x3_dw(x)) 307 | out = self.conv1x1_pw(out) 308 | return out 309 | 310 | 311 | class ConvDW3x3(nn.Module): 312 | def __init__(self, dim, kernel_size=3): 313 | super(ConvDW3x3, self).__init__() 314 | self.conv = nn.Conv2d( 315 | in_channels=dim, 316 | out_channels=dim, 317 | kernel_size=make_pairs(kernel_size), 318 | padding=make_pairs(1), 319 | groups=dim) 320 | 321 | def forward(self, x): 322 | x = self.conv(x) 323 | return x 324 | 325 | 326 | class ConvGeluBN(nn.Module): 327 | def __init__(self, in_channel, out_channel, kernel_size, stride_size, padding=1): 328 | """build the conv3x3 + gelu + bn module 329 | """ 330 | super(ConvGeluBN, self).__init__() 331 | self.kernel_size = make_pairs(kernel_size) 332 | self.stride_size = make_pairs(stride_size) 333 | self.padding_size = make_pairs(padding) 334 | self.in_channel = in_channel 335 | self.out_channel = out_channel 336 | self.conv3x3_gelu_bn = nn.Sequential( 337 | nn.Conv2d(in_channels=self.in_channel, 338 | out_channels=self.out_channel, 339 | kernel_size=self.kernel_size, 340 | stride=self.stride_size, 341 | padding=self.padding_size), 342 | nn.GELU(), 343 | nn.BatchNorm2d(self.out_channel) 344 | ) 345 | 346 | def forward(self, x): 347 | x = self.conv3x3_gelu_bn(x) 348 | return x 349 | 350 | class Block(nn.Module): 351 | """ MiM-ISTD Block 352 | """ 353 | def __init__(self, outer_dim, inner_dim, outer_head, inner_head, num_words, mlp_ratio=4., 354 | qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, 355 | 
norm_layer=nn.LayerNorm, se=0, sr_ratio=1): 356 | super().__init__() 357 | self.has_inner = inner_dim > 0 358 | if self.has_inner: 359 | # Inner 360 | self.inner_norm1 = norm_layer(num_words * inner_dim) 361 | self.inner_attn = SS2D(d_model=inner_dim, dropout=0, d_state=16) 362 | self.inner_norm2 = norm_layer(num_words * inner_dim) 363 | self.inner_mlp = InvertedResidualFeedForward(inner_dim) 364 | # self.inner_mlp = Mlp(in_features=inner_dim, hidden_features=int(inner_dim * mlp_ratio), 365 | # out_features=inner_dim, act_layer=act_layer, drop=drop) 366 | 367 | self.proj_norm1 = norm_layer(num_words * inner_dim) 368 | self.proj = nn.Linear(num_words * inner_dim, outer_dim, bias=False) 369 | self.proj_norm2 = norm_layer(outer_dim) 370 | # Outer 371 | self.outer_norm1 = norm_layer(outer_dim) 372 | 373 | self.outer_attn = SS2D(d_model=outer_dim, dropout=0, d_state=16) 374 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 375 | self.outer_norm2 = norm_layer(outer_dim) 376 | self.outer_mlp = InvertedResidualFeedForward(outer_dim) 377 | # self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio), 378 | # out_features=outer_dim, act_layer=act_layer, drop=drop) 379 | 380 | def forward(self, x, outer_tokens, H_out, W_out, H_in, W_in, relative_pos): 381 | B, N, C = outer_tokens.size() 382 | #print('outer_tokens input',outer_tokens.shape) 383 | if self.has_inner: 384 | x = x + self.drop_path(self.inner_attn(self.inner_norm1(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1), H_in, W_in)) # B*N, k*k, c 385 | mid=self.inner_norm2(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1) 386 | mid=mid.reshape(B,mid.size(-1),int(math.sqrt(N*H_in*W_in)),int(math.sqrt(N*H_in*W_in))) 387 | x = x + self.drop_path(self.inner_mlp(mid).reshape(B*N, H_in*W_in, -1)).reshape(B*N, H_in*W_in, -1) 388 | #x = x + self.drop_path(self.inner_mlp(self.inner_norm2(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1))) # B*N, k*k, c 389 | outer_tokens = 
outer_tokens + self.proj_norm2(self.proj(self.proj_norm1(x.reshape(B, N, -1)))) # B, N, C 390 | outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens), H_out, W_out, relative_pos)) 391 | mid_out=self.outer_norm2(outer_tokens) 392 | mid_out=mid_out.reshape(B,mid_out.size(-1),int(math.sqrt(N)),int(math.sqrt(N))) 393 | outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(mid_out).reshape(B,N,C)) 394 | return x, outer_tokens 395 | 396 | 397 | 398 | 399 | 400 | 401 | class PatchMerging2D_sentence(nn.Module): 402 | r""" Patch Merging Layer for outer (sentence) tokens. 403 | Args: 404 | dim (int): Number of input channels. 405 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 406 | The input tokens (B, N, C) are assumed to form a square sqrt(N) x sqrt(N) grid. 407 | """ 408 | 409 | def __init__(self, dim, norm_layer=nn.LayerNorm): 410 | super().__init__() 411 | self.dim = dim 412 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 413 | self.norm = norm_layer(4 * dim) 414 | 415 | def forward(self, x):#(b,h*w,c)->(b,h/2*w/2,2c) 416 | B, N, C = x.shape 417 | x=x.reshape(B,int(math.sqrt(N)),int(math.sqrt(N)),C) 418 | B, H, W, C = x.shape 419 | 420 | SHAPE_FIX = [-1, -1] 421 | if (W % 2 != 0) or (H % 2 != 0): 422 | print(f"Warning, x.shape {x.shape} does not have even H and W ===========", flush=True) 423 | SHAPE_FIX[0] = H // 2 424 | SHAPE_FIX[1] = W // 2 425 | 426 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 427 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 428 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 429 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 430 | 431 | if SHAPE_FIX[0] > 0: 432 | x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 433 | x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 434 | x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 435 | x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 436 | 437 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 438 | x = x.view(B, H//2, W//2, 4 * C) # B H/2 W/2 4*C 439 | 440 | x = self.norm(x) 441 | x = self.reduction(x) 442 | b, h,
w, c = x.shape 443 | x=x.reshape(b,h*w,c) 444 | 445 | return x,h,w 446 | 447 | 448 | class PatchMerging2D_word(nn.Module): 449 | r""" Patch Merging Layer for inner (word) tokens. 450 | Args: 451 | dim_in (int): Number of input channels. 452 | dim_out (int): Number of output channels. 453 | stride (int): Stride of the merging convolution. Default: 2 454 | """ 455 | 456 | def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU): 457 | super().__init__() 458 | self.stride = stride 459 | self.dim_out = dim_out 460 | self.norm = nn.LayerNorm(dim_in) 461 | self.conv = nn.Sequential( 462 | nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride), 463 | ) 464 | 465 | def forward(self, x, H_out, W_out, H_in, W_in): 466 | B_N, M, C = x.shape # B*N, M, C 467 | x = self.norm(x) 468 | x = x.reshape(-1, H_out, W_out, H_in, W_in, C) 469 | # padding to fit (1333, 800) in detection. 470 | pad_input = (H_out % 2 == 1) or (W_out % 2 == 1) 471 | if pad_input: 472 | x = F.pad(x.permute(0, 3, 4, 5, 1, 2), (0, W_out % 2, 0, H_out % 2)) 473 | x = x.permute(0, 4, 5, 1, 2, 3) 474 | 475 | H,W=x.shape[1],x.shape[2] 476 | SHAPE_FIX = [-1, -1] 477 | if (W % 2 != 0) or (H % 2 != 0): 478 | print(f"Warning, x.shape {x.shape} does not have even H and W ===========", flush=True) 479 | SHAPE_FIX[0] = H // 2 480 | SHAPE_FIX[1] = W // 2 481 | 482 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 H_in W_in C 483 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 H_in W_in C 484 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 H_in W_in C 485 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 H_in W_in C 486 | 487 | if SHAPE_FIX[0] > 0: 488 | x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 489 | x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 490 | x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 491 | x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] 492 | 493 | x = torch.cat([torch.cat([x0, x1], 3), torch.cat([x2, x3], 3)], 4) # B, H/2, W/2, 2*H_in, 2*W_in, C 494 | x = x.reshape(-1, 2*H_in, 2*W_in, C).permute(0, 3, 1, 2) # B_N/4, C, 2*H_in, 2*W_in 495 | x = self.conv(x)
# B_N/4, C, H_in, W_in 496 | x = x.reshape(-1, self.dim_out, M).transpose(1, 2) 497 | 498 | return x 499 | 500 | 501 | 502 | 503 | class Stem(nn.Module): 504 | 505 | def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24): 506 | super().__init__() 507 | img_size = to_2tuple(img_size) 508 | self.img_size = img_size 509 | self.inner_dim = inner_dim 510 | self.num_patches = img_size[0] // 8 * img_size[1] // 8 511 | self.num_words = 16 512 | 513 | self.common_conv = nn.Sequential( 514 | nn.Conv2d(in_chans, inner_dim*2, 3, stride=2, padding=1), 515 | nn.BatchNorm2d(inner_dim*2), 516 | nn.ReLU(inplace=True), 517 | ) 518 | self.inner_convs = nn.Sequential( 519 | nn.Conv2d(inner_dim*2, inner_dim, 3, stride=1, padding=1), 520 | nn.BatchNorm2d(inner_dim), 521 | nn.ReLU(inplace=False), 522 | ) 523 | self.outer_convs = nn.Sequential( 524 | nn.Conv2d(inner_dim*2, inner_dim*4, 3, stride=2, padding=1), 525 | nn.BatchNorm2d(inner_dim*4), 526 | nn.ReLU(inplace=True), 527 | nn.Conv2d(inner_dim*4, inner_dim*8, 3, stride=2, padding=1), 528 | nn.BatchNorm2d(inner_dim*8), 529 | nn.ReLU(inplace=True), 530 | nn.Conv2d(inner_dim*8, outer_dim, 3, stride=1, padding=1), 531 | nn.BatchNorm2d(outer_dim), 532 | nn.ReLU(inplace=False), 533 | ) 534 | self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4) 535 | 536 | def forward(self, x): 537 | B, C, H, W = x.shape 538 | x = self.common_conv(x) 539 | H_out, W_out = H // 8, W // 8 # Each visual sentence corresponds to 8x8 pixel area of the original image 540 | H_in, W_in = 4, 4 # Every visual sentence is composed of 4x4 visual words, Every visual word at the stem stage corresponds to 2x2 pixel area of the original image 541 | # inner_tokens 542 | inner_tokens = self.inner_convs(x) # B, C, H, W 543 | inner_tokens = self.unfold(inner_tokens).transpose(1, 2) # B, N, Ck2 544 | inner_tokens = inner_tokens.reshape(B * H_out * W_out, self.inner_dim, H_in*W_in).transpose(1, 2) # B*N, C, 4*4 545 | # outer_tokens 546 | outer_tokens = 
self.outer_convs(x) # B, C, H_out, W_out 547 | outer_tokens = outer_tokens.permute(0, 2, 3, 1).reshape(B, H_out * W_out, -1) 548 | return inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in) 549 | 550 | class Stage(nn.Module): 551 | """ MiM-ISTD stage (structure adapted from PyramidTNT) 552 | """ 553 | def __init__(self, num_blocks, outer_dim, inner_dim, outer_head, inner_head, num_patches, num_words, mlp_ratio=4., 554 | qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, 555 | norm_layer=nn.LayerNorm, se=0, sr_ratio=1): 556 | super().__init__() 557 | blocks = [] 558 | drop_path = drop_path if isinstance(drop_path, list) else [drop_path] * num_blocks 559 | 560 | for j in range(num_blocks): 561 | if j == 0: 562 | _inner_dim = inner_dim 563 | elif j == 1 and num_blocks > 6: 564 | _inner_dim = inner_dim 565 | else: 566 | _inner_dim = -1 567 | blocks.append(Block( 568 | outer_dim, _inner_dim, outer_head=outer_head, inner_head=inner_head, 569 | num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, 570 | attn_drop=attn_drop, drop_path=drop_path[j], act_layer=act_layer, norm_layer=norm_layer, 571 | se=se, sr_ratio=sr_ratio)) 572 | 573 | self.blocks = nn.ModuleList(blocks) 574 | self.relative_pos = nn.Parameter(torch.randn( 575 | 1, outer_head, num_patches, num_patches // sr_ratio // sr_ratio)) 576 | 577 | def forward(self, inner_tokens, outer_tokens, H_out, W_out, H_in, W_in): 578 | for blk in self.blocks: 579 | inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens, H_out, W_out, H_in, W_in, self.relative_pos) 580 | return inner_tokens, outer_tokens 581 | 582 | 583 | class UpsampleBlock(nn.Module): 584 | def __init__(self, in_channels, out_channels): 585 | super(UpsampleBlock, self).__init__() 586 | # 2x2 transposed convolution with stride 2 587 | self.transposed_conv = nn.ConvTranspose2d( 588 | in_channels, out_channels, kernel_size=2, stride=2, padding=0 589 | ) 590 | # Batch normalization 591 | self.batch_norm1 = nn.BatchNorm2d(out_channels) 592 | # GELU activation 593
| self.gelu1 = nn.GELU() 594 | # 3x3 convolution with stride 1 595 | self.conv = nn.Conv2d( 596 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 597 | ) 598 | # Second batch normalization 599 | self.batch_norm2 = nn.BatchNorm2d(out_channels) 600 | # Second GELU activation 601 | self.gelu2 = nn.GELU() 602 | 603 | def forward(self, x): 604 | x = self.transposed_conv(x) 605 | x = self.batch_norm1(x) 606 | x = self.gelu1(x) 607 | x = self.conv(x) 608 | x = self.batch_norm2(x) 609 | x = self.gelu2(x) 610 | return x 611 | 612 | 613 | 614 | 615 | class PyramidMiM_enc(nn.Module): 616 | """ Pyramid MiM-ISTD encoder including conv stem for computer vision 617 | """ 618 | def __init__(self, configs=None, img_size=512, in_chans=3, num_classes=1, mlp_ratio=4., qkv_bias=False, 619 | qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, se=0): 620 | super().__init__() 621 | self.num_classes = num_classes 622 | depths = [2, 2, 2, 2] 623 | outer_dims = [32, 32*2, 32*4, 32*8] 624 | inner_dims = [4, 4*2, 4*4, 4*8]# original mim-istd 625 | outer_heads = [2, 2*2, 2*4, 2*8] 626 | inner_heads = [1, 1*2, 1*4, 1*8] 627 | sr_ratios = [4, 2, 1, 1] 628 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 629 | self.num_features = outer_dims[-1] 630 | 631 | self.patch_embed = Stem( 632 | img_size=img_size, in_chans=in_chans, outer_dim=outer_dims[0], inner_dim=inner_dims[0]) 633 | num_patches = self.patch_embed.num_patches 634 | num_words = self.patch_embed.num_words 635 | 636 | 637 | depth = 0 638 | self.word_merges = nn.ModuleList([]) 639 | self.sentence_merges = nn.ModuleList([]) 640 | self.stages = nn.ModuleList([]) 641 | for i in range(4): 642 | if i > 0: 643 | self.word_merges.append(PatchMerging2D_word(inner_dims[i-1], inner_dims[i])) 644 | self.sentence_merges.append(PatchMerging2D_sentence(outer_dims[i-1])) 645 | self.stages.append(Stage(depths[i], outer_dim=outer_dims[i], inner_dim=inner_dims[i], 646 | outer_head=outer_heads[i], inner_head=inner_heads[i],
                num_patches=num_patches // (2 ** i) // (2 ** i), num_words=num_words, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[depth:depth+depths[i]], norm_layer=norm_layer, se=se, sr_ratio=sr_ratios[i])
            )
            depth += depths[i]

        self.norm = norm_layer(outer_dims[-1])

        self.up_blocks = nn.ModuleList([])
        for i in range(4):
            self.up_blocks.append(UpsampleBlock(outer_dims[i], outer_dims[i]))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.Conv2d):
            # He-style normal init with std = sqrt(2 / fan_out), fan_out = kh * kw * out_channels // groups
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'outer_pos', 'inner_pos'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        size = x.size()[2:]
        inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in) = self.patch_embed(x)
        outputs = []

        for i in range(4):
            if i > 0:
                inner_tokens = self.word_merges[i-1](inner_tokens, H_out, W_out, H_in, W_in)
                outer_tokens, H_out, W_out = self.sentence_merges[i-1](outer_tokens)
            inner_tokens, outer_tokens = self.stages[i](inner_tokens, outer_tokens, H_out, W_out, H_in, W_in)
            # Reshape (b, l, m) outer tokens into a (b, m, sqrt(l), sqrt(l)) map; assumes a square token grid.
            b, l, m = outer_tokens.shape
            mid_out = outer_tokens.reshape(b, int(math.sqrt(l)), int(math.sqrt(l)), m).permute(0, 3, 1, 2)
            mid_out = self.up_blocks[i](mid_out)

            outputs.append(mid_out)

        return outputs

    def forward(self, x):
        x = self.forward_features(x)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, downsample):
        super(ResidualBlock, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),

            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        residual = x
        x = self.body(x)

        # An empty nn.Sequential has length 0 and is falsy, so the shortcut
        # projection only runs when the block was built with downsample=True.
        if self.downsample:
            residual = self.downsample(residual)

        out = F.relu(x+residual, True)
        return out


class _FCNHead(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(_FCNHead, self).__init__()
        inter_channels = in_channels // 4
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(True),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, out_channels, 1, 1, 0)
        )

    def forward(self, x):
        return self.block(x)


class PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim*2
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):  # (b,c,h,w) -> (b,h,w,c) -> (b,h,w,2c) -> (b,2h,2w,c/2) -> (b,c/2,2h,2w)
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        x = self.expand(x)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
        x = self.norm(x).permute(0, 3, 1, 2)

        return x


class MiM(nn.Module):
    def __init__(self, layer_blocks, channels):
        super(MiM, self).__init__()

        self.deconv3 = PatchExpand2D(channels[4]//2)
        #self.deconv3 = nn.ConvTranspose2d(channels[4], channels[3], 4, 2, 1)
        self.uplayer3 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[2],
                                         in_channels=channels[3], out_channels=channels[3], stride=1)
        self.deconv2 = PatchExpand2D(channels[3]//2)
        #self.deconv2 = nn.ConvTranspose2d(channels[3], channels[2], 4, 2, 1)
        self.uplayer2 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[1],
                                         in_channels=channels[2], out_channels=channels[2], stride=1)
        self.deconv1 = PatchExpand2D(channels[2]//2)
        #self.deconv1 = nn.ConvTranspose2d(channels[2], channels[1], 4, 2, 1)
        self.uplayer1 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[0],
                                         in_channels=channels[1], out_channels=channels[1], stride=1)
        self.head = _FCNHead(channels[1], 1)
        #####################
        self.mim_backbone = PyramidMiM_enc()

    def forward(self, x):  # The input is of size (b,3,512,512) and the output of size (b,1,512,512), since num_class=1 in ISTD.
        _, _, hei, wid = x.shape

        outputs = self.mim_backbone(x)
        t1, t2, t3, t4 = outputs

        deconc3 = self.deconv3(t4)
        fusec3 = deconc3+t3

        upc3 = self.uplayer3(fusec3)

        deconc2 = self.deconv2(upc3)
        fusec2 = deconc2+t2

        upc2 = self.uplayer2(fusec2)

        deconc1 = self.deconv1(upc2)
        fusec1 = deconc1+t1

        upc1 = self.uplayer1(fusec1)

        pred = self.head(upc1)
        out = F.interpolate(pred, size=[hei, wid], mode='bilinear')

        return out

    def _make_layer(self, block, block_num, in_channels, out_channels, stride):
        layer = []
        downsample = (in_channels != out_channels) or (stride != 1)
        layer.append(block(in_channels, out_channels, stride, downsample))
        for _ in range(block_num-1):
            layer.append(block(out_channels, out_channels, 1, False))
        return nn.Sequential(*layer)


if __name__ == '__main__':
    # torch.rand gives defined values; torch.Tensor(5, 3, 256, 256) would return uninitialized memory.
    input_ = torch.rand(5, 3, 256, 256)
    net = MiM([2]*3, [8, 16, 32, 64, 128])
    out = net(input_)
    print(out.shape)
--------------------------------------------------------------------------------
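The channel-to-space step in `PatchExpand2D.forward` (the einops pattern `'b h w (p1 p2 c)-> b (h p1) (w p2) c'`) can be reproduced with a reshape and a transpose. The NumPy sketch below is illustrative only and not part of the repository: the helper name `expand_channels_to_space` is hypothetical, the 16x16 input size is an arbitrary assumption, and the learned `nn.Linear` expansion and `LayerNorm` are left out so that only the index shuffling is shown.

```python
import numpy as np


def expand_channels_to_space(x, scale):
    """Hypothetical NumPy equivalent of
    rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=scale, p2=scale).
    x has shape (b, h, w, scale * scale * c)."""
    b, h, w, ch = x.shape
    if ch % (scale * scale) != 0:
        raise ValueError("channel count must be divisible by scale**2")
    c = ch // (scale * scale)
    # Split the channel axis as (p1, p2, c), matching the einops grouping order.
    x = x.reshape(b, h, w, scale, scale, c)
    # Reorder to (b, h, p1, w, p2, c) so h/p1 and w/p2 become adjacent before merging.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h * scale, w * scale, c)


# Mirrors PatchExpand2D(64) with dim_scale=2: the Linear layer maps 128 -> 256
# channels, and the rearrange yields 64 channels at twice the spatial size.
x = np.random.rand(1, 16, 16, 256)
y = expand_channels_to_space(x, 2)
print(y.shape)  # (1, 32, 32, 64)
```

Splitting the channel axis as `(p1, p2, c)` before the transpose is what makes the result match the einops grouping order; splitting it as `(c, p1, p2)` would give the same shape but a different pixel arrangement.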
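`forward_features` converts each stage's outer tokens of shape `(b, l, m)` into a channels-first map via `reshape(b, int(math.sqrt(l)), int(math.sqrt(l)), m).permute(0, 3, 1, 2)`, which assumes a square, row-major token grid. The NumPy sketch below (hypothetical helper `tokens_to_map`, not repository code) performs the same reshape but checks the square assumption explicitly and optionally accepts the grid height and width. This matters because a non-square grid whose token count happens to be a perfect square, such as 2x8 = 16 tokens, would otherwise be reshaped into 4x4 without any error.

```python
import math

import numpy as np


def tokens_to_map(tokens, h=None, w=None):
    """Hypothetical helper: reshape (b, l, m) tokens into a (b, m, h, w) map.
    Without h and w the grid is assumed square, as in forward_features."""
    b, l, m = tokens.shape
    if h is None or w is None:
        side = math.isqrt(l)
        if side * side != l:
            raise ValueError(f"{l} tokens do not form a square grid; pass h and w")
        h = w = side
    if h * w != l:
        raise ValueError(f"h*w = {h * w} does not match the {l} tokens")
    # Row-major token order: token t sits at row t // w, column t % w.
    return tokens.reshape(b, h, w, m).transpose(0, 3, 1, 2)


tokens = np.arange(2 * 16 * 3).reshape(2, 16, 3)
fmap = tokens_to_map(tokens)
print(fmap.shape)  # (2, 3, 4, 4)
```

Inside the encoder loop, `H_out` and `W_out` are already tracked and presumably give the outer grid size, so passing them would be one way to drop the square-input assumption; this has not been checked against the rest of the repository.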
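The `nn.Conv2d` branch of `_init_weights` draws weights from a normal distribution with standard deviation `sqrt(2 / fan_out)`, where `fan_out = kernel_h * kernel_w * out_channels // groups`. The small pure-Python sketch below (hypothetical helper `conv_init_std`) only evaluates that formula for two layer configurations chosen for illustration.

```python
import math


def conv_init_std(kernel_size, out_channels, groups=1):
    """Hypothetical helper: std used by the Conv2d branch of _init_weights,
    sqrt(2 / fan_out) with fan_out = kh * kw * out_channels // groups."""
    kh, kw = kernel_size
    fan_out = kh * kw * out_channels
    fan_out //= groups
    return math.sqrt(2.0 / fan_out)


# Ungrouped 3x3 conv with 64 output channels: fan_out = 576.
print(round(conv_init_std((3, 3), 64), 4))  # 0.0589
# Depthwise 3x3 conv (groups == out_channels == 64): fan_out = 9.
print(round(conv_init_std((3, 3), 64, groups=64), 4))  # 0.4714
```

For ungrouped convolutions this equals the standard deviation used by `torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')`. For grouped or depthwise convolutions the extra division by `groups` gives a larger value than PyTorch's fan-out computation, which does not divide by `groups`.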
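The decoder in `MiM.forward` only works if every `PatchExpand2D` output has the same shape as the skip feature it is added to. The pure-Python sketch below (hypothetical helper `decoder_shapes`) traces `(C, H, W)` through the decoder starting from the shape of `t4`, using only what the code above implies: `PatchExpand2D(channels[k+1] // 2)` halves the channels and doubles the resolution, the stride-1 `ResidualBlock` stacks built by `_make_layer` keep the shape, and `_FCNHead` produces one channel at the same resolution. The 8x8 size of `t4` is an arbitrary assumption chosen for illustration.

```python
def decoder_shapes(channels, t4_hw):
    """Hypothetical helper: trace (C, H, W) through the MiM decoder for one sample.

    PatchExpand2D(channels[k+1] // 2) maps (channels[k+1], H, W) to
    (channels[k+1] // 2, 2H, 2W); the result is added to skip t_k, which
    therefore must have shape (channels[k], 2H, 2W)."""
    c = channels[4]
    h, w = t4_hw
    shapes = [("t4", (c, h, w))]
    for k in (3, 2, 1):
        c, h, w = c // 2, h * 2, w * 2  # PatchExpand2D: halve channels, double H and W
        if c != channels[k]:
            raise ValueError(
                f"deconv{k} outputs {c} channels but uplayer{k} expects {channels[k]}")
        # The stride-1 ResidualBlock stack keeps (C, H, W) unchanged.
        shapes.append((f"fuse{k}", (c, h, w)))
    # _FCNHead: 3x3 conv to C // 4 channels, then 1x1 conv to 1 channel, same resolution.
    shapes.append(("pred", (1, h, w)))
    return shapes


for name, shape in decoder_shapes([8, 16, 32, 64, 128], (8, 8)):
    print(name, shape)
```

With `channels=[8, 16, 32, 64, 128]`, as in the `__main__` smoke test, each stage passes the check; `channels[0]` is not used by any of the decoder layers shown. The final `F.interpolate` then resizes the one-channel prediction to the input resolution.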