├── README.md
└── edgevit.py

/README.md:
--------------------------------------------------------------------------------
This is an unofficial PyTorch implementation of EdgeViT, introduced in "EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers", arXiv 2022.

Pretrained models will be released soon.

## Usage

```python
import torch

from edgevit import EdgeViT_XXS, EdgeViT_XS, EdgeViT_S

model = EdgeViT_XXS()
inputs = torch.randn((1, 3, 224, 224))
print(model(inputs))
```
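For inference, a minimal sketch in eval mode (no gradient tracking; the default classification head has 1000 classes):

```python
import torch

from edgevit import EdgeViT_S

model = EdgeViT_S().eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # shape: (1, 1000)
print(logits.softmax(dim=-1).argmax(dim=-1))
```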
## Citation

```
@article{pan2022edgevits,
  title={EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers},
  author={Pan, Junting and Bulat, Adrian and Tan, Fuwen and Zhu, Xiatian and Dudziak, Lukasz and Li, Hongsheng and Tzimiropoulos, Georgios and Martinez, Brais},
  journal={arXiv preprint arXiv:2205.03436},
  year={2022}
}
```

--------------------------------------------------------------------------------
/edgevit.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

edgevit_configs = {
    'XXS': {
        'channels': (36, 72, 144, 288),
        'blocks': (1, 1, 3, 2),
        'heads': (1, 2, 4, 8)
    },
    'XS': {
        'channels': (48, 96, 240, 384),
        'blocks': (1, 1, 2, 2),
        'heads': (1, 2, 4, 8)
    },
    'S': {
        'channels': (48, 96, 240, 384),
        'blocks': (1, 2, 3, 2),
        'heads': (1, 2, 4, 8)
    }
}

# Default per-stage sub-sampling ratios for the global sparse attention.
HYPERPARAMETERS = {
    'r': (4, 2, 2, 1)
}


class Residual(nn.Module):
    """Skip connection wrapper: y = x + f(x)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return x + self.module(x)


class ConditionalPositionalEncoding(nn.Sequential):
    """Conditional positional encoding: a depth-wise 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.add_module('conditional_positional_encoding', nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels, bias=False))


class MLP(nn.Sequential):
    """Point-wise feed-forward network with expansion ratio 4."""

    def __init__(self, channels):
        super().__init__()
        expansion = 4
        self.add_module('mlp_layer_0', nn.Conv2d(channels, channels * expansion, kernel_size=1, bias=False))
        self.add_module('mlp_act', nn.GELU())
        self.add_module('mlp_layer_1', nn.Conv2d(channels * expansion, channels, kernel_size=1, bias=False))


class LocalAggModule(nn.Sequential):
    """Local aggregation: point-wise conv -> depth-wise 3x3 conv -> point-wise conv."""

    def __init__(self, channels):
        super().__init__()
        self.add_module('pointwise_prenorm_0', nn.BatchNorm2d(channels))
        self.add_module('pointwise_conv_0', nn.Conv2d(channels, channels, kernel_size=1, bias=False))
        self.add_module('depthwise_conv', nn.Conv2d(channels, channels, padding=1, kernel_size=3, groups=channels, bias=False))
        self.add_module('pointwise_prenorm_1', nn.BatchNorm2d(channels))
        self.add_module('pointwise_conv_1', nn.Conv2d(channels, channels, kernel_size=1, bias=False))


class GlobalSparseAttentionModule(nn.Module):
    """Global sparse attention: sub-sample delegate tokens with stride r, run multi-head
    self-attention on them, then propagate back to full resolution with a transposed conv."""

    def __init__(self, channels, r, heads):
        super().__init__()
        self.head_dim = channels // heads
        self.scale = self.head_dim ** -0.5
        self.num_heads = heads

        self.sparse_sampler = nn.AvgPool2d(kernel_size=1, stride=r)
        self.norm = nn.GroupNorm(num_groups=1, num_channels=channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.local_prop = nn.ConvTranspose2d(channels, channels, kernel_size=r, stride=r, groups=channels)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.sparse_sampler(x)
        B, C, H, W = x.shape
        q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.head_dim, self.head_dim, self.head_dim], dim=2)
        # Scaled dot-product attention over the sub-sampled (delegate) tokens.
        attn = (q.transpose(-2, -1) @ k * self.scale).softmax(-1)
        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
        x = self.local_prop(x)
        x = self.norm(x)
        x = self.proj(x)

        return x


class ConvDownsampling(nn.Sequential):
    """Patch-embedding-style downsampling: strided convolution followed by GroupNorm."""

    def __init__(self, inp, oup, r, bias=False):
        super().__init__()
        self.add_module('downsampling_conv', nn.Conv2d(inp, oup, kernel_size=r, stride=r, bias=bias))
        self.add_module('downsampling_norm', nn.GroupNorm(num_groups=1, num_channels=oup))


class EdgeViT(nn.Module):
    def __init__(self, channels, blocks, heads, r=HYPERPARAMETERS['r'], num_classes=1000, distillation=False):
        super().__init__()
        self.distillation = distillation

        l = []
        in_channels = 3
        for stage_id, (num_channels, num_blocks, num_heads, sample_ratio) in enumerate(zip(channels, blocks, heads, r)):
            l.append(ConvDownsampling(inp=in_channels, oup=num_channels, r=4 if stage_id == 0 else 2))

            for _ in range(num_blocks):
                # Local-global-local block: CPE -> local aggregation -> FFN,
                # then CPE -> global sparse attention -> FFN, all with residuals.
                l.append(Residual(ConditionalPositionalEncoding(num_channels)))
                l.append(Residual(LocalAggModule(num_channels)))
                l.append(Residual(MLP(num_channels)))
                l.append(Residual(ConditionalPositionalEncoding(num_channels)))
                l.append(Residual(GlobalSparseAttentionModule(channels=num_channels, r=sample_ratio, heads=num_heads)))
                l.append(Residual(MLP(num_channels)))

            in_channels = num_channels

        self.main_body = nn.Sequential(*l)
        self.pooling = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Linear(in_channels, num_classes, bias=True)

        if self.distillation:
            self.dist_classifier = nn.Linear(in_channels, num_classes, bias=True)

    def forward(self, x):
        x = self.main_body(x)
        x = self.pooling(x).flatten(1)

        if self.distillation:
            x = self.classifier(x), self.dist_classifier(x)

            if not self.training:
                # Average the two heads at inference time.
                x = 1 / 2 * (x[0] + x[1])
        else:
            x = self.classifier(x)

        return x


def EdgeViT_XXS(pretrained=False):
    model = EdgeViT(**edgevit_configs['XXS'])

    if pretrained:
        raise NotImplementedError

    return model


def EdgeViT_XS(pretrained=False):
    model = EdgeViT(**edgevit_configs['XS'])

    if pretrained:
        raise NotImplementedError

    return model


def EdgeViT_S(pretrained=False):
    model = EdgeViT(**edgevit_configs['S'])

    if pretrained:
        raise NotImplementedError

    return model
--------------------------------------------------------------------------------
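The builders above always use the stock 1000-class head; the `EdgeViT` class itself also accepts a custom `num_classes` and a `distillation` flag (two classification heads whose logits are averaged at inference). A hedged usage sketch with illustrative values:

```python
import torch

from edgevit import EdgeViT, edgevit_configs

# Illustrative: an XS backbone with a 10-class head and a distillation head.
model = EdgeViT(**edgevit_configs['XS'], num_classes=10, distillation=True)

model.train()
logits, dist_logits = model(torch.randn(2, 3, 224, 224))  # two heads during training
print(logits.shape, dist_logits.shape)                    # torch.Size([2, 10]) twice

model.eval()
with torch.no_grad():
    avg_logits = model(torch.randn(2, 3, 224, 224))       # heads averaged at inference
print(avg_logits.shape)                                   # torch.Size([2, 10])
```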