├── assets
│   └── model.PNG
├── README.md
├── LICENSE
├── .gitignore
├── module.py
└── crossvit.py
--------------------------------------------------------------------------------
/assets/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/CrossViT-pytorch/HEAD/assets/model.PNG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification
This is an unofficial PyTorch implementation of [CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification](https://arxiv.org/abs/2103.14899).

![](assets/model.PNG)

## Usage
```python
import torch
from crossvit import CrossViT

img = torch.ones([1, 3, 224, 224])

model = CrossViT(image_size = 224, channels = 3, num_classes = 100)
out = model(img)

print("Shape of out :", out.shape)  # [B, num_classes]
```
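
To sanity-check the model size, you can count the trainable parameters; this continues the snippet above and mirrors the `__main__` block in `crossvit.py`:

```python
import numpy as np

# Total trainable parameters, in millions.
parameters = filter(lambda p: p.requires_grad, model.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)
```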

## Citation
```
@misc{chen2021crossvit,
      title={CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
      author={Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
      year={2021},
      eprint={2103.14899},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Acknowledgement
* Base ViT code is borrowed from the [@lucidrains](https://github.com/lucidrains) repo: https://github.com/lucidrains/vit-pytorch
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Rishikesh (ऋषिकेश)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class Residual(nn.Module):
    """Wraps a module with a skip connection: fn(x) + x."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    """Applies LayerNorm before the wrapped module (pre-norm transformer block)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP with GELU, used as the transformer feed-forward block."""
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Standard multi-head self-attention over all tokens."""
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class CrossAttention(nn.Module):
    """Multi-head attention in which only the first (CLS) token acts as the query;
    keys and values are computed from the full token sequence."""
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x_qkv):
        b, n, _, h = *x_qkv.shape, self.heads

        k = self.to_k(x_qkv)
        k = rearrange(k, 'b n (h d) -> b h n d', h = h)

        v = self.to_v(x_qkv)
        v = rearrange(v, 'b n (h d) -> b h n d', h = h)

        # Only the CLS token (position 0) is used as the query.
        q = self.to_q(x_qkv[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
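

# Minimal shape check (illustrative, using only the classes defined above):
# Attention returns one output token per input token, while CrossAttention
# returns a single token, since only the first (CLS) token acts as the query.
if __name__ == "__main__":
    x = torch.randn(2, 17, 96)                      # (batch, 1 CLS + 16 patch tokens, dim)

    attn = Attention(dim=96, heads=3, dim_head=32)
    print(attn(x).shape)                            # torch.Size([2, 17, 96])

    cross = CrossAttention(dim=96, heads=3, dim_head=32)
    print(cross(x).shape)                           # torch.Size([2, 1, 96])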
--------------------------------------------------------------------------------
/crossvit.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from module import Attention, PreNorm, FeedForward, CrossAttention
import numpy as np


class Transformer(nn.Module):
    """Stack of pre-norm self-attention + feed-forward blocks with residual connections."""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MultiScaleTransformerEncoder(nn.Module):
    """One multi-scale encoder block: a transformer per branch, followed by
    cross-attention layers that exchange CLS tokens between the two branches."""

    def __init__(self, small_dim = 96, small_depth = 4, small_heads = 3, small_dim_head = 32, small_mlp_dim = 384,
                 large_dim = 192, large_depth = 1, large_heads = 3, large_dim_head = 64, large_mlp_dim = 768,
                 cross_attn_depth = 1, cross_attn_heads = 3, dropout = 0.):
        super().__init__()
        # Forward dropout to the per-branch encoders as well (it was previously
        # accepted but silently ignored here).
        self.transformer_enc_small = Transformer(small_dim, small_depth, small_heads, small_dim_head, small_mlp_dim, dropout)
        self.transformer_enc_large = Transformer(large_dim, large_depth, large_heads, large_dim_head, large_mlp_dim, dropout)

        self.cross_attn_layers = nn.ModuleList([])
        for _ in range(cross_attn_depth):
            self.cross_attn_layers.append(nn.ModuleList([
                nn.Linear(small_dim, large_dim),   # f_sl: project small CLS into the large branch
                nn.Linear(large_dim, small_dim),   # g_ls: project it back to the small branch
                PreNorm(large_dim, CrossAttention(large_dim, heads = cross_attn_heads, dim_head = large_dim_head, dropout = dropout)),
                nn.Linear(large_dim, small_dim),   # f_ls: project large CLS into the small branch
                nn.Linear(small_dim, large_dim),   # g_sl: project it back to the large branch
                PreNorm(small_dim, CrossAttention(small_dim, heads = cross_attn_heads, dim_head = small_dim_head, dropout = dropout)),
            ]))

    def forward(self, xs, xl):
        xs = self.transformer_enc_small(xs)
        xl = self.transformer_enc_large(xl)

        for f_sl, g_ls, cross_attn_s, f_ls, g_sl, cross_attn_l in self.cross_attn_layers:
            small_class = xs[:, 0]
            x_small = xs[:, 1:]
            large_class = xl[:, 0]
            x_large = xl[:, 1:]

            # Cross-attention for the large branch: its CLS token queries the
            # small branch's patch tokens.
            cal_q = f_ls(large_class.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_small), dim=1)
            cal_out = cal_q + cross_attn_l(cal_qkv)
            cal_out = g_sl(cal_out)
            xl = torch.cat((cal_out, x_large), dim=1)

            # Cross-attention for the small branch: its CLS token queries the
            # large branch's patch tokens.
            cal_q = f_sl(small_class.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_large), dim=1)
            cal_out = cal_q + cross_attn_s(cal_qkv)
            cal_out = g_ls(cal_out)
            xs = torch.cat((cal_out, x_small), dim=1)

        return xs, xl

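# Shape walkthrough for the CrossViT defaults below (224x224 input):
#   small branch xs: (B, 1 + (224 // 14) ** 2, 96)  = (B, 257, 96)
#   large branch xl: (B, 1 + (224 // 16) ** 2, 192) = (B, 197, 192)
# Each cross-attention step projects one branch's CLS token to the other
# branch's width, uses it as the sole query over that branch's patch tokens,
# then projects the result back and re-attaches it to its own patch tokens.
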
class CrossViT(nn.Module):
    def __init__(self, image_size, channels, num_classes, patch_size_small = 14, patch_size_large = 16, small_dim = 96,
                 large_dim = 192, small_depth = 1, large_depth = 4, cross_attn_depth = 1, multi_scale_enc_depth = 3,
                 heads = 3, pool = 'cls', dropout = 0., emb_dropout = 0., scale_dim = 4):
        super().__init__()

        assert image_size % patch_size_small == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches_small = (image_size // patch_size_small) ** 2
        patch_dim_small = channels * patch_size_small ** 2

        assert image_size % patch_size_large == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches_large = (image_size // patch_size_large) ** 2
        patch_dim_large = channels * patch_size_large ** 2

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Patch embeddings: split the image into patches and project each
        # flattened patch to the branch's embedding dimension.
        self.to_patch_embedding_small = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size_small, p2 = patch_size_small),
            nn.Linear(patch_dim_small, small_dim),
        )

        self.to_patch_embedding_large = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size_large, p2 = patch_size_large),
            nn.Linear(patch_dim_large, large_dim),
        )

        self.pos_embedding_small = nn.Parameter(torch.randn(1, num_patches_small + 1, small_dim))
        self.cls_token_small = nn.Parameter(torch.randn(1, 1, small_dim))
        self.dropout_small = nn.Dropout(emb_dropout)

        self.pos_embedding_large = nn.Parameter(torch.randn(1, num_patches_large + 1, large_dim))
        self.cls_token_large = nn.Parameter(torch.randn(1, 1, large_dim))
        self.dropout_large = nn.Dropout(emb_dropout)

        self.multi_scale_transformers = nn.ModuleList([])
        for _ in range(multi_scale_enc_depth):
            self.multi_scale_transformers.append(
                MultiScaleTransformerEncoder(small_dim = small_dim, small_depth = small_depth,
                                             small_heads = heads, small_dim_head = small_dim // heads,
                                             small_mlp_dim = small_dim * scale_dim,
                                             large_dim = large_dim, large_depth = large_depth,
                                             large_heads = heads, large_dim_head = large_dim // heads,
                                             large_mlp_dim = large_dim * scale_dim,
                                             cross_attn_depth = cross_attn_depth, cross_attn_heads = heads,
                                             dropout = dropout))

        self.pool = pool
        self.to_latent = nn.Identity()

        # One classifier head per branch; their logits are summed in forward().
        self.mlp_head_small = nn.Sequential(
            nn.LayerNorm(small_dim),
            nn.Linear(small_dim, num_classes)
        )

        self.mlp_head_large = nn.Sequential(
            nn.LayerNorm(large_dim),
            nn.Linear(large_dim, num_classes)
        )

    def forward(self, img):
        # Small-patch branch: embed patches, prepend CLS token, add positions.
        xs = self.to_patch_embedding_small(img)
        b, n, _ = xs.shape

        cls_token_small = repeat(self.cls_token_small, '() n d -> b n d', b = b)
        xs = torch.cat((cls_token_small, xs), dim=1)
        xs += self.pos_embedding_small[:, :(n + 1)]
        xs = self.dropout_small(xs)

        # Large-patch branch, built the same way.
        xl = self.to_patch_embedding_large(img)
        b, n, _ = xl.shape

        cls_token_large = repeat(self.cls_token_large, '() n d -> b n d', b = b)
        xl = torch.cat((cls_token_large, xl), dim=1)
        xl += self.pos_embedding_large[:, :(n + 1)]
        xl = self.dropout_large(xl)

        for multi_scale_transformer in self.multi_scale_transformers:
            xs, xl = multi_scale_transformer(xs, xl)

        # Pool each branch (CLS token by default) and sum the two heads' logits.
        xs = xs.mean(dim = 1) if self.pool == 'mean' else xs[:, 0]
        xl = xl.mean(dim = 1) if self.pool == 'mean' else xl[:, 0]

        xs = self.mlp_head_small(xs)
        xl = self.mlp_head_large(xl)
        x = xs + xl
        return x

"__main__": 171 | 172 | img = torch.ones([1, 3, 224, 224]) 173 | 174 | model = CrossViT(224, 3, 1000) 175 | 176 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 177 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 178 | print('Trainable Parameters: %.3fM' % parameters) 179 | 180 | out = model(img) 181 | 182 | print("Shape of out :", out.shape) # [B, num_classes] 183 | 184 | 185 | --------------------------------------------------------------------------------