├── LICENSE
├── README.md
└── mobilevit.py

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2021 Chin-Hsuan Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MobileViT

## Overview

This is a PyTorch implementation of MobileViT, as described in ["MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer"](https://arxiv.org/abs/2110.02178), arXiv 2021.

![img](https://user-images.githubusercontent.com/67839539/136470152-2573529e-1a24-4494-821d-70eb4647a51d.png)

👉 Check out [CoAtNet](https://github.com/chinhsuanwu/coatnet-pytorch) if you are interested in other **Convolution + Transformer** models.

## Usage

```python
import torch
from mobilevit import mobilevit_xxs

img = torch.randn(1, 3, 256, 256)
vit = mobilevit_xxs()
out = vit(img)
```

## Citation

```bibtex
@article{mehta2021mobilevit,
  title={MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
  author={Mehta, Sachin and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2110.02178},
  year={2021}
}
```

## Credits

Code adapted from [MobileNetV2](https://github.com/tonylins/pytorch-mobilenet-v2) and [ViT](https://github.com/lucidrains/vit-pytorch).
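
## Sanity check

A minimal smoke test over the three variants, mirroring the `__main__` block of `mobilevit.py` (assumes the default 256×256 input these constructors are built for):

```python
import torch
from mobilevit import mobilevit_xxs, mobilevit_xs, mobilevit_s

img = torch.randn(1, 3, 256, 256)
for build in (mobilevit_xxs, mobilevit_xs, mobilevit_s):
    model = build()
    out = model(img)
    assert out.shape == (1, 1000)  # ImageNet-style logits
```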
--------------------------------------------------------------------------------
/mobilevit.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from einops import rearrange


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    # Pad by kernel_size // 2 so spatial size is preserved for any odd kernel
    # (a hard-coded padding of 1 is only correct for 3x3 kernels).
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, patch_area, num_patches, dim) -- attention runs over the
        # num_patches axis, independently for each pixel offset within a patch.
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2Block(nn.Module):
    """Inverted residual block from MobileNetV2."""

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
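
# Shape sketch for MV2Block (illustrative values, not from the original source):
# MV2Block(32, 64, stride=2, expansion=4) expands 32 -> 128 channels with a 1x1
# conv, filters spatially with a stride-2 depthwise 3x3 (halving H and W), then
# projects 128 -> 64 with a linear 1x1 conv: (B, 32, H, W) -> (B, 64, H/2, W/2).
# The residual shortcut is taken only when stride == 1 and inp == oup.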

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations: unfold the feature map into
        # (patch_area, num_patches) tokens, run the transformer, fold back
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # Fusion: project back to `channel`, concatenate with the input copy
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x


class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        # Transformer depths of the three MobileViT blocks (from the paper)
        L = [2, 4, 3]

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))   # Repeat
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)      # Repeat

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        x = self.pool(x).view(-1, x.shape[1])
        x = self.fc(x)
        return x
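
# Resolution sketch for a 256x256 input (strides taken from the layers above):
# conv1 and four of the MV2 blocks downsample by 2, so the feature map shrinks
# 256 -> 128 -> 64 -> 32 -> 16 -> 8 before the final (ih // 32) average pool.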

def mobilevit_xxs():
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)


def mobilevit_xs():
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((256, 256), dims, channels, num_classes=1000)


def mobilevit_s():
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT((256, 256), dims, channels, num_classes=1000)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    img = torch.randn(5, 3, 256, 256)

    # Each variant should print torch.Size([5, 1000]) followed by its trainable
    # parameter count (roughly 1.3M / 2.3M / 5.6M for XXS / XS / S per the paper).
    vit = mobilevit_xxs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_xs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_s()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))
--------------------------------------------------------------------------------