├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── res_mlp_pytorch
├── __init__.py
└── res_mlp_pytorch.py
├── resmlp.png
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    # v2 of these actions targets a deprecated Node runtime; use current majors.
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build twine
    - name: Build and publish
      env:
        # NOTE: PyPI no longer accepts account passwords for uploads; verify that
        # PYPI_PASSWORD holds an API token and PYPI_USERNAME is `__token__`.
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        # `python -m build` builds both the sdist and the wheel into dist/,
        # replacing the deprecated direct `python setup.py sdist bdist_wheel` call.
        python -m build
        twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## ResMLP - Pytorch
4 |
5 | Implementation of ResMLP, an all-MLP solution to image classification from Facebook AI, in PyTorch
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install res-mlp-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from res_mlp_pytorch import ResMLP
18 |
19 | model = ResMLP(
20 | image_size = 256,
21 | patch_size = 16,
22 | dim = 512,
23 | depth = 12,
24 | num_classes = 1000
25 | )
26 |
27 | img = torch.randn(1, 3, 256, 256)
28 | pred = model(img) # (1, 1000)
29 | ```
30 |
31 | Rectangular image
32 |
33 | ```python
34 | import torch
35 | from res_mlp_pytorch import ResMLP
36 |
37 | model = ResMLP(
38 | image_size = (128, 256), # (128 x 256)
39 | patch_size = 16,
40 | dim = 512,
41 | depth = 12,
42 | num_classes = 1000
43 | )
44 |
45 | img = torch.randn(1, 3, 128, 256)
46 | pred = model(img) # (1, 1000)
47 | ```
48 |
49 | ## Citations
50 |
51 | ```bibtex
52 | @misc{touvron2021resmlp,
53 | title = {ResMLP: Feedforward networks for image classification with data-efficient training},
54 | author = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
55 | year = {2021},
56 | eprint = {2105.03404},
57 | archivePrefix = {arXiv},
58 | primaryClass = {cs.CV}
59 | }
60 | ```
61 |
--------------------------------------------------------------------------------
/res_mlp_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from res_mlp_pytorch.res_mlp_pytorch import ResMLP
2 |
--------------------------------------------------------------------------------
/res_mlp_pytorch/res_mlp_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops.layers.torch import Rearrange, Reduce
4 |
5 | # helpers
6 |
def pair(val):
    """Normalise a size argument to a 2-tuple.

    A tuple is returned unchanged; a list (e.g. ``[128, 256]``) is converted
    to a tuple; any other value ``v`` becomes ``(v, v)``.

    Previously a list was treated as a scalar and duplicated, producing
    ``([h, w], [h, w])``, which broke downstream arithmetic.
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val, val)
9 |
10 | # classes
11 |
class Affine(nn.Module):
    """Learnable scale-and-shift along the last axis: ``x * g + b``.

    ``g`` starts at ones and ``b`` at zeros, both shaped ``(1, 1, dim)``, so
    the module is the identity at initialisation and the parameters
    broadcast over the two leading axes of the input.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, dim)
        self.g = nn.Parameter(torch.ones(param_shape))   # multiplicative gain
        self.b = nn.Parameter(torch.zeros(param_shape))  # additive bias

    def forward(self, x):
        # assumes x's last axis is broadcast-compatible with dim
        # (in ResMLP it is a (batch, num_patches, dim) token tensor)
        scaled = x * self.g
        return scaled + self.b
20 |
class PreAffinePostLayerScale(nn.Module): # https://arxiv.org/abs/2103.17239
    """Residual wrapper computing ``fn(affine(x)) * scale + x``.

    The input first goes through a learnable ``Affine``, then through ``fn``.
    The result is multiplied by a learnable per-channel ``scale``
    (LayerScale) and added back to the untouched input.

    Args:
        dim: size of the last axis; ``scale`` has shape ``(1, 1, dim)``.
        depth: picks the initial value of ``scale``:
            ``<= 18`` -> 0.1, ``<= 24`` -> 1e-5, otherwise 1e-6.
        fn: module applied to the affine-transformed input; its output must
            broadcast against ``scale`` and ``x``.
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        # registration order (scale, affine, fn) preserved so state_dict keys keep their order
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.affine = Affine(dim)
        self.fn = fn

    def forward(self, x):
        branch = self.fn(self.affine(x))
        return branch * self.scale + x
38 |
def ResMLP(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4, channels = 3):
    """Build a ResMLP image classifier as a plain ``nn.Sequential``.

    Pipeline:
        1. Split the image into non-overlapping patches and project each
           flattened patch to ``dim``.
        2. Apply ``depth`` residual blocks. Each block does cross-patch
           linear mixing followed by a per-patch MLP.
        3. Apply a final ``Affine``, average over patches and classify with
           a linear head.

    Args:
        image_size: int, or ``(height, width)`` for rectangular images.
        patch_size: int, or ``(patch_height, patch_width)`` for rectangular
            patches.
        dim: width of each patch embedding.
        depth: number of residual blocks.
        num_classes: number of output logits.
        expansion_factor: hidden-width multiplier of the per-patch MLP.
        channels: number of input image channels. Defaults to 3, the value
            that was previously hard-coded.

    Returns:
        ``nn.Sequential`` mapping a ``(batch, channels, height, width)``
        tensor to ``(batch, num_classes)`` logits.

    Raises:
        AssertionError: if the image height/width is not divisible by the
            patch height/width. The check is skipped under ``python -O``.
    """
    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)
    assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0, 'image height and width must be divisible by patch size'
    num_patches = (image_height // patch_height) * (image_width // patch_width)
    patch_dim = patch_height * patch_width * channels

    # NOTE(review): the LayerScale init is chosen from the 1-based *layer index*
    # (i + 1), not the total depth, so the early layers of a deep model start
    # with scale 0.1. The CaiT / ResMLP papers describe choosing it from the
    # network depth; confirm which is intended before changing it.
    def wrapper(i, fn):
        return PreAffinePostLayerScale(dim, i + 1, fn)

    return nn.Sequential(
        # (b, c, H, W) -> (b, num_patches, patch_height * patch_width * c)
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
        nn.Linear(patch_dim, dim),
        *[nn.Sequential(
            # Conv1d reads (b, num_patches, dim) as (N, C, L). A kernel-size-1
            # conv is therefore an affine map mixing patches, shared across dim.
            wrapper(i, nn.Conv1d(num_patches, num_patches, 1)),
            # per-patch MLP acting on the last (dim) axis
            wrapper(i, nn.Sequential(
                nn.Linear(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Linear(dim * expansion_factor, dim)
            ))
        ) for i in range(depth)],
        Affine(dim),
        # average over patches: (b, num_patches, dim) -> (b, dim)
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )
60 |
--------------------------------------------------------------------------------
/resmlp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/res-mlp-pytorch/562814a406cc418bdb4710aa3bdc569206ac171b/resmlp.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Packaging metadata for the `res-mlp-pytorch` distribution.
setup(
    name='res-mlp-pytorch',
    packages=find_packages(exclude=[]),
    version='0.0.6',
    license='MIT',
    description='ResMLP - Pytorch',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/res-mlp-pytorch',
    keywords=[
        'artificial intelligence',
        'deep learning',
        'image recognition',
    ],
    # runtime dependencies installed with the package
    install_requires=[
        'einops>=0.3',
        'torch>=1.6',
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
29 |
--------------------------------------------------------------------------------