├── .gitignore ├── LICENSE ├── README.md ├── model.PNG └── resmlp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rishikesh (ऋषिकेश) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ResMLP : Feedforward networks for image classification with data-efficient training 2 | PyTorch implementation of [ResMLP: Feedforward networks for image classification with data-efficient training](https://arxiv.org/abs/2105.03404). 3 | 4 | ![](model.PNG) 5 | ## Usage: 6 | ```python 7 | import torch 8 | import numpy as np 9 | from resmlp import ResMLP 10 | 11 | img = torch.ones([1, 3, 224, 224]) 12 | 13 | model = ResMLP(in_channels=3, image_size=224, patch_size=16, num_classes=1000, 14 | dim=384, depth=12, mlp_dim=384*4) 15 | 16 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 17 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 18 | print('Trainable Parameters: %.3fM' % parameters) 19 | 20 | out_img = model(img) 21 | 22 | print("Shape of out :", out_img.shape) # [B, num_classes] 23 | ``` 24 | 25 | ## Citation: 26 | ``` 27 | @misc{touvron2021resmlp, 28 | title={ResMLP: Feedforward networks for image classification with data-efficient training}, 29 | author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou}, 30 | year={2021}, 31 | eprint={2105.03404}, 32 | archivePrefix={arXiv}, 33 | primaryClass={cs.CV} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /model.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/ResMLP-pytorch/091ab75cb36fdb9fe34946ecfbb9fe826c815013/model.PNG -------------------------------------------------------------------------------- /resmlp.py: -------------------------------------------------------------------------------- 1 | import
import torch
import numpy as np
from torch import nn
from einops.layers.torch import Rearrange


class Aff(nn.Module):
    """Learnable element-wise affine transform ``x * alpha + beta``.

    ``alpha`` is initialised to ones and ``beta`` to zeros, so the layer is
    the identity right after construction. Both parameters have shape
    ``(1, 1, dim)`` and broadcast against the input.

    Args:
        dim: Size of the last dimension of the tensors passed to ``forward``.
    """

    def __init__(self, dim):
        super().__init__()

        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        """Return ``x`` scaled by ``alpha`` and shifted by ``beta``."""
        return x * self.alpha + self.beta


class FeedForward(nn.Module):
    """Two-layer MLP acting on the last dimension of its input.

    The pipeline is ``Linear(dim, hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim, dim) -> Dropout``.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after both linear layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)


class MLPblock(nn.Module):
    """One residual block: a token-mixing step followed by a channel MLP.

    Args:
        dim: Channel (feature) size of each token.
        num_patch: Number of tokens; the token-mixing layer is a
            ``Linear(num_patch, num_patch)`` applied across tokens.
        mlp_dim: Hidden width of the channel MLP.
        dropout: Dropout probability forwarded to the channel MLP.
        init_values: Initial value of the per-channel residual scales
            ``gamma_1`` and ``gamma_2``.
    """

    def __init__(self, dim, num_patch, mlp_dim, dropout=0., init_values=1e-4):
        super().__init__()

        self.pre_affine = Aff(dim)
        # Swap token/channel axes so the linear layer mixes across tokens,
        # then swap back to the (batch, tokens, channels) layout.
        self.token_mix = nn.Sequential(
            Rearrange('b n d -> b d n'),
            nn.Linear(num_patch, num_patch),
            Rearrange('b d n -> b n d'),
        )
        self.ff = nn.Sequential(
            FeedForward(dim, mlp_dim, dropout),
        )
        self.post_affine = Aff(dim)
        # Per-channel scales on each residual branch, started small so that
        # every block is close to its skip path at initialisation.
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """Transform ``x`` of layout ``(batch, num_patch, dim)``; shape is preserved."""
        # NOTE(review): the skip connections below carry the affine-transformed
        # tensor rather than the untouched block input. The reference ResMLP
        # formulation appears to add each branch to the unmodified input --
        # confirm before changing, since that would alter the outputs of
        # already-trained checkpoints.
        x = self.pre_affine(x)
        x = x + self.gamma_1 * self.token_mix(x)
        x = self.post_affine(x)
        x = x + self.gamma_2 * self.ff(x)
        return x


class ResMLP(nn.Module):
    """ResMLP image classifier built from a patch embedding and ``MLPblock`` layers.

    Args:
        in_channels: Number of channels of the input images.
        dim: Embedding size of each patch token.
        num_classes: Number of output logits.
        patch_size: Side length of the square patches (also the conv stride).
        image_size: Side length of the (square) input images.
        depth: Number of ``MLPblock`` layers.
        mlp_dim: Hidden width of the channel MLP inside every block.
        dropout: Dropout probability forwarded to every block.
        init_values: Initial residual-scale value forwarded to every block.

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, in_channels, dim, num_classes, patch_size, image_size, depth, mlp_dim,
                 dropout=0., init_values=1e-4):
        super().__init__()

        # Raise explicitly instead of using `assert`, which is removed under
        # `python -O` and would let a mismatched size through silently.
        if image_size % patch_size != 0:
            raise ValueError('Image dimensions must be divisible by the patch size.')

        self.num_patch = (image_size // patch_size) ** 2
        # Non-overlapping patches: kernel size == stride == patch_size, then
        # flatten the spatial grid into a token axis.
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, dim, patch_size, patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )

        self.mlp_blocks = nn.ModuleList(
            [
                MLPblock(dim, self.num_patch, mlp_dim, dropout=dropout, init_values=init_values)
                for _ in range(depth)
            ]
        )

        self.affine = Aff(dim)

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Map images ``(B, in_channels, H, W)`` to logits ``(B, num_classes)``."""
        x = self.to_patch_embedding(x)

        for mlp_block in self.mlp_blocks:
            x = mlp_block(x)

        x = self.affine(x)

        # Average over the token axis to get one feature vector per image.
        x = x.mean(dim=1)

        return self.mlp_head(x)


if __name__ == "__main__":
    img = torch.ones([1, 3, 224, 224])

    model = ResMLP(in_channels=3, image_size=224, patch_size=16, num_classes=1000,
                   dim=384, depth=12, mlp_dim=384*4)

    trainable = (p for p in model.parameters() if p.requires_grad)
    parameters = sum(np.prod(p.size()) for p in trainable) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)

    # Smoke test only: no gradients are needed, so skip building the graph.
    model.eval()
    with torch.no_grad():
        out_img = model(img)

    print("Shape of out :", out_img.shape)  # [B, num_classes]