├── model.PNG
├── LICENSE
├── README.md
├── .gitignore
├── module.py
└── coat.py

/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/CoaT-pytorch/HEAD/model.PNG
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Rishikesh (ऋषिकेश)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CoaT: Co-Scale Conv-Attentional Image Transformers
This repo contains an unofficial PyTorch implementation of the paper
[Co-Scale Conv-Attentional Image Transformers](https://arxiv.org/abs/2104.06399).
For the official implementation, please visit [here](https://github.com/mlpc-ucsd/CoaT).

![](model.PNG)


## Usage:
```python
import numpy as np
from coat import CoaT
import torch

img = torch.ones([1, 3, 224, 224])

coatlite = CoaT(3, 224, 1000)
out = coatlite(img)
print("Shape of out :", out.shape)  # [B, num_classes]

parameters = filter(lambda p: p.requires_grad, coatlite.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters in CoaT-Lite: %.3fM' % parameters)

# use_parallel=True for the parallel group
coat_tiny = CoaT(3, 224, 1000, out_channels=[152, 152, 152, 152], scales=[4, 4, 4, 4], use_parallel=True)
out = coat_tiny(img)
print("Shape of out :", out.shape)  # [B, num_classes]

parameters = filter(lambda p: p.requires_grad, coat_tiny.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters in CoaT Tiny: %.3fM' % parameters)
```
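
The `CoaT` constructor in `coat.py` exposes the per-block configuration directly, so other serial-only variants can be sketched by adjusting its arguments. The values below are illustrative placeholders rather than the paper's exact variant hyper-parameters. Note that when `use_parallel=True`, the last three serial blocks must output `parallel_channels` channels (152 by default), which is why CoaT Tiny above uses `out_channels=[152, 152, 152, 152]`.

```python
# Illustrative serial-only configuration (placeholder values, not a paper variant)
coat_custom = CoaT(
    3, 224, 1000,
    out_channels=[64, 128, 320, 512],  # channels of the four serial blocks
    depths=[3, 4, 6, 3],               # conv-attention transformer depth per block
    scales=[8, 8, 4, 4],               # feed-forward expansion ratio per block
    downscales=[4, 2, 2, 2],           # patch-embedding stride per block
    kernels=[7, 3, 3, 3],              # patch-embedding kernel size per block
)
out = coat_custom(img)                 # [B, num_classes]
```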
## Citation:
```
@misc{xu2021coscale,
      title={Co-Scale Conv-Attentional Image Transformers},
      author={Weijian Xu and Yifan Xu and Tyler Chang and Zhuowen Tu},
      year={2021},
      eprint={2104.06399},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import torch
import math
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class SepConv2d(nn.Module):
    def __init__(self, nin, nout):
        super(SepConv2d, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=3, padding=1, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class ConvAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.in_depthwiseconv = SepConv2d(dim, dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.attn_depthwiseconv = SepConv2d(dim, dim)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, d, h = *x.shape, self.heads

        cls = x[:, :1]
        image_token = x[:, 1:]
        H = W = int(math.sqrt(n - 1))

        # convolutional position encoding on the image tokens (class token left unchanged)
        image_token = rearrange(image_token, 'b (l w) d -> b d l w', l=H, w=W)
        image_token = self.in_depthwiseconv(image_token)
        image_token = rearrange(image_token, 'b d h w -> b (h w) d')
        x = x + torch.cat((cls, image_token), dim=1)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        # factorized attention: q @ (softmax(k)^T @ v), scaled by dim_head ** -0.5
        k = k.transpose(2, 3)
        k = k.softmax(dim=-1)
        context = einsum('b h i j, b h j a -> b h i a', k, v)
        attn = einsum('b h i j, b h j a -> b h i a', q, context) * self.scale

        # convolutional relative position encoding: q (elementwise) depthwise-conv(v)
        cls = v[:, :, :1]
        value_token = v[:, :, 1:]
        value_token = rearrange(value_token, 'b h (l w) d -> b (h d) l w', l=H, w=W)
        value_token = self.attn_depthwiseconv(value_token)
        value_token = rearrange(value_token, 'b (h d) l w -> b h (l w) d', h=h)
        v = torch.cat((cls, value_token), dim=2)

        out = einsum('b h i j, b h i j -> b h i j', q, v)
        out = out + attn

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
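

# A hedged usage sketch (added for illustration; not part of the original file):
# quick shape check for ConvAttention on a 14x14 token grid plus one class token.
if __name__ == "__main__":
    tokens = torch.randn(1, 1 + 14 * 14, 64)       # [B, 1 + H*W, dim]
    conv_attn = ConvAttention(dim=64, heads=8, dim_head=8)
    print(conv_attn(tokens).shape)                 # torch.Size([1, 197, 64])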
--------------------------------------------------------------------------------
/coat.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from module import FeedForward, ConvAttention, PreNorm
import numpy as np


class Transformer(nn.Module):

    def __init__(self, depth, dim, heads, dim_head, scale, dropout):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(dim, ConvAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                    PreNorm(dim, FeedForward(dim, dim*scale, dropout=dropout))
                ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class SerialBlock(nn.Module):

    def __init__(self, feature_size, in_channels, out_channels, depth=2, nheads=8, scale=8,
                 conv_kernel=7, stride=2, dropout=0.):
        super(SerialBlock, self).__init__()
        self.cls_embed = nn.Linear(in_channels, out_channels)
        padding = (conv_kernel - 1) // 2
        self.conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, conv_kernel, stride, padding),
            Rearrange('b c h w -> b (h w) c', h=feature_size, w=feature_size),
            nn.LayerNorm(out_channels)
        )

        self.transformer = Transformer(depth=depth, dim=out_channels, heads=nheads, dim_head=out_channels//nheads,
                                       scale=scale, dropout=dropout)

    def forward(self, x, cls_tokens):
        '''
        :param x: feature map [B, C_in, H, W]
        :param cls_tokens: class token [B, 1, C_in]
        :return: tokens [B, 1 + (H/stride)*(W/stride), C_out], class token first
        '''
        x = self.conv_embed(x)
        cls_tokens = self.cls_embed(cls_tokens)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.transformer(x)
        return x


class ParallelBlock(nn.Module):

    def __init__(self, in_channels, nheads=8, dropout=0.):
        super(ParallelBlock, self).__init__()

        self.p1 = PreNorm(in_channels, ConvAttention(in_channels,
                                                     heads=nheads,
                                                     dim_head=in_channels // nheads,
                                                     dropout=dropout))
        self.p2 = PreNorm(in_channels, ConvAttention(in_channels,
                                                     heads=nheads,
                                                     dim_head=in_channels // nheads,
                                                     dropout=dropout))
        self.p3 = PreNorm(in_channels, ConvAttention(in_channels,
                                                     heads=nheads,
                                                     dim_head=in_channels // nheads,
                                                     dropout=dropout))

    def forward(self, x1, x2, x3):
        '''
        :param x1, x2, x3: token sequences [B, 1 + H*W, C] from three serial blocks
        :return: three token sequences of the same shapes after conv-attention
        '''
        return self.p1(x1), self.p2(x2), self.p3(x3)


class CoaT(nn.Module):

    def __init__(self, in_channels, image_size, num_classes, out_channels=[64, 128, 256, 320], depths=[2, 2, 2, 2],
                 heads=8, scales=[8, 8, 4, 4], downscales=[4, 2, 2, 2], kernels=[7, 3, 3, 3], use_parallel=False,
                 parallel_depth=6, parallel_channels=152, dropout=0.):
        super(CoaT, self).__init__()

        assert len(out_channels) == len(depths) == len(scales) == len(downscales) == len(kernels)
        feature_size = image_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, in_channels))
        self.serial_layers = nn.ModuleList([])
        for out_channel, depth, scale, downscale, kernel in zip(out_channels, depths, scales, downscales, kernels):
            feature_size = feature_size // downscale
            self.serial_layers.append(
                SerialBlock(feature_size, in_channels, out_channel, depth, heads, scale, kernel, downscale, dropout)
            )
            in_channels = out_channel

        self.use_parallel = use_parallel
        if use_parallel:
            self.parallel_conv_attn = nn.ModuleList([])
            self.parallel_ffn = nn.ModuleList([])
            for _ in range(parallel_depth):
                self.parallel_conv_attn.append(
                    ParallelBlock(parallel_channels, heads, dropout)
                )
                self.parallel_ffn.append(
                    PreNorm(parallel_channels, FeedForward(parallel_channels, parallel_channels * 4, dropout=dropout))
                )

            self.parallel_mlp_head = nn.Sequential(
                nn.LayerNorm(in_channels*3),
                nn.Linear(in_channels*3, num_classes)
            )

        self.serial_mlp_head = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, num_classes)
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        serial_outputs = []
        for serial_block in self.serial_layers:
            x = serial_block(x, cls_tokens)
            serial_outputs.append(x)
            cls_tokens = x[:, :1]
            l = w = int(math.sqrt(x[:, 1:].shape[1]))
            x = rearrange(x[:, 1:], 'b (l w) c -> b c l w', l=l, w=w)

        s2 = serial_outputs[1]
        s3 = serial_outputs[2]
        s4 = serial_outputs[3]
        if self.use_parallel:
            for attn, ffn in zip(self.parallel_conv_attn, self.parallel_ffn):
                s2, s3, s4 = attn(s2, s3, s4)
                cls_s2 = s2[:, :1]
                cls_s3 = s3[:, :1]
                cls_s4 = s4[:, :1]
                # the hard-coded 28/14/7 grids assume a 224x224 input (see the note after the class)
                s2 = rearrange(s2[:, 1:], 'b (l w) d -> b d l w', l=28, w=28)
                s3 = rearrange(s3[:, 1:], 'b (l w) d -> b d l w', l=14, w=14)
                s4 = rearrange(s4[:, 1:], 'b (l w) d -> b d l w', l=7, w=7)

                s2 = s2 + F.interpolate(s3, (28, 28), mode='bilinear') + F.interpolate(s4, (28, 28), mode='bilinear')
                s3 = s3 + F.interpolate(s2, (14, 14), mode='bilinear') + F.interpolate(s4, (14, 14), mode='bilinear')
                s4 = s4 + F.interpolate(s2, (7, 7), mode='bilinear') + F.interpolate(s3, (7, 7), mode='bilinear')

                s2 = rearrange(s2, 'b d l w -> b (l w) d')
                s3 = rearrange(s3, 'b d l w -> b (l w) d')
                s4 = rearrange(s4, 'b d l w -> b (l w) d')

                s2 = ffn(torch.cat([cls_s2, s2], dim=1))
                s3 = ffn(torch.cat([cls_s3, s3], dim=1))
                s4 = ffn(torch.cat([cls_s4, s4], dim=1))

            cls_tokens = torch.cat([s2[:, 0], s3[:, 0], s4[:, 0]], dim=1)
            return self.parallel_mlp_head(cls_tokens)
        else:
            return self.serial_mlp_head(cls_tokens.squeeze(1))
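
# Added note (illustrative, not part of the original file): shape walk-through for a
# 224x224 input with the default downscales [4, 2, 2, 2], which is what the hard-coded
# 28/14/7 grids in the parallel branch assume:
#   serial block 1: 224 / 4 = 56 -> tokens [B, 1 + 56*56, out_channels[0]]
#   serial block 2:  56 / 2 = 28 -> tokens [B, 1 + 28*28, out_channels[1]]  (s2)
#   serial block 3:  28 / 2 = 14 -> tokens [B, 1 + 14*14, out_channels[2]]  (s3)
#   serial block 4:  14 / 2 =  7 -> tokens [B, 1 +  7*7,  out_channels[3]]  (s4)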


if __name__ == "__main__":
    img = torch.ones([1, 3, 224, 224])

    model = CoaT(3, 224, 1000, out_channels=[152, 152, 152, 152], scales=[4, 4, 4, 4], use_parallel=True)

    out = model(img)

    print("Shape of out :", out.shape)  # [B, num_classes]

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)
--------------------------------------------------------------------------------