├── README.md
├── img
│   └── framework-github.png
└── models
    ├── hmt_unet.py
    ├── mamba_vision.py
    └── registry.py

/README.md:
--------------------------------------------------------------------------------
1 | # HMT-Unet
2 | 
3 | This is the official code repository for "HMT-UNet: A hybird Mamba-Transformer Vision UNet for Medical Image Segmentation" ([arXiv paper](https://arxiv.org/html/2408.11289v1)).
4 | 
5 | ![framework](img/framework-github.png)
6 | 
7 | ## Training details
8 | 
9 | Our training pipeline follows the VM-UNetV2 repository ([git link](https://github.com/nobodyplayer1/VM-UNetV2)); to train HMT-UNet, replace that repository's model files with the ones provided here.
10 | 
11 | If you run into any problems, feel free to open an issue or contact me.
12 | 
13 | My email: dg20330034@smail.nju.edu.cn
14 | 
15 | ## Citation
16 | 
17 | ```
18 | @misc{2408.11289,
19 |   Author = {Mingya Zhang and Limei Gu and Tingshen Ling and Xianping Tao},
20 |   Title = {HMT-UNet: A hybird Mamba-Transformer Vision UNet for Medical Image Segmentation},
21 |   Year = {2024},
22 |   Eprint = {arXiv:2408.11289},
23 | }
24 | ```
25 | 
--------------------------------------------------------------------------------
/img/framework-github.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simzhangbest/HMT-Unet/2e84d2d31430cd2ecbadf30ae12effab221a1293/img/framework-github.png
--------------------------------------------------------------------------------
/models/hmt_unet.py:
--------------------------------------------------------------------------------
1 | from .mamba_vision import MambaVision, mamba_vision_T, MambaVision_sim
2 | import torch
3 | from torch import nn
4 | 
5 | # by mingya zhang dg20330034@smail.nju.edu.cn 2024 08 16
6 | 
7 | class HMTUNet(nn.Module):
8 | 
9 |     def __init__(self,
10 |                  input_channels=3,
11 |                  num_classes=1,
12 |                  depths=[1, 3, 8, 4],
13 |                  num_heads=[2, 4, 8, 16],
14 |                  window_size=[8, 8, 14, 7],
15 |                  dim=80,
16 |                  in_dim=32,
17 |                  mlp_ratio=4,
18 |                  resolution=224,
19 |                  drop_path_rate=0.2,
20 |                  load_ckpt_path=None,
21 |                  **kwargs):
22 | 
23 |         super().__init__()
24 | 
25 |         self.load_ckpt_path = load_ckpt_path
26 |         self.num_classes = num_classes
27 | 
28 |         self.hmtunet = MambaVision_sim(
29 |             depths=depths,
30 |             num_heads=num_heads,
31 |             window_size=window_size,
32 |             dim=dim,
33 |             in_dim=in_dim,
34 |             mlp_ratio=mlp_ratio,
35 |             resolution=resolution,
36 |             drop_path_rate=drop_path_rate,
37 |         )
38 | 
39 | 
40 | 
41 | 
42 |     def forward(self, x):
43 |         return self.hmtunet(x)
44 | 
45 | 
46 |     def load_from(self):  # load matching encoder weights, then (below) remap levels.* -> layers_up.* to initialize the decoder
47 |         if self.load_ckpt_path is not None:
48 |             model_dict = self.hmtunet.state_dict()
49 |             model_checkpoint = torch.load(self.load_ckpt_path)
50 |             pretrained_dict = model_checkpoint['state_dict']
51 |             new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
52 |             model_dict.update(new_dict)
53 |             print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
54 |             self.hmtunet.load_state_dict(model_dict)
55 | 
56 |             not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
57 |             print('Not loaded keys:', not_loaded_keys)
58 |             print("encoder loaded finished!")
59 | 
60 |             model_dict = self.hmtunet.state_dict()
61 |             model_checkpoint = torch.load(self.load_ckpt_path)
62 |             pretrained_order_dict = model_checkpoint['state_dict']
63 |             pretrained_dict = {}
64 |             for k,v in pretrained_order_dict.items():
65 |                 if 'levels.0' in k:
66 |                     new_k = k.replace('levels.0', 'layers_up.3')
67 | 
pretrained_dict[new_k] = v 68 | elif 'levels.1' in k: 69 | new_k = k.replace('levels.1', 'layers_up.2') 70 | pretrained_dict[new_k] = v 71 | elif 'levels.2' in k: 72 | new_k = k.replace('levels.2', 'layers_up.1') 73 | pretrained_dict[new_k] = v 74 | elif 'levels.3' in k: 75 | new_k = k.replace('levels.3', 'layers_up.0') 76 | pretrained_dict[new_k] = v 77 | 78 | # decoder 79 | new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 80 | model_dict.update(new_dict) 81 | print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict))) 82 | self.hmtunet.load_state_dict(model_dict) 83 | 84 | # 找到没有加载的键(keys) 85 | not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()] 86 | print('Not loaded keys:', not_loaded_keys) 87 | print("decoder loaded finished!") 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /models/mamba_vision.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | 12 | import torch 13 | import torch.nn as nn 14 | from timm.models.registry import register_model 15 | import math 16 | from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d 17 | from timm.models._builder import resolve_pretrained_cfg 18 | try: 19 | from timm.models._builder import _update_default_kwargs as update_args 20 | except: 21 | from timm.models._builder import _update_default_model_kwargs as update_args 22 | from timm.models.vision_transformer import Mlp, PatchEmbed 23 | from timm.models.layers import DropPath, trunc_normal_ 24 | from timm.models.registry import register_model 25 | import torch.nn.functional as F 26 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn 27 | from einops import rearrange, repeat 28 | # from .registry import register_pip_model 29 | from .registry import register_pip_model # debug use only 30 | from pathlib import Path 31 | 32 | 33 | def _cfg(url='', **kwargs): 34 | return {'url': url, 35 | 'num_classes': 1000, 36 | 'input_size': (3, 224, 224), 37 | 'pool_size': None, 38 | 'crop_pct': 0.875, 39 | 'interpolation': 'bicubic', 40 | 'fixed_input_size': True, 41 | 'mean': (0.485, 0.456, 0.406), 42 | 'std': (0.229, 0.224, 0.225), 43 | **kwargs 44 | } 45 | 46 | 47 | default_cfgs = { 48 | 'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar', 49 | crop_pct=1.0, 50 | input_size=(3, 224, 224), 51 | crop_mode='center'), 52 | 'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar', 53 | crop_pct=0.98, 54 | input_size=(3, 224, 224), 55 | crop_mode='center'), 56 | 'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar', 57 | crop_pct=0.93, 58 | input_size=(3, 224, 224), 59 | crop_mode='center'), 60 | 'mamba_vision_B': 
_cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar', 61 | crop_pct=1.0, 62 | input_size=(3, 224, 224), 63 | crop_mode='center'), 64 | 'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar', 65 | crop_pct=1.0, 66 | input_size=(3, 224, 224), 67 | crop_mode='center'), 68 | 'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar', 69 | crop_pct=1.0, 70 | input_size=(3, 224, 224), 71 | crop_mode='center') 72 | } 73 | 74 | 75 | def window_partition(x, window_size): 76 | """ 77 | Args: 78 | x: (B, C, H, W) 79 | window_size: window size 80 | h_w: Height of window 81 | w_w: Width of window 82 | Returns: 83 | local window features (num_windows*B, window_size*window_size, C) 84 | """ 85 | B, C, H, W = x.shape 86 | x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) 87 | windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C) 88 | return windows 89 | 90 | 91 | def window_reverse(windows, window_size, H, W): 92 | """ 93 | Args: 94 | windows: local window features (num_windows*B, window_size, window_size, C) 95 | window_size: Window size 96 | H: Height of image 97 | W: Width of image 98 | Returns: 99 | x: (B, C, H, W) 100 | """ 101 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 102 | x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1) 103 | x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W) 104 | return x 105 | 106 | 107 | def _load_state_dict(module, state_dict, strict=False, logger=None): 108 | """Load state_dict to a module. 109 | 110 | This method is modified from :meth:`torch.nn.Module.load_state_dict`. 111 | Default value for ``strict`` is set to ``False`` and the message for 112 | param mismatch will be shown even if strict is False. 113 | 114 | Args: 115 | module (Module): Module that receives the state_dict. 116 | state_dict (OrderedDict): Weights. 117 | strict (bool): whether to strictly enforce that the keys 118 | in :attr:`state_dict` match the keys returned by this module's 119 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. 120 | logger (:obj:`logging.Logger`, optional): Logger to log the error 121 | message. If not specified, print function will be used. 
122 | """ 123 | unexpected_keys = [] 124 | all_missing_keys = [] 125 | err_msg = [] 126 | 127 | metadata = getattr(state_dict, '_metadata', None) 128 | state_dict = state_dict.copy() 129 | if metadata is not None: 130 | state_dict._metadata = metadata 131 | 132 | def load(module, prefix=''): 133 | local_metadata = {} if metadata is None else metadata.get( 134 | prefix[:-1], {}) 135 | module._load_from_state_dict(state_dict, prefix, local_metadata, True, 136 | all_missing_keys, unexpected_keys, 137 | err_msg) 138 | for name, child in module._modules.items(): 139 | if child is not None: 140 | load(child, prefix + name + '.') 141 | 142 | load(module) 143 | load = None 144 | missing_keys = [ 145 | key for key in all_missing_keys if 'num_batches_tracked' not in key 146 | ] 147 | 148 | if unexpected_keys: 149 | err_msg.append('unexpected key in source ' 150 | f'state_dict: {", ".join(unexpected_keys)}\n') 151 | if missing_keys: 152 | err_msg.append( 153 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n') 154 | 155 | 156 | if len(err_msg) > 0: 157 | err_msg.insert( 158 | 0, 'The model and loaded state dict do not match exactly\n') 159 | err_msg = '\n'.join(err_msg) 160 | if strict: 161 | raise RuntimeError(err_msg) 162 | elif logger is not None: 163 | logger.warning(err_msg) 164 | else: 165 | print(err_msg) 166 | 167 | 168 | def _load_checkpoint(model, 169 | filename, 170 | map_location='cpu', 171 | strict=False, 172 | logger=None): 173 | """Load checkpoint from a file or URI. 174 | 175 | Args: 176 | model (Module): Module to load checkpoint. 177 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 178 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 179 | details. 180 | map_location (str): Same as :func:`torch.load`. 181 | strict (bool): Whether to allow different params for the model and 182 | checkpoint. 183 | logger (:mod:`logging.Logger` or None): The logger for error message. 184 | 185 | Returns: 186 | dict or OrderedDict: The loaded checkpoint. 187 | """ 188 | checkpoint = torch.load(filename, map_location=map_location) 189 | if not isinstance(checkpoint, dict): 190 | raise RuntimeError( 191 | f'No state_dict found in checkpoint file {filename}') 192 | if 'state_dict' in checkpoint: 193 | state_dict = checkpoint['state_dict'] 194 | elif 'model' in checkpoint: 195 | state_dict = checkpoint['model'] 196 | else: 197 | state_dict = checkpoint 198 | if list(state_dict.keys())[0].startswith('module.'): 199 | state_dict = {k[7:]: v for k, v in state_dict.items()} 200 | 201 | if sorted(list(state_dict.keys()))[0].startswith('encoder'): 202 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} 203 | 204 | _load_state_dict(model, state_dict, strict, logger) 205 | return checkpoint 206 | 207 | 208 | class Downsample(nn.Module): 209 | """ 210 | Down-sampling block" 211 | """ 212 | 213 | def __init__(self, 214 | dim, 215 | keep_dim=False, 216 | ): 217 | """ 218 | Args: 219 | dim: feature size dimension. 220 | norm_layer: normalization layer. 221 | keep_dim: bool argument for maintaining the resolution. 
222 | """ 223 | 224 | super().__init__() 225 | if keep_dim: 226 | dim_out = dim 227 | else: 228 | dim_out = 2 * dim 229 | self.reduction = nn.Sequential( 230 | nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False), 231 | ) 232 | 233 | def forward(self, x): 234 | x = self.reduction(x) 235 | return x 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | class Upsample(nn.Module): 245 | """ 246 | Up-sampling block" 247 | """ 248 | 249 | def __init__(self, 250 | dim, 251 | keep_dim=False, 252 | ): 253 | """ 254 | Args: 255 | dim: feature size dimension. 256 | norm_layer: normalization layer. 257 | keep_dim: bool argument for maintaining the resolution. 258 | """ 259 | 260 | super().__init__() 261 | if keep_dim: 262 | dim_out = dim 263 | else: 264 | dim_out = dim // 2 265 | self.expansion = nn.Sequential( 266 | nn.ConvTranspose2d(dim, dim_out, kernel_size=3, stride=2, padding=1, output_padding=0, bias=False), 267 | # nn.BatchNorm2d(dim_out) # Optionally add batch normalization after upsampling 268 | ) 269 | 270 | def forward(self, x): 271 | x = self.expansion(x) 272 | return x 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | class PatchEmbed(nn.Module): 283 | """ 284 | Patch embedding block" 285 | """ 286 | 287 | def __init__(self, in_chans=3, in_dim=64, dim=96): 288 | """ 289 | Args: 290 | in_chans: number of input channels. 291 | dim: feature size dimension. 292 | """ 293 | # in_dim = 1 294 | super().__init__() 295 | self.proj = nn.Identity() 296 | self.conv_down = nn.Sequential( 297 | nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False), 298 | nn.BatchNorm2d(in_dim, eps=1e-4), 299 | nn.ReLU(), 300 | nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False), 301 | nn.BatchNorm2d(dim, eps=1e-4), 302 | nn.ReLU() 303 | ) 304 | 305 | def forward(self, x): 306 | x = self.proj(x) 307 | x = self.conv_down(x) 308 | return x 309 | 310 | 311 | class ConvBlock(nn.Module): 312 | 313 | def __init__(self, dim, 314 | drop_path=0., 315 | layer_scale=None, 316 | kernel_size=3): 317 | super().__init__() 318 | 319 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1) 320 | self.norm1 = nn.BatchNorm2d(dim, eps=1e-5) 321 | self.act1 = nn.GELU(approximate= 'tanh') 322 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1) 323 | self.norm2 = nn.BatchNorm2d(dim, eps=1e-5) 324 | self.layer_scale = layer_scale 325 | if layer_scale is not None and type(layer_scale) in [int, float]: 326 | self.gamma = nn.Parameter(layer_scale * torch.ones(dim)) 327 | self.layer_scale = True 328 | else: 329 | self.layer_scale = False 330 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 331 | 332 | def forward(self, x): 333 | input = x 334 | x = self.conv1(x) 335 | x = self.norm1(x) 336 | x = self.act1(x) 337 | x = self.conv2(x) 338 | x = self.norm2(x) 339 | if self.layer_scale: 340 | x = x * self.gamma.view(1, -1, 1, 1) 341 | x = input + self.drop_path(x) 342 | return x 343 | 344 | 345 | class MambaVisionMixer(nn.Module): 346 | def __init__( 347 | self, 348 | d_model, 349 | d_state=16, 350 | d_conv=4, 351 | expand=2, 352 | dt_rank="auto", 353 | dt_min=0.001, 354 | dt_max=0.1, 355 | dt_init="random", 356 | dt_scale=1.0, 357 | dt_init_floor=1e-4, 358 | conv_bias=True, 359 | bias=False, 360 | use_fast_path=True, 361 | layer_idx=None, 362 | device=None, 363 | dtype=None, 364 | ): 365 | factory_kwargs = {"device": device, "dtype": dtype} 366 | super().__init__() 367 | self.d_model = d_model 368 | self.d_state = d_state 369 | self.d_conv = d_conv 370 | self.expand = expand 371 | self.d_inner = int(self.expand * self.d_model) 372 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 373 | self.use_fast_path = use_fast_path 374 | self.layer_idx = layer_idx 375 | self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs) 376 | self.x_proj = nn.Linear( 377 | self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 378 | ) 379 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs) 380 | dt_init_std = self.dt_rank**-0.5 * dt_scale 381 | if dt_init == "constant": 382 | nn.init.constant_(self.dt_proj.weight, dt_init_std) 383 | elif dt_init == "random": 384 | nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) 385 | else: 386 | raise NotImplementedError 387 | dt = torch.exp( 388 | torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 389 | + math.log(dt_min) 390 | ).clamp(min=dt_init_floor) 391 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 392 | with torch.no_grad(): 393 | self.dt_proj.bias.copy_(inv_dt) 394 | self.dt_proj.bias._no_reinit = True 395 | A = repeat( 396 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 397 | "n -> d n", 398 | d=self.d_inner//2, 399 | ).contiguous() 400 | A_log = torch.log(A) 401 | self.A_log = nn.Parameter(A_log) 402 | self.A_log._no_weight_decay = True 403 | self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device)) 404 | self.D._no_weight_decay = True 405 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 406 | self.conv1d_x = nn.Conv1d( 407 | in_channels=self.d_inner//2, 408 | out_channels=self.d_inner//2, 409 | bias=conv_bias//2, 410 | kernel_size=d_conv, 411 | groups=self.d_inner//2, 412 | **factory_kwargs, 413 | ) 414 | self.conv1d_z = nn.Conv1d( 415 | in_channels=self.d_inner//2, 416 | out_channels=self.d_inner//2, 417 | bias=conv_bias//2, 418 | kernel_size=d_conv, 419 | groups=self.d_inner//2, 420 | **factory_kwargs, 421 | ) 422 | 423 | def forward(self, hidden_states): 424 | """ 425 | hidden_states: (B, L, D) 426 | Returns: same shape as hidden_states 427 | """ 428 | _, seqlen, _ = hidden_states.shape 429 | xz = self.in_proj(hidden_states) 430 | xz = rearrange(xz, "b l d -> b d l") 431 | x, z = xz.chunk(2, dim=1) 432 | A = -torch.exp(self.A_log.float()) 433 | # 注意测试 434 | x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2)) 435 | z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2)) 
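        # [Editor's note] Shapes at this point, inferred from this module's definitions:
        #   x, z: (batch, d_inner/2, seqlen) after the depth-wise conv + SiLU above.
        #   Below, x_proj maps each position of x to (dt, B, C) with sizes
        #   (dt_rank, d_state, d_state); selective_scan_fn then runs the SSM over x only
        #   (z=None), and the gate branch z is concatenated onto the scan output before
        #   out_proj, rather than being multiplied in as in the standard Mamba block.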
436 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) 437 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 438 | dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen) 439 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 440 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 441 | 442 | # 这里的x 就是 不用先silu 443 | # x = F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2) 444 | # x = x*z # Hadamard product simzhang added 445 | 446 | y = selective_scan_fn(x, 447 | dt, 448 | A, 449 | B, 450 | C, 451 | self.D.float(), 452 | z=None, 453 | delta_bias=self.dt_proj.bias.float(), 454 | delta_softplus=True, 455 | return_last_state=None) 456 | 457 | y = torch.cat([y, z], dim=1) 458 | y = rearrange(y, "b d l -> b l d") 459 | out = self.out_proj(y) 460 | return out 461 | 462 | 463 | class Attention(nn.Module): 464 | 465 | def __init__( 466 | self, 467 | dim, 468 | num_heads=8, 469 | qkv_bias=False, 470 | qk_norm=False, 471 | attn_drop=0., 472 | proj_drop=0., 473 | norm_layer=nn.LayerNorm, 474 | ): 475 | super().__init__() 476 | assert dim % num_heads == 0 477 | self.num_heads = num_heads 478 | self.head_dim = dim // num_heads 479 | self.scale = self.head_dim ** -0.5 480 | self.fused_attn = True 481 | 482 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 483 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 484 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 485 | self.attn_drop = nn.Dropout(attn_drop) 486 | self.proj = nn.Linear(dim, dim) 487 | self.proj_drop = nn.Dropout(proj_drop) 488 | 489 | def forward(self, x): 490 | B, N, C = x.shape 491 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 492 | q, k, v = qkv.unbind(0) 493 | q, k = self.q_norm(q), self.k_norm(k) 494 | 495 | if self.fused_attn: 496 | x = F.scaled_dot_product_attention( 497 | q, k, v, 498 | dropout_p=self.attn_drop.p, 499 | ) 500 | else: 501 | q = q * self.scale 502 | attn = q @ k.transpose(-2, -1) 503 | attn = attn.softmax(dim=-1) 504 | attn = self.attn_drop(attn) 505 | x = attn @ v 506 | 507 | x = x.transpose(1, 2).reshape(B, N, C) 508 | x = self.proj(x) 509 | x = self.proj_drop(x) 510 | return x 511 | 512 | 513 | class Block(nn.Module): 514 | def __init__(self, 515 | dim, 516 | num_heads, 517 | counter, 518 | transformer_blocks, 519 | mlp_ratio=4., 520 | qkv_bias=False, 521 | qk_scale=False, 522 | drop=0., 523 | attn_drop=0., 524 | drop_path=0., 525 | act_layer=nn.GELU, 526 | norm_layer=nn.LayerNorm, 527 | Mlp_block=Mlp, 528 | layer_scale=None, 529 | ): 530 | super().__init__() 531 | self.norm1 = norm_layer(dim) 532 | if counter in transformer_blocks: 533 | self.mixer = Attention( 534 | dim, 535 | num_heads=num_heads, 536 | qkv_bias=qkv_bias, 537 | qk_norm=qk_scale, 538 | attn_drop=attn_drop, 539 | proj_drop=drop, 540 | norm_layer=norm_layer, 541 | ) 542 | else: 543 | self.mixer = MambaVisionMixer(d_model=dim, 544 | d_state=8, 545 | d_conv=3, 546 | expand=1 547 | ) 548 | 549 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 550 | self.norm2 = norm_layer(dim) 551 | mlp_hidden_dim = int(dim * mlp_ratio) 552 | self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 553 | use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float] 554 | self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1 555 | self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1 556 | 557 | def forward(self, x): 558 | x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x))) 559 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 560 | return x 561 | 562 | 563 | class MambaVisionLayer(nn.Module): 564 | """ 565 | MambaVision layer" 566 | """ 567 | 568 | def __init__(self, 569 | dim, 570 | depth, 571 | num_heads, 572 | window_size, 573 | conv=False, 574 | downsample=True, 575 | mlp_ratio=4., 576 | qkv_bias=True, 577 | qk_scale=None, 578 | drop=0., 579 | attn_drop=0., 580 | drop_path=0., 581 | layer_scale=None, 582 | layer_scale_conv=None, 583 | transformer_blocks = [], 584 | ): 585 | """ 586 | Args: 587 | dim: feature size dimension. 588 | depth: number of layers in each stage. 589 | window_size: window size in each stage. 590 | conv: bool argument for conv stage flag. 591 | downsample: bool argument for down-sampling. 592 | mlp_ratio: MLP ratio. 593 | num_heads: number of heads in each stage. 594 | qkv_bias: bool argument for query, key, value learnable bias. 595 | qk_scale: bool argument to scaling query, key. 596 | drop: dropout rate. 597 | attn_drop: attention dropout rate. 598 | drop_path: drop path rate. 599 | norm_layer: normalization layer. 600 | layer_scale: layer scaling coefficient. 601 | layer_scale_conv: conv layer scaling coefficient. 602 | transformer_blocks: list of transformer blocks. 
603 | """ 604 | 605 | super().__init__() 606 | self.conv = conv 607 | self.transformer_block = False 608 | if conv: 609 | self.blocks = nn.ModuleList([ConvBlock(dim=dim, 610 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 611 | layer_scale=layer_scale_conv) 612 | for i in range(depth)]) 613 | self.transformer_block = False 614 | else: 615 | self.blocks = nn.ModuleList([Block(dim=dim, 616 | counter=i, 617 | transformer_blocks=transformer_blocks, 618 | num_heads=num_heads, 619 | mlp_ratio=mlp_ratio, 620 | qkv_bias=qkv_bias, 621 | qk_scale=qk_scale, 622 | drop=drop, 623 | attn_drop=attn_drop, 624 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 625 | layer_scale=layer_scale) 626 | for i in range(depth)]) 627 | self.transformer_block = True 628 | 629 | self.downsample = None if not downsample else Downsample(dim=dim) 630 | self.do_gt = False 631 | self.window_size = window_size 632 | 633 | def forward(self, x): 634 | _, _, H, W = x.shape 635 | 636 | if self.transformer_block: 637 | pad_r = (self.window_size - W % self.window_size) % self.window_size 638 | pad_b = (self.window_size - H % self.window_size) % self.window_size 639 | if pad_r > 0 or pad_b > 0: 640 | x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b)) 641 | _, _, Hp, Wp = x.shape 642 | else: 643 | Hp, Wp = H, W 644 | x = window_partition(x, self.window_size) 645 | 646 | for _, blk in enumerate(self.blocks): 647 | x = blk(x) 648 | if self.transformer_block: 649 | x = window_reverse(x, self.window_size, Hp, Wp) 650 | if pad_r > 0 or pad_b > 0: 651 | x = x[:, :, :H, :W].contiguous() 652 | if self.downsample is None: 653 | return x 654 | return self.downsample(x) 655 | 656 | 657 | 658 | 659 | 660 | 661 | 662 | 663 | 664 | 665 | class MambaVisionLayer_up(nn.Module): 666 | """ 667 | MambaVision layer" 668 | """ 669 | 670 | def __init__(self, 671 | dim, 672 | depth, 673 | num_heads, 674 | window_size, 675 | conv=False, 676 | upsample=True, 677 | mlp_ratio=4., 678 | qkv_bias=True, 679 | qk_scale=None, 680 | drop=0., 681 | attn_drop=0., 682 | drop_path=0., 683 | layer_scale=None, 684 | layer_scale_conv=None, 685 | transformer_blocks = [], 686 | ): 687 | """ 688 | Args: 689 | dim: feature size dimension. 690 | depth: number of layers in each stage. 691 | window_size: window size in each stage. 692 | conv: bool argument for conv stage flag. 693 | upsample: bool argument for up-sampling. 694 | mlp_ratio: MLP ratio. 695 | num_heads: number of heads in each stage. 696 | qkv_bias: bool argument for query, key, value learnable bias. 697 | qk_scale: bool argument to scaling query, key. 698 | drop: dropout rate. 699 | attn_drop: attention dropout rate. 700 | drop_path: drop path rate. 701 | norm_layer: normalization layer. 702 | layer_scale: layer scaling coefficient. 703 | layer_scale_conv: conv layer scaling coefficient. 704 | transformer_blocks: list of transformer blocks. 
705 | """ 706 | 707 | super().__init__() 708 | self.conv = conv 709 | self.transformer_block = False 710 | if conv: 711 | self.blocks = nn.ModuleList([ConvBlock(dim=dim, 712 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 713 | layer_scale=layer_scale_conv) 714 | for i in range(depth)]) 715 | self.transformer_block = False 716 | else: 717 | self.blocks = nn.ModuleList([Block(dim=dim, 718 | counter=i, 719 | transformer_blocks=transformer_blocks, 720 | num_heads=num_heads, 721 | mlp_ratio=mlp_ratio, 722 | qkv_bias=qkv_bias, 723 | qk_scale=qk_scale, 724 | drop=drop, 725 | attn_drop=attn_drop, 726 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 727 | layer_scale=layer_scale) 728 | for i in range(depth)]) 729 | self.transformer_block = True 730 | 731 | self.upsample = None if not upsample else PatchExpand2D_sim(dim=dim) 732 | self.do_gt = False 733 | self.window_size = window_size 734 | 735 | def forward(self, x): 736 | if self.upsample is not None: 737 | x = self.upsample(x) 738 | 739 | _, _, H, W = x.shape 740 | if self.transformer_block: 741 | pad_r = (self.window_size - W % self.window_size) % self.window_size 742 | pad_b = (self.window_size - H % self.window_size) % self.window_size 743 | if pad_r > 0 or pad_b > 0: 744 | x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b)) 745 | _, _, Hp, Wp = x.shape 746 | else: 747 | Hp, Wp = H, W 748 | x = window_partition(x, self.window_size) 749 | 750 | for _, blk in enumerate(self.blocks): 751 | x = blk(x) 752 | if self.transformer_block: 753 | x = window_reverse(x, self.window_size, Hp, Wp) 754 | if pad_r > 0 or pad_b > 0: 755 | x = x[:, :, :H, :W].contiguous() 756 | 757 | return x 758 | 759 | 760 | 761 | 762 | 763 | class MambaVision(nn.Module): 764 | """ 765 | MambaVision, 766 | """ 767 | 768 | def __init__(self, 769 | dim, 770 | in_dim, 771 | depths, 772 | window_size, 773 | mlp_ratio, 774 | num_heads, 775 | drop_path_rate=0.2, 776 | in_chans=3, 777 | num_classes=1000, 778 | qkv_bias=True, 779 | qk_scale=None, 780 | drop_rate=0., 781 | attn_drop_rate=0., 782 | layer_scale=None, 783 | layer_scale_conv=None, 784 | **kwargs): 785 | """ 786 | Args: 787 | dim: feature size dimension. 788 | depths: number of layers in each stage. 789 | window_size: window size in each stage. 790 | mlp_ratio: MLP ratio. 791 | num_heads: number of heads in each stage. 792 | drop_path_rate: drop path rate. 793 | in_chans: number of input channels. 794 | num_classes: number of classes. 795 | qkv_bias: bool argument for query, key, value learnable bias. 796 | qk_scale: bool argument to scaling query, key. 797 | drop_rate: dropout rate. 798 | attn_drop_rate: attention dropout rate. 799 | norm_layer: normalization layer. 800 | layer_scale: layer scaling coefficient. 801 | layer_scale_conv: conv layer scaling coefficient. 
802 | """ 803 | super().__init__() 804 | num_features = int(dim * 2 ** (len(depths) - 1)) 805 | self.num_classes = num_classes 806 | self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim) 807 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 808 | dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))][::-1] # encoder 和 decoder dpr 相反 809 | 810 | self.levels = nn.ModuleList() 811 | for i in range(len(depths)): 812 | conv = True if (i == 0 or i == 1) else False 813 | level = MambaVisionLayer(dim=int(dim * 2 ** i), 814 | depth=depths[i], 815 | num_heads=num_heads[i], 816 | window_size=window_size[i], 817 | mlp_ratio=mlp_ratio, 818 | qkv_bias=qkv_bias, 819 | qk_scale=qk_scale, 820 | conv=conv, 821 | drop=drop_rate, 822 | attn_drop=attn_drop_rate, 823 | drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], # 选择每个layer 的 dpr 参数 824 | downsample=(i < 3), 825 | layer_scale=layer_scale, 826 | layer_scale_conv=layer_scale_conv, 827 | transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])), 828 | ) 829 | self.levels.append(level) 830 | self.norm = nn.BatchNorm2d(num_features) 831 | self.avgpool = nn.AdaptiveAvgPool2d(1) 832 | self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity() 833 | self.apply(self._init_weights) 834 | 835 | def _init_weights(self, m): 836 | if isinstance(m, nn.Linear): 837 | trunc_normal_(m.weight, std=.02) 838 | if isinstance(m, nn.Linear) and m.bias is not None: 839 | nn.init.constant_(m.bias, 0) 840 | elif isinstance(m, nn.LayerNorm): 841 | nn.init.constant_(m.bias, 0) 842 | nn.init.constant_(m.weight, 1.0) 843 | elif isinstance(m, LayerNorm2d): 844 | nn.init.constant_(m.bias, 0) 845 | nn.init.constant_(m.weight, 1.0) 846 | elif isinstance(m, nn.BatchNorm2d): 847 | nn.init.ones_(m.weight) 848 | nn.init.zeros_(m.bias) 849 | 850 | @torch.jit.ignore 851 | def no_weight_decay_keywords(self): 852 | return {'rpb'} 853 | 854 | def forward_features(self, x): 855 | x = self.patch_embed(x) # [2,3,224,224] -> [2,80,56,56] 856 | # each layer output shape print: 857 | # level 1: [2,80,56,56] level 2: [2,160,28,28] 858 | # level 3: [2,320,14,14] level 4: [2,640,7,7] 859 | for level in self.levels: 860 | x = level(x) 861 | x = self.norm(x) # [2,640,7,7] 862 | x = self.avgpool(x) # [2,640,1,1] 863 | x = torch.flatten(x, 1) # [2,640] 864 | return x 865 | 866 | def forward(self, x): 867 | x = self.forward_features(x) # [2,640] 868 | x = self.head(x) # # [2, 1000] 869 | return x 870 | 871 | def _load_state_dict(self, 872 | pretrained, 873 | strict: bool = False): 874 | _load_checkpoint(self, 875 | pretrained, 876 | strict=strict) 877 | 878 | 879 | 880 | 881 | 882 | 883 | class MambaVision_sim(nn.Module): 884 | """ 885 | MambaVision sim for unet, 886 | """ 887 | 888 | def __init__(self, 889 | dim, 890 | in_dim, 891 | depths, 892 | window_size, 893 | mlp_ratio, 894 | num_heads, 895 | drop_path_rate=0.2, 896 | in_chans=3, 897 | num_classes=1, 898 | qkv_bias=True, 899 | qk_scale=None, 900 | drop_rate=0., 901 | attn_drop_rate=0., 902 | layer_scale=None, 903 | layer_scale_conv=None, 904 | **kwargs): 905 | """ 906 | Args: 907 | dim: feature size dimension. 908 | depths: number of layers in each stage. 909 | window_size: window size in each stage. 910 | mlp_ratio: MLP ratio. 911 | num_heads: number of heads in each stage. 912 | drop_path_rate: drop path rate. 913 | in_chans: number of input channels. 914 | num_classes: number of classes. 
915 | qkv_bias: bool argument for query, key, value learnable bias. 916 | qk_scale: bool argument to scaling query, key. 917 | drop_rate: dropout rate. 918 | attn_drop_rate: attention dropout rate. 919 | norm_layer: normalization layer. 920 | layer_scale: layer scaling coefficient. 921 | layer_scale_conv: conv layer scaling coefficient. 922 | """ 923 | super().__init__() 924 | num_features = int(dim * 2 ** (len(depths) - 1)) 925 | self.num_classes = num_classes 926 | self.num_layers = len(depths) 927 | self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim) 928 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 929 | self.levels = nn.ModuleList() 930 | self.pos_drop = nn.Dropout(p=drop_rate) # train drop rate 931 | self.encoder_depths = depths 932 | self.decoder_depths = depths[::-1] 933 | 934 | 935 | for i in range(len(depths)): # i --> 0,1,2,3 936 | conv = True if (i == 0 or i == 1) else False 937 | level = MambaVisionLayer(dim=int(dim * 2 ** i), # up layer 和 这个layer 的dim 是反过来的 938 | depth=depths[i], 939 | num_heads=num_heads[i], 940 | window_size=window_size[i], 941 | mlp_ratio=mlp_ratio, 942 | qkv_bias=qkv_bias, 943 | qk_scale=qk_scale, 944 | conv=conv, 945 | drop=drop_rate, 946 | attn_drop=attn_drop_rate, 947 | drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], 948 | downsample=(i < 3), 949 | layer_scale=layer_scale, 950 | layer_scale_conv=layer_scale_conv, 951 | transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])), 952 | ) 953 | self.levels.append(level) 954 | 955 | 956 | # dim=int( dim * 2 ** (3-i) ) 957 | # depth = depths[3-i] 958 | # num_heads=num_heads[3-i] 959 | # window_size=window_size[3-i] 960 | # drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])] 961 | self.layers_up = nn.ModuleList() 962 | for i in range(len(depths)): 963 | up_layer = MambaVisionLayer_up(dim=int( dim * 2 ** (3-i) ), 964 | depth=depths[3-i], 965 | num_heads=num_heads[3-i], 966 | window_size=window_size[3-i], 967 | mlp_ratio=mlp_ratio, 968 | qkv_bias=qkv_bias, 969 | qk_scale=qk_scale, 970 | conv=conv, 971 | drop=drop_rate, 972 | attn_drop=attn_drop_rate, 973 | drop_path=dpr[sum(self.decoder_depths[:i]):sum(self.decoder_depths[:i + 1])], 974 | upsample=(i != 0), 975 | layer_scale=layer_scale, 976 | layer_scale_conv=layer_scale_conv, 977 | transformer_blocks=list(range(depths[3-i]//2+1, depths[3-i])) if depths[3-i]%2!=0 else list(range(depths[3-i]//2, depths[3-i])), 978 | ) 979 | self.layers_up.append(up_layer) 980 | 981 | 982 | # classification 983 | # self.norm = nn.BatchNorm2d(num_features) 984 | # self.avgpool = nn.AdaptiveAvgPool2d(1) 985 | # self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity() 986 | self.norm_layer = nn.LayerNorm 987 | self.final_up = Final_PatchExpand2D(dim=dim, dim_scale=4, norm_layer=self.norm_layer) 988 | self.final_conv = nn.Conv2d(dim//4, num_classes, 1) 989 | 990 | 991 | self.apply(self._init_weights) 992 | 993 | def _init_weights(self, m): 994 | if isinstance(m, nn.Linear): 995 | trunc_normal_(m.weight, std=.02) 996 | if isinstance(m, nn.Linear) and m.bias is not None: 997 | nn.init.constant_(m.bias, 0) 998 | elif isinstance(m, nn.LayerNorm): 999 | nn.init.constant_(m.bias, 0) 1000 | nn.init.constant_(m.weight, 1.0) 1001 | elif isinstance(m, LayerNorm2d): 1002 | nn.init.constant_(m.bias, 0) 1003 | nn.init.constant_(m.weight, 1.0) 1004 | elif isinstance(m, nn.BatchNorm2d): 1005 | nn.init.ones_(m.weight) 1006 | nn.init.zeros_(m.bias) 1007 | 
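    # [Editor's note] Decoder data flow in the forward pass defined below (shapes for a
    # 224x224 input with the default dim=80, inferred from the encoder shape comments):
    #   encoder skips: [B,80,56,56] -> [B,160,28,28] -> [B,320,14,14] -> [B,640,7,7]
    #   layers_up[0] keeps the [B,640,7,7] bottleneck; layers_up[1..3] each add the
    #   matching skip (x = layer_up(x + skip_list[-inx])), upsample 2x with
    #   PatchExpand2D_sim (channels halved), and run their blocks.
    #   final_up expands [B,80,56,56] -> [B,20,224,224]; final_conv maps 20 -> num_classes.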
1008 | @torch.jit.ignore 1009 | def no_weight_decay_keywords(self): 1010 | return {'rpb'} 1011 | 1012 | def forward_features(self, x): 1013 | skip_list = [] 1014 | x = self.patch_embed(x) # [2,3,224,224] -> [2,80,56,56] 1015 | x = self.pos_drop(x) 1016 | # each layer output shape print: 1017 | # level 0: [2,80,56,56] level 1: [2,160,28,28] 1018 | # level 2: [2,320,14,14] level 3: [2,640,7,7] 1019 | for level in self.levels: 1020 | skip_list.append(x) 1021 | x = level(x) 1022 | return x, skip_list 1023 | 1024 | 1025 | def forward_features_up(self, x, skip_list): 1026 | for inx, layer_up in enumerate(self.layers_up): 1027 | if inx == 0: 1028 | x = layer_up(x) 1029 | else: 1030 | x = layer_up(x+skip_list[-inx]) 1031 | 1032 | return x 1033 | 1034 | def forward_final(self, x): 1035 | x = self.final_up(x) 1036 | # x = x.permute(0,3,1,2) 1037 | x = self.final_conv(x) 1038 | return x 1039 | 1040 | 1041 | def forward(self, x): 1042 | x, skip_list = self.forward_features(x) 1043 | x = self.forward_features_up(x, skip_list) 1044 | x = self.forward_final(x) 1045 | 1046 | if self.num_classes == 1: 1047 | return torch.sigmoid(x) 1048 | return x 1049 | 1050 | # def forward(self, x): 1051 | # x = self.forward_features(x) 1052 | # x = self.head(x) 1053 | # return x 1054 | 1055 | # def _load_state_dict(self, 1056 | # pretrained, 1057 | # strict: bool = False): 1058 | # _load_checkpoint(self, 1059 | # pretrained, 1060 | # strict=strict) 1061 | 1062 | 1063 | 1064 | 1065 | 1066 | 1067 | @register_pip_model 1068 | @register_model 1069 | def mamba_vision_T(pretrained=False, **kwargs): 1070 | model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar") 1071 | pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict() 1072 | update_args(pretrained_cfg, kwargs, kwargs_filter=None) 1073 | model = MambaVision(depths=[1, 3, 8, 4], 1074 | num_heads=[2, 4, 8, 16], 1075 | window_size=[8, 8, 14, 7], 1076 | dim=80, 1077 | in_dim=32, 1078 | mlp_ratio=4, 1079 | resolution=224, 1080 | drop_path_rate=0.2, 1081 | **kwargs) 1082 | model.pretrained_cfg = pretrained_cfg 1083 | model.default_cfg = model.pretrained_cfg 1084 | if pretrained: 1085 | # download model 1086 | if not Path(model_path).is_file(): 1087 | url = model.default_cfg['url'] 1088 | torch.hub.download_url_to_file(url=url, dst=model_path) 1089 | model._load_state_dict(model_path) 1090 | return model 1091 | 1092 | 1093 | @register_pip_model 1094 | @register_model 1095 | def mamba_vision_T2(pretrained=False, **kwargs): 1096 | model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar") 1097 | pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict() 1098 | update_args(pretrained_cfg, kwargs, kwargs_filter=None) 1099 | model = MambaVision(depths=[1, 3, 11, 4], 1100 | num_heads=[2, 4, 8, 16], 1101 | window_size=[8, 8, 14, 7], 1102 | dim=80, 1103 | in_dim=32, 1104 | mlp_ratio=4, 1105 | resolution=224, 1106 | drop_path_rate=0.2, 1107 | **kwargs) 1108 | model.pretrained_cfg = pretrained_cfg 1109 | model.default_cfg = model.pretrained_cfg 1110 | if pretrained: 1111 | if not Path(model_path).is_file(): 1112 | url = model.default_cfg['url'] 1113 | torch.hub.download_url_to_file(url=url, dst=model_path) 1114 | model._load_state_dict(model_path) 1115 | return model 1116 | 1117 | 1118 | @register_pip_model 1119 | @register_model 1120 | def mamba_vision_S(pretrained=False, **kwargs): 1121 | model_path = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar") 1122 | pretrained_cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict() 1123 | 
update_args(pretrained_cfg, kwargs, kwargs_filter=None) 1124 | model = MambaVision(depths=[3, 3, 7, 5], 1125 | num_heads=[2, 4, 8, 16], 1126 | window_size=[8, 8, 14, 7], 1127 | dim=96, 1128 | in_dim=64, 1129 | mlp_ratio=4, 1130 | resolution=224, 1131 | drop_path_rate=0.2, 1132 | **kwargs) 1133 | model.pretrained_cfg = pretrained_cfg 1134 | model.default_cfg = model.pretrained_cfg 1135 | if pretrained: 1136 | if not Path(model_path).is_file(): 1137 | url = model.default_cfg['url'] 1138 | torch.hub.download_url_to_file(url=url, dst=model_path) 1139 | model._load_state_dict(model_path) 1140 | return model 1141 | 1142 | 1143 | @register_pip_model 1144 | @register_model 1145 | def mamba_vision_B(pretrained=False, **kwargs): 1146 | model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar") 1147 | pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict() 1148 | update_args(pretrained_cfg, kwargs, kwargs_filter=None) 1149 | model = MambaVision(depths=[3, 3, 10, 5], 1150 | num_heads=[2, 4, 8, 16], 1151 | window_size=[8, 8, 14, 7], 1152 | dim=128, 1153 | in_dim=64, 1154 | mlp_ratio=4, 1155 | resolution=224, 1156 | drop_path_rate=0.3, 1157 | layer_scale=1e-5, 1158 | layer_scale_conv=None, 1159 | **kwargs) 1160 | model.pretrained_cfg = pretrained_cfg 1161 | model.default_cfg = model.pretrained_cfg 1162 | if pretrained: 1163 | if not Path(model_path).is_file(): 1164 | url = model.default_cfg['url'] 1165 | torch.hub.download_url_to_file(url=url, dst=model_path) 1166 | model._load_state_dict(model_path) 1167 | return model 1168 | 1169 | 1170 | @register_pip_model 1171 | @register_model 1172 | def mamba_vision_L(pretrained=False, **kwargs): 1173 | model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar") 1174 | pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict() 1175 | update_args(pretrained_cfg, kwargs, kwargs_filter=None) 1176 | model = MambaVision(depths=[3, 3, 10, 5], 1177 | num_heads=[4, 8, 16, 32], 1178 | window_size=[8, 8, 14, 7], 1179 | dim=196, 1180 | in_dim=64, 1181 | mlp_ratio=4, 1182 | resolution=224, 1183 | drop_path_rate=0.3, 1184 | layer_scale=1e-5, 1185 | layer_scale_conv=None, 1186 | **kwargs) 1187 | model.pretrained_cfg = pretrained_cfg 1188 | model.default_cfg = model.pretrained_cfg 1189 | if pretrained: 1190 | if not Path(model_path).is_file(): 1191 | url = model.default_cfg['url'] 1192 | torch.hub.download_url_to_file(url=url, dst=model_path) 1193 | model._load_state_dict(model_path) 1194 | return model 1195 | 1196 | 1197 | @register_pip_model 1198 | @register_model 1199 | def mamba_vision_L2(pretrained=False, **kwargs): 1200 | model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar") 1201 | pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict() 1202 | update_args(pretrained_cfg, kwargs, kwargs_filter=None) 1203 | model = MambaVision(depths=[3, 3, 12, 5], 1204 | num_heads=[4, 8, 16, 32], 1205 | window_size=[8, 8, 14, 7], 1206 | dim=196, 1207 | in_dim=64, 1208 | mlp_ratio=4, 1209 | resolution=224, 1210 | drop_path_rate=0.3, 1211 | layer_scale=1e-5, 1212 | layer_scale_conv=None, 1213 | **kwargs) 1214 | model.pretrained_cfg = pretrained_cfg 1215 | model.default_cfg = model.pretrained_cfg 1216 | if pretrained: 1217 | if not Path(model_path).is_file(): 1218 | url = model.default_cfg['url'] 1219 | torch.hub.download_url_to_file(url=url, dst=model_path) 1220 | model._load_state_dict(model_path) 1221 | return model 1222 | 1223 | 1224 | 1225 | 1226 | 1227 | class PatchExpand2D_sim(nn.Module): 1228 | def __init__(self, 
dim, dim_scale=2, norm_layer=nn.LayerNorm): 1229 | super().__init__() 1230 | self.dim = dim*2 1231 | self.dim_scale = dim_scale 1232 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 1233 | self.norm = norm_layer(self.dim // dim_scale) 1234 | 1235 | def forward(self, x): 1236 | B, C, H, W = x.shape 1237 | x = x.permute(0,2,3,1) # b h w c 1238 | x = self.expand(x) 1239 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 1240 | x= self.norm(x) 1241 | x = x.permute(0,3,1,2) # b c h w 1242 | 1243 | return x 1244 | 1245 | 1246 | 1247 | 1248 | class PatchExpand2D(nn.Module): 1249 | def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): 1250 | super().__init__() 1251 | self.dim = dim*2 1252 | self.dim_scale = dim_scale 1253 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 1254 | self.norm = norm_layer(self.dim // dim_scale) 1255 | 1256 | def forward(self, x): 1257 | B, H, W, C = x.shape 1258 | x = self.expand(x) 1259 | 1260 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 1261 | x= self.norm(x) 1262 | 1263 | return x 1264 | 1265 | 1266 | class Final_PatchExpand2D(nn.Module): 1267 | def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): 1268 | super().__init__() 1269 | self.dim = dim 1270 | self.dim_scale = dim_scale 1271 | self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) 1272 | self.norm = norm_layer(self.dim // dim_scale) 1273 | 1274 | def forward(self, x): 1275 | B, C, H, W = x.shape 1276 | x = x.permute(0,2,3,1) # b h w c 1277 | x = self.expand(x) 1278 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) 1279 | x= self.norm(x) 1280 | x = x.permute(0,3,1,2) # b c h w 1281 | 1282 | return x 1283 | 1284 | 1285 | 1286 | 1287 | 1288 | if __name__ == "__main__": 1289 | pretrained_path = "/root/workspace/code/MambaVision/mambavision/pretrained/mambavision_tiny_1k.pth.tar" 1290 | model = mamba_vision_T(model_path = pretrained_path).cuda() 1291 | 1292 | 1293 | 1294 | x = torch.randn((2,3,224,224)).cuda() 1295 | ys = model(x) 1296 | for y in ys: 1297 | print(y.shape) 1298 | 1299 | 1300 | 1301 | -------------------------------------------------------------------------------- /models/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scripts to register and load model, adopted from: 3 | https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_registry.py 4 | https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_factory.py 5 | Hacked together by / Copyright 2023 Ross Wightman 6 | """ 7 | import torch 8 | 9 | import os 10 | from collections import OrderedDict 11 | from copy import deepcopy 12 | from typing import Any 13 | 14 | import sys 15 | import re 16 | import fnmatch 17 | from collections import defaultdict 18 | from copy import deepcopy 19 | 20 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 21 | 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] 22 | 23 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 24 | _model_to_module = {} # mapping of model names to module names 25 | _model_entrypoints = {} # mapping of model names to entrypoint fns 26 | _model_has_pretrained = set() # set of model names that have 
pretrained weight url present 27 | _model_default_cfgs = dict() # central repo for model default_cfgs 28 | 29 | 30 | def register_pip_model(fn): 31 | # lookup containing module 32 | mod = sys.modules[fn.__module__] 33 | module_name_split = fn.__module__.split('.') 34 | module_name = module_name_split[-1] if len(module_name_split) else '' 35 | 36 | # add model to __all__ in module 37 | model_name = fn.__name__ 38 | if hasattr(mod, '__all__'): 39 | mod.__all__.append(model_name) 40 | else: 41 | mod.__all__ = [model_name] 42 | 43 | # add entries to registry dict/sets 44 | _model_entrypoints[model_name] = fn 45 | _model_to_module[model_name] = module_name 46 | _module_to_models[module_name].add(model_name) 47 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 48 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 49 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 50 | # entrypoints or non-matching combos 51 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 52 | _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) 53 | if has_pretrained: 54 | _model_has_pretrained.add(model_name) 55 | return fn 56 | 57 | 58 | def _natural_key(string_): 59 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 60 | 61 | 62 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 63 | """ Return list of available model names, sorted alphabetically 64 | 65 | Args: 66 | filter (str) - Wildcard filter string that works with fnmatch 67 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 68 | pretrained (bool) - Include only models with pretrained weights if True 69 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 70 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) 71 | 72 | Example: 73 | model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 74 | model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module 75 | """ 76 | if module: 77 | all_models = list(_module_to_models[module]) 78 | else: 79 | all_models = _model_entrypoints.keys() 80 | if filter: 81 | models = [] 82 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter] 83 | for f in include_filters: 84 | include_models = fnmatch.filter(all_models, f) # include these models 85 | if len(include_models): 86 | models = set(models).union(include_models) 87 | else: 88 | models = all_models 89 | if exclude_filters: 90 | if not isinstance(exclude_filters, (tuple, list)): 91 | exclude_filters = [exclude_filters] 92 | for xf in exclude_filters: 93 | exclude_models = fnmatch.filter(models, xf) # exclude these models 94 | if len(exclude_models): 95 | models = set(models).difference(exclude_models) 96 | if pretrained: 97 | models = _model_has_pretrained.intersection(models) 98 | if name_matches_cfg: 99 | models = set(_model_default_cfgs).intersection(models) 100 | return list(sorted(models, key=_natural_key)) 101 | 102 | 103 | def is_model(model_name): 104 | """ Check if a model name exists 105 | """ 106 | return model_name in _model_entrypoints 107 | 108 | 109 | def model_entrypoint(model_name): 110 | """Fetch a model entrypoint for specified model name 111 | """ 112 | return 
_model_entrypoints[model_name] 113 | 114 | 115 | def list_modules(): 116 | """ Return list of module names that contain models / model entrypoints 117 | """ 118 | modules = _module_to_models.keys() 119 | return list(sorted(modules)) 120 | 121 | 122 | def is_model_in_modules(model_name, module_names): 123 | """Check if a model exists within a subset of modules 124 | Args: 125 | model_name (str) - name of model to check 126 | module_names (tuple, list, set) - names of modules to search in 127 | """ 128 | assert isinstance(module_names, (tuple, list, set)) 129 | return any(model_name in _module_to_models[n] for n in module_names) 130 | 131 | 132 | def has_model_default_key(model_name, cfg_key): 133 | """ Query model default_cfgs for existence of a specific key. 134 | """ 135 | if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: 136 | return True 137 | return False 138 | 139 | 140 | def is_model_default_key(model_name, cfg_key): 141 | """ Return truthy value for specified model default_cfg key, False if does not exist. 142 | """ 143 | if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): 144 | return True 145 | return False 146 | 147 | 148 | def get_model_default_value(model_name, cfg_key): 149 | """ Get a specific model default_cfg value by key. None if it doesn't exist. 150 | """ 151 | if model_name in _model_default_cfgs: 152 | return _model_default_cfgs[model_name].get(cfg_key, None) 153 | else: 154 | return None 155 | 156 | 157 | def is_model_pretrained(model_name): 158 | return model_name in _model_has_pretrained 159 | 160 | 161 | def load_state_dict(checkpoint_path, use_ema=False): 162 | if checkpoint_path and os.path.isfile(checkpoint_path): 163 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 164 | state_dict_key = 'state_dict' 165 | if isinstance(checkpoint, dict): 166 | if use_ema and 'state_dict_ema' in checkpoint: 167 | state_dict_key = 'state_dict_ema' 168 | if state_dict_key and state_dict_key in checkpoint: 169 | new_state_dict = OrderedDict() 170 | for k, v in checkpoint[state_dict_key].items(): 171 | # strip `module.` prefix 172 | name = k[7:] if k.startswith('module') else k 173 | new_state_dict[name] = v 174 | state_dict = new_state_dict 175 | else: 176 | state_dict = checkpoint 177 | print("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 178 | return state_dict 179 | else: 180 | print("No checkpoint found at '{}'".format(checkpoint_path)) 181 | raise FileNotFoundError() 182 | 183 | 184 | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): 185 | if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): 186 | # numpy checkpoint, try to load via model specific load_pretrained fn 187 | if hasattr(model, 'load_pretrained'): 188 | model.load_pretrained(checkpoint_path) 189 | else: 190 | raise NotImplementedError('Model cannot load numpy checkpoint') 191 | return 192 | state_dict = load_state_dict(checkpoint_path, use_ema) 193 | model.load_state_dict(state_dict, strict=strict) 194 | 195 | def create_model( 196 | model_name, 197 | pretrained=False, 198 | checkpoint_path='', 199 | **kwargs): 200 | create_fn = model_entrypoint(model_name) 201 | model = create_fn(pretrained=pretrained, **kwargs) 202 | if checkpoint_path: 203 | load_checkpoint(model, checkpoint_path) 204 | 205 | return model --------------------------------------------------------------------------------
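
To complement the README's training notes, here is a minimal usage sketch of `HMTUNet` (not part of the repository). It assumes the repo root is on `PYTHONPATH` with `models` importable as a package (e.g., it contains an `__init__.py`), that `mamba_ssm` and a CUDA device are available (required by `selective_scan_fn`), and that the optional MambaVision-T checkpoint path is supplied by the user and stores its weights under a `state_dict` key, as `HMTUNet.load_from()` expects.

```
import torch

from models.hmt_unet import HMTUNet

# Hypothetical local path to the MambaVision-T ImageNet checkpoint (downloaded separately).
ckpt = "pretrained/mambavision_tiny_1k.pth.tar"

# Default arguments reproduce the tiny configuration used throughout this repo.
model = HMTUNet(num_classes=1, load_ckpt_path=ckpt).cuda()
model.load_from()   # loads matching encoder weights, then mirrors levels.* into layers_up.*
model.eval()

x = torch.randn(2, 3, 224, 224).cuda()
with torch.no_grad():
    y = model(x)    # sigmoid probabilities when num_classes == 1
print(y.shape)      # expected: torch.Size([2, 1, 224, 224])
```

Dataset loading, losses, and the training loop itself follow the VM-UNetV2 codebase referenced in the README.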