├── models
│   ├── __init__.py
│   ├── t2t_vit_block.py
│   ├── token_transformer.py
│   ├── configs.py
│   ├── localvit_pvt.py
│   ├── token_performer.py
│   ├── localvit_tnt.py
│   ├── modeling_resnet.py
│   ├── localvit_t2t.py
│   ├── deit.py
│   ├── tnt.py
│   ├── localvit_swin.py
│   ├── t2t_vit.py
│   ├── localvit.py
│   ├── tnt_moex.py
│   ├── pvt.py
│   ├── swin_transformer.py
│   └── swin_moex.py
├── README.md
├── LNL.py
└── LNL_MoEx.py

/models/__init__.py:
--------------------------------------------------------------------------------
 1 | from models import deit
 2 | from models import tnt
 3 | from models import pvt
 4 | from models import t2t_vit
 5 | from models import localvit
 6 | from models import localvit_tnt
 7 | from models import localvit_pvt
 8 | from models import localvit_t2t
 9 | from models import swin_transformer
10 | from models import localvit_swin
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Robust Transformer with Locality Inductive Bias and Feature Normalization
 2 | 
 3 | This repo is the official implementation of ["Locality iN Locality"](https://arxiv.org/abs/2301.11553).
 4 | 
 5 | ## Train & Test --- Prepare data
 6 | Please go to ["Instructions.ipynb"](https://github.com/Omid-Nejati/Locality-iN-Locality/blob/main/Instructions.ipynb) for complete details on dataset preparation and the train/test procedure, or follow the instructions below. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-phpyKQSLdmwISrkVqprtXgEBMP3fFNz?usp=sharing)
 7 | 
 8 | ## Citation
 9 | If you find this project useful in your research, please consider citing:
10 | ```
11 | @article{manzari2023robust,
12 |   title={Robust transformer with locality inductive bias and feature normalization},
13 |   author={Manzari, Omid Nejati and Kashiani, Hossein and Dehkordi, Hojat Asgarian and Shokouhi, Shahriar B},
14 |   journal={Engineering Science and Technology, an International Journal},
15 |   volume={38},
16 |   pages={101320},
17 |   year={2023},
18 |   publisher={Elsevier}
19 | }
20 | ```
21 | 
22 | ## Contact Information
23 | 
24 | For any inquiries or questions regarding the code, please feel free to contact us directly via email:
25 | 
26 | - Omid Nejaty: [omid.nejaty@gmail.com](mailto:omid.nejaty@gmail.com)
27 | - Hossein Kashiani: [hkashia@clemson.edu](mailto:hkashia@clemson.edu)
28 | 
--------------------------------------------------------------------------------
/models/t2t_vit_block.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Author: Omid Nejati
 3 | Email: omid_nejaty@alumni.iust.ac.ir
 4 | 
 5 | Implementation of "Tokens-to-token vit: Training vision transformers from scratch on imagenet".
6 | Code borrowed from https://github.com/yitu-opensource/T2T-ViT 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from timm.models.layers import DropPath 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | 40 | self.scale = qk_scale or head_dim ** -0.5 41 | 42 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 43 | self.attn_drop = nn.Dropout(attn_drop) 44 | self.proj = nn.Linear(dim, dim) 45 | self.proj_drop = nn.Dropout(proj_drop) 46 | 47 | def forward(self, x): 48 | B, N, C = x.shape 49 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 50 | q, k, v = qkv[0], qkv[1], qkv[2] 51 | 52 | attn = (q @ k.transpose(-2, -1)) * self.scale 53 | attn = attn.softmax(dim=-1) 54 | attn = self.attn_drop(attn) 55 | 56 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 57 | x = self.proj(x) 58 | x = self.proj_drop(x) 59 | return x 60 | 61 | 62 | class Block(nn.Module): 63 | 64 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 65 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 66 | super().__init__() 67 | self.norm1 = norm_layer(dim) 68 | self.attn = Attention( 69 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 70 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 71 | self.norm2 = norm_layer(dim) 72 | mlp_hidden_dim = int(dim * mlp_ratio) 73 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 74 | 75 | def forward(self, x): 76 | x = x + self.drop_path(self.attn(self.norm1(x))) 77 | x = x + self.drop_path(self.mlp(self.norm2(x))) 78 | return x 79 | 80 | 81 | def get_sinusoid_encoding(n_position, d_hid): 82 | ''' Sinusoid position encoding table ''' 83 | 84 | def get_position_angle_vec(position): 85 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 86 | 87 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 88 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 89 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 90 | 91 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) -------------------------------------------------------------------------------- /models/token_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Implementation of "Tokens-to-token vit: Training vision transformers from scratch on imagenet". 
6 | Code borrowed from https://github.com/yitu-opensource/T2T-ViT 7 | 8 | Take the standard Transformer as T2T Transformer 9 | """ 10 | import torch.nn as nn 11 | from timm.models.layers import DropPath 12 | from models.t2t_vit_block import Mlp 13 | from models.localvit import LocalityFeedForward 14 | import math 15 | import torch 16 | 17 | 18 | class Attention(nn.Module): 19 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 20 | super().__init__() 21 | self.num_heads = num_heads 22 | self.in_dim = in_dim 23 | head_dim = dim // num_heads 24 | self.scale = qk_scale or head_dim ** -0.5 25 | 26 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias) 27 | self.attn_drop = nn.Dropout(attn_drop) 28 | self.proj = nn.Linear(in_dim, in_dim) 29 | self.proj_drop = nn.Dropout(proj_drop) 30 | 31 | def forward(self, x): 32 | B, N, C = x.shape 33 | 34 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4) 35 | q, k, v = qkv[0], qkv[1], qkv[2] 36 | 37 | attn = (q @ k.transpose(-2, -1)) * self.scale 38 | attn = attn.softmax(dim=-1) 39 | attn = self.attn_drop(attn) 40 | 41 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim) 42 | x = self.proj(x) 43 | x = self.proj_drop(x) 44 | 45 | # skip connection 46 | x = v.squeeze(1) + x # because the original x has different size with current x, use v to do skip connection 47 | 48 | return x 49 | 50 | class Token_transformer(nn.Module): 51 | 52 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 53 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 54 | super().__init__() 55 | self.norm1 = norm_layer(dim) 56 | self.attn = Attention( 57 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 58 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 59 | self.norm2 = norm_layer(in_dim) 60 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop) 61 | 62 | def forward(self, x): 63 | x = self.attn(self.norm1(x)) 64 | x = x + self.drop_path(self.mlp(self.norm2(x))) 65 | return x 66 | 67 | 68 | class Token_transformer_local(nn.Module): 69 | 70 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 71 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 72 | super().__init__() 73 | self.norm1 = norm_layer(dim) 74 | self.attn = Attention( 75 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 76 | self.conv = LocalityFeedForward(in_dim, in_dim, 1, mlp_ratio, act='hs', reduction=in_dim//4) 77 | 78 | def forward(self, x): 79 | x = self.attn(self.norm1(x)) 80 | 81 | batch_size, num_token, embed_dim = x.shape # (B, 197, dim) 82 | patch_size = int(math.sqrt(num_token)) 83 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) # (B, dim, 14, 14) 84 | x = self.conv(x).flatten(2).transpose(1, 2) # (B, 196, dim) 85 | return x 86 | 87 | -------------------------------------------------------------------------------- /models/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ml_collections 16 | 17 | 18 | def get_testing(): 19 | """Returns a minimal configuration for testing.""" 20 | config = ml_collections.ConfigDict() 21 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 22 | config.hidden_size = 1 23 | config.transformer = ml_collections.ConfigDict() 24 | config.transformer.mlp_dim = 1 25 | config.transformer.num_heads = 1 26 | config.transformer.num_layers = 1 27 | config.transformer.attention_dropout_rate = 0.0 28 | config.transformer.dropout_rate = 0.1 29 | config.classifier = 'token' 30 | config.representation_size = None 31 | return config 32 | 33 | 34 | def get_b16_config(): 35 | """Returns the ViT-B/16 configuration.""" 36 | config = ml_collections.ConfigDict() 37 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 38 | config.hidden_size = 768 39 | config.transformer = ml_collections.ConfigDict() 40 | config.transformer.mlp_dim = 3072 41 | config.transformer.num_heads = 12 42 | config.transformer.num_layers = 12 43 | config.transformer.attention_dropout_rate = 0.0 44 | config.transformer.dropout_rate = 0.1 45 | config.classifier = 'token' 46 | config.representation_size = None 47 | return config 48 | 49 | 50 | def get_r50_b16_config(): 51 | """Returns the Resnet50 + ViT-B/16 configuration.""" 52 | config = get_b16_config() 53 | del config.patches.size 54 | config.patches.grid = (14, 14) 55 | config.resnet = ml_collections.ConfigDict() 56 | config.resnet.num_layers = (3, 4, 9) 57 | config.resnet.width_factor = 1 58 | return config 59 | 60 | 61 | def get_b32_config(): 62 | """Returns the ViT-B/32 configuration.""" 63 | config = get_b16_config() 64 | config.patches.size = (32, 32) 65 | return config 66 | 67 | 68 | def get_l16_config(): 69 | """Returns the ViT-L/16 configuration.""" 70 | config = ml_collections.ConfigDict() 71 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 72 | config.hidden_size = 1024 73 | config.transformer = ml_collections.ConfigDict() 74 | config.transformer.mlp_dim = 4096 75 | config.transformer.num_heads = 16 76 | config.transformer.num_layers = 24 77 | config.transformer.attention_dropout_rate = 0.0 78 | config.transformer.dropout_rate = 0.1 79 | config.classifier = 'token' 80 | config.representation_size = None 81 | return config 82 | 83 | 84 | def get_l32_config(): 85 | """Returns the ViT-L/32 configuration.""" 86 | config = get_l16_config() 87 | config.patches.size = (32, 32) 88 | return config 89 | 90 | 91 | def get_h14_config(): 92 | """Returns the ViT-L/16 configuration.""" 93 | config = ml_collections.ConfigDict() 94 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 95 | config.hidden_size = 1280 96 | config.transformer = ml_collections.ConfigDict() 97 | config.transformer.mlp_dim = 5120 98 | config.transformer.num_heads = 16 99 | config.transformer.num_layers = 32 100 | config.transformer.attention_dropout_rate = 0.0 101 | config.transformer.dropout_rate = 0.1 102 | config.classifier = 'token' 103 | config.representation_size = None 104 | return config 105 | 
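# --- Added usage sketch (not part of the original configs.py) ---
# The factories above only build ml_collections.ConfigDict objects; they do not
# construct a model. The snippet below is a minimal, hedged illustration of how
# such a config might be consumed downstream. The input resolution (224) and the
# derived token count are assumptions of this sketch, not fields stored in the
# config; only attributes defined by the factories above are read.
if __name__ == '__main__':
    cfg = get_b16_config()
    img_size = 224                           # assumed input resolution
    patch_h, patch_w = cfg.patches.size      # (16, 16) for ViT-B/16
    num_patches = (img_size // patch_h) * (img_size // patch_w)  # 14 * 14 = 196
    # The 'token' classifier prepends a single [CLS] token to the patch sequence.
    seq_len = num_patches + (1 if cfg.classifier == 'token' else 0)
    print(f'ViT-B/16: hidden_size={cfg.hidden_size}, '
          f'layers={cfg.transformer.num_layers}, '
          f'heads={cfg.transformer.num_heads}, sequence length={seq_len}')
    # Note how the factories compose: get_b32_config() and get_l32_config()
    # reuse the /16 configs and only override the patch size.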
-------------------------------------------------------------------------------- /models/localvit_pvt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Introducing locality mechanism to "Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions". 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial 11 | 12 | from timm.models.layers import DropPath, trunc_normal_ 13 | from timm.models.registry import register_model 14 | from timm.models.vision_transformer import _cfg 15 | from models.localvit import LocalityFeedForward 16 | from models.pvt import Attention, PyramidVisionTransformer 17 | import math 18 | 19 | __all__ = [ 20 | 'localvit_pvt_tiny' 21 | ] 22 | 23 | 24 | class Block(nn.Module): 25 | 26 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 27 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 28 | super().__init__() 29 | self.sr_ratio = sr_ratio 30 | self.norm1 = norm_layer(dim) 31 | self.attn = Attention( 32 | dim, 33 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 34 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 35 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 36 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 37 | # self.norm2 = norm_layer(dim) 38 | # mlp_hidden_dim = int(dim * mlp_ratio) 39 | # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 40 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=dim) 41 | 42 | def forward(self, x, H, W): 43 | batch_size, num_token, embed_dim = x.shape # (B, 197, dim) 44 | patch_size = int(math.sqrt(num_token)) 45 | 46 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 47 | # x = x + self.drop_path(self.mlp(self.norm2(x))) 48 | 49 | if self.sr_ratio == 1: 50 | cls_token, x = torch.split(x, [1, num_token - 1], dim=1) # (B, 1, dim), (B, 196, dim) 51 | # print(cls_token.shape, x.shape) 52 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) # (B, dim, 14, 14) 53 | x = self.conv(x).flatten(2).transpose(1, 2) # (B, 196, dim) 54 | x = torch.cat([cls_token, x], dim=1) 55 | else: 56 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) # (B, dim, 14, 14) 57 | x = self.conv(x).flatten(2).transpose(1, 2) # (B, 196, dim) 58 | return x 59 | 60 | 61 | class LocalViT_PVT(PyramidVisionTransformer): 62 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 63 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 64 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 65 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 66 | super().__init__(img_size, patch_size, in_chans, num_classes, embed_dims, 67 | num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate, 68 | attn_drop_rate, drop_path_rate, norm_layer, 69 | depths, sr_ratios) 70 | 71 | # transformer encoder 72 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 73 | cur = 0 74 | self.block1 = nn.ModuleList([Block( 75 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 76 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur 
+ i], norm_layer=norm_layer, 77 | sr_ratio=sr_ratios[0]) 78 | for i in range(depths[0])]) 79 | 80 | cur += depths[0] 81 | self.block2 = nn.ModuleList([Block( 82 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 83 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 84 | sr_ratio=sr_ratios[1]) 85 | for i in range(depths[1])]) 86 | 87 | cur += depths[1] 88 | self.block3 = nn.ModuleList([Block( 89 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 90 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 91 | sr_ratio=sr_ratios[2]) 92 | for i in range(depths[2])]) 93 | 94 | cur += depths[2] 95 | self.block4 = nn.ModuleList([Block( 96 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 97 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 98 | sr_ratio=sr_ratios[3]) 99 | for i in range(depths[3])]) 100 | 101 | # init weights 102 | self.apply(self._init_weights) 103 | 104 | 105 | @register_model 106 | def localvit_pvt_tiny(pretrained=False, **kwargs): 107 | model = LocalViT_PVT( 108 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 109 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 110 | **kwargs) 111 | model.default_cfg = _cfg() 112 | 113 | return model 114 | -------------------------------------------------------------------------------- /models/token_performer.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Author: Omid Nejati 4 | Email: omid_nejaty@alumni.iust.ac.ir 5 | 6 | Implementation of "Tokens-to-token vit: Training vision transformers from scratch on imagenet". 
7 | Code borrowed from https://github.com/yitu-opensource/T2T-ViT 8 | 9 | Take Performer as T2T Transformer 10 | """ 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | from models.localvit import LocalityFeedForward 15 | 16 | class Token_performer(nn.Module): 17 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1): 18 | super().__init__() 19 | self.emb = in_dim * head_cnt # we use 1, so it is no need here 20 | self.kqv = nn.Linear(dim, 3 * self.emb) 21 | self.dp = nn.Dropout(dp1) 22 | self.proj = nn.Linear(self.emb, self.emb) 23 | self.head_cnt = head_cnt 24 | self.norm1 = nn.LayerNorm(dim) 25 | self.norm2 = nn.LayerNorm(self.emb) 26 | self.epsilon = 1e-8 # for stable in division 27 | 28 | self.mlp = nn.Sequential( 29 | nn.Linear(self.emb, 1 * self.emb), 30 | nn.GELU(), 31 | nn.Linear(1 * self.emb, self.emb), 32 | nn.Dropout(dp2), 33 | ) 34 | 35 | self.m = int(self.emb * kernel_ratio) 36 | self.w = torch.randn(self.m, self.emb) 37 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) 38 | 39 | def prm_exp(self, x): 40 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 41 | # and Simo Ryu (https://github.com/cloneofsimo) 42 | # ==== positive random features for gaussian kernels ==== 43 | # x = (B, T, hs) 44 | # w = (m, hs) 45 | # return : x : B, T, m 46 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 47 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 48 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 49 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w) 50 | 51 | return torch.exp(wtx - xd) / math.sqrt(self.m) 52 | 53 | def single_attn(self, x): 54 | k, q, v = torch.split(self.kqv(x), self.emb, dim=-1) 55 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m) 56 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) 57 | kptv = torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m) 58 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 59 | # skip connection 60 | y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection 61 | 62 | return y 63 | 64 | def forward(self, x): 65 | x = self.single_attn(self.norm1(x)) 66 | x = x + self.mlp(self.norm2(x)) 67 | return x 68 | 69 | 70 | class Token_performer_local(nn.Module): 71 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1): 72 | super().__init__() 73 | self.emb = in_dim * head_cnt # we use 1, so it is no need here 74 | self.kqv = nn.Linear(dim, 3 * self.emb) 75 | self.dp = nn.Dropout(dp1) 76 | self.proj = nn.Linear(self.emb, self.emb) 77 | self.head_cnt = head_cnt 78 | self.norm1 = nn.LayerNorm(dim) 79 | self.norm2 = nn.LayerNorm(self.emb) 80 | self.epsilon = 1e-8 # for stable in division 81 | 82 | self.conv = LocalityFeedForward(in_dim, in_dim, 1, expand_ratio=1, act='hs', reduction=in_dim//4) 83 | 84 | self.m = int(self.emb * kernel_ratio) 85 | self.w = torch.randn(self.m, self.emb) 86 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) 87 | 88 | def prm_exp(self, x): 89 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 90 | # and Simo Ryu (https://github.com/cloneofsimo) 91 | # ==== positive random features for gaussian kernels ==== 92 | # x = (B, T, hs) 93 | # w = (m, hs) 94 | # return : x : B, T, m 95 | # SM(x, y) = 
E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 96 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 97 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 98 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w) 99 | 100 | return torch.exp(wtx - xd) / math.sqrt(self.m) 101 | 102 | def single_attn(self, x): 103 | k, q, v = torch.split(self.kqv(x), self.emb, dim=-1) 104 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m) 105 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) 106 | kptv = torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m) 107 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 108 | # skip connection 109 | y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection 110 | 111 | return y 112 | 113 | def forward(self, x): 114 | 115 | x = self.single_attn(self.norm1(x)) 116 | 117 | batch_size, num_token, embed_dim = x.shape # (B, 197, dim) 118 | patch_size = int(math.sqrt(num_token)) 119 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) # (B, dim, 14, 14) 120 | x = self.conv(x).flatten(2).transpose(1, 2) 121 | return x 122 | 123 | -------------------------------------------------------------------------------- /LNL.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | LNL : Introducing locality mechanism into Transformer in Transformer (TNT) 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.models.helpers import load_pretrained 11 | from timm.models.layers import DropPath, trunc_normal_ 12 | from timm.models.vision_transformer import Mlp 13 | from timm.models.registry import register_model 14 | from models.localvit import LocalityFeedForward 15 | from models.tnt import Attention, TNT 16 | import math 17 | 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 25 | 'first_conv': 'pixel_embed.proj', 'classifier': 'head', 26 | **kwargs 27 | } 28 | 29 | 30 | default_cfgs = { 31 | 'tnt_t_conv_patch16_224': _cfg( 32 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 33 | ), 34 | 'tnt_s_conv_patch16_224': _cfg( 35 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 36 | ), 37 | 'tnt_b_conv_patch16_224': _cfg( 38 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 39 | ), 40 | } 41 | 42 | 43 | class Block(nn.Module): 44 | """ TNT Block 45 | """ 46 | 47 | def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., 48 | qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 49 | super().__init__() 50 | # Inner transformer 51 | self.norm_in = norm_layer(in_dim) 52 | self.attn_in = Attention( 53 | in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, 54 | attn_drop=attn_drop, proj_drop=drop) 55 | 56 | self.norm_mlp_in = norm_layer(in_dim) 57 | self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), 58 | out_features=in_dim, act_layer=act_layer, drop=drop) 59 | 60 | self.norm1_proj = norm_layer(in_dim) 61 | self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) 62 | # Outer transformer 63 | self.norm_out = norm_layer(dim) 64 | 
self.attn_out = Attention( 65 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, 66 | attn_drop=attn_drop, proj_drop=drop) 67 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 68 | 69 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=dim) 70 | 71 | 72 | def forward(self, pixel_embed, patch_embed): 73 | # inner 74 | x, _ = self.attn_in(self.norm_in(pixel_embed)) 75 | pixel_embed = pixel_embed + self.drop_path(x) 76 | pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) 77 | 78 | # outer 79 | B, N, C = patch_embed.size() 80 | Nsqrt = int(math.sqrt(N)) 81 | patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) 82 | x, weights = self.attn_out(self.norm_out(patch_embed)) 83 | patch_embed = patch_embed + self.drop_path(x) 84 | 85 | cls_token, patch_embed = torch.split(patch_embed, [1, N - 1], dim=1) # (B, 1, dim), (B, 196, dim) 86 | patch_embed = patch_embed.transpose(1, 2).view(B, C, Nsqrt, Nsqrt) # (B, dim, 14, 14) 87 | patch_embed = self.conv(patch_embed).flatten(2).transpose(1, 2) # (B, 196, dim) 88 | patch_embed = torch.cat([cls_token, patch_embed], dim=1) 89 | 90 | return pixel_embed, patch_embed, weights 91 | 92 | 93 | class LocalViT_TNT(TNT): 94 | """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 95 | """ 96 | 97 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, 98 | num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 99 | drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): 100 | super().__init__(img_size, patch_size, in_chans, num_classes, embed_dim, in_dim, depth, 101 | num_heads, in_num_head, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, 102 | drop_path_rate, norm_layer, first_stride) 103 | new_patch_size = self.pixel_embed.new_patch_size 104 | num_pixel = new_patch_size ** 2 105 | 106 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 107 | blocks = [] 108 | for i in range(depth): 109 | blocks.append(Block( 110 | dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, 111 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, 112 | drop_path=dpr[i], norm_layer=norm_layer)) 113 | self.blocks = nn.ModuleList(blocks) 114 | 115 | self.apply(self._init_weights) 116 | 117 | 118 | @register_model 119 | def LNL_Ti(pretrained=False, **kwargs): 120 | model = LocalViT_TNT(patch_size=16, embed_dim=192, in_dim=12, depth=12, num_heads=3, in_num_head=3, 121 | qkv_bias=False, **kwargs) 122 | model.default_cfg = default_cfgs['tnt_t_conv_patch16_224'] 123 | if pretrained: 124 | load_pretrained( 125 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 126 | return model 127 | 128 | 129 | @register_model 130 | def LNL_S(pretrained=False, **kwargs): 131 | model = LocalViT_TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, 132 | qkv_bias=False, **kwargs) 133 | model.default_cfg = default_cfgs['tnt_s_conv_patch16_224'] 134 | if pretrained: 135 | load_pretrained( 136 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 137 | return 138 | -------------------------------------------------------------------------------- /models/localvit_tnt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | 
Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | LNL : Introducing locality mechanism into Transformer in Transformer (TNT) 6 | 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.models.helpers import load_pretrained 13 | from timm.models.layers import DropPath, trunc_normal_ 14 | from timm.models.vision_transformer import Mlp 15 | from timm.models.registry import register_model 16 | from models.localvit import LocalityFeedForward 17 | from models.tnt import Attention, TNT 18 | import math 19 | 20 | 21 | def _cfg(url='', **kwargs): 22 | return { 23 | 'url': url, 24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 25 | 'crop_pct': .9, 'interpolation': 'bicubic', 26 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 27 | 'first_conv': 'pixel_embed.proj', 'classifier': 'head', 28 | **kwargs 29 | } 30 | 31 | 32 | default_cfgs = { 33 | 'tnt_t_conv_patch16_224': _cfg( 34 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 35 | ), 36 | 'tnt_s_conv_patch16_224': _cfg( 37 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 38 | ), 39 | 'tnt_b_conv_patch16_224': _cfg( 40 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 41 | ), 42 | } 43 | 44 | 45 | class Block(nn.Module): 46 | """ TNT Block 47 | """ 48 | 49 | def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., 50 | qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 51 | super().__init__() 52 | # Inner transformer 53 | self.norm_in = norm_layer(in_dim) 54 | self.attn_in = Attention( 55 | in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, 56 | attn_drop=attn_drop, proj_drop=drop) 57 | 58 | self.norm_mlp_in = norm_layer(in_dim) 59 | self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), 60 | out_features=in_dim, act_layer=act_layer, drop=drop) 61 | 62 | self.norm1_proj = norm_layer(in_dim) 63 | self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) 64 | # Outer transformer 65 | self.norm_out = norm_layer(dim) 66 | self.attn_out = Attention( 67 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, 68 | attn_drop=attn_drop, proj_drop=drop) 69 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 70 | 71 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=dim) 72 | 73 | 74 | def forward(self, pixel_embed, patch_embed): 75 | # inner 76 | x, _ = self.attn_in(self.norm_in(pixel_embed)) 77 | pixel_embed = pixel_embed + self.drop_path(x) 78 | pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) 79 | 80 | # outer 81 | B, N, C = patch_embed.size() 82 | Nsqrt = int(math.sqrt(N)) 83 | patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) 84 | x, weights = self.attn_out(self.norm_out(patch_embed)) 85 | patch_embed = patch_embed + self.drop_path(x) 86 | 87 | cls_token, patch_embed = torch.split(patch_embed, [1, N - 1], dim=1) # (B, 1, dim), (B, 196, dim) 88 | patch_embed = patch_embed.transpose(1, 2).view(B, C, Nsqrt, Nsqrt) # (B, dim, 14, 14) 89 | patch_embed = self.conv(patch_embed).flatten(2).transpose(1, 2) # (B, 196, dim) 90 | patch_embed = torch.cat([cls_token, patch_embed], dim=1) 91 | 92 | return pixel_embed, patch_embed, weights 93 | 94 | 95 | class LocalViT_TNT(TNT): 96 | """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 97 | """ 98 | 99 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, 100 | num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 101 | drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): 102 | super().__init__(img_size, patch_size, in_chans, num_classes, embed_dim, in_dim, depth, 103 | num_heads, in_num_head, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, 104 | drop_path_rate, norm_layer, first_stride) 105 | new_patch_size = self.pixel_embed.new_patch_size 106 | num_pixel = new_patch_size ** 2 107 | 108 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 109 | blocks = [] 110 | for i in range(depth): 111 | blocks.append(Block( 112 | dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, 113 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, 114 | drop_path=dpr[i], norm_layer=norm_layer)) 115 | self.blocks = nn.ModuleList(blocks) 116 | 117 | self.apply(self._init_weights) 118 | 119 | 120 | @register_model 121 | def LNL_Ti(pretrained=False, **kwargs): 122 | model = LocalViT_TNT(patch_size=16, embed_dim=192, in_dim=12, depth=12, num_heads=3, in_num_head=3, 123 | qkv_bias=False, **kwargs) 124 | model.default_cfg = default_cfgs['tnt_t_conv_patch16_224'] 125 | if pretrained: 126 | load_pretrained( 127 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 128 | return model 129 | 130 | 131 | @register_model 132 | def LNL_S(pretrained=False, **kwargs): 133 | model = LocalViT_TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, 134 | qkv_bias=False, **kwargs) 135 | model.default_cfg = default_cfgs['tnt_s_conv_patch16_224'] 136 | if pretrained: 137 | load_pretrained( 138 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 139 | return model 140 | -------------------------------------------------------------------------------- /LNL_MoEx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Introducing locality mechanism into Transformer in Transformer (TNT) 6 | 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | from 
timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.models.helpers import load_pretrained 13 | from timm.models.layers import DropPath, trunc_normal_ 14 | from timm.models.vision_transformer import Mlp 15 | from timm.models.registry import register_model 16 | from models.localvit import LocalityFeedForward 17 | from models.tnt_moex import Attention, TNT 18 | import math 19 | 20 | 21 | def _cfg(url='', **kwargs): 22 | return { 23 | 'url': url, 24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 25 | 'crop_pct': .9, 'interpolation': 'bicubic', 26 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 27 | 'first_conv': 'pixel_embed.proj', 'classifier': 'head', 28 | **kwargs 29 | } 30 | 31 | 32 | default_cfgs = { 33 | 'tnt_t_conv_patch16_224': _cfg( 34 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 35 | ), 36 | 'tnt_s_conv_patch16_224': _cfg( 37 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 38 | ), 39 | 'tnt_b_conv_patch16_224': _cfg( 40 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 41 | ), 42 | } 43 | 44 | 45 | class Block(nn.Module): 46 | """ TNT Block 47 | """ 48 | 49 | def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., 50 | qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 51 | super().__init__() 52 | # Inner transformer 53 | self.norm_in = norm_layer(in_dim) 54 | self.attn_in = Attention( 55 | in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, 56 | attn_drop=attn_drop, proj_drop=drop) 57 | 58 | self.norm_mlp_in = norm_layer(in_dim) 59 | self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), 60 | out_features=in_dim, act_layer=act_layer, drop=drop) 61 | 62 | self.norm1_proj = norm_layer(in_dim) 63 | self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) 64 | # Outer transformer 65 | self.norm_out = norm_layer(dim) 66 | self.attn_out = Attention( 67 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, 68 | attn_drop=attn_drop, proj_drop=drop) 69 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 70 | 71 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=dim) 72 | 73 | 74 | def forward(self, pixel_embed, patch_embed): 75 | # inner 76 | x, _ = self.attn_in(self.norm_in(pixel_embed)) 77 | pixel_embed = pixel_embed + self.drop_path(x) 78 | pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) 79 | 80 | # outer 81 | B, N, C = patch_embed.size() 82 | Nsqrt = int(math.sqrt(N)) 83 | patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) 84 | x, weights = self.attn_out(self.norm_out(patch_embed)) 85 | patch_embed = patch_embed + self.drop_path(x) 86 | 87 | cls_token, patch_embed = torch.split(patch_embed, [1, N - 1], dim=1) # (B, 1, dim), (B, 196, dim) 88 | patch_embed = patch_embed.transpose(1, 2).view(B, C, Nsqrt, Nsqrt) # (B, dim, 14, 14) 89 | patch_embed = self.conv(patch_embed).flatten(2).transpose(1, 2) # (B, 196, dim) 90 | patch_embed = torch.cat([cls_token, patch_embed], dim=1) 91 | 92 | return pixel_embed, patch_embed, weights 93 | 94 | 95 | class LocalViT_TNT(TNT): 96 | """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 97 | """ 98 | 99 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, 100 | num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 101 | drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): 102 | super().__init__(img_size, patch_size, in_chans, num_classes, embed_dim, in_dim, depth, 103 | num_heads, in_num_head, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, 104 | drop_path_rate, norm_layer, first_stride) 105 | new_patch_size = self.pixel_embed.new_patch_size 106 | num_pixel = new_patch_size ** 2 107 | 108 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 109 | blocks = [] 110 | for i in range(depth): 111 | blocks.append(Block( 112 | dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, 113 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, 114 | drop_path=dpr[i], norm_layer=norm_layer)) 115 | self.blocks = nn.ModuleList(blocks) 116 | 117 | self.apply(self._init_weights) 118 | 119 | 120 | @register_model 121 | def LNL_MoEx_Ti(pretrained=False, **kwargs): 122 | model = LocalViT_TNT(patch_size=16, embed_dim=192, in_dim=12, depth=12, num_heads=3, in_num_head=3, 123 | qkv_bias=False, **kwargs) 124 | model.default_cfg = default_cfgs['tnt_t_conv_patch16_224'] 125 | if pretrained: 126 | load_pretrained( 127 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 128 | return model 129 | 130 | 131 | @register_model 132 | def LNL_MoEx_S(pretrained=False, **kwargs): 133 | model = LocalViT_TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, 134 | qkv_bias=False, **kwargs) 135 | model.default_cfg = default_cfgs['tnt_s_conv_patch16_224'] 136 | if pretrained: 137 | load_pretrained( 138 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 139 | return model 140 | -------------------------------------------------------------------------------- /models/modeling_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Bottleneck ResNet v2 with GroupNorm and Weight Standardization.""" 17 | import math 18 | 19 | from os.path import join as pjoin 20 | 21 | from collections import OrderedDict # pylint: disable=g-importing-member 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | 28 | def np2th(weights, conv=False): 29 | """Possibly convert HWIO to OIHW.""" 30 | if conv: 31 | weights = weights.transpose([3, 2, 0, 1]) 32 | return torch.from_numpy(weights) 33 | 34 | 35 | class StdConv2d(nn.Conv2d): 36 | 37 | def forward(self, x): 38 | w = self.weight 39 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 40 | w = (w - m) / torch.sqrt(v + 1e-5) 41 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 42 | self.dilation, self.groups) 43 | 44 | 45 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 46 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 47 | padding=1, bias=bias, groups=groups) 48 | 49 | 50 | def conv1x1(cin, cout, stride=1, bias=False): 51 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 52 | padding=0, bias=bias) 53 | 54 | 55 | class PreActBottleneck(nn.Module): 56 | """Pre-activation (v2) bottleneck block. 57 | """ 58 | 59 | def __init__(self, cin, cout=None, cmid=None, stride=1): 60 | super().__init__() 61 | cout = cout or cin 62 | cmid = cmid or cout//4 63 | 64 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 65 | self.conv1 = conv1x1(cin, cmid, bias=False) 66 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 67 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 68 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 69 | self.conv3 = conv1x1(cmid, cout, bias=False) 70 | self.relu = nn.ReLU(inplace=True) 71 | 72 | if (stride != 1 or cin != cout): 73 | # Projection also with pre-activation according to paper. 
74 | self.downsample = conv1x1(cin, cout, stride, bias=False) 75 | self.gn_proj = nn.GroupNorm(cout, cout) 76 | 77 | def forward(self, x): 78 | 79 | # Residual branch 80 | residual = x 81 | if hasattr(self, 'downsample'): 82 | residual = self.downsample(x) 83 | residual = self.gn_proj(residual) 84 | 85 | # Unit's branch 86 | y = self.relu(self.gn1(self.conv1(x))) 87 | y = self.relu(self.gn2(self.conv2(y))) 88 | y = self.gn3(self.conv3(y)) 89 | 90 | y = self.relu(residual + y) 91 | return y 92 | 93 | def load_from(self, weights, n_block, n_unit): 94 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 95 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 96 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 97 | 98 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 99 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 100 | 101 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 102 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 103 | 104 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 105 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 106 | 107 | self.conv1.weight.copy_(conv1_weight) 108 | self.conv2.weight.copy_(conv2_weight) 109 | self.conv3.weight.copy_(conv3_weight) 110 | 111 | self.gn1.weight.copy_(gn1_weight.view(-1)) 112 | self.gn1.bias.copy_(gn1_bias.view(-1)) 113 | 114 | self.gn2.weight.copy_(gn2_weight.view(-1)) 115 | self.gn2.bias.copy_(gn2_bias.view(-1)) 116 | 117 | self.gn3.weight.copy_(gn3_weight.view(-1)) 118 | self.gn3.bias.copy_(gn3_bias.view(-1)) 119 | 120 | if hasattr(self, 'downsample'): 121 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 122 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 123 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 124 | 125 | self.downsample.weight.copy_(proj_conv_weight) 126 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 127 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 128 | 129 | class ResNetV2(nn.Module): 130 | """Implementation of Pre-activation (v2) ResNet mode.""" 131 | 132 | def __init__(self, block_units, width_factor): 133 | super().__init__() 134 | width = int(64 * width_factor) 135 | self.width = width 136 | 137 | # The following will be unreadable if we split lines. 
138 | # pylint: disable=line-too-long 139 | self.root = nn.Sequential(OrderedDict([ 140 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), 141 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 142 | ('relu', nn.ReLU(inplace=True)), 143 | ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 144 | ])) 145 | 146 | self.body = nn.Sequential(OrderedDict([ 147 | ('block1', nn.Sequential(OrderedDict( 148 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 149 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 150 | ))), 151 | ('block2', nn.Sequential(OrderedDict( 152 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 153 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 154 | ))), 155 | ('block3', nn.Sequential(OrderedDict( 156 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 157 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 158 | ))), 159 | ])) 160 | 161 | def forward(self, x): 162 | x = self.root(x) 163 | x = self.body(x) 164 | return x 165 | -------------------------------------------------------------------------------- /models/localvit_t2t.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Introducing locality mechanism to "Tokens-to-token vit: Training vision transformers from scratch on imagenet". 6 | """ 7 | import torch 8 | import math 9 | import torch.nn as nn 10 | import numpy as np 11 | from timm.models.helpers import load_pretrained 12 | from timm.models.registry import register_model 13 | from timm.models.layers import DropPath 14 | from models.token_transformer import Token_transformer_local 15 | from models.token_performer import Token_performer_local 16 | from models.t2t_vit import T2T_module, T2T_ViT 17 | from models.t2t_vit_block import Attention 18 | from models.localvit import LocalityFeedForward 19 | 20 | def _cfg(url='', **kwargs): 21 | return { 22 | 'url': url, 23 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 24 | 'crop_pct': .9, 'interpolation': 'bicubic', 25 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 26 | 'classifier': 'head', 27 | **kwargs 28 | } 29 | 30 | default_cfgs = { 31 | 'localvit_T2t_conv7': _cfg(), 32 | } 33 | 34 | 35 | class Block(nn.Module): 36 | 37 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 38 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patches=196, reduction=4): 39 | super().__init__() 40 | self.num_patches = num_patches 41 | self.norm1 = norm_layer(dim) 42 | self.attn = Attention( 43 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 44 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 45 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 46 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=reduction) 47 | 48 | def forward(self, x): 49 | # print(x.shape) 50 | batch_size, num_token, embed_dim = x.shape # (B, 197, dim) 51 | patch_size = int(math.sqrt(num_token)) 52 | 53 | x = x + self.drop_path(self.attn(self.norm1(x))) # (B, 197, dim) 54 | cls_token, x = torch.split(x, [1, self.num_patches], dim=1) # (B, 1, dim), (B, 196, dim) 55 | # print(cls_token.shape, x.shape) 56 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) # (B, dim, 14, 14) 57 | x = self.conv(x).flatten(2).transpose(1, 2) # (B, 196, dim) 58 | x = torch.cat([cls_token, x], dim=1) 59 | return x 60 | 61 | 62 | class T2T_module_local(nn.Module): 63 | """ 64 | Tokens-to-Token encoding module 65 | """ 66 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64): 67 | super().__init__() 68 | 69 | if tokens_type == 'transformer': 70 | print('adopt transformer encoder for tokens-to-token') 71 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 72 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 73 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 74 | 75 | self.attention1 = Token_transformer_local(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1) 76 | self.attention2 = Token_transformer_local(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1) 77 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 78 | 79 | elif tokens_type == 'performer': 80 | print('adopt performer encoder for tokens-to-token') 81 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 82 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 83 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 84 | 85 | self.attention1 = Token_performer_local(dim=in_chans * 7 * 7, in_dim=token_dim, kernel_ratio=0.5) 86 | self.attention2 = Token_performer_local(dim=token_dim * 3 * 3, in_dim=token_dim, kernel_ratio=0.5) 87 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 88 | 89 | elif tokens_type == 'convolution': # just for comparison with conolution, not our model 90 | # for this tokens type, you need change forward as three convolution operation 91 | print('adopt convolution layers for tokens-to-token') 92 | self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution 93 | self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution 94 | self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution 95 | 96 | self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately 97 | 98 | def forward(self, x): 99 | # step0: soft split 100 | x = self.soft_split0(x).transpose(1, 2) 101 | 102 | # iteration1: re-structurization/reconstruction 103 | x = self.attention1(x) 104 | B, new_HW, C = x.shape 105 | x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 106 | # iteration1: soft split 107 | x = self.soft_split1(x).transpose(1, 2) 108 | 109 | # iteration2: re-structurization/reconstruction 110 | x = self.attention2(x) 111 | B, new_HW, C = x.shape 112 | x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 113 | # iteration2: soft split 114 | x 
= self.soft_split2(x).transpose(1, 2) 115 | 116 | # final tokens 117 | x = self.project(x) 118 | 119 | return x 120 | 121 | class LocalViT_T2T(T2T_ViT): 122 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, 123 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 124 | drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64, reduction=4): 125 | super().__init__(img_size, tokens_type, in_chans, num_classes, embed_dim, depth, num_heads, mlp_ratio, 126 | qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, token_dim) 127 | 128 | self.tokens_to_token = T2T_module_local( 129 | img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim) 130 | 131 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 132 | self.blocks = nn.ModuleList([ 133 | Block( 134 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 135 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 136 | reduction=reduction 137 | ) 138 | for i in range(depth)]) 139 | 140 | self.apply(self._init_weights) 141 | 142 | 143 | @register_model 144 | def localvit_T2t_conv7(pretrained=False, **kwargs): # adopt performer for tokens to token 145 | if pretrained: 146 | kwargs.setdefault('qk_scale', 256 ** -0.5) 147 | model = LocalViT_T2T(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., reduction=128, 148 | **kwargs) 149 | model.default_cfg = default_cfgs['localvit_T2t_conv7'] 150 | if pretrained: 151 | load_pretrained( 152 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 153 | return model 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /models/deit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Implementation of "DeiT: Data-efficient Image Transformers". 
6 | Code borrowed from https://github.com/facebookresearch/deit 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from functools import partial 12 | 13 | from timm.models.vision_transformer import VisionTransformer, _cfg 14 | from timm.models.registry import register_model 15 | from timm.models.layers import trunc_normal_ 16 | 17 | 18 | __all__ = [ 19 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 20 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', 21 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', 22 | 'deit_base_distilled_patch16_384', 23 | ] 24 | 25 | 26 | class DistilledVisionTransformer(VisionTransformer): 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 30 | num_patches = self.patch_embed.num_patches 31 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 32 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 33 | 34 | trunc_normal_(self.dist_token, std=.02) 35 | trunc_normal_(self.pos_embed, std=.02) 36 | self.head_dist.apply(self._init_weights) 37 | 38 | def forward_features(self, x): 39 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 40 | # with slight modifications to add the dist_token 41 | B = x.shape[0] 42 | x = self.patch_embed(x) 43 | 44 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 45 | dist_token = self.dist_token.expand(B, -1, -1) 46 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 47 | 48 | x = x + self.pos_embed 49 | x = self.pos_drop(x) 50 | 51 | for blk in self.blocks: 52 | x = blk(x) 53 | 54 | x = self.norm(x) 55 | return x[:, 0], x[:, 1] 56 | 57 | def forward(self, x): 58 | x, x_dist = self.forward_features(x) 59 | x = self.head(x) 60 | x_dist = self.head_dist(x_dist) 61 | if self.training: 62 | return x, x_dist 63 | else: 64 | # during inference, return the average of both classifier predictions 65 | return (x + x_dist) / 2 66 | 67 | 68 | @register_model 69 | def deit_tiny_patch16_224_ex6(pretrained=False, **kwargs): 70 | # the expanded Deit-T in Table 1 71 | model = VisionTransformer( 72 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=6, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | model.default_cfg = _cfg() 75 | return model 76 | 77 | 78 | @register_model 79 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 80 | model = VisionTransformer( 81 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 83 | model.default_cfg = _cfg() 84 | if pretrained: 85 | checkpoint = torch.hub.load_state_dict_from_url( 86 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 87 | map_location="cpu", check_hash=True 88 | ) 89 | model.load_state_dict(checkpoint["model"]) 90 | return model 91 | 92 | 93 | @register_model 94 | def deit_small_patch16_224(pretrained=False, **kwargs): 95 | model = VisionTransformer( 96 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 97 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 98 | model.default_cfg = _cfg() 99 | if pretrained: 100 | checkpoint = torch.hub.load_state_dict_from_url( 101 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 102 
| map_location="cpu", check_hash=True 103 | ) 104 | model.load_state_dict(checkpoint["model"]) 105 | return model 106 | 107 | 108 | @register_model 109 | def deit_base_patch16_224(pretrained=False, **kwargs): 110 | model = VisionTransformer( 111 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 113 | model.default_cfg = _cfg() 114 | if pretrained: 115 | checkpoint = torch.hub.load_state_dict_from_url( 116 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 117 | map_location="cpu", check_hash=True 118 | ) 119 | model.load_state_dict(checkpoint["model"]) 120 | return model 121 | 122 | 123 | @register_model 124 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 125 | model = DistilledVisionTransformer( 126 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 127 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 128 | model.default_cfg = _cfg() 129 | if pretrained: 130 | checkpoint = torch.hub.load_state_dict_from_url( 131 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", 132 | map_location="cpu", check_hash=True 133 | ) 134 | model.load_state_dict(checkpoint["model"]) 135 | return model 136 | 137 | 138 | @register_model 139 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 140 | model = DistilledVisionTransformer( 141 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 142 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 143 | model.default_cfg = _cfg() 144 | if pretrained: 145 | checkpoint = torch.hub.load_state_dict_from_url( 146 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", 147 | map_location="cpu", check_hash=True 148 | ) 149 | model.load_state_dict(checkpoint["model"]) 150 | return model 151 | 152 | 153 | @register_model 154 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 155 | model = DistilledVisionTransformer( 156 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 157 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 158 | model.default_cfg = _cfg() 159 | if pretrained: 160 | checkpoint = torch.hub.load_state_dict_from_url( 161 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", 162 | map_location="cpu", check_hash=True 163 | ) 164 | model.load_state_dict(checkpoint["model"]) 165 | return model 166 | 167 | 168 | @register_model 169 | def deit_base_patch16_384(pretrained=False, **kwargs): 170 | model = VisionTransformer( 171 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 172 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 173 | model.default_cfg = _cfg() 174 | if pretrained: 175 | checkpoint = torch.hub.load_state_dict_from_url( 176 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", 177 | map_location="cpu", check_hash=True 178 | ) 179 | model.load_state_dict(checkpoint["model"]) 180 | return model 181 | 182 | 183 | @register_model 184 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 185 | model = DistilledVisionTransformer( 186 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 187 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 188 | model.default_cfg = _cfg() 189 | if pretrained: 190 | checkpoint = 
torch.hub.load_state_dict_from_url( 191 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", 192 | map_location="cpu", check_hash=True 193 | ) 194 | model.load_state_dict(checkpoint["model"]) 195 | return model 196 | -------------------------------------------------------------------------------- /models/tnt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Code borrowed from https://github.com/rwightman/pytorch-image-models 6 | 7 | Transformer in Transformer (TNT) in PyTorch 8 | A PyTorch implement of TNT as described in 9 | 'Transformer in Transformer' - https://arxiv.org/abs/2103.00112 10 | The official mindspore code is released and available at 11 | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT 12 | """ 13 | import math 14 | import torch 15 | import torch.nn as nn 16 | from functools import partial 17 | 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | from timm.models.helpers import load_pretrained 20 | from timm.models.layers import DropPath, trunc_normal_ 21 | from timm.models.vision_transformer import Mlp 22 | from timm.models.registry import register_model 23 | 24 | 25 | def _cfg(url='', **kwargs): 26 | return { 27 | 'url': url, 28 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 29 | 'crop_pct': .9, 'interpolation': 'bicubic', 30 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 31 | 'first_conv': 'pixel_embed.proj', 'classifier': 'head', 32 | **kwargs 33 | } 34 | 35 | 36 | default_cfgs = { 37 | 'tnt_t_patch16_224': _cfg( 38 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 39 | ), 40 | 'tnt_s_patch16_224': _cfg( 41 | url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', 42 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 43 | ), 44 | 'tnt_b_patch16_224': _cfg( 45 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 46 | ), 47 | } 48 | 49 | 50 | class Attention(nn.Module): 51 | """ Multi-Head Attention 52 | """ 53 | 54 | def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 55 | super().__init__() 56 | self.hidden_dim = hidden_dim 57 | self.num_heads = num_heads 58 | head_dim = hidden_dim // num_heads 59 | self.head_dim = head_dim 60 | self.scale = head_dim ** -0.5 61 | 62 | self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) 63 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop, inplace=True) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop, inplace=True) 67 | 68 | def forward(self, x): 69 | B, N, C = x.shape 70 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 71 | q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple) 72 | v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) 73 | 74 | attn = (q @ k.transpose(-2, -1)) * self.scale 75 | attn = attn.softmax(dim=-1) 76 | weights = attn 77 | attn = self.attn_drop(attn) 78 | 79 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 80 | x = self.proj(x) 81 | x = self.proj_drop(x) 82 | return x, weights 83 | 84 | 85 | class Block(nn.Module): 86 | """ TNT Block 87 | """ 88 | 89 | def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., 90 | qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 91 | 
super().__init__() 92 | # Inner transformer 93 | self.norm_in = norm_layer(in_dim) 94 | self.attn_in = Attention( 95 | in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, 96 | attn_drop=attn_drop, proj_drop=drop) 97 | 98 | self.norm_mlp_in = norm_layer(in_dim) 99 | self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), 100 | out_features=in_dim, act_layer=act_layer, drop=drop) 101 | 102 | self.norm1_proj = norm_layer(in_dim) 103 | self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) 104 | # Outer transformer 105 | self.norm_out = norm_layer(dim) 106 | self.attn_out = Attention( 107 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, 108 | attn_drop=attn_drop, proj_drop=drop) 109 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 110 | 111 | self.norm_mlp = norm_layer(dim) 112 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), 113 | out_features=dim, act_layer=act_layer, drop=drop) 114 | 115 | def forward(self, pixel_embed, patch_embed): 116 | # inner 117 | x, _ = self.attn_in(self.norm_in(pixel_embed)) 118 | pixel_embed = pixel_embed + self.drop_path(x) 119 | pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) 120 | # outer 121 | B, N, C = patch_embed.size() 122 | patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) 123 | x, weights = self.attn_out(self.norm_out(patch_embed)) 124 | patch_embed = patch_embed + self.drop_path(x) 125 | patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) 126 | return pixel_embed, patch_embed, weights 127 | 128 | 129 | class PixelEmbed(nn.Module): 130 | """ Image to Pixel Embedding 131 | """ 132 | 133 | def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): 134 | super().__init__() 135 | num_patches = (img_size // patch_size) ** 2 136 | self.img_size = img_size 137 | self.num_patches = num_patches 138 | self.in_dim = in_dim 139 | new_patch_size = math.ceil(patch_size / stride) 140 | self.new_patch_size = new_patch_size 141 | 142 | self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) 143 | self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) 144 | 145 | def forward(self, x, pixel_pos): 146 | B, C, H, W = x.shape 147 | assert H == self.img_size and W == self.img_size, \ 148 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." 
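# Illustrative shape walk-through of the steps below (assumes the class defaults
# img_size=224, patch_size=16, stride=4, so num_patches=196 and new_patch_size=4;
# this note is not part of the original source):
#   proj:    (B, 3, 224, 224) -> (B, in_dim, 56, 56)       7x7 conv, stride 4
#   unfold:  (B, in_dim, 56, 56) -> (B, in_dim*16, 196)    4x4 non-overlapping windows
#   reshape: -> (B*196, in_dim, 4, 4); add pixel_pos; flatten -> (B*196, 16, in_dim)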
149 | x = self.proj(x) 150 | x = self.unfold(x) 151 | x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) 152 | x = x + pixel_pos 153 | x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) 154 | return x 155 | 156 | 157 | class TNT(nn.Module): 158 | """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 159 | """ 160 | 161 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, 162 | num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 163 | drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): 164 | super().__init__() 165 | self.num_classes = num_classes 166 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 167 | 168 | self.pixel_embed = PixelEmbed( 169 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) 170 | num_patches = self.pixel_embed.num_patches 171 | self.num_patches = num_patches 172 | new_patch_size = self.pixel_embed.new_patch_size 173 | num_pixel = new_patch_size ** 2 174 | 175 | self.norm1_proj = norm_layer(num_pixel * in_dim) 176 | self.proj = nn.Linear(num_pixel * in_dim, embed_dim) 177 | self.norm2_proj = norm_layer(embed_dim) 178 | 179 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 180 | self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 181 | self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) 182 | self.pos_drop = nn.Dropout(p=drop_rate) 183 | 184 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 185 | blocks = [] 186 | for i in range(depth): 187 | blocks.append(Block( 188 | dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, 189 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, 190 | drop_path=dpr[i], norm_layer=norm_layer)) 191 | self.blocks = nn.ModuleList(blocks) 192 | self.norm = norm_layer(embed_dim) 193 | 194 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 195 | 196 | trunc_normal_(self.cls_token, std=.02) 197 | trunc_normal_(self.patch_pos, std=.02) 198 | trunc_normal_(self.pixel_pos, std=.02) 199 | self.apply(self._init_weights) 200 | 201 | def _init_weights(self, m): 202 | if isinstance(m, nn.Linear): 203 | trunc_normal_(m.weight, std=.02) 204 | if isinstance(m, nn.Linear) and m.bias is not None: 205 | nn.init.constant_(m.bias, 0) 206 | elif isinstance(m, nn.LayerNorm): 207 | nn.init.constant_(m.bias, 0) 208 | nn.init.constant_(m.weight, 1.0) 209 | 210 | @torch.jit.ignore 211 | def no_weight_decay(self): 212 | return {'patch_pos', 'pixel_pos', 'cls_token'} 213 | 214 | def get_classifier(self): 215 | return self.head 216 | 217 | def reset_classifier(self, num_classes, global_pool=''): 218 | self.num_classes = num_classes 219 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 220 | 221 | def forward_features(self, x): 222 | attn_weights = [] 223 | B = x.shape[0] 224 | pixel_embed = self.pixel_embed(x, self.pixel_pos) 225 | 226 | patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) 227 | patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1) 228 | patch_embed = patch_embed + self.patch_pos 229 | patch_embed = self.pos_drop(patch_embed) 
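# At this point there are two token streams (illustrative shapes, assuming the default
# 224x224 input with patch_size=16; not from the original source):
#   pixel_embed: (B*196, 16, in_dim)    inner tokens, one short sequence per patch
#   patch_embed: (B, 197, embed_dim)    outer tokens: cls token + 196 patch tokens
# Each Block below refines both streams and returns its outer attention map, which
# forward() exposes when called with vis=True, e.g. (hypothetical usage):
#   logits, attn = tnt_s_patch16_224()(torch.randn(1, 3, 224, 224), vis=True)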
230 | 231 | for blk in self.blocks: 232 | pixel_embed, patch_embed, weights = blk(pixel_embed, patch_embed) 233 | attn_weights.append(weights) 234 | patch_embed = self.norm(patch_embed) 235 | return patch_embed[:, 0], attn_weights 236 | 237 | def forward(self, x, vis=False): 238 | x, attn_weights = self.forward_features(x) 239 | x = self.head(x) 240 | if vis: 241 | return x, attn_weights 242 | else: 243 | return x 244 | 245 | 246 | @register_model 247 | def tnt_t_patch16_224(pretrained=False, **kwargs): 248 | model = TNT(patch_size=16, embed_dim=192, in_dim=12, depth=12, num_heads=3, in_num_head=3, 249 | qkv_bias=False, **kwargs) 250 | model.default_cfg = default_cfgs['tnt_t_patch16_224'] 251 | if pretrained: 252 | load_pretrained( 253 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 254 | return model 255 | 256 | 257 | @register_model 258 | def tnt_s_patch16_224(pretrained=False, **kwargs): 259 | model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, 260 | qkv_bias=False, **kwargs) 261 | model.default_cfg = default_cfgs['tnt_s_patch16_224'] 262 | if pretrained: 263 | load_pretrained( 264 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 265 | return model 266 | 267 | 268 | @register_model 269 | def tnt_b_patch16_224(pretrained=False, **kwargs): 270 | model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, 271 | qkv_bias=False, **kwargs) 272 | model.default_cfg = default_cfgs['tnt_b_patch16_224'] 273 | if pretrained: 274 | load_pretrained( 275 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 276 | return model -------------------------------------------------------------------------------- /models/localvit_swin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Introducing locality mechanism to "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | from timm.models.registry import register_model 13 | from models.swin_transformer import window_partition, window_reverse, WindowAttention, PatchMerging, PatchEmbed 14 | from models.swin_transformer import SwinTransformer 15 | from models.localvit import LocalityFeedForward 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | class SwinTransformerBlock(nn.Module): 38 | r""" Swin Transformer Block. 39 | Args: 40 | dim (int): Number of input channels. 41 | input_resolution (tuple[int]): Input resulotion. 42 | num_heads (int): Number of attention heads. 43 | window_size (int): Window size. 44 | shift_size (int): Shift size for SW-MSA. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
46 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 47 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 48 | drop (float, optional): Dropout rate. Default: 0.0 49 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 50 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 51 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 52 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 53 | """ 54 | 55 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 56 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 57 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_depthwise=True): 58 | super().__init__() 59 | self.dim = dim 60 | self.input_resolution = input_resolution 61 | self.num_heads = num_heads 62 | self.window_size = window_size 63 | self.shift_size = shift_size 64 | self.mlp_ratio = mlp_ratio 65 | self.use_depthwise = use_depthwise 66 | if min(self.input_resolution) <= self.window_size: 67 | # if window size is larger than input resolution, we don't partition windows 68 | self.shift_size = 0 69 | self.window_size = min(self.input_resolution) 70 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 71 | 72 | self.norm1 = norm_layer(dim) 73 | self.attn = WindowAttention( 74 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 75 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 76 | 77 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 78 | if not use_depthwise: 79 | self.norm2 = norm_layer(dim) 80 | mlp_hidden_dim = int(dim * mlp_ratio) 81 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 82 | else: 83 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, reduction=dim // 4) 84 | 85 | if self.shift_size > 0: 86 | # calculate attention mask for SW-MSA 87 | H, W = self.input_resolution 88 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 89 | h_slices = (slice(0, -self.window_size), 90 | slice(-self.window_size, -self.shift_size), 91 | slice(-self.shift_size, None)) 92 | w_slices = (slice(0, -self.window_size), 93 | slice(-self.window_size, -self.shift_size), 94 | slice(-self.shift_size, None)) 95 | cnt = 0 96 | for h in h_slices: 97 | for w in w_slices: 98 | img_mask[:, h, w, :] = cnt 99 | cnt += 1 100 | 101 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 102 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 103 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 104 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 105 | else: 106 | attn_mask = None 107 | 108 | self.register_buffer("attn_mask", attn_mask) 109 | 110 | def forward(self, x): 111 | H, W = self.input_resolution 112 | B, L, C = x.shape 113 | assert L == H * W, "input feature has wrong size" 114 | 115 | shortcut = x 116 | x = self.norm1(x) 117 | x = x.view(B, H, W, C) 118 | 119 | # cyclic shift 120 | if self.shift_size > 0: 121 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 122 | else: 123 | shifted_x = x 124 | 125 | # partition windows 126 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 127 | x_windows = x_windows.view(-1, 
self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 128 | 129 | # W-MSA/SW-MSA 130 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 131 | 132 | # merge windows 133 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 134 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 135 | 136 | # reverse cyclic shift 137 | if self.shift_size > 0: 138 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 139 | else: 140 | x = shifted_x 141 | x = x.view(B, H * W, C) 142 | 143 | # FFN 144 | x = shortcut + self.drop_path(x) 145 | if not self.use_depthwise: 146 | x = x + self.drop_path(self.mlp(self.norm2(x))) 147 | else: 148 | x = self.conv(x.view(B, H, W, C).permute(0, 3, 1, 2)) 149 | x = x.permute(0, 2, 3, 1).view(B, H * W, C) 150 | 151 | return x 152 | 153 | def extra_repr(self) -> str: 154 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 155 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 156 | 157 | def flops(self): 158 | flops = 0 159 | H, W = self.input_resolution 160 | # norm1 161 | flops += self.dim * H * W 162 | # W-MSA/SW-MSA 163 | nW = H * W / self.window_size / self.window_size 164 | flops += nW * self.attn.flops(self.window_size * self.window_size) 165 | # mlp 166 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 167 | # norm2 168 | flops += self.dim * H * W 169 | return flops 170 | 171 | 172 | class BasicLayer(nn.Module): 173 | """ A basic Swin Transformer layer for one stage. 174 | Args: 175 | dim (int): Number of input channels. 176 | input_resolution (tuple[int]): Input resolution. 177 | depth (int): Number of blocks. 178 | num_heads (int): Number of attention heads. 179 | window_size (int): Local window size. 180 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 181 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 182 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 183 | drop (float, optional): Dropout rate. Default: 0.0 184 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 185 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 186 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 187 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 188 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
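use_depthwise (bool): Whether to replace the MLP in each block with LocalityFeedForward. Default: True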
189 | """ 190 | 191 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 192 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 193 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, use_depthwise=True): 194 | 195 | super().__init__() 196 | self.dim = dim 197 | self.input_resolution = input_resolution 198 | self.depth = depth 199 | self.use_checkpoint = use_checkpoint 200 | 201 | # build blocks 202 | self.blocks = nn.ModuleList([ 203 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 204 | num_heads=num_heads, window_size=window_size, 205 | shift_size=0 if (i % 2 == 0) else window_size // 2, 206 | mlp_ratio=mlp_ratio, 207 | qkv_bias=qkv_bias, qk_scale=qk_scale, 208 | drop=drop, attn_drop=attn_drop, 209 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 210 | norm_layer=norm_layer, 211 | use_depthwise=use_depthwise) 212 | for i in range(depth)]) 213 | 214 | # patch merging layer 215 | if downsample is not None: 216 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 217 | else: 218 | self.downsample = None 219 | 220 | def forward(self, x): 221 | for blk in self.blocks: 222 | if self.use_checkpoint: 223 | x = checkpoint.checkpoint(blk, x) 224 | else: 225 | x = blk(x) 226 | if self.downsample is not None: 227 | x = self.downsample(x) 228 | return x 229 | 230 | def extra_repr(self) -> str: 231 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 232 | 233 | def flops(self): 234 | flops = 0 235 | for blk in self.blocks: 236 | flops += blk.flops() 237 | if self.downsample is not None: 238 | flops += self.downsample.flops() 239 | return flops 240 | 241 | 242 | class LocalViT_swin(SwinTransformer): 243 | 244 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 245 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 246 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 247 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 248 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 249 | use_checkpoint=False, **kwargs): 250 | super().__init__(img_size, patch_size, in_chans, num_classes, embed_dim, depths, num_heads, 251 | window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, 252 | norm_layer, ape, patch_norm, use_checkpoint, **kwargs) 253 | 254 | patches_resolution = self.patch_embed.patches_resolution 255 | 256 | # stochastic depth 257 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 258 | 259 | # build layers 260 | self.layers = nn.ModuleList() 261 | for i_layer in range(self.num_layers): 262 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 263 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 264 | patches_resolution[1] // (2 ** i_layer)), 265 | depth=depths[i_layer], 266 | num_heads=num_heads[i_layer], 267 | window_size=window_size, 268 | mlp_ratio=self.mlp_ratio, 269 | qkv_bias=qkv_bias, qk_scale=qk_scale, 270 | drop=drop_rate, attn_drop=attn_drop_rate, 271 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 272 | norm_layer=norm_layer, 273 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 274 | use_checkpoint=use_checkpoint, 275 | use_depthwise=i_layer < (self.num_layers)) 276 | self.layers.append(layer) 277 | 278 | self.apply(self._init_weights) 279 | 280 | 281 | @register_model 282 | def 
localvit_swin_tiny_patch4_window7_224(pretrain=False, **kwargs): 283 | model = LocalViT_swin(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7) 284 | return model -------------------------------------------------------------------------------- /models/t2t_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Implementation of "Tokens-to-token vit: Training vision transformers from scratch on imagenet". 6 | Code borrowed from https://github.com/yitu-opensource/T2T-ViT 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | from timm.models.helpers import load_pretrained 12 | from timm.models.registry import register_model 13 | from timm.models.layers import trunc_normal_ 14 | from models.token_transformer import Token_transformer 15 | from models.token_performer import Token_performer 16 | from models.t2t_vit_block import Block, get_sinusoid_encoding 17 | 18 | def _cfg(url='', **kwargs): 19 | return { 20 | 'url': url, 21 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 22 | 'crop_pct': .9, 'interpolation': 'bicubic', 23 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 24 | 'classifier': 'head', 25 | **kwargs 26 | } 27 | 28 | default_cfgs = { 29 | 'T2t_vit_7': _cfg(), 30 | 'T2t_vit_10': _cfg(), 31 | 'T2t_vit_12': _cfg(), 32 | 'T2t_vit_14': _cfg(), 33 | 'T2t_vit_19': _cfg(), 34 | 'T2t_vit_24': _cfg(), 35 | 'T2t_vit_t_14': _cfg(), 36 | 'T2t_vit_t_19': _cfg(), 37 | 'T2t_vit_t_24': _cfg(), 38 | 'T2t_vit_14_resnext': _cfg(), 39 | 'T2t_vit_14_wide': _cfg(), 40 | } 41 | 42 | 43 | 44 | 45 | 46 | class T2T_module(nn.Module): 47 | """ 48 | Tokens-to-Token encoding module 49 | """ 50 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64): 51 | super().__init__() 52 | 53 | if tokens_type == 'transformer': 54 | print('adopt transformer encoder for tokens-to-token') 55 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 56 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 57 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 58 | 59 | self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 60 | self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 61 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 62 | 63 | elif tokens_type == 'performer': 64 | print('adopt performer encoder for tokens-to-token') 65 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 66 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 67 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 68 | 69 | #self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5) 70 | #self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5) 71 | self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5) 72 | self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5) 73 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 74 | 75 | elif tokens_type == 'convolution': # just for comparison with conolution, not our model 76 | # for this tokens type, you need change forward as three convolution operation 77 | 
print('adopt convolution layers for tokens-to-token') 78 | self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution 79 | self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution 80 | self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution 81 | 82 | self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately 83 | 84 | def forward(self, x): 85 | # step0: soft split 86 | x = self.soft_split0(x).transpose(1, 2) 87 | 88 | # iteration1: re-structurization/reconstruction 89 | x = self.attention1(x) 90 | B, new_HW, C = x.shape 91 | x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 92 | # iteration1: soft split 93 | x = self.soft_split1(x).transpose(1, 2) 94 | 95 | # iteration2: re-structurization/reconstruction 96 | x = self.attention2(x) 97 | B, new_HW, C = x.shape 98 | x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 99 | # iteration2: soft split 100 | x = self.soft_split2(x).transpose(1, 2) 101 | 102 | # final tokens 103 | x = self.project(x) 104 | 105 | return x 106 | 107 | class T2T_ViT(nn.Module): 108 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, 109 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 110 | drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64): 111 | super().__init__() 112 | self.num_classes = num_classes 113 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 114 | 115 | self.tokens_to_token = T2T_module( 116 | img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim) 117 | num_patches = self.tokens_to_token.num_patches 118 | 119 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 120 | self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) 121 | self.pos_drop = nn.Dropout(p=drop_rate) 122 | 123 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 124 | self.blocks = nn.ModuleList([ 125 | Block( 126 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 127 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 128 | for i in range(depth)]) 129 | self.norm = norm_layer(embed_dim) 130 | 131 | # Classifier head 132 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 133 | 134 | trunc_normal_(self.cls_token, std=.02) 135 | self.apply(self._init_weights) 136 | 137 | def _init_weights(self, m): 138 | if isinstance(m, nn.Linear): 139 | trunc_normal_(m.weight, std=.02) 140 | if isinstance(m, nn.Linear) and m.bias is not None: 141 | nn.init.constant_(m.bias, 0) 142 | elif isinstance(m, nn.LayerNorm): 143 | nn.init.constant_(m.bias, 0) 144 | nn.init.constant_(m.weight, 1.0) 145 | 146 | @torch.jit.ignore 147 | def no_weight_decay(self): 148 | return {'cls_token'} 149 | 150 | def get_classifier(self): 151 | return self.head 152 | 153 | def reset_classifier(self, num_classes, global_pool=''): 154 | self.num_classes = num_classes 155 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 156 | 157 | def 
forward_features(self, x): 158 | B = x.shape[0] 159 | x = self.tokens_to_token(x) 160 | 161 | cls_tokens = self.cls_token.expand(B, -1, -1) 162 | x = torch.cat((cls_tokens, x), dim=1) 163 | x = x + self.pos_embed 164 | x = self.pos_drop(x) 165 | 166 | for blk in self.blocks: 167 | x = blk(x) 168 | 169 | x = self.norm(x) 170 | return x[:, 0] 171 | 172 | def forward(self, x): 173 | x = self.forward_features(x) 174 | x = self.head(x) 175 | return x 176 | 177 | @register_model 178 | def T2t_vit_7(pretrained=False, **kwargs): # adopt performer for tokens to token 179 | if pretrained: 180 | kwargs.setdefault('qk_scale', 256 ** -0.5) 181 | model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., **kwargs) 182 | model.default_cfg = default_cfgs['T2t_vit_7'] 183 | if pretrained: 184 | load_pretrained( 185 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 186 | return model 187 | 188 | @register_model 189 | def T2t_vit_10(pretrained=False, **kwargs): # adopt performer for tokens to token 190 | if pretrained: 191 | kwargs.setdefault('qk_scale', 256 ** -0.5) 192 | model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=10, num_heads=4, mlp_ratio=2., **kwargs) 193 | model.default_cfg = default_cfgs['T2t_vit_10'] 194 | if pretrained: 195 | load_pretrained( 196 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 197 | return model 198 | 199 | @register_model 200 | def T2t_vit_12(pretrained=False, **kwargs): # adopt performer for tokens to token 201 | if pretrained: 202 | kwargs.setdefault('qk_scale', 256 ** -0.5) 203 | model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=12, num_heads=4, mlp_ratio=2., **kwargs) 204 | model.default_cfg = default_cfgs['T2t_vit_12'] 205 | if pretrained: 206 | load_pretrained( 207 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 208 | return model 209 | 210 | 211 | @register_model 212 | def T2t_vit_14(pretrained=False, **kwargs): # adopt performer for tokens to token 213 | if pretrained: 214 | kwargs.setdefault('qk_scale', 384 ** -0.5) 215 | model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) 216 | model.default_cfg = default_cfgs['T2t_vit_14'] 217 | if pretrained: 218 | load_pretrained( 219 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 220 | return model 221 | 222 | @register_model 223 | def T2t_vit_19(pretrained=False, **kwargs): # adopt performer for tokens to token 224 | if pretrained: 225 | kwargs.setdefault('qk_scale', 448 ** -0.5) 226 | model = T2T_ViT(tokens_type='performer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) 227 | model.default_cfg = default_cfgs['T2t_vit_19'] 228 | if pretrained: 229 | load_pretrained( 230 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 231 | return model 232 | 233 | @register_model 234 | def T2t_vit_24(pretrained=False, **kwargs): # adopt performer for tokens to token 235 | if pretrained: 236 | kwargs.setdefault('qk_scale', 512 ** -0.5) 237 | model = T2T_ViT(tokens_type='performer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs) 238 | model.default_cfg = default_cfgs['T2t_vit_24'] 239 | if pretrained: 240 | load_pretrained( 241 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 242 | return model 243 | 244 | @register_model 245 | def T2t_vit_t_14(pretrained=False, **kwargs): # adopt transformers for tokens to token 246 | if pretrained: 247 | 
kwargs.setdefault('qk_scale', 384 ** -0.5) 248 | model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) 249 | model.default_cfg = default_cfgs['T2t_vit_t_14'] 250 | if pretrained: 251 | load_pretrained( 252 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 253 | return model 254 | 255 | @register_model 256 | def T2t_vit_t_19(pretrained=False, **kwargs): # adopt transformers for tokens to token 257 | if pretrained: 258 | kwargs.setdefault('qk_scale', 448 ** -0.5) 259 | model = T2T_ViT(tokens_type='transformer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) 260 | model.default_cfg = default_cfgs['T2t_vit_t_19'] 261 | if pretrained: 262 | load_pretrained( 263 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 264 | return model 265 | 266 | @register_model 267 | def T2t_vit_t_24(pretrained=False, **kwargs): # adopt transformers for tokens to token 268 | if pretrained: 269 | kwargs.setdefault('qk_scale', 512 ** -0.5) 270 | model = T2T_ViT(tokens_type='transformer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs) 271 | model.default_cfg = default_cfgs['T2t_vit_t_24'] 272 | if pretrained: 273 | load_pretrained( 274 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 275 | return model 276 | 277 | # rexnext and wide structure 278 | @register_model 279 | def T2t_vit_14_resnext(pretrained=False, **kwargs): 280 | if pretrained: 281 | kwargs.setdefault('qk_scale', 384 ** -0.5) 282 | model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=32, mlp_ratio=3., **kwargs) 283 | model.default_cfg = default_cfgs['T2t_vit_14_resnext'] 284 | if pretrained: 285 | load_pretrained( 286 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 287 | return model 288 | 289 | @register_model 290 | def T2t_vit_14_wide(pretrained=False, **kwargs): 291 | if pretrained: 292 | kwargs.setdefault('qk_scale', 512 ** -0.5) 293 | model = T2T_ViT(tokens_type='performer', embed_dim=768, depth=4, num_heads=12, mlp_ratio=3., **kwargs) 294 | model.default_cfg = default_cfgs['T2t_vit_14_wide'] 295 | if pretrained: 296 | load_pretrained( 297 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 298 | return model 299 | -------------------------------------------------------------------------------- /models/localvit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Introducing locality mechanism to "DeiT: Data-efficient Image Transformers". 
6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | from functools import partial 11 | from timm.models.vision_transformer import VisionTransformer 12 | from timm.models.layers import DropPath 13 | from timm.models.registry import register_model 14 | 15 | 16 | class h_sigmoid(nn.Module): 17 | def __init__(self, inplace=True): 18 | super(h_sigmoid, self).__init__() 19 | self.relu = nn.ReLU6(inplace=inplace) 20 | 21 | def forward(self, x): 22 | return self.relu(x + 3) / 6 23 | 24 | 25 | class h_swish(nn.Module): 26 | def __init__(self, inplace=True): 27 | super(h_swish, self).__init__() 28 | self.sigmoid = h_sigmoid(inplace=inplace) 29 | 30 | def forward(self, x): 31 | return x * self.sigmoid(x) 32 | 33 | 34 | class ECALayer(nn.Module): 35 | def __init__(self, channel, gamma=2, b=1, sigmoid=True): 36 | super(ECALayer, self).__init__() 37 | t = int(abs((math.log(channel, 2) + b) / gamma)) 38 | k = t if t % 2 else t + 1 39 | 40 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 41 | self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=k // 2, bias=False) 42 | if sigmoid: 43 | self.sigmoid = nn.Sigmoid() 44 | else: 45 | self.sigmoid = h_sigmoid() 46 | 47 | def forward(self, x): 48 | y = self.avg_pool(x) 49 | y = self.conv(y.squeeze(-1).transpose(-1, -2)) 50 | y = y.transpose(-1, -2).unsqueeze(-1) 51 | y = self.sigmoid(y) 52 | return x * y.expand_as(x) 53 | 54 | 55 | class SELayer(nn.Module): 56 | def __init__(self, channel, reduction=4): 57 | super(SELayer, self).__init__() 58 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 59 | self.fc = nn.Sequential( 60 | nn.Linear(channel, channel // reduction), 61 | nn.ReLU(inplace=True), 62 | nn.Linear(channel // reduction, channel), 63 | h_sigmoid() 64 | ) 65 | 66 | def forward(self, x): 67 | b, c, _, _ = x.size() 68 | y = self.avg_pool(x).view(b, c) 69 | y = self.fc(y).view(b, c, 1, 1) 70 | return x * y 71 | 72 | 73 | class LocalityFeedForward(nn.Module): 74 | def __init__(self, in_dim, out_dim, stride, expand_ratio=4., act='hs+se', reduction=4, 75 | wo_dp_conv=False, dp_first=False): 76 | """ 77 | :param in_dim: the input dimension 78 | :param out_dim: the output dimension. The input and output dimension should be the same. 79 | :param stride: stride of the depth-wise convolution. 80 | :param expand_ratio: expansion ratio of the hidden dimension. 81 | :param act: the activation function. 82 | relu: ReLU 83 | hs: h_swish 84 | hs+se: h_swish and SE module 85 | hs+eca: h_swish and ECA module 86 | hs+ecah: h_swish and ECA module. Compared with eca, h_sigmoid is used. 87 | :param reduction: reduction rate in SE module. 88 | :param wo_dp_conv: without depth-wise convolution. 89 | :param dp_first: place depth-wise convolution as the first layer. 90 | """ 91 | super(LocalityFeedForward, self).__init__() 92 | hidden_dim = int(in_dim * expand_ratio) 93 | kernel_size = 3 94 | 95 | layers = [] 96 | # the first linear layer is replaced by 1x1 convolution. 
97 | layers.extend([ 98 | nn.Conv2d(in_dim, hidden_dim, 1, 1, 0, bias=False), 99 | nn.BatchNorm2d(hidden_dim), 100 | h_swish() if act.find('hs') >= 0 else nn.ReLU6(inplace=True)]) 101 | 102 | # the depth-wise convolution between the two linear layers 103 | if not wo_dp_conv: 104 | dp = [ 105 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, kernel_size // 2, groups=hidden_dim, bias=False), 106 | nn.BatchNorm2d(hidden_dim), 107 | h_swish() if act.find('hs') >= 0 else nn.ReLU6(inplace=True) 108 | ] 109 | if dp_first: 110 | layers = dp + layers 111 | else: 112 | layers.extend(dp) 113 | 114 | if act.find('+') >= 0: 115 | attn = act.split('+')[1] 116 | if attn == 'se': 117 | layers.append(SELayer(hidden_dim, reduction=reduction)) 118 | elif attn.find('eca') >= 0: 119 | layers.append(ECALayer(hidden_dim, sigmoid=attn == 'eca')) 120 | else: 121 | raise NotImplementedError('Activation type {} is not implemented'.format(act)) 122 | 123 | # the second linear layer is replaced by 1x1 convolution. 124 | layers.extend([ 125 | nn.Conv2d(hidden_dim, out_dim, 1, 1, 0, bias=False), 126 | nn.BatchNorm2d(out_dim) 127 | ]) 128 | self.conv = nn.Sequential(*layers) 129 | 130 | def forward(self, x): 131 | x = x + self.conv(x) 132 | return x 133 | 134 | 135 | class Attention(nn.Module): 136 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, qk_reduce=1, attn_drop=0., proj_drop=0.): 137 | """ 138 | :param dim: 139 | :param num_heads: 140 | :param qkv_bias: 141 | :param qk_scale: 142 | :param qk_reduce: reduce the output dimension for QK projection 143 | :param attn_drop: 144 | :param proj_drop: 145 | """ 146 | super().__init__() 147 | self.num_heads = num_heads 148 | head_dim = dim // num_heads 149 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 150 | self.scale = qk_scale or head_dim ** -0.5 151 | self.qk_reduce = qk_reduce 152 | self.dim = dim 153 | self.qk_dim = int(dim / self.qk_reduce) 154 | 155 | self.qkv = nn.Linear(dim, int(dim * (1 + 1 / qk_reduce * 2)), bias=qkv_bias) 156 | self.attn_drop = nn.Dropout(attn_drop) 157 | self.proj = nn.Linear(dim, dim) 158 | self.proj_drop = nn.Dropout(proj_drop) 159 | 160 | def forward(self, x): 161 | B, N, C = x.shape 162 | if self.qk_reduce == 1: 163 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 164 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 165 | else: 166 | q, k, v = torch.split(self.qkv(x), [self.qk_dim, self.qk_dim, self.dim], dim=-1) 167 | q = q.reshape(B, N, self.num_heads, -1).transpose(1, 2) 168 | k = k.reshape(B, N, self.num_heads, -1).transpose(1, 2) 169 | v = v.reshape(B, N, self.num_heads, -1).transpose(1, 2) 170 | 171 | attn = (q @ k.transpose(-2, -1)) * self.scale 172 | attn = attn.softmax(dim=-1) 173 | attn = self.attn_drop(attn) 174 | 175 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 176 | x = self.proj(x) 177 | x = self.proj_drop(x) 178 | return x 179 | 180 | 181 | class Block(nn.Module): 182 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, qk_reduce=1, drop=0., attn_drop=0., 183 | drop_path=0., norm_layer=nn.LayerNorm, act='hs+se', reduction=4, wo_dp_conv=False, dp_first=False): 184 | super().__init__() 185 | self.norm1 = norm_layer(dim) 186 | self.attn = Attention( 187 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, qk_reduce=qk_reduce, 188 | attn_drop=attn_drop, proj_drop=drop) 189 | # NOTE: drop path for stochastic 
depth, we shall see if this is better than dropout here 190 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 191 | # The MLP is replaced by the conv layers. 192 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, act, reduction, wo_dp_conv, dp_first) 193 | 194 | def forward(self, x): 195 | batch_size, num_token, embed_dim = x.shape # (B, 197, dim) 196 | patch_size = int(math.sqrt(num_token)) 197 | 198 | x = x + self.drop_path(self.attn(self.norm1(x))) # (B, 197, dim) 199 | # Split the class token and the image token. 200 | cls_token, x = torch.split(x, [1, num_token - 1], dim=1) # (B, 1, dim), (B, 196, dim) 201 | # Reshape and update the image token. 202 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) # (B, dim, 14, 14) 203 | x = self.conv(x).flatten(2).transpose(1, 2) # (B, 196, dim) 204 | # Concatenate the class token and the newly computed image token. 205 | x = torch.cat([cls_token, x], dim=1) 206 | return x 207 | 208 | 209 | class TransformerLayer(nn.Module): 210 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 211 | drop_path=0., norm_layer=nn.LayerNorm): 212 | super().__init__() 213 | self.norm1 = norm_layer(dim) 214 | self.attn = Attention( 215 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 216 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 217 | 218 | ######################################### 219 | # Origianl implementation 220 | # self.norm2 = norm_layer(dim) 221 | # mlp_hidden_dim = int(dim * mlp_ratio) 222 | # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 223 | ######################################### 224 | 225 | # Replace the MLP layer by LocalityFeedForward. 226 | self.conv = LocalityFeedForward(dim, dim, 1, mlp_ratio, act='hs+se', reduction=dim//4) 227 | 228 | def forward(self, x): 229 | x = x + self.drop_path(self.attn(self.norm1(x))) 230 | ######################################### 231 | # Origianl implementation 232 | # x = x + self.drop_path(self.mlp(self.norm2(x))) 233 | ######################################### 234 | 235 | # Change the computation accordingly in three steps. 236 | batch_size, num_token, embed_dim = x.shape 237 | patch_size = int(math.sqrt(num_token)) 238 | # 1. Split the class token and the image token. 239 | cls_token, x = torch.split(x, [1, embed_dim - 1], dim=1) 240 | # 2. Reshape and update the image token. 241 | x = x.transpose(1, 2).view(batch_size, embed_dim, patch_size, patch_size) 242 | x = self.conv(x).flatten(2).transpose(1, 2) 243 | # 3. Concatenate the class token and the newly computed image token. 
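# NOTE: the split in step 1 above uses [1, embed_dim - 1]; the analogous split in
# Block.forward uses [1, num_token - 1], which matches the number of image tokens.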
244 | x = torch.cat([cls_token, x], dim=1) 245 | return x 246 | 247 | 248 | class LocalVisionTransformer(VisionTransformer): 249 | """ Vision Transformer with support for patch or hybrid CNN input stage 250 | """ 251 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 252 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 253 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, 254 | act=3, reduction=4, wo_dp_conv=False, dp_first=False): 255 | # print(hybrid_backbone is None) 256 | super().__init__(img_size, patch_size, in_chans, num_classes, embed_dim, depth, 257 | num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, 258 | drop_path_rate, hybrid_backbone, norm_layer) 259 | 260 | # parse act 261 | if act == 1: 262 | act = 'relu6' 263 | elif act == 2: 264 | act = 'hs' 265 | elif act == 3: 266 | act = 'hs+se' 267 | elif act == 4: 268 | act = 'hs+eca' 269 | else: 270 | act = 'hs+ecah' 271 | 272 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 273 | self.blocks = nn.ModuleList([ 274 | Block( 275 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 276 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 277 | act=act, reduction=reduction, wo_dp_conv=wo_dp_conv, dp_first=dp_first 278 | ) 279 | for i in range(depth)]) 280 | self.norm = norm_layer(embed_dim) 281 | 282 | self.apply(self._init_weights) 283 | 284 | 285 | 286 | @register_model 287 | def localvit_tiny_mlp6_act1(pretrained=False, **kwargs): 288 | model = LocalVisionTransformer( 289 | patch_size=16, embed_dim=192, depth=12, num_heads=4, mlp_ratio=6, qkv_bias=True, act=1, 290 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 291 | return model 292 | 293 | 294 | # reduction = 4 295 | @register_model 296 | def localvit_tiny_mlp4_act3_r4(pretrained=False, **kwargs): 297 | model = LocalVisionTransformer( 298 | patch_size=16, embed_dim=192, depth=12, num_heads=4, mlp_ratio=4, qkv_bias=True, act=3, reduction=4, 299 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 300 | return model 301 | 302 | # reduction = 192 303 | @register_model 304 | def localvit_tiny_mlp4_act3_r192(pretrained=False, **kwargs): 305 | model = LocalVisionTransformer( 306 | patch_size=16, embed_dim=192, depth=12, num_heads=4, mlp_ratio=4, qkv_bias=True, act=3, reduction=192, 307 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 308 | return model 309 | 310 | 311 | @register_model 312 | def localvit_small_mlp4_act3_r384(pretrained=False, **kwargs): 313 | model = LocalVisionTransformer( 314 | patch_size=16, embed_dim=384, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True, act=3, reduction=384, 315 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 316 | return model 317 | -------------------------------------------------------------------------------- /models/tnt_moex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Code borrowed from https://github.com/rwightman/pytorch-image-models 6 | 7 | Transformer in Transformer (TNT) in PyTorch 8 | A PyTorch implement of TNT as described in 9 | 'Transformer in Transformer' - https://arxiv.org/abs/2103.00112 10 | The official mindspore code is released and available at 11 | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT 12 
| """ 13 | import math 14 | import torch 15 | import torch.nn as nn 16 | from functools import partial 17 | 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | from timm.models.helpers import load_pretrained 20 | from timm.models.layers import DropPath, trunc_normal_ 21 | from timm.models.vision_transformer import Mlp 22 | from timm.models.registry import register_model 23 | 24 | def moex(x, swap_index, norm_type, epsilon=1e-5, positive_only=False): 25 | '''MoEx operation''' 26 | dtype = x.dtype 27 | x = x.float() 28 | 29 | B, C, L = x.shape 30 | if norm_type == 'bn': 31 | norm_dims = [0, 2, 3] 32 | elif norm_type == 'in': 33 | norm_dims = [2, 3] 34 | elif norm_type == 'ln': 35 | norm_dims = [1, 2, 3] 36 | elif norm_type == 'pono': 37 | norm_dims = [1] 38 | elif norm_type.startswith('gn'): 39 | if norm_type.startswith('gn-d'): 40 | # gn-d4 means GN where each group has 4 dims 41 | G_dim = int(norm_type[4:]) 42 | G = C // G_dim 43 | else: 44 | # gn4 means GN with 4 groups 45 | G = int(norm_type[2:]) 46 | G_dim = C // G 47 | x = x.view(B, G, G_dim, H, W) 48 | norm_dims = [2, 3, 4] 49 | elif norm_type.startswith('gpono'): 50 | if norm_type.startswith('gpono-d'): 51 | # gpono-d4 means GPONO where each group has 4 dims 52 | G_dim = int(norm_type[len('gpono-d'):]) 53 | G = C // G_dim 54 | else: 55 | # gpono4 means GPONO with 4 groups 56 | G = int(norm_type[len('gpono'):]) 57 | G_dim = C // G 58 | x = x.view(B, G, G_dim, H, W) 59 | norm_dims = [2] 60 | else: 61 | raise NotImplementedError(f'norm_type={norm_type}') 62 | 63 | if positive_only: 64 | x_pos = F.relu(x) 65 | s1 = x_pos.sum(dim=norm_dims, keepdim=True) 66 | s2 = x_pos.pow(2).sum(dim=norm_dims, keepdim=True) 67 | count = x_pos.gt(0).sum(dim=norm_dims, keepdim=True) 68 | count[count == 0] = 1 # deal with 0/0 69 | mean = s1 / count 70 | var = s2 / count - mean.pow(2) 71 | std = var.add(epsilon).sqrt() 72 | else: 73 | mean = x.mean(dim=norm_dims, keepdim=True) 74 | std = x.var(dim=norm_dims, keepdim=True).add(epsilon).sqrt() 75 | swap_mean = mean[swap_index] 76 | swap_std = std[swap_index] 77 | # output = (x - mean) / std * swap_std + swap_mean 78 | # equvalent but for efficient 79 | scale = swap_std / std 80 | shift = swap_mean - mean * scale 81 | output = x * scale + shift 82 | return output.view(B, C, L).to(dtype) 83 | 84 | def _cfg(url='', **kwargs): 85 | return { 86 | 'url': url, 87 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 88 | 'crop_pct': .9, 'interpolation': 'bicubic', 89 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 90 | 'first_conv': 'pixel_embed.proj', 'classifier': 'head', 91 | **kwargs 92 | } 93 | 94 | 95 | default_cfgs = { 96 | 'tnt_t_patch16_224': _cfg( 97 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 98 | ), 99 | 'tnt_s_patch16_224': _cfg( 100 | url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', 101 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 102 | ), 103 | 'tnt_b_patch16_224': _cfg( 104 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 105 | ), 106 | } 107 | 108 | 109 | class Attention(nn.Module): 110 | """ Multi-Head Attention 111 | """ 112 | 113 | def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 114 | super().__init__() 115 | self.hidden_dim = hidden_dim 116 | self.num_heads = num_heads 117 | head_dim = hidden_dim // num_heads 118 | self.head_dim = head_dim 119 | self.scale = head_dim ** -0.5 120 | 121 | self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) 
122 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 123 | self.attn_drop = nn.Dropout(attn_drop, inplace=True) 124 | self.proj = nn.Linear(dim, dim) 125 | self.proj_drop = nn.Dropout(proj_drop, inplace=True) 126 | 127 | def forward(self, x): 128 | B, N, C = x.shape 129 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 130 | q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple) 131 | v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) 132 | 133 | attn = (q @ k.transpose(-2, -1)) * self.scale 134 | attn = attn.softmax(dim=-1) 135 | weights = attn 136 | attn = self.attn_drop(attn) 137 | 138 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 139 | x = self.proj(x) 140 | x = self.proj_drop(x) 141 | return x, weights 142 | 143 | 144 | class Block(nn.Module): 145 | """ TNT Block 146 | """ 147 | 148 | def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., 149 | qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 150 | super().__init__() 151 | # Inner transformer 152 | self.norm_in = norm_layer(in_dim) 153 | self.attn_in = Attention( 154 | in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, 155 | attn_drop=attn_drop, proj_drop=drop) 156 | 157 | self.norm_mlp_in = norm_layer(in_dim) 158 | self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), 159 | out_features=in_dim, act_layer=act_layer, drop=drop) 160 | 161 | self.norm1_proj = norm_layer(in_dim) 162 | self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) 163 | # Outer transformer 164 | self.norm_out = norm_layer(dim) 165 | self.attn_out = Attention( 166 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, 167 | attn_drop=attn_drop, proj_drop=drop) 168 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 169 | 170 | self.norm_mlp = norm_layer(dim) 171 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), 172 | out_features=dim, act_layer=act_layer, drop=drop) 173 | 174 | def forward(self, pixel_embed, patch_embed): 175 | # inner 176 | x, _ = self.attn_in(self.norm_in(pixel_embed)) 177 | pixel_embed = pixel_embed + self.drop_path(x) 178 | pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) 179 | # outer 180 | B, N, C = patch_embed.size() 181 | patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) 182 | x, weights = self.attn_out(self.norm_out(patch_embed)) 183 | patch_embed = patch_embed + self.drop_path(x) 184 | patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) 185 | return pixel_embed, patch_embed, weights 186 | 187 | 188 | class PixelEmbed(nn.Module): 189 | """ Image to Pixel Embedding 190 | """ 191 | 192 | def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): 193 | super().__init__() 194 | num_patches = (img_size // patch_size) ** 2 195 | self.img_size = img_size 196 | self.num_patches = num_patches 197 | self.in_dim = in_dim 198 | new_patch_size = math.ceil(patch_size / stride) 199 | self.new_patch_size = new_patch_size 200 | 201 | self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) 202 | self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) 203 | 204 | def forward(self, x, pixel_pos): 205 | B, C, H, W = x.shape 206 | assert H == self.img_size and W == self.img_size, \ 207 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." 208 | x = self.proj(x) 209 | x = self.unfold(x) 210 | x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) 211 | x = x + pixel_pos 212 | x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) 213 | return x 214 | 215 | 216 | class TNT(nn.Module): 217 | """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 218 | """ 219 | 220 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, 221 | num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 222 | drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): 223 | super().__init__() 224 | self.num_classes = num_classes 225 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 226 | 227 | self.pixel_embed = PixelEmbed( 228 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) 229 | num_patches = self.pixel_embed.num_patches 230 | self.num_patches = num_patches 231 | new_patch_size = self.pixel_embed.new_patch_size 232 | num_pixel = new_patch_size ** 2 233 | 234 | self.norm1_proj = norm_layer(num_pixel * in_dim) 235 | self.proj = nn.Linear(num_pixel * in_dim, embed_dim) 236 | self.norm2_proj = norm_layer(embed_dim) 237 | 238 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 239 | self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 240 | self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) 241 | self.pos_drop = nn.Dropout(p=drop_rate) 242 | 243 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 244 | blocks = [] 245 | for i in range(depth): 246 | blocks.append(Block( 247 | 
dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, 248 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, 249 | drop_path=dpr[i], norm_layer=norm_layer)) 250 | self.blocks = nn.ModuleList(blocks) 251 | self.norm = norm_layer(embed_dim) 252 | 253 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 254 | 255 | trunc_normal_(self.cls_token, std=.02) 256 | trunc_normal_(self.patch_pos, std=.02) 257 | trunc_normal_(self.pixel_pos, std=.02) 258 | self.apply(self._init_weights) 259 | 260 | def _init_weights(self, m): 261 | if isinstance(m, nn.Linear): 262 | trunc_normal_(m.weight, std=.02) 263 | if isinstance(m, nn.Linear) and m.bias is not None: 264 | nn.init.constant_(m.bias, 0) 265 | elif isinstance(m, nn.LayerNorm): 266 | nn.init.constant_(m.bias, 0) 267 | nn.init.constant_(m.weight, 1.0) 268 | 269 | @torch.jit.ignore 270 | def no_weight_decay(self): 271 | return {'patch_pos', 'pixel_pos', 'cls_token'} 272 | 273 | def get_classifier(self): 274 | return self.head 275 | 276 | def reset_classifier(self, num_classes, global_pool=''): 277 | self.num_classes = num_classes 278 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 279 | 280 | def forward_features(self, x, swap_index, moex_norm, moex_epsilon, 281 | moex_layer, moex_positive_only): 282 | attn_weights = [] 283 | B = x.shape[0] 284 | pixel_embed = self.pixel_embed(x, self.pixel_pos) 285 | 286 | patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) 287 | patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1) 288 | patch_embed = patch_embed + self.patch_pos 289 | patch_embed = self.pos_drop(patch_embed) 290 | 291 | # moex 292 | if swap_index is not None and moex_layer == 'stem': 293 | patch_embed = moex(patch_embed, swap_index, moex_norm, moex_epsilon, moex_positive_only) 294 | 295 | for blk in self.blocks: 296 | pixel_embed, patch_embed, weights = blk(pixel_embed, patch_embed) 297 | attn_weights.append(weights) 298 | patch_embed = self.norm(patch_embed) 299 | return patch_embed[:, 0], attn_weights 300 | 301 | def forward(self, x, swap_index=None, moex_norm='pono', moex_epsilon=1e-5, 302 | moex_layer='stem', moex_positive_only=False, vis=False): 303 | x, attn_weights = self.forward_features(x, swap_index, moex_norm, moex_epsilon, 304 | moex_layer, moex_positive_only) 305 | x = self.head(x) 306 | if vis: 307 | return x, attn_weights 308 | else: 309 | return x 310 | 311 | 312 | @register_model 313 | def tnt_t_patch16_224(pretrained=False, **kwargs): 314 | model = TNT(patch_size=16, embed_dim=192, in_dim=12, depth=12, num_heads=3, in_num_head=3, 315 | qkv_bias=False, **kwargs) 316 | model.default_cfg = default_cfgs['tnt_t_patch16_224'] 317 | if pretrained: 318 | load_pretrained( 319 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 320 | return model 321 | 322 | 323 | @register_model 324 | def tnt_s_patch16_224(pretrained=False, **kwargs): 325 | model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, 326 | qkv_bias=False, **kwargs) 327 | model.default_cfg = default_cfgs['tnt_s_patch16_224'] 328 | if pretrained: 329 | load_pretrained( 330 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 331 | return model 332 | 333 | 334 | @register_model 335 | def tnt_b_patch16_224(pretrained=False, **kwargs): 336 | model = TNT(patch_size=16, 
embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, 337 | qkv_bias=False, **kwargs) 338 | model.default_cfg = default_cfgs['tnt_b_patch16_224'] 339 | if pretrained: 340 | load_pretrained( 341 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 342 | return model -------------------------------------------------------------------------------- /models/pvt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Implementation of "Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions". 6 | Code borrowed from https://github.com/whai362/PVT 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from functools import partial 12 | 13 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 14 | from timm.models.registry import register_model 15 | from timm.models.vision_transformer import _cfg 16 | 17 | __all__ = [ 18 | 'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large' 19 | ] 20 | 21 | 22 | class Mlp(nn.Module): 23 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.fc1 = nn.Linear(in_features, hidden_features) 28 | self.act = act_layer() 29 | self.fc2 = nn.Linear(hidden_features, out_features) 30 | self.drop = nn.Dropout(drop) 31 | 32 | def forward(self, x): 33 | x = self.fc1(x) 34 | x = self.act(x) 35 | x = self.drop(x) 36 | x = self.fc2(x) 37 | x = self.drop(x) 38 | return x 39 | 40 | 41 | class Attention(nn.Module): 42 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 43 | super().__init__() 44 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
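        # PVT's spatial-reduction attention (SRA): queries are computed for all N tokens,
        # but when sr_ratio > 1 the keys/values come from a feature map downsampled by a
        # stride-`sr_ratio` convolution, so the attention matrix is N x (N / sr_ratio**2)
        # instead of N x N.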
45 | 46 | self.dim = dim 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = qk_scale or head_dim ** -0.5 50 | 51 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 52 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 53 | self.attn_drop = nn.Dropout(attn_drop) 54 | self.proj = nn.Linear(dim, dim) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | 57 | self.sr_ratio = sr_ratio 58 | if sr_ratio > 1: 59 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 60 | self.norm = nn.LayerNorm(dim) 61 | 62 | def forward(self, x, H, W): 63 | B, N, C = x.shape 64 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 65 | 66 | if self.sr_ratio > 1: 67 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 68 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 69 | x_ = self.norm(x_) 70 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 71 | else: 72 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | k, v = kv[0], kv[1] 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 80 | x = self.proj(x) 81 | x = self.proj_drop(x) 82 | 83 | return x 84 | 85 | 86 | class Block(nn.Module): 87 | 88 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 89 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 90 | super().__init__() 91 | self.norm1 = norm_layer(dim) 92 | self.attn = Attention( 93 | dim, 94 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 95 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 96 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 97 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 98 | self.norm2 = norm_layer(dim) 99 | mlp_hidden_dim = int(dim * mlp_ratio) 100 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 101 | 102 | def forward(self, x, H, W): 103 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 104 | x = x + self.drop_path(self.mlp(self.norm2(x))) 105 | 106 | return x 107 | 108 | 109 | class PatchEmbed(nn.Module): 110 | """ Image to Patch Embedding 111 | """ 112 | 113 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 114 | super().__init__() 115 | img_size = to_2tuple(img_size) 116 | patch_size = to_2tuple(patch_size) 117 | 118 | self.img_size = img_size 119 | self.patch_size = patch_size 120 | assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 121 | f"img_size {img_size} should be divided by patch_size {patch_size}." 
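        # e.g. at stage 1 of the pvt_* variants defined below (img_size=224, patch_size=4)
        # this gives a 56x56 grid of 3136 patch tokens; stages 2-4 use patch_size=2 on the
        # previous stage's feature map, halving the grid resolution at each stage.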
122 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 123 | self.num_patches = self.H * self.W 124 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 125 | self.norm = nn.LayerNorm(embed_dim) 126 | 127 | def forward(self, x): 128 | B, C, H, W = x.shape 129 | 130 | x = self.proj(x).flatten(2).transpose(1, 2) 131 | x = self.norm(x) 132 | H, W = H // self.patch_size[0], W // self.patch_size[1] 133 | 134 | return x, (H, W) 135 | 136 | 137 | class PyramidVisionTransformer(nn.Module): 138 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 139 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 140 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 141 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 142 | super().__init__() 143 | self.num_classes = num_classes 144 | self.depths = depths 145 | 146 | # patch_embed 147 | self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, 148 | embed_dim=embed_dims[0]) 149 | self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], 150 | embed_dim=embed_dims[1]) 151 | self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], 152 | embed_dim=embed_dims[2]) 153 | self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], 154 | embed_dim=embed_dims[3]) 155 | 156 | # pos_embed 157 | self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0])) 158 | self.pos_drop1 = nn.Dropout(p=drop_rate) 159 | self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1])) 160 | self.pos_drop2 = nn.Dropout(p=drop_rate) 161 | self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2])) 162 | self.pos_drop3 = nn.Dropout(p=drop_rate) 163 | self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3])) 164 | self.pos_drop4 = nn.Dropout(p=drop_rate) 165 | 166 | # transformer encoder 167 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 168 | cur = 0 169 | self.block1 = nn.ModuleList([Block( 170 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 171 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 172 | sr_ratio=sr_ratios[0]) 173 | for i in range(depths[0])]) 174 | 175 | cur += depths[0] 176 | self.block2 = nn.ModuleList([Block( 177 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 178 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 179 | sr_ratio=sr_ratios[1]) 180 | for i in range(depths[1])]) 181 | 182 | cur += depths[1] 183 | self.block3 = nn.ModuleList([Block( 184 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 185 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 186 | sr_ratio=sr_ratios[2]) 187 | for i in range(depths[2])]) 188 | 189 | cur += depths[2] 190 | self.block4 = nn.ModuleList([Block( 191 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 192 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], 
norm_layer=norm_layer, 193 | sr_ratio=sr_ratios[3]) 194 | for i in range(depths[3])]) 195 | self.norm = norm_layer(embed_dims[3]) 196 | 197 | # cls_token 198 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) 199 | 200 | # classification head 201 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 202 | 203 | # init weights 204 | trunc_normal_(self.pos_embed1, std=.02) 205 | trunc_normal_(self.pos_embed2, std=.02) 206 | trunc_normal_(self.pos_embed3, std=.02) 207 | trunc_normal_(self.pos_embed4, std=.02) 208 | trunc_normal_(self.cls_token, std=.02) 209 | self.apply(self._init_weights) 210 | 211 | def reset_drop_path(self, drop_path_rate): 212 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 213 | cur = 0 214 | for i in range(self.depths[0]): 215 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 216 | 217 | cur += self.depths[0] 218 | for i in range(self.depths[1]): 219 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 220 | 221 | cur += self.depths[1] 222 | for i in range(self.depths[2]): 223 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 224 | 225 | cur += self.depths[2] 226 | for i in range(self.depths[3]): 227 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 228 | 229 | def _init_weights(self, m): 230 | if isinstance(m, nn.Linear): 231 | trunc_normal_(m.weight, std=.02) 232 | if isinstance(m, nn.Linear) and m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | elif isinstance(m, nn.LayerNorm): 235 | nn.init.constant_(m.bias, 0) 236 | nn.init.constant_(m.weight, 1.0) 237 | 238 | @torch.jit.ignore 239 | def no_weight_decay(self): 240 | # return {'pos_embed', 'cls_token'} # has pos_embed may be better 241 | return {'cls_token'} 242 | 243 | def get_classifier(self): 244 | return self.head 245 | 246 | def reset_classifier(self, num_classes, global_pool=''): 247 | self.num_classes = num_classes 248 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 249 | 250 | # def _get_pos_embed(self, pos_embed, patch_embed, H, W): 251 | # if H * W == self.patch_embed1.num_patches: 252 | # return pos_embed 253 | # else: 254 | # return F.interpolate( 255 | # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 256 | # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 257 | 258 | def forward_features(self, x): 259 | B = x.shape[0] 260 | 261 | # stage 1 262 | x, (H, W) = self.patch_embed1(x) 263 | x = x + self.pos_embed1 264 | x = self.pos_drop1(x) 265 | for blk in self.block1: 266 | x = blk(x, H, W) 267 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 268 | 269 | # stage 2 270 | x, (H, W) = self.patch_embed2(x) 271 | x = x + self.pos_embed2 272 | x = self.pos_drop2(x) 273 | for blk in self.block2: 274 | x = blk(x, H, W) 275 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 276 | 277 | # stage 3 278 | x, (H, W) = self.patch_embed3(x) 279 | x = x + self.pos_embed3 280 | x = self.pos_drop3(x) 281 | for blk in self.block3: 282 | x = blk(x, H, W) 283 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 284 | 285 | # stage 4 286 | x, (H, W) = self.patch_embed4(x) 287 | cls_tokens = self.cls_token.expand(B, -1, -1) 288 | x = torch.cat((cls_tokens, x), dim=1) 289 | x = x + self.pos_embed4 290 | x = self.pos_drop4(x) 291 | for blk in self.block4: 292 | x = blk(x, H, W) 293 | 294 | x = self.norm(x) 295 | 296 | return x[:, 0] 297 | 298 | def forward(self, x): 299 | x = self.forward_features(x) 300 | x = 
self.head(x) 301 | 302 | return x 303 | 304 | 305 | def _conv_filter(state_dict, patch_size=16): 306 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 307 | out_dict = {} 308 | for k, v in state_dict.items(): 309 | if 'patch_embed.proj.weight' in k: 310 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 311 | out_dict[k] = v 312 | 313 | return out_dict 314 | 315 | 316 | @register_model 317 | def pvt_tiny(pretrained=False, **kwargs): 318 | model = PyramidVisionTransformer( 319 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 320 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 321 | **kwargs) 322 | model.default_cfg = _cfg() 323 | 324 | return model 325 | 326 | 327 | @register_model 328 | def pvt_small(pretrained=False, **kwargs): 329 | model = PyramidVisionTransformer( 330 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 331 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) 332 | model.default_cfg = _cfg() 333 | 334 | return model 335 | 336 | 337 | @register_model 338 | def pvt_medium(pretrained=False, **kwargs): 339 | model = PyramidVisionTransformer( 340 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 341 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 342 | **kwargs) 343 | model.default_cfg = _cfg() 344 | 345 | return model 346 | 347 | 348 | @register_model 349 | def pvt_large(pretrained=False, **kwargs): 350 | model = PyramidVisionTransformer( 351 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 352 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 353 | **kwargs) 354 | model.default_cfg = _cfg() 355 | 356 | return model 357 | 358 | 359 | @register_model 360 | def pvt_huge_v2(pretrained=False, **kwargs): 361 | model = PyramidVisionTransformer( 362 | patch_size=4, embed_dims=[128, 256, 512, 768], num_heads=[2, 4, 8, 12], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 363 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 10, 60, 3], sr_ratios=[8, 4, 2, 1], 364 | # drop_rate=0.0, drop_path_rate=0.02) 365 | **kwargs) 366 | model.default_cfg = _cfg() 367 | 368 | return model -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". 
6 | Code borrowed from https://github.com/microsoft/Swin-Transformer 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | from timm.models.registry import register_model 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.act = act_layer() 23 | self.fc2 = nn.Linear(hidden_features, out_features) 24 | self.drop = nn.Dropout(drop) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop(x) 30 | x = self.fc2(x) 31 | x = self.drop(x) 32 | return x 33 | 34 | 35 | def window_partition(x, window_size): 36 | """ 37 | Args: 38 | x: (B, H, W, C) 39 | window_size (int): window size 40 | Returns: 41 | windows: (num_windows*B, window_size, window_size, C) 42 | """ 43 | B, H, W, C = x.shape 44 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 45 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 46 | return windows 47 | 48 | 49 | def window_reverse(windows, window_size, H, W): 50 | """ 51 | Args: 52 | windows: (num_windows*B, window_size, window_size, C) 53 | window_size (int): Window size 54 | H (int): Height of image 55 | W (int): Width of image 56 | Returns: 57 | x: (B, H, W, C) 58 | """ 59 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 60 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 61 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 62 | return x 63 | 64 | 65 | class WindowAttention(nn.Module): 66 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 67 | It supports both of shifted and non-shifted window. 68 | Args: 69 | dim (int): Number of input channels. 70 | window_size (tuple[int]): The height and width of the window. 71 | num_heads (int): Number of attention heads. 72 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 73 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 74 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 75 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 76 | """ 77 | 78 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 79 | 80 | super().__init__() 81 | self.dim = dim 82 | self.window_size = window_size # Wh, Ww 83 | self.num_heads = num_heads 84 | head_dim = dim // num_heads 85 | self.scale = qk_scale or head_dim ** -0.5 86 | 87 | # define a parameter table of relative position bias 88 | self.relative_position_bias_table = nn.Parameter( 89 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 90 | 91 | # get pair-wise relative position index for each token inside the window 92 | coords_h = torch.arange(self.window_size[0]) 93 | coords_w = torch.arange(self.window_size[1]) 94 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 95 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 96 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 97 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 98 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 99 | relative_coords[:, :, 1] += self.window_size[1] - 1 100 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 101 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 102 | self.register_buffer("relative_position_index", relative_position_index) 103 | 104 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 105 | self.attn_drop = nn.Dropout(attn_drop) 106 | self.proj = nn.Linear(dim, dim) 107 | self.proj_drop = nn.Dropout(proj_drop) 108 | 109 | trunc_normal_(self.relative_position_bias_table, std=.02) 110 | self.softmax = nn.Softmax(dim=-1) 111 | 112 | def forward(self, x, mask=None): 113 | """ 114 | Args: 115 | x: input features with shape of (num_windows*B, N, C) 116 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 117 | """ 118 | B_, N, C = x.shape 119 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 120 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 121 | 122 | q = q * self.scale 123 | attn = (q @ k.transpose(-2, -1)) 124 | 125 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 126 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 127 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 128 | attn = attn + relative_position_bias.unsqueeze(0) 129 | 130 | if mask is not None: 131 | nW = mask.shape[0] 132 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 133 | attn = attn.view(-1, self.num_heads, N, N) 134 | attn = self.softmax(attn) 135 | else: 136 | attn = self.softmax(attn) 137 | 138 | attn = self.attn_drop(attn) 139 | 140 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 141 | x = self.proj(x) 142 | x = self.proj_drop(x) 143 | return x 144 | 145 | def extra_repr(self) -> str: 146 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 147 | 148 | def flops(self, N): 149 | # calculate flops for 1 window with token length of N 150 | flops = 0 151 | # qkv = self.qkv(x) 152 | flops += N * self.dim * 3 * self.dim 153 | # attn = (q @ k.transpose(-2, -1)) 154 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 155 | # x = (attn @ v) 156 | flops += self.num_heads * N * N * (self.dim // 
self.num_heads) 157 | # x = self.proj(x) 158 | flops += N * self.dim * self.dim 159 | return flops 160 | 161 | 162 | class SwinTransformerBlock(nn.Module): 163 | r""" Swin Transformer Block. 164 | Args: 165 | dim (int): Number of input channels. 166 | input_resolution (tuple[int]): Input resulotion. 167 | num_heads (int): Number of attention heads. 168 | window_size (int): Window size. 169 | shift_size (int): Shift size for SW-MSA. 170 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 171 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 172 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 173 | drop (float, optional): Dropout rate. Default: 0.0 174 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 175 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 176 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 177 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 178 | """ 179 | 180 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 181 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 182 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 183 | super().__init__() 184 | self.dim = dim 185 | self.input_resolution = input_resolution 186 | self.num_heads = num_heads 187 | self.window_size = window_size 188 | self.shift_size = shift_size 189 | self.mlp_ratio = mlp_ratio 190 | if min(self.input_resolution) <= self.window_size: 191 | # if window size is larger than input resolution, we don't partition windows 192 | self.shift_size = 0 193 | self.window_size = min(self.input_resolution) 194 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 195 | 196 | self.norm1 = norm_layer(dim) 197 | self.attn = WindowAttention( 198 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 199 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 200 | 201 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 202 | self.norm2 = norm_layer(dim) 203 | mlp_hidden_dim = int(dim * mlp_ratio) 204 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 205 | 206 | if self.shift_size > 0: 207 | # calculate attention mask for SW-MSA 208 | H, W = self.input_resolution 209 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 210 | h_slices = (slice(0, -self.window_size), 211 | slice(-self.window_size, -self.shift_size), 212 | slice(-self.shift_size, None)) 213 | w_slices = (slice(0, -self.window_size), 214 | slice(-self.window_size, -self.shift_size), 215 | slice(-self.shift_size, None)) 216 | cnt = 0 217 | for h in h_slices: 218 | for w in w_slices: 219 | img_mask[:, h, w, :] = cnt 220 | cnt += 1 221 | 222 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 223 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 224 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 225 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 226 | else: 227 | attn_mask = None 228 | 229 | self.register_buffer("attn_mask", attn_mask) 230 | 231 | def forward(self, x): 232 | H, W = self.input_resolution 233 | B, L, C = x.shape 234 | assert L == H * W, "input feature has wrong size" 235 | 236 | shortcut = x 237 | x = self.norm1(x) 238 | x = x.view(B, H, W, C) 239 | 240 | # cyclic shift 241 | if self.shift_size > 0: 242 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 243 | else: 244 | shifted_x = x 245 | 246 | # partition windows 247 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 248 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 249 | 250 | # W-MSA/SW-MSA 251 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 252 | 253 | # merge windows 254 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 255 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 256 | 257 | # reverse cyclic shift 258 | if self.shift_size > 0: 259 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 260 | else: 261 | x = shifted_x 262 | x = x.view(B, H * W, C) 263 | 264 | # FFN 265 | x = shortcut + self.drop_path(x) 266 | x = x + self.drop_path(self.mlp(self.norm2(x))) 267 | 268 | return x 269 | 270 | def extra_repr(self) -> str: 271 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 272 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 273 | 274 | def flops(self): 275 | flops = 0 276 | H, W = self.input_resolution 277 | # norm1 278 | flops += self.dim * H * W 279 | # W-MSA/SW-MSA 280 | nW = H * W / self.window_size / self.window_size 281 | flops += nW * self.attn.flops(self.window_size * self.window_size) 282 | # mlp 283 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 284 | # norm2 285 | flops += self.dim * H * W 286 | return flops 287 | 288 | 289 | class PatchMerging(nn.Module): 290 | r""" Patch Merging Layer. 291 | Args: 292 | input_resolution (tuple[int]): Resolution of input feature. 293 | dim (int): Number of input channels. 294 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 295 | """ 296 | 297 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 298 | super().__init__() 299 | self.input_resolution = input_resolution 300 | self.dim = dim 301 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 302 | self.norm = norm_layer(4 * dim) 303 | 304 | def forward(self, x): 305 | """ 306 | x: B, H*W, C 307 | """ 308 | H, W = self.input_resolution 309 | B, L, C = x.shape 310 | assert L == H * W, "input feature has wrong size" 311 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 312 | 313 | x = x.view(B, H, W, C) 314 | 315 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 316 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 317 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 318 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 319 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 320 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 321 | 322 | x = self.norm(x) 323 | x = self.reduction(x) 324 | 325 | return x 326 | 327 | def extra_repr(self) -> str: 328 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 329 | 330 | def flops(self): 331 | H, W = self.input_resolution 332 | flops = H * W * self.dim 333 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 334 | return flops 335 | 336 | 337 | class BasicLayer(nn.Module): 338 | """ A basic Swin Transformer layer for one stage. 339 | Args: 340 | dim (int): Number of input channels. 341 | input_resolution (tuple[int]): Input resolution. 342 | depth (int): Number of blocks. 343 | num_heads (int): Number of attention heads. 344 | window_size (int): Local window size. 345 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 346 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 347 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 348 | drop (float, optional): Dropout rate. Default: 0.0 349 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 350 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 351 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 352 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 353 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
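        Example (illustrative only; mirrors the stage-1 configuration built by SwinTransformer below):
            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
            ...                    num_heads=3, window_size=7)
            >>> x = torch.randn(2, 56 * 56, 96)   # (B, H*W, C)
            >>> layer(x).shape                    # no downsample, token count unchanged
            torch.Size([2, 3136, 96])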
354 | """ 355 | 356 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 357 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 358 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 359 | 360 | super().__init__() 361 | self.dim = dim 362 | self.input_resolution = input_resolution 363 | self.depth = depth 364 | self.use_checkpoint = use_checkpoint 365 | 366 | # build blocks 367 | self.blocks = nn.ModuleList([ 368 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 369 | num_heads=num_heads, window_size=window_size, 370 | shift_size=0 if (i % 2 == 0) else window_size // 2, 371 | mlp_ratio=mlp_ratio, 372 | qkv_bias=qkv_bias, qk_scale=qk_scale, 373 | drop=drop, attn_drop=attn_drop, 374 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 375 | norm_layer=norm_layer) 376 | for i in range(depth)]) 377 | 378 | # patch merging layer 379 | if downsample is not None: 380 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 381 | else: 382 | self.downsample = None 383 | 384 | def forward(self, x): 385 | for blk in self.blocks: 386 | if self.use_checkpoint: 387 | x = checkpoint.checkpoint(blk, x) 388 | else: 389 | x = blk(x) 390 | if self.downsample is not None: 391 | x = self.downsample(x) 392 | return x 393 | 394 | def extra_repr(self) -> str: 395 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 396 | 397 | def flops(self): 398 | flops = 0 399 | for blk in self.blocks: 400 | flops += blk.flops() 401 | if self.downsample is not None: 402 | flops += self.downsample.flops() 403 | return flops 404 | 405 | 406 | class PatchEmbed(nn.Module): 407 | r""" Image to Patch Embedding 408 | Args: 409 | img_size (int): Image size. Default: 224. 410 | patch_size (int): Patch token size. Default: 4. 411 | in_chans (int): Number of input image channels. Default: 3. 412 | embed_dim (int): Number of linear projection output channels. Default: 96. 413 | norm_layer (nn.Module, optional): Normalization layer. Default: None 414 | """ 415 | 416 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 417 | super().__init__() 418 | img_size = to_2tuple(img_size) 419 | patch_size = to_2tuple(patch_size) 420 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 421 | self.img_size = img_size 422 | self.patch_size = patch_size 423 | self.patches_resolution = patches_resolution 424 | self.num_patches = patches_resolution[0] * patches_resolution[1] 425 | 426 | self.in_chans = in_chans 427 | self.embed_dim = embed_dim 428 | 429 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 430 | if norm_layer is not None: 431 | self.norm = norm_layer(embed_dim) 432 | else: 433 | self.norm = None 434 | 435 | def forward(self, x): 436 | B, C, H, W = x.shape 437 | # FIXME look at relaxing size constraints 438 | assert H == self.img_size[0] and W == self.img_size[1], \ 439 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
440 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 441 | if self.norm is not None: 442 | x = self.norm(x) 443 | return x 444 | 445 | def flops(self): 446 | Ho, Wo = self.patches_resolution 447 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 448 | if self.norm is not None: 449 | flops += Ho * Wo * self.embed_dim 450 | return flops 451 | 452 | 453 | class SwinTransformer(nn.Module): 454 | r""" Swin Transformer 455 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 456 | https://arxiv.org/pdf/2103.14030 457 | Args: 458 | img_size (int | tuple(int)): Input image size. Default 224 459 | patch_size (int | tuple(int)): Patch size. Default: 4 460 | in_chans (int): Number of input image channels. Default: 3 461 | num_classes (int): Number of classes for classification head. Default: 1000 462 | embed_dim (int): Patch embedding dimension. Default: 96 463 | depths (tuple(int)): Depth of each Swin Transformer layer. 464 | num_heads (tuple(int)): Number of attention heads in different layers. 465 | window_size (int): Window size. Default: 7 466 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 467 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 468 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 469 | drop_rate (float): Dropout rate. Default: 0 470 | attn_drop_rate (float): Attention dropout rate. Default: 0 471 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 472 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 473 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 474 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 475 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 476 | """ 477 | 478 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 479 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 480 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 481 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 482 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 483 | use_checkpoint=False, **kwargs): 484 | super().__init__() 485 | 486 | self.num_classes = num_classes 487 | self.num_layers = len(depths) 488 | self.embed_dim = embed_dim 489 | self.ape = ape 490 | self.patch_norm = patch_norm 491 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 492 | self.mlp_ratio = mlp_ratio 493 | 494 | # split image into non-overlapping patches 495 | self.patch_embed = PatchEmbed( 496 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 497 | norm_layer=norm_layer if self.patch_norm else None) 498 | num_patches = self.patch_embed.num_patches 499 | patches_resolution = self.patch_embed.patches_resolution 500 | self.patches_resolution = patches_resolution 501 | 502 | # absolute position embedding 503 | if self.ape: 504 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 505 | trunc_normal_(self.absolute_pos_embed, std=.02) 506 | 507 | self.pos_drop = nn.Dropout(p=drop_rate) 508 | 509 | # stochastic depth 510 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 511 | 512 | # build layers 513 | self.layers = nn.ModuleList() 514 | for i_layer in range(self.num_layers): 515 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 516 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 517 | patches_resolution[1] // (2 ** i_layer)), 518 | depth=depths[i_layer], 519 | num_heads=num_heads[i_layer], 520 | window_size=window_size, 521 | mlp_ratio=self.mlp_ratio, 522 | qkv_bias=qkv_bias, qk_scale=qk_scale, 523 | drop=drop_rate, attn_drop=attn_drop_rate, 524 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 525 | norm_layer=norm_layer, 526 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 527 | use_checkpoint=use_checkpoint) 528 | self.layers.append(layer) 529 | 530 | self.norm = norm_layer(self.num_features) 531 | self.avgpool = nn.AdaptiveAvgPool1d(1) 532 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 533 | 534 | self.apply(self._init_weights) 535 | 536 | def _init_weights(self, m): 537 | if isinstance(m, nn.Linear): 538 | trunc_normal_(m.weight, std=.02) 539 | if isinstance(m, nn.Linear) and m.bias is not None: 540 | nn.init.constant_(m.bias, 0) 541 | elif isinstance(m, nn.LayerNorm): 542 | nn.init.constant_(m.bias, 0) 543 | nn.init.constant_(m.weight, 1.0) 544 | 545 | @torch.jit.ignore 546 | def no_weight_decay(self): 547 | return {'absolute_pos_embed'} 548 | 549 | @torch.jit.ignore 550 | def no_weight_decay_keywords(self): 551 | return {'relative_position_bias_table'} 552 | 553 | def forward_features(self, x): 554 | x = self.patch_embed(x) 555 | if self.ape: 556 | x = x + self.absolute_pos_embed 557 | x = self.pos_drop(x) 558 | 559 | for layer in self.layers: 560 | x = layer(x) 561 | 562 | x = self.norm(x) # B L C 563 | x = self.avgpool(x.transpose(1, 2)) # B C 1 564 | x = torch.flatten(x, 1) 565 | return x 566 | 567 | def forward(self, x): 568 | x = self.forward_features(x) 569 | x = self.head(x) 570 | return x 571 | 572 | def flops(self): 573 | flops = 0 574 | flops += self.patch_embed.flops() 
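        # the remaining terms below add each stage's blocks, the final LayerNorm over the
        # last-stage tokens, and the classification head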
575 | for i, layer in enumerate(self.layers): 576 | flops += layer.flops() 577 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 578 | flops += self.num_features * self.num_classes 579 | return flops 580 | 581 | 582 | @register_model 583 | def swin_tiny_patch4_window7_224(pretrain=False, **kwargs): 584 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, use_checkpoint=False) 585 | return model -------------------------------------------------------------------------------- /models/swin_moex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Omid Nejati 3 | Email: omid_nejaty@alumni.iust.ac.ir 4 | 5 | Implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". 6 | Code borrowed from https://github.com/microsoft/Swin-Transformer 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | from timm.models.registry import register_model 14 | 15 | def moex(x, swap_index, norm_type, epsilon=1e-5, positive_only=False): 16 | '''MoEx operation''' 17 | dtype = x.dtype 18 | x = x.float() 19 | 20 | B, C, L = x.shape 21 | if norm_type == 'bn': 22 | norm_dims = [0, 2, 3] 23 | elif norm_type == 'in': 24 | norm_dims = [2, 3] 25 | elif norm_type == 'ln': 26 | norm_dims = [1, 2, 3] 27 | elif norm_type == 'pono': 28 | norm_dims = [1] 29 | elif norm_type.startswith('gn'): 30 | if norm_type.startswith('gn-d'): 31 | # gn-d4 means GN where each group has 4 dims 32 | G_dim = int(norm_type[4:]) 33 | G = C // G_dim 34 | else: 35 | # gn4 means GN with 4 groups 36 | G = int(norm_type[2:]) 37 | G_dim = C // G 38 | x = x.view(B, G, G_dim, H, W) 39 | norm_dims = [2, 3, 4] 40 | elif norm_type.startswith('gpono'): 41 | if norm_type.startswith('gpono-d'): 42 | # gpono-d4 means GPONO where each group has 4 dims 43 | G_dim = int(norm_type[len('gpono-d'):]) 44 | G = C // G_dim 45 | else: 46 | # gpono4 means GPONO with 4 groups 47 | G = int(norm_type[len('gpono'):]) 48 | G_dim = C // G 49 | x = x.view(B, G, G_dim, H, W) 50 | norm_dims = [2] 51 | else: 52 | raise NotImplementedError(f'norm_type={norm_type}') 53 | 54 | if positive_only: 55 | x_pos = F.relu(x) 56 | s1 = x_pos.sum(dim=norm_dims, keepdim=True) 57 | s2 = x_pos.pow(2).sum(dim=norm_dims, keepdim=True) 58 | count = x_pos.gt(0).sum(dim=norm_dims, keepdim=True) 59 | count[count == 0] = 1 # deal with 0/0 60 | mean = s1 / count 61 | var = s2 / count - mean.pow(2) 62 | std = var.add(epsilon).sqrt() 63 | else: 64 | mean = x.mean(dim=norm_dims, keepdim=True) 65 | std = x.var(dim=norm_dims, keepdim=True).add(epsilon).sqrt() 66 | swap_mean = mean[swap_index] 67 | swap_std = std[swap_index] 68 | # output = (x - mean) / std * swap_std + swap_mean 69 | # equvalent but for efficient 70 | scale = swap_std / std 71 | shift = swap_mean - mean * scale 72 | output = x * scale + shift 73 | return output.view(B, C, L).to(dtype) 74 | 75 | class Mlp(nn.Module): 76 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 77 | super().__init__() 78 | out_features = out_features or in_features 79 | hidden_features = hidden_features or in_features 80 | self.fc1 = nn.Linear(in_features, hidden_features) 81 | self.act = act_layer() 82 | self.fc2 = nn.Linear(hidden_features, out_features) 83 | self.drop = nn.Dropout(drop) 84 | 85 | def 
forward(self, x): 86 | x = self.fc1(x) 87 | x = self.act(x) 88 | x = self.drop(x) 89 | x = self.fc2(x) 90 | x = self.drop(x) 91 | return x 92 | 93 | 94 | def window_partition(x, window_size): 95 | """ 96 | Args: 97 | x: (B, H, W, C) 98 | window_size (int): window size 99 | Returns: 100 | windows: (num_windows*B, window_size, window_size, C) 101 | """ 102 | B, H, W, C = x.shape 103 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 104 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 105 | return windows 106 | 107 | 108 | def window_reverse(windows, window_size, H, W): 109 | """ 110 | Args: 111 | windows: (num_windows*B, window_size, window_size, C) 112 | window_size (int): Window size 113 | H (int): Height of image 114 | W (int): Width of image 115 | Returns: 116 | x: (B, H, W, C) 117 | """ 118 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 119 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 120 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 121 | return x 122 | 123 | 124 | class WindowAttention(nn.Module): 125 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 126 | It supports both of shifted and non-shifted window. 127 | Args: 128 | dim (int): Number of input channels. 129 | window_size (tuple[int]): The height and width of the window. 130 | num_heads (int): Number of attention heads. 131 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 132 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 133 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 134 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 135 | """ 136 | 137 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 138 | 139 | super().__init__() 140 | self.dim = dim 141 | self.window_size = window_size # Wh, Ww 142 | self.num_heads = num_heads 143 | head_dim = dim // num_heads 144 | self.scale = qk_scale or head_dim ** -0.5 145 | 146 | # define a parameter table of relative position bias 147 | self.relative_position_bias_table = nn.Parameter( 148 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 149 | 150 | # get pair-wise relative position index for each token inside the window 151 | coords_h = torch.arange(self.window_size[0]) 152 | coords_w = torch.arange(self.window_size[1]) 153 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 154 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 155 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 156 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 157 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 158 | relative_coords[:, :, 1] += self.window_size[1] - 1 159 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 160 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 161 | self.register_buffer("relative_position_index", relative_position_index) 162 | 163 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 164 | self.attn_drop = nn.Dropout(attn_drop) 165 | self.proj = nn.Linear(dim, dim) 166 | self.proj_drop = nn.Dropout(proj_drop) 167 | 168 | trunc_normal_(self.relative_position_bias_table, std=.02) 169 | self.softmax = nn.Softmax(dim=-1) 170 | 171 | def forward(self, x, mask=None): 172 | """ 173 | Args: 174 | x: input features with shape of (num_windows*B, N, C) 175 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 176 | """ 177 | B_, N, C = x.shape 178 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 179 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 180 | 181 | q = q * self.scale 182 | attn = (q @ k.transpose(-2, -1)) 183 | 184 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 185 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 186 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 187 | attn = attn + relative_position_bias.unsqueeze(0) 188 | 189 | if mask is not None: 190 | nW = mask.shape[0] 191 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 192 | attn = attn.view(-1, self.num_heads, N, N) 193 | attn = self.softmax(attn) 194 | else: 195 | attn = self.softmax(attn) 196 | 197 | attn = self.attn_drop(attn) 198 | 199 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 200 | x = self.proj(x) 201 | x = self.proj_drop(x) 202 | return x 203 | 204 | def extra_repr(self) -> str: 205 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 206 | 207 | def flops(self, N): 208 | # calculate flops for 1 window with token length of N 209 | flops = 0 210 | # qkv = self.qkv(x) 211 | flops += N * self.dim * 3 * self.dim 212 | # attn = (q @ k.transpose(-2, -1)) 213 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 214 | # x = (attn @ v) 215 | flops += self.num_heads * N * N 
* (self.dim // self.num_heads) 216 | # x = self.proj(x) 217 | flops += N * self.dim * self.dim 218 | return flops 219 | 220 | 221 | class SwinTransformerBlock(nn.Module): 222 | r""" Swin Transformer Block. 223 | Args: 224 | dim (int): Number of input channels. 225 | input_resolution (tuple[int]): Input resolution. 226 | num_heads (int): Number of attention heads. 227 | window_size (int): Window size. 228 | shift_size (int): Shift size for SW-MSA. 229 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 230 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 231 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 232 | drop (float, optional): Dropout rate. Default: 0.0 233 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 234 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 235 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 236 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 237 | """ 238 | 239 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 240 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 241 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 242 | super().__init__() 243 | self.dim = dim 244 | self.input_resolution = input_resolution 245 | self.num_heads = num_heads 246 | self.window_size = window_size 247 | self.shift_size = shift_size 248 | self.mlp_ratio = mlp_ratio 249 | if min(self.input_resolution) <= self.window_size: 250 | # if window size is larger than input resolution, we don't partition windows 251 | self.shift_size = 0 252 | self.window_size = min(self.input_resolution) 253 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size" 254 | 255 | self.norm1 = norm_layer(dim) 256 | self.attn = WindowAttention( 257 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 258 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 259 | 260 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
261 | self.norm2 = norm_layer(dim) 262 | mlp_hidden_dim = int(dim * mlp_ratio) 263 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 264 | 265 | if self.shift_size > 0: 266 | # calculate attention mask for SW-MSA 267 | H, W = self.input_resolution 268 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 269 | h_slices = (slice(0, -self.window_size), 270 | slice(-self.window_size, -self.shift_size), 271 | slice(-self.shift_size, None)) 272 | w_slices = (slice(0, -self.window_size), 273 | slice(-self.window_size, -self.shift_size), 274 | slice(-self.shift_size, None)) 275 | cnt = 0 276 | for h in h_slices: 277 | for w in w_slices: 278 | img_mask[:, h, w, :] = cnt 279 | cnt += 1 280 | 281 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 282 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 283 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 284 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 285 | else: 286 | attn_mask = None 287 | 288 | self.register_buffer("attn_mask", attn_mask) 289 | 290 | def forward(self, x): 291 | H, W = self.input_resolution 292 | B, L, C = x.shape 293 | assert L == H * W, "input feature has wrong size" 294 | 295 | shortcut = x 296 | x = self.norm1(x) 297 | x = x.view(B, H, W, C) 298 | 299 | # cyclic shift 300 | if self.shift_size > 0: 301 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 302 | else: 303 | shifted_x = x 304 | 305 | # partition windows 306 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 307 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 308 | 309 | # W-MSA/SW-MSA 310 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 311 | 312 | # merge windows 313 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 314 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 315 | 316 | # reverse cyclic shift 317 | if self.shift_size > 0: 318 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 319 | else: 320 | x = shifted_x 321 | x = x.view(B, H * W, C) 322 | 323 | # FFN 324 | x = shortcut + self.drop_path(x) 325 | x = x + self.drop_path(self.mlp(self.norm2(x))) 326 | 327 | return x 328 | 329 | def extra_repr(self) -> str: 330 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 331 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 332 | 333 | def flops(self): 334 | flops = 0 335 | H, W = self.input_resolution 336 | # norm1 337 | flops += self.dim * H * W 338 | # W-MSA/SW-MSA 339 | nW = H * W / self.window_size / self.window_size 340 | flops += nW * self.attn.flops(self.window_size * self.window_size) 341 | # mlp 342 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 343 | # norm2 344 | flops += self.dim * H * W 345 | return flops 346 | 347 | 348 | class PatchMerging(nn.Module): 349 | r""" Patch Merging Layer. 350 | Args: 351 | input_resolution (tuple[int]): Resolution of input feature. 352 | dim (int): Number of input channels. 353 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354 | """ 355 | 356 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 357 | super().__init__() 358 | self.input_resolution = input_resolution 359 | self.dim = dim 360 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 361 | self.norm = norm_layer(4 * dim) 362 | 363 | def forward(self, x): 364 | """ 365 | x: B, H*W, C 366 | """ 367 | H, W = self.input_resolution 368 | B, L, C = x.shape 369 | assert L == H * W, "input feature has wrong size" 370 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even." 371 | 372 | x = x.view(B, H, W, C) 373 | 374 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 375 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 376 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 377 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 378 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 379 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 380 | 381 | x = self.norm(x) 382 | x = self.reduction(x) 383 | 384 | return x 385 | 386 | def extra_repr(self) -> str: 387 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 388 | 389 | def flops(self): 390 | H, W = self.input_resolution 391 | flops = H * W * self.dim 392 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 393 | return flops 394 | 395 | 396 | class BasicLayer(nn.Module): 397 | """ A basic Swin Transformer layer for one stage. 398 | Args: 399 | dim (int): Number of input channels. 400 | input_resolution (tuple[int]): Input resolution. 401 | depth (int): Number of blocks. 402 | num_heads (int): Number of attention heads. 403 | window_size (int): Local window size. 404 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 405 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 406 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 407 | drop (float, optional): Dropout rate. Default: 0.0 408 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 409 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 410 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 411 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 412 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
413 | """ 414 | 415 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 416 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 417 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 418 | 419 | super().__init__() 420 | self.dim = dim 421 | self.input_resolution = input_resolution 422 | self.depth = depth 423 | self.use_checkpoint = use_checkpoint 424 | 425 | # build blocks 426 | self.blocks = nn.ModuleList([ 427 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 428 | num_heads=num_heads, window_size=window_size, 429 | shift_size=0 if (i % 2 == 0) else window_size // 2, 430 | mlp_ratio=mlp_ratio, 431 | qkv_bias=qkv_bias, qk_scale=qk_scale, 432 | drop=drop, attn_drop=attn_drop, 433 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 434 | norm_layer=norm_layer) 435 | for i in range(depth)]) 436 | 437 | # patch merging layer 438 | if downsample is not None: 439 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 440 | else: 441 | self.downsample = None 442 | 443 | def forward(self, x): 444 | for blk in self.blocks: 445 | if self.use_checkpoint: 446 | x = checkpoint.checkpoint(blk, x) 447 | else: 448 | x = blk(x) 449 | if self.downsample is not None: 450 | x = self.downsample(x) 451 | return x 452 | 453 | def extra_repr(self) -> str: 454 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 455 | 456 | def flops(self): 457 | flops = 0 458 | for blk in self.blocks: 459 | flops += blk.flops() 460 | if self.downsample is not None: 461 | flops += self.downsample.flops() 462 | return flops 463 | 464 | 465 | class PatchEmbed(nn.Module): 466 | r""" Image to Patch Embedding 467 | Args: 468 | img_size (int): Image size. Default: 224. 469 | patch_size (int): Patch token size. Default: 4. 470 | in_chans (int): Number of input image channels. Default: 3. 471 | embed_dim (int): Number of linear projection output channels. Default: 96. 472 | norm_layer (nn.Module, optional): Normalization layer. Default: None 473 | """ 474 | 475 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 476 | super().__init__() 477 | img_size = to_2tuple(img_size) 478 | patch_size = to_2tuple(patch_size) 479 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 480 | self.img_size = img_size 481 | self.patch_size = patch_size 482 | self.patches_resolution = patches_resolution 483 | self.num_patches = patches_resolution[0] * patches_resolution[1] 484 | 485 | self.in_chans = in_chans 486 | self.embed_dim = embed_dim 487 | 488 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 489 | if norm_layer is not None: 490 | self.norm = norm_layer(embed_dim) 491 | else: 492 | self.norm = None 493 | 494 | def forward(self, x): 495 | B, C, H, W = x.shape 496 | # FIXME look at relaxing size constraints 497 | assert H == self.img_size[0] and W == self.img_size[1], \ 498 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
499 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 500 | if self.norm is not None: 501 | x = self.norm(x) 502 | return x 503 | 504 | def flops(self): 505 | Ho, Wo = self.patches_resolution 506 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 507 | if self.norm is not None: 508 | flops += Ho * Wo * self.embed_dim 509 | return flops 510 | 511 | 512 | class SwinTransformer(nn.Module): 513 | r""" Swin Transformer 514 | A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 515 | https://arxiv.org/pdf/2103.14030 516 | Args: 517 | img_size (int | tuple(int)): Input image size. Default: 224 518 | patch_size (int | tuple(int)): Patch size. Default: 4 519 | in_chans (int): Number of input image channels. Default: 3 520 | num_classes (int): Number of classes for classification head. Default: 1000 521 | embed_dim (int): Patch embedding dimension. Default: 96 522 | depths (tuple(int)): Depth of each Swin Transformer layer. 523 | num_heads (tuple(int)): Number of attention heads in different layers. 524 | window_size (int): Window size. Default: 7 525 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 526 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 527 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 528 | drop_rate (float): Dropout rate. Default: 0 529 | attn_drop_rate (float): Attention dropout rate. Default: 0 530 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 531 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 532 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 533 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 534 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
535 | """ 536 | 537 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 538 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 539 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 540 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 541 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 542 | use_checkpoint=False, **kwargs): 543 | super().__init__() 544 | 545 | self.num_classes = num_classes 546 | self.num_layers = len(depths) 547 | self.embed_dim = embed_dim 548 | self.ape = ape 549 | self.patch_norm = patch_norm 550 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 551 | self.mlp_ratio = mlp_ratio 552 | 553 | # split image into non-overlapping patches 554 | self.patch_embed = PatchEmbed( 555 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 556 | norm_layer=norm_layer if self.patch_norm else None) 557 | num_patches = self.patch_embed.num_patches 558 | patches_resolution = self.patch_embed.patches_resolution 559 | self.patches_resolution = patches_resolution 560 | 561 | # absolute position embedding 562 | if self.ape: 563 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 564 | trunc_normal_(self.absolute_pos_embed, std=.02) 565 | 566 | self.pos_drop = nn.Dropout(p=drop_rate) 567 | 568 | # stochastic depth 569 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 570 | 571 | # build layers 572 | self.layers = nn.ModuleList() 573 | for i_layer in range(self.num_layers): 574 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 575 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 576 | patches_resolution[1] // (2 ** i_layer)), 577 | depth=depths[i_layer], 578 | num_heads=num_heads[i_layer], 579 | window_size=window_size, 580 | mlp_ratio=self.mlp_ratio, 581 | qkv_bias=qkv_bias, qk_scale=qk_scale, 582 | drop=drop_rate, attn_drop=attn_drop_rate, 583 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 584 | norm_layer=norm_layer, 585 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 586 | use_checkpoint=use_checkpoint) 587 | self.layers.append(layer) 588 | 589 | self.norm = norm_layer(self.num_features) 590 | self.avgpool = nn.AdaptiveAvgPool1d(1) 591 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 592 | 593 | self.apply(self._init_weights) 594 | 595 | def _init_weights(self, m): 596 | if isinstance(m, nn.Linear): 597 | trunc_normal_(m.weight, std=.02) 598 | if isinstance(m, nn.Linear) and m.bias is not None: 599 | nn.init.constant_(m.bias, 0) 600 | elif isinstance(m, nn.LayerNorm): 601 | nn.init.constant_(m.bias, 0) 602 | nn.init.constant_(m.weight, 1.0) 603 | 604 | @torch.jit.ignore 605 | def no_weight_decay(self): 606 | return {'absolute_pos_embed'} 607 | 608 | @torch.jit.ignore 609 | def no_weight_decay_keywords(self): 610 | return {'relative_position_bias_table'} 611 | 612 | def forward_features(self, x, swap_index, moex_norm, moex_epsilon, 613 | moex_layer, moex_positive_only): 614 | x = self.patch_embed(x) 615 | 616 | # moex 617 | if swap_index is not None and moex_layer == 'stem': 618 | x = moex(x, swap_index, moex_norm, moex_epsilon, moex_positive_only) 619 | 620 | if self.ape: 621 | x = x + self.absolute_pos_embed 622 | x = self.pos_drop(x) 623 | 624 | for layer in self.layers: 625 | x = layer(x) 626 | 627 | x = self.norm(x) # B L C 628 | x = self.avgpool(x.transpose(1, 2)) # B C 1
629 | x = torch.flatten(x, 1) 630 | return x 631 | 632 | def forward(self, x, swap_index=None, moex_norm='pono', moex_epsilon=1e-5, 633 | moex_layer='stem', moex_positive_only=False): 634 | 635 | x = self.forward_features(x, swap_index, moex_norm, moex_epsilon, 636 | moex_layer, moex_positive_only) 637 | x = self.head(x) 638 | return x 639 | 640 | def flops(self): 641 | flops = 0 642 | flops += self.patch_embed.flops() 643 | for i, layer in enumerate(self.layers): 644 | flops += layer.flops() 645 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 646 | flops += self.num_features * self.num_classes 647 | return flops 648 | 649 | 650 | @register_model 651 | def swin_tiny_patch4_window7_224(pretrain=False, **kwargs): 652 | model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, use_checkpoint=False) 653 | return model --------------------------------------------------------------------------------
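A minimal, standalone sketch of the SW-MSA attention-mask construction used in `SwinTransformerBlock.__init__` above. The feature-map size (8x8), `window_size=4`, and `shift_size=2` are assumed toy values, and the local `window_partition` simply mirrors the helper defined earlier in this file; everything else follows the listing.

```python
# Toy illustration (not part of the repository): rebuilds the SW-MSA attention mask
# for an assumed 8x8 feature map with window_size=4 and shift_size=2.
import torch


def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


H, W, window_size, shift_size = 8, 8, 4, 2
img_mask = torch.zeros((1, H, W, 1))  # label each position with its region id
slices = (slice(0, -window_size),
          slice(-window_size, -shift_size),
          slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)  # torch.Size([4, 16, 16]): one 16x16 mask per window
```

Entries are 0 for token pairs that belonged to the same region before the cyclic shift and -100 otherwise, so after the softmax the cross-region attention weights are effectively zero.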
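A hedged usage sketch of the MoEx hooks exposed by `SwinTransformer.forward` above. The import path `models.swin_moex` and the batch size are assumptions; the constructor name and the keyword arguments (`swap_index`, `moex_norm`, `moex_epsilon`, `moex_layer`) come from the listing.

```python
# Usage sketch, not a prescribed training recipe.
import torch
from models.swin_moex import swin_tiny_patch4_window7_224  # assumed module path

model = swin_tiny_patch4_window7_224()
x = torch.randn(8, 3, 224, 224)  # dummy batch of 8 RGB images at 224x224

# Plain forward pass: swap_index defaults to None, so no moments are exchanged.
logits = model(x)

# MoEx-style pass: pair each sample with a shuffled partner and exchange feature
# moments right after the patch-embedding stem (moex_layer='stem').
swap_index = torch.randperm(x.size(0))
logits_moex = model(x, swap_index=swap_index, moex_norm='pono',
                    moex_epsilon=1e-5, moex_layer='stem')
print(logits.shape, logits_moex.shape)  # torch.Size([8, 1000]) for both
```

Note that the registered constructor ignores its `pretrain` flag and `**kwargs`, so any architectural change (for example a different `window_size`) has to be made by instantiating `SwinTransformer` directly.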