├── .github ├── FUNDING.yml └── workflows │ ├── python-publish.yml │ └── python-test.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples └── cats_and_dogs.ipynb ├── images ├── ats.png ├── cait.png ├── cross_vit.png ├── crossformer.png ├── crossformer2.png ├── cvt.png ├── dino.png ├── distill.png ├── esvit.png ├── learnable-memory-vit.png ├── levit.png ├── mae.png ├── max-vit.png ├── mbvit.png ├── mp3.png ├── navit.png ├── nest.png ├── parallel-vit.png ├── patch_merger.png ├── pit.png ├── regionvit.png ├── regionvit2.png ├── scalable-vit-1.png ├── scalable-vit-2.png ├── sep-vit.png ├── simmim.png ├── t2t.png ├── twins_svt.png ├── vit.gif ├── vit_for_small_datasets.png ├── vivit.png └── xcit.png ├── setup.py ├── tests └── test.py └── vit_pytorch ├── __init__.py ├── ats_vit.py ├── cait.py ├── cct.py ├── cct_3d.py ├── cross_vit.py ├── crossformer.py ├── cvt.py ├── deepvit.py ├── dino.py ├── distill.py ├── efficient.py ├── es_vit.py ├── extractor.py ├── jumbo_vit.py ├── learnable_memory_vit.py ├── levit.py ├── local_vit.py ├── look_vit.py ├── mae.py ├── max_vit.py ├── max_vit_with_registers.py ├── mobile_vit.py ├── mp3.py ├── mpp.py ├── na_vit.py ├── na_vit_nested_tensor.py ├── na_vit_nested_tensor_3d.py ├── nest.py ├── normalized_vit.py ├── parallel_vit.py ├── pit.py ├── recorder.py ├── regionvit.py ├── rvt.py ├── scalable_vit.py ├── sep_vit.py ├── simmim.py ├── simple_flash_attn_vit.py ├── simple_flash_attn_vit_3d.py ├── simple_uvit.py ├── simple_vit.py ├── simple_vit_1d.py ├── simple_vit_3d.py ├── simple_vit_with_fft.py ├── simple_vit_with_hyper_connections.py ├── simple_vit_with_patch_dropout.py ├── simple_vit_with_qk_norm.py ├── simple_vit_with_register_tokens.py ├── simple_vit_with_value_residual.py ├── t2t.py ├── twins_svt.py ├── vit.py ├── vit_1d.py ├── vit_3d.py ├── vit_for_small_dataset.py ├── vit_with_patch_dropout.py ├── vit_with_patch_merger.py ├── vivit.py └── xcit.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [lucidrains] 4 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/python-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Test 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.8, 3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install pytest 30 | python -m pip install wheel 31 | python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | - name: Test with pytest 34 | run: | 35 | python setup.py test 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include tests * 2 | -------------------------------------------------------------------------------- /images/ats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/ats.png -------------------------------------------------------------------------------- /images/cait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/cait.png -------------------------------------------------------------------------------- /images/cross_vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/cross_vit.png -------------------------------------------------------------------------------- /images/crossformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/crossformer.png -------------------------------------------------------------------------------- /images/crossformer2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/crossformer2.png -------------------------------------------------------------------------------- /images/cvt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/cvt.png -------------------------------------------------------------------------------- /images/dino.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/dino.png -------------------------------------------------------------------------------- /images/distill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/distill.png -------------------------------------------------------------------------------- /images/esvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/esvit.png -------------------------------------------------------------------------------- /images/learnable-memory-vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/learnable-memory-vit.png -------------------------------------------------------------------------------- /images/levit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/levit.png 
-------------------------------------------------------------------------------- /images/mae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/mae.png -------------------------------------------------------------------------------- /images/max-vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/max-vit.png -------------------------------------------------------------------------------- /images/mbvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/mbvit.png -------------------------------------------------------------------------------- /images/mp3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/mp3.png -------------------------------------------------------------------------------- /images/navit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/navit.png -------------------------------------------------------------------------------- /images/nest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/nest.png -------------------------------------------------------------------------------- /images/parallel-vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/parallel-vit.png -------------------------------------------------------------------------------- /images/patch_merger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/patch_merger.png -------------------------------------------------------------------------------- /images/pit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/pit.png -------------------------------------------------------------------------------- /images/regionvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/regionvit.png -------------------------------------------------------------------------------- /images/regionvit2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/regionvit2.png -------------------------------------------------------------------------------- /images/scalable-vit-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/scalable-vit-1.png -------------------------------------------------------------------------------- /images/scalable-vit-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/scalable-vit-2.png -------------------------------------------------------------------------------- /images/sep-vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/sep-vit.png -------------------------------------------------------------------------------- /images/simmim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/simmim.png -------------------------------------------------------------------------------- /images/t2t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/t2t.png -------------------------------------------------------------------------------- /images/twins_svt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/twins_svt.png -------------------------------------------------------------------------------- /images/vit.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/vit.gif -------------------------------------------------------------------------------- /images/vit_for_small_datasets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/vit_for_small_datasets.png -------------------------------------------------------------------------------- /images/vivit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/vivit.png -------------------------------------------------------------------------------- /images/xcit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/vit-pytorch/db05a141a6e3886353a249e64dc9678c6fa30419/images/xcit.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name = 'vit-pytorch', 8 | packages = find_packages(exclude=['examples']), 9 | version = '1.10.1', 10 | license='MIT', 11 | description = 'Vision Transformer (ViT) - Pytorch', 12 | long_description = long_description, 13 | long_description_content_type = 'text/markdown', 14 | author = 'Phil Wang', 15 | author_email = 'lucidrains@gmail.com', 16 | url = 'https://github.com/lucidrains/vit-pytorch', 17 
| keywords = [ 18 | 'artificial intelligence', 19 | 'attention mechanism', 20 | 'image recognition' 21 | ], 22 | install_requires=[ 23 | 'einops>=0.7.0', 24 | 'torch>=1.10', 25 | 'torchvision' 26 | ], 27 | setup_requires=[ 28 | 'pytest-runner', 29 | ], 30 | tests_require=[ 31 | 'pytest', 32 | 'torch==2.4.0', 33 | 'torchvision==0.19.0' 34 | ], 35 | classifiers=[ 36 | 'Development Status :: 4 - Beta', 37 | 'Intended Audience :: Developers', 38 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 39 | 'License :: OSI Approved :: MIT License', 40 | 'Programming Language :: Python :: 3.6', 41 | ], 42 | ) 43 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vit_pytorch import ViT 3 | 4 | def test(): 5 | v = ViT( 6 | image_size = 256, 7 | patch_size = 32, 8 | num_classes = 1000, 9 | dim = 1024, 10 | depth = 6, 11 | heads = 16, 12 | mlp_dim = 2048, 13 | dropout = 0.1, 14 | emb_dropout = 0.1 15 | ) 16 | 17 | img = torch.randn(1, 3, 256, 256) 18 | 19 | preds = v(img) 20 | assert preds.shape == (1, 1000), 'correct logits outputted' 21 | -------------------------------------------------------------------------------- /vit_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from vit_pytorch.vit import ViT 2 | from vit_pytorch.simple_vit import SimpleViT 3 | 4 | from vit_pytorch.mae import MAE 5 | from vit_pytorch.dino import Dino 6 | -------------------------------------------------------------------------------- /vit_pytorch/cait.py: -------------------------------------------------------------------------------- 1 | from random import randrange 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def dropout_layers(layers, dropout): 15 | if dropout == 0: 16 | return layers 17 | 18 | num_layers = len(layers) 19 | to_drop = torch.zeros(num_layers).uniform_(0., 1.) 
< dropout 20 | 21 | # make sure at least one layer makes it 22 | if all(to_drop): 23 | rand_index = randrange(num_layers) 24 | to_drop[rand_index] = False 25 | 26 | layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop] 27 | return layers 28 | 29 | # classes 30 | 31 | class LayerScale(nn.Module): 32 | def __init__(self, dim, fn, depth): 33 | super().__init__() 34 | if depth <= 18: # epsilon detailed in section 2 of paper 35 | init_eps = 0.1 36 | elif depth > 18 and depth <= 24: 37 | init_eps = 1e-5 38 | else: 39 | init_eps = 1e-6 40 | 41 | scale = torch.zeros(1, 1, dim).fill_(init_eps) 42 | self.scale = nn.Parameter(scale) 43 | self.fn = fn 44 | def forward(self, x, **kwargs): 45 | return self.fn(x, **kwargs) * self.scale 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, hidden_dim, dropout = 0.): 49 | super().__init__() 50 | self.net = nn.Sequential( 51 | nn.LayerNorm(dim), 52 | nn.Linear(dim, hidden_dim), 53 | nn.GELU(), 54 | nn.Dropout(dropout), 55 | nn.Linear(hidden_dim, dim), 56 | nn.Dropout(dropout) 57 | ) 58 | def forward(self, x): 59 | return self.net(x) 60 | 61 | class Attention(nn.Module): 62 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 63 | super().__init__() 64 | inner_dim = dim_head * heads 65 | self.heads = heads 66 | self.scale = dim_head ** -0.5 67 | 68 | self.norm = nn.LayerNorm(dim) 69 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 70 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 71 | 72 | self.attend = nn.Softmax(dim = -1) 73 | self.dropout = nn.Dropout(dropout) 74 | 75 | self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads)) 76 | self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads)) 77 | 78 | self.to_out = nn.Sequential( 79 | nn.Linear(inner_dim, dim), 80 | nn.Dropout(dropout) 81 | ) 82 | 83 | def forward(self, x, context = None): 84 | b, n, _, h = *x.shape, self.heads 85 | 86 | x = self.norm(x) 87 | context = x if not exists(context) else torch.cat((x, context), dim = 1) 88 | 89 | qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 90 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 91 | 92 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 93 | 94 | dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax 95 | 96 | attn = self.attend(dots) 97 | attn = self.dropout(attn) 98 | 99 | attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax 100 | 101 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 102 | out = rearrange(out, 'b h n d -> b n (h d)') 103 | return self.to_out(out) 104 | 105 | class Transformer(nn.Module): 106 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.): 107 | super().__init__() 108 | self.layers = nn.ModuleList([]) 109 | self.layer_dropout = layer_dropout 110 | 111 | for ind in range(depth): 112 | self.layers.append(nn.ModuleList([ 113 | LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1), 114 | LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1) 115 | ])) 116 | def forward(self, x, context = None): 117 | layers = dropout_layers(self.layers, dropout = self.layer_dropout) 118 | 119 | for attn, ff in layers: 120 | x = attn(x, context = context) + x 121 | x = ff(x) + x 122 | return x 123 | 124 | class CaiT(nn.Module): 125 | def __init__( 126 | self, 127 | *, 128 | image_size, 129 | 
patch_size, 130 | num_classes, 131 | dim, 132 | depth, 133 | cls_depth, 134 | heads, 135 | mlp_dim, 136 | dim_head = 64, 137 | dropout = 0., 138 | emb_dropout = 0., 139 | layer_dropout = 0. 140 | ): 141 | super().__init__() 142 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 143 | num_patches = (image_size // patch_size) ** 2 144 | patch_dim = 3 * patch_size ** 2 145 | 146 | self.to_patch_embedding = nn.Sequential( 147 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 148 | nn.LayerNorm(patch_dim), 149 | nn.Linear(patch_dim, dim), 150 | nn.LayerNorm(dim) 151 | ) 152 | 153 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) 154 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 155 | 156 | self.dropout = nn.Dropout(emb_dropout) 157 | 158 | self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout) 159 | self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout) 160 | 161 | self.mlp_head = nn.Sequential( 162 | nn.LayerNorm(dim), 163 | nn.Linear(dim, num_classes) 164 | ) 165 | 166 | def forward(self, img): 167 | x = self.to_patch_embedding(img) 168 | b, n, _ = x.shape 169 | 170 | x += self.pos_embedding[:, :n] 171 | x = self.dropout(x) 172 | 173 | x = self.patch_transformer(x) 174 | 175 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 176 | x = self.cls_transformer(cls_tokens, context = x) 177 | 178 | return self.mlp_head(x[:, 0]) 179 | -------------------------------------------------------------------------------- /vit_pytorch/cvt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | # helper methods 9 | 10 | def group_dict_by_key(cond, d): 11 | return_val = [dict(), dict()] 12 | for key in d.keys(): 13 | match = bool(cond(key)) 14 | ind = int(not match) 15 | return_val[ind][key] = d[key] 16 | return (*return_val,) 17 | 18 | def group_by_key_prefix_and_remove_prefix(prefix, d): 19 | kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d) 20 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 21 | return kwargs_without_prefix, kwargs 22 | 23 | # classes 24 | 25 | class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1 26 | def __init__(self, dim, eps = 1e-5): 27 | super().__init__() 28 | self.eps = eps 29 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 30 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) 31 | 32 | def forward(self, x): 33 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 34 | mean = torch.mean(x, dim = 1, keepdim = True) 35 | return (x - mean) / (var + self.eps).sqrt() * self.g + self.b 36 | 37 | class FeedForward(nn.Module): 38 | def __init__(self, dim, mult = 4, dropout = 0.): 39 | super().__init__() 40 | self.net = nn.Sequential( 41 | LayerNorm(dim), 42 | nn.Conv2d(dim, dim * mult, 1), 43 | nn.GELU(), 44 | nn.Dropout(dropout), 45 | nn.Conv2d(dim * mult, dim, 1), 46 | nn.Dropout(dropout) 47 | ) 48 | def forward(self, x): 49 | return self.net(x) 50 | 51 | class DepthWiseConv2d(nn.Module): 52 | def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True): 53 | super().__init__() 54 | self.net = nn.Sequential( 55 | 
nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias), 56 | nn.BatchNorm2d(dim_in), 57 | nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias) 58 | ) 59 | def forward(self, x): 60 | return self.net(x) 61 | 62 | class Attention(nn.Module): 63 | def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.): 64 | super().__init__() 65 | inner_dim = dim_head * heads 66 | padding = proj_kernel // 2 67 | self.heads = heads 68 | self.scale = dim_head ** -0.5 69 | 70 | self.norm = LayerNorm(dim) 71 | self.attend = nn.Softmax(dim = -1) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False) 75 | self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False) 76 | 77 | self.to_out = nn.Sequential( 78 | nn.Conv2d(inner_dim, dim, 1), 79 | nn.Dropout(dropout) 80 | ) 81 | 82 | def forward(self, x): 83 | shape = x.shape 84 | b, n, _, y, h = *shape, self.heads 85 | 86 | x = self.norm(x) 87 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1)) 88 | q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v)) 89 | 90 | dots = einsum('b i d, b j d -> b i j', q, k) * self.scale 91 | 92 | attn = self.attend(dots) 93 | attn = self.dropout(attn) 94 | 95 | out = einsum('b i j, b j d -> b i d', attn, v) 96 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y) 97 | return self.to_out(out) 98 | 99 | class Transformer(nn.Module): 100 | def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.): 101 | super().__init__() 102 | self.layers = nn.ModuleList([]) 103 | for _ in range(depth): 104 | self.layers.append(nn.ModuleList([ 105 | Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout), 106 | FeedForward(dim, mlp_mult, dropout = dropout) 107 | ])) 108 | def forward(self, x): 109 | for attn, ff in self.layers: 110 | x = attn(x) + x 111 | x = ff(x) + x 112 | return x 113 | 114 | class CvT(nn.Module): 115 | def __init__( 116 | self, 117 | *, 118 | num_classes, 119 | s1_emb_dim = 64, 120 | s1_emb_kernel = 7, 121 | s1_emb_stride = 4, 122 | s1_proj_kernel = 3, 123 | s1_kv_proj_stride = 2, 124 | s1_heads = 1, 125 | s1_depth = 1, 126 | s1_mlp_mult = 4, 127 | s2_emb_dim = 192, 128 | s2_emb_kernel = 3, 129 | s2_emb_stride = 2, 130 | s2_proj_kernel = 3, 131 | s2_kv_proj_stride = 2, 132 | s2_heads = 3, 133 | s2_depth = 2, 134 | s2_mlp_mult = 4, 135 | s3_emb_dim = 384, 136 | s3_emb_kernel = 3, 137 | s3_emb_stride = 2, 138 | s3_proj_kernel = 3, 139 | s3_kv_proj_stride = 2, 140 | s3_heads = 6, 141 | s3_depth = 10, 142 | s3_mlp_mult = 4, 143 | dropout = 0., 144 | channels = 3 145 | ): 146 | super().__init__() 147 | kwargs = dict(locals()) 148 | 149 | dim = channels 150 | layers = [] 151 | 152 | for prefix in ('s1', 's2', 's3'): 153 | config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs) 154 | 155 | layers.append(nn.Sequential( 156 | nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']), 157 | LayerNorm(config['emb_dim']), 158 | Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = 
config['mlp_mult'], dropout = dropout) 159 | )) 160 | 161 | dim = config['emb_dim'] 162 | 163 | self.layers = nn.Sequential(*layers) 164 | 165 | self.to_logits = nn.Sequential( 166 | nn.AdaptiveAvgPool2d(1), 167 | Rearrange('... () () -> ...'), 168 | nn.Linear(dim, num_classes) 169 | ) 170 | 171 | def forward(self, x): 172 | latents = self.layers(x) 173 | return self.to_logits(latents) 174 | -------------------------------------------------------------------------------- /vit_pytorch/deepvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | class FeedForward(nn.Module): 9 | def __init__(self, dim, hidden_dim, dropout = 0.): 10 | super().__init__() 11 | self.net = nn.Sequential( 12 | nn.LayerNorm(dim), 13 | nn.Linear(dim, hidden_dim), 14 | nn.GELU(), 15 | nn.Dropout(dropout), 16 | nn.Linear(hidden_dim, dim), 17 | nn.Dropout(dropout) 18 | ) 19 | def forward(self, x): 20 | return self.net(x) 21 | 22 | class Attention(nn.Module): 23 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 24 | super().__init__() 25 | inner_dim = dim_head * heads 26 | self.heads = heads 27 | self.scale = dim_head ** -0.5 28 | 29 | self.norm = nn.LayerNorm(dim) 30 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | self.reattn_weights = nn.Parameter(torch.randn(heads, heads)) 35 | 36 | self.reattn_norm = nn.Sequential( 37 | Rearrange('b h i j -> b i j h'), 38 | nn.LayerNorm(heads), 39 | Rearrange('b i j h -> b h i j') 40 | ) 41 | 42 | self.to_out = nn.Sequential( 43 | nn.Linear(inner_dim, dim), 44 | nn.Dropout(dropout) 45 | ) 46 | 47 | def forward(self, x): 48 | b, n, _, h = *x.shape, self.heads 49 | x = self.norm(x) 50 | 51 | qkv = self.to_qkv(x).chunk(3, dim = -1) 52 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 53 | 54 | # attention 55 | 56 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 57 | attn = dots.softmax(dim=-1) 58 | attn = self.dropout(attn) 59 | 60 | # re-attention 61 | 62 | attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights) 63 | attn = self.reattn_norm(attn) 64 | 65 | # aggregate and out 66 | 67 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 68 | out = rearrange(out, 'b h n d -> b n (h d)') 69 | out = self.to_out(out) 70 | return out 71 | 72 | class Transformer(nn.Module): 73 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 74 | super().__init__() 75 | self.layers = nn.ModuleList([]) 76 | for _ in range(depth): 77 | self.layers.append(nn.ModuleList([ 78 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 79 | FeedForward(dim, mlp_dim, dropout = dropout) 80 | ])) 81 | def forward(self, x): 82 | for attn, ff in self.layers: 83 | x = attn(x) + x 84 | x = ff(x) + x 85 | return x 86 | 87 | class DeepViT(nn.Module): 88 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 89 | super().__init__() 90 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
91 | num_patches = (image_size // patch_size) ** 2 92 | patch_dim = channels * patch_size ** 2 93 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 94 | 95 | self.to_patch_embedding = nn.Sequential( 96 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 97 | nn.LayerNorm(patch_dim), 98 | nn.Linear(patch_dim, dim), 99 | nn.LayerNorm(dim) 100 | ) 101 | 102 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 103 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 104 | self.dropout = nn.Dropout(emb_dropout) 105 | 106 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 107 | 108 | self.pool = pool 109 | self.to_latent = nn.Identity() 110 | 111 | self.mlp_head = nn.Sequential( 112 | nn.LayerNorm(dim), 113 | nn.Linear(dim, num_classes) 114 | ) 115 | 116 | def forward(self, img): 117 | x = self.to_patch_embedding(img) 118 | b, n, _ = x.shape 119 | 120 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 121 | x = torch.cat((cls_tokens, x), dim=1) 122 | x += self.pos_embedding[:, :(n + 1)] 123 | x = self.dropout(x) 124 | 125 | x = self.transformer(x) 126 | 127 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 128 | 129 | x = self.to_latent(x) 130 | return self.mlp_head(x) 131 | -------------------------------------------------------------------------------- /vit_pytorch/distill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module 4 | import torch.nn.functional as F 5 | 6 | from vit_pytorch.vit import ViT 7 | from vit_pytorch.t2t import T2TViT 8 | from vit_pytorch.efficient import ViT as EfficientViT 9 | 10 | from einops import rearrange, repeat 11 | 12 | # helpers 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def default(val, d): 18 | return val if exists(val) else d 19 | 20 | # classes 21 | 22 | class DistillMixin: 23 | def forward(self, img, distill_token = None): 24 | distilling = exists(distill_token) 25 | x = self.to_patch_embedding(img) 26 | b, n, _ = x.shape 27 | 28 | cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b) 29 | x = torch.cat((cls_tokens, x), dim = 1) 30 | x += self.pos_embedding[:, :(n + 1)] 31 | 32 | if distilling: 33 | distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b) 34 | x = torch.cat((x, distill_tokens), dim = 1) 35 | 36 | x = self._attend(x) 37 | 38 | if distilling: 39 | x, distill_tokens = x[:, :-1], x[:, -1] 40 | 41 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 42 | 43 | x = self.to_latent(x) 44 | out = self.mlp_head(x) 45 | 46 | if distilling: 47 | return out, distill_tokens 48 | 49 | return out 50 | 51 | class DistillableViT(DistillMixin, ViT): 52 | def __init__(self, *args, **kwargs): 53 | super(DistillableViT, self).__init__(*args, **kwargs) 54 | self.args = args 55 | self.kwargs = kwargs 56 | self.dim = kwargs['dim'] 57 | self.num_classes = kwargs['num_classes'] 58 | 59 | def to_vit(self): 60 | v = ViT(*self.args, **self.kwargs) 61 | v.load_state_dict(self.state_dict()) 62 | return v 63 | 64 | def _attend(self, x): 65 | x = self.dropout(x) 66 | x = self.transformer(x) 67 | return x 68 | 69 | class DistillableT2TViT(DistillMixin, T2TViT): 70 | def __init__(self, *args, **kwargs): 71 | super(DistillableT2TViT, self).__init__(*args, **kwargs) 72 | self.args = args 73 | self.kwargs = kwargs 74 | self.dim = kwargs['dim'] 75 | self.num_classes = 
kwargs['num_classes'] 76 | 77 | def to_vit(self): 78 | v = T2TViT(*self.args, **self.kwargs) 79 | v.load_state_dict(self.state_dict()) 80 | return v 81 | 82 | def _attend(self, x): 83 | x = self.dropout(x) 84 | x = self.transformer(x) 85 | return x 86 | 87 | class DistillableEfficientViT(DistillMixin, EfficientViT): 88 | def __init__(self, *args, **kwargs): 89 | super(DistillableEfficientViT, self).__init__(*args, **kwargs) 90 | self.args = args 91 | self.kwargs = kwargs 92 | self.dim = kwargs['dim'] 93 | self.num_classes = kwargs['num_classes'] 94 | 95 | def to_vit(self): 96 | v = EfficientViT(*self.args, **self.kwargs) 97 | v.load_state_dict(self.state_dict()) 98 | return v 99 | 100 | def _attend(self, x): 101 | return self.transformer(x) 102 | 103 | # knowledge distillation wrapper 104 | 105 | class DistillWrapper(Module): 106 | def __init__( 107 | self, 108 | *, 109 | teacher, 110 | student, 111 | temperature = 1., 112 | alpha = 0.5, 113 | hard = False, 114 | mlp_layernorm = False 115 | ): 116 | super().__init__() 117 | assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer' 118 | 119 | self.teacher = teacher 120 | self.student = student 121 | 122 | dim = student.dim 123 | num_classes = student.num_classes 124 | self.temperature = temperature 125 | self.alpha = alpha 126 | self.hard = hard 127 | 128 | self.distillation_token = nn.Parameter(torch.randn(1, 1, dim)) 129 | 130 | self.distill_mlp = nn.Sequential( 131 | nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(), 132 | nn.Linear(dim, num_classes) 133 | ) 134 | 135 | def forward(self, img, labels, temperature = None, alpha = None, **kwargs): 136 | 137 | alpha = default(alpha, self.alpha) 138 | T = default(temperature, self.temperature) 139 | 140 | with torch.no_grad(): 141 | teacher_logits = self.teacher(img) 142 | 143 | student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs) 144 | distill_logits = self.distill_mlp(distill_tokens) 145 | 146 | loss = F.cross_entropy(student_logits, labels) 147 | 148 | if not self.hard: 149 | distill_loss = F.kl_div( 150 | F.log_softmax(distill_logits / T, dim = -1), 151 | F.softmax(teacher_logits / T, dim = -1).detach(), 152 | reduction = 'batchmean') 153 | distill_loss *= T ** 2 154 | 155 | else: 156 | teacher_labels = teacher_logits.argmax(dim = -1) 157 | distill_loss = F.cross_entropy(distill_logits, teacher_labels) 158 | 159 | return loss * (1 - alpha) + distill_loss * alpha 160 | -------------------------------------------------------------------------------- /vit_pytorch/efficient.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange, repeat 4 | from einops.layers.torch import Rearrange 5 | 6 | def pair(t): 7 | return t if isinstance(t, tuple) else (t, t) 8 | 9 | class ViT(nn.Module): 10 | def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3): 11 | super().__init__() 12 | image_size_h, image_size_w = pair(image_size) 13 | assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size' 14 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 15 | num_patches = (image_size_h // patch_size) * (image_size_w // patch_size) 16 | patch_dim = channels * patch_size ** 2 17 | 18 | self.to_patch_embedding = nn.Sequential( 
19 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 20 | nn.LayerNorm(patch_dim), 21 | nn.Linear(patch_dim, dim), 22 | nn.LayerNorm(dim) 23 | ) 24 | 25 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 26 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 27 | self.transformer = transformer 28 | 29 | self.pool = pool 30 | self.to_latent = nn.Identity() 31 | 32 | self.mlp_head = nn.Sequential( 33 | nn.LayerNorm(dim), 34 | nn.Linear(dim, num_classes) 35 | ) 36 | 37 | def forward(self, img): 38 | x = self.to_patch_embedding(img) 39 | b, n, _ = x.shape 40 | 41 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 42 | x = torch.cat((cls_tokens, x), dim=1) 43 | x += self.pos_embedding[:, :(n + 1)] 44 | x = self.transformer(x) 45 | 46 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 47 | 48 | x = self.to_latent(x) 49 | return self.mlp_head(x) 50 | -------------------------------------------------------------------------------- /vit_pytorch/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | def exists(val): 5 | return val is not None 6 | 7 | def identity(t): 8 | return t 9 | 10 | def clone_and_detach(t): 11 | return t.clone().detach() 12 | 13 | def apply_tuple_or_single(fn, val): 14 | if isinstance(val, tuple): 15 | return tuple(map(fn, val)) 16 | return fn(val) 17 | 18 | class Extractor(nn.Module): 19 | def __init__( 20 | self, 21 | vit, 22 | device = None, 23 | layer = None, 24 | layer_name = 'transformer', 25 | layer_save_input = False, 26 | return_embeddings_only = False, 27 | detach = True 28 | ): 29 | super().__init__() 30 | self.vit = vit 31 | 32 | self.data = None 33 | self.latents = None 34 | self.hooks = [] 35 | self.hook_registered = False 36 | self.ejected = False 37 | self.device = device 38 | 39 | self.layer = layer 40 | self.layer_name = layer_name 41 | self.layer_save_input = layer_save_input # whether to save input or output of layer 42 | self.return_embeddings_only = return_embeddings_only 43 | 44 | self.detach_fn = clone_and_detach if detach else identity 45 | 46 | def _hook(self, _, inputs, output): 47 | layer_output = inputs if self.layer_save_input else output 48 | self.latents = apply_tuple_or_single(self.detach_fn, layer_output) 49 | 50 | def _register_hook(self): 51 | if not exists(self.layer): 52 | assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer' 53 | layer = getattr(self.vit, self.layer_name) 54 | else: 55 | layer = self.layer 56 | 57 | handle = layer.register_forward_hook(self._hook) 58 | self.hooks.append(handle) 59 | self.hook_registered = True 60 | 61 | def eject(self): 62 | self.ejected = True 63 | for hook in self.hooks: 64 | hook.remove() 65 | self.hooks.clear() 66 | return self.vit 67 | 68 | def clear(self): 69 | del self.latents 70 | self.latents = None 71 | 72 | def forward( 73 | self, 74 | img, 75 | return_embeddings_only = False 76 | ): 77 | assert not self.ejected, 'extractor has been ejected, cannot be used anymore' 78 | self.clear() 79 | if not self.hook_registered: 80 | self._register_hook() 81 | 82 | pred = self.vit(img) 83 | 84 | target_device = self.device if exists(self.device) else img.device 85 | latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents) 86 | 87 | if return_embeddings_only or self.return_embeddings_only: 88 | return latents 89 | 90 | return pred, latents 91 | 
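A minimal usage sketch for the Extractor wrapper defined in the file above (this example is not part of the repository dump). It assumes the ViT class exported from vit_pytorch, built with the same hyperparameters as tests/test.py; Extractor registers a forward hook on the model's `transformer` submodule, returns the hooked token outputs alongside the logits, and eject() removes the hook and hands back the original model.

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

# plain ViT, same settings as tests/test.py
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrap it so each forward pass also captures the transformer's token outputs
v = Extractor(v)

img = torch.randn(1, 3, 256, 256)

# returns (logits, embeddings); embeddings are the hooked token representations,
# cloned and detached by default (detach = True)
logits, embeddings = v(img)
# logits: (1, 1000); embeddings expected to be (1, 65, 1024) - 64 patch tokens plus the cls token

# remove the hook and recover the unwrapped ViT when done
v = v.eject()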
-------------------------------------------------------------------------------- /vit_pytorch/jumbo_vit.py: -------------------------------------------------------------------------------- 1 | # Simpler Fast Vision Transformers with a Jumbo CLS Token 2 | # https://arxiv.org/abs/2502.15021 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import Module, ModuleList 7 | 8 | from einops import rearrange, repeat, reduce, pack, unpack 9 | from einops.layers.torch import Rearrange 10 | 11 | # helpers 12 | 13 | def pair(t): 14 | return t if isinstance(t, tuple) else (t, t) 15 | 16 | def divisible_by(num, den): 17 | return (num % den) == 0 18 | 19 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 20 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 21 | assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb" 22 | 23 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 24 | omega = temperature ** -omega 25 | 26 | y = y.flatten()[:, None] * omega[None, :] 27 | x = x.flatten()[:, None] * omega[None, :] 28 | pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 29 | 30 | return pos_emb.type(dtype) 31 | 32 | # classes 33 | 34 | def FeedForward(dim, mult = 4.): 35 | hidden_dim = int(dim * mult) 36 | return nn.Sequential( 37 | nn.LayerNorm(dim), 38 | nn.Linear(dim, hidden_dim), 39 | nn.GELU(), 40 | nn.Linear(hidden_dim, dim), 41 | ) 42 | 43 | class Attention(Module): 44 | def __init__(self, dim, heads = 8, dim_head = 64): 45 | super().__init__() 46 | inner_dim = dim_head * heads 47 | self.heads = heads 48 | self.scale = dim_head ** -0.5 49 | self.norm = nn.LayerNorm(dim) 50 | 51 | self.attend = nn.Softmax(dim = -1) 52 | 53 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 54 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 55 | 56 | def forward(self, x): 57 | x = self.norm(x) 58 | 59 | qkv = self.to_qkv(x).chunk(3, dim = -1) 60 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 61 | 62 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 63 | 64 | attn = self.attend(dots) 65 | 66 | out = torch.matmul(attn, v) 67 | out = rearrange(out, 'b h n d -> b n (h d)') 68 | return self.to_out(out) 69 | 70 | class JumboViT(Module): 71 | def __init__( 72 | self, 73 | *, 74 | image_size, 75 | patch_size, 76 | num_classes, 77 | dim, 78 | depth, 79 | heads, 80 | mlp_dim, 81 | num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example 82 | jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on 83 | jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward 84 | channels = 3, 85 | dim_head = 64 86 | ): 87 | super().__init__() 88 | image_height, image_width = pair(image_size) 89 | patch_height, patch_width = pair(patch_size) 90 | 91 | assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.' 
92 | 93 | patch_dim = channels * patch_height * patch_width 94 | 95 | self.to_patch_embedding = nn.Sequential( 96 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 97 | nn.LayerNorm(patch_dim), 98 | nn.Linear(patch_dim, dim), 99 | nn.LayerNorm(dim), 100 | ) 101 | 102 | self.pos_embedding = posemb_sincos_2d( 103 | h = image_height // patch_height, 104 | w = image_width // patch_width, 105 | dim = dim, 106 | ) 107 | 108 | jumbo_cls_dim = dim * jumbo_cls_k 109 | 110 | self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim)) 111 | 112 | jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k) 113 | self.jumbo_cls_to_tokens = jumbo_cls_to_tokens 114 | 115 | self.norm = nn.LayerNorm(dim) 116 | self.layers = ModuleList([]) 117 | 118 | # attention and feedforwards 119 | 120 | self.jumbo_ff = nn.Sequential( 121 | Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k), 122 | FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient 123 | jumbo_cls_to_tokens 124 | ) 125 | 126 | for _ in range(depth): 127 | self.layers.append(ModuleList([ 128 | Attention(dim, heads = heads, dim_head = dim_head), 129 | FeedForward(dim, mlp_dim), 130 | ])) 131 | 132 | self.to_latent = nn.Identity() 133 | 134 | self.linear_head = nn.Linear(dim, num_classes) 135 | 136 | def forward(self, img): 137 | 138 | batch, device = img.shape[0], img.device 139 | 140 | x = self.to_patch_embedding(img) 141 | 142 | # pos embedding 143 | 144 | pos_emb = self.pos_embedding.to(device, dtype = x.dtype) 145 | 146 | x = x + pos_emb 147 | 148 | # add cls tokens 149 | 150 | cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch) 151 | 152 | jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens) 153 | 154 | x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d') 155 | 156 | # attention and feedforwards 157 | 158 | for layer, (attn, ff) in enumerate(self.layers, start = 1): 159 | is_last = layer == len(self.layers) 160 | 161 | x = attn(x) + x 162 | 163 | # jumbo feedforward 164 | 165 | jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d') 166 | 167 | x = ff(x) + x 168 | jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens 169 | 170 | if is_last: 171 | continue 172 | 173 | x, _ = pack([jumbo_cls_tokens, x], 'b * d') 174 | 175 | pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean') 176 | 177 | # normalization and project to logits 178 | 179 | embed = self.norm(pooled) 180 | 181 | embed = self.to_latent(embed) 182 | logits = self.linear_head(embed) 183 | return logits 184 | 185 | # copy pasteable file 186 | 187 | if __name__ == '__main__': 188 | 189 | v = JumboViT( 190 | num_classes = 1000, 191 | image_size = 64, 192 | patch_size = 8, 193 | dim = 16, 194 | depth = 2, 195 | heads = 2, 196 | mlp_dim = 32, 197 | jumbo_cls_k = 3, 198 | jumbo_ff_mult = 2, 199 | ) 200 | 201 | images = torch.randn(1, 3, 64, 64) 202 | 203 | logits = v(images) 204 | assert logits.shape == (1, 1000) 205 | -------------------------------------------------------------------------------- /vit_pytorch/learnable_memory_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def pair(t): 
14 | return t if isinstance(t, tuple) else (t, t) 15 | 16 | # controlling freezing of layers 17 | 18 | def set_module_requires_grad_(module, requires_grad): 19 | for param in module.parameters(): 20 | param.requires_grad = requires_grad 21 | 22 | def freeze_all_layers_(module): 23 | set_module_requires_grad_(module, False) 24 | 25 | def unfreeze_all_layers_(module): 26 | set_module_requires_grad_(module, True) 27 | 28 | # classes 29 | 30 | class FeedForward(nn.Module): 31 | def __init__(self, dim, hidden_dim, dropout = 0.): 32 | super().__init__() 33 | self.net = nn.Sequential( 34 | nn.LayerNorm(dim), 35 | nn.Linear(dim, hidden_dim), 36 | nn.GELU(), 37 | nn.Dropout(dropout), 38 | nn.Linear(hidden_dim, dim), 39 | nn.Dropout(dropout) 40 | ) 41 | def forward(self, x): 42 | return self.net(x) 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 46 | super().__init__() 47 | inner_dim = dim_head * heads 48 | 49 | self.heads = heads 50 | self.scale = dim_head ** -0.5 51 | self.norm = nn.LayerNorm(dim) 52 | 53 | self.attend = nn.Softmax(dim = -1) 54 | self.dropout = nn.Dropout(dropout) 55 | 56 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 57 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 58 | 59 | self.to_out = nn.Sequential( 60 | nn.Linear(inner_dim, dim), 61 | nn.Dropout(dropout) 62 | ) 63 | 64 | def forward(self, x, attn_mask = None, memories = None): 65 | x = self.norm(x) 66 | 67 | x_kv = x # input for key / values projection 68 | 69 | if exists(memories): 70 | # add memories to key / values if it is passed in 71 | memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories 72 | x_kv = torch.cat((x_kv, memories), dim = 1) 73 | 74 | qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1)) 75 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 76 | 77 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 78 | 79 | if exists(attn_mask): 80 | dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max) 81 | 82 | attn = self.attend(dots) 83 | attn = self.dropout(attn) 84 | 85 | out = torch.matmul(attn, v) 86 | out = rearrange(out, 'b h n d -> b n (h d)') 87 | return self.to_out(out) 88 | 89 | class Transformer(nn.Module): 90 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 91 | super().__init__() 92 | self.layers = nn.ModuleList([]) 93 | for _ in range(depth): 94 | self.layers.append(nn.ModuleList([ 95 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 96 | FeedForward(dim, mlp_dim, dropout = dropout) 97 | ])) 98 | 99 | def forward(self, x, attn_mask = None, memories = None): 100 | for ind, (attn, ff) in enumerate(self.layers): 101 | layer_memories = memories[ind] if exists(memories) else None 102 | 103 | x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x 104 | x = ff(x) + x 105 | return x 106 | 107 | class ViT(nn.Module): 108 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 109 | super().__init__() 110 | image_height, image_width = pair(image_size) 111 | patch_height, patch_width = pair(patch_size) 112 | 113 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
114 | 115 | num_patches = (image_height // patch_height) * (image_width // patch_width) 116 | patch_dim = channels * patch_height * patch_width 117 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 118 | 119 | self.to_patch_embedding = nn.Sequential( 120 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 121 | nn.LayerNorm(patch_dim), 122 | nn.Linear(patch_dim, dim), 123 | nn.LayerNorm(dim) 124 | ) 125 | 126 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 127 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 128 | self.dropout = nn.Dropout(emb_dropout) 129 | 130 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 131 | 132 | self.mlp_head = nn.Sequential( 133 | nn.LayerNorm(dim), 134 | nn.Linear(dim, num_classes) 135 | ) 136 | 137 | def img_to_tokens(self, img): 138 | x = self.to_patch_embedding(img) 139 | 140 | cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0]) 141 | x = torch.cat((cls_tokens, x), dim = 1) 142 | 143 | x += self.pos_embedding 144 | x = self.dropout(x) 145 | return x 146 | 147 | def forward(self, img): 148 | x = self.img_to_tokens(img) 149 | 150 | x = self.transformer(x) 151 | 152 | cls_tokens = x[:, 0] 153 | return self.mlp_head(cls_tokens) 154 | 155 | # adapter with learnable memories per layer, memory CLS token, and learnable adapter head 156 | 157 | class Adapter(nn.Module): 158 | def __init__( 159 | self, 160 | *, 161 | vit, 162 | num_memories_per_layer = 10, 163 | num_classes = 2, 164 | ): 165 | super().__init__() 166 | assert isinstance(vit, ViT) 167 | 168 | # extract some model variables needed 169 | 170 | dim = vit.cls_token.shape[-1] 171 | layers = len(vit.transformer.layers) 172 | num_patches = vit.pos_embedding.shape[-2] 173 | 174 | self.vit = vit 175 | 176 | # freeze ViT backbone - only memories will be finetuned 177 | 178 | freeze_all_layers_(vit) 179 | 180 | # learnable parameters 181 | 182 | self.memory_cls_token = nn.Parameter(torch.randn(dim)) 183 | self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim)) 184 | 185 | self.mlp_head = nn.Sequential( 186 | nn.LayerNorm(dim), 187 | nn.Linear(dim, num_classes) 188 | ) 189 | 190 | # specialized attention mask to preserve the output of the original ViT 191 | # it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa 192 | 193 | attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool) 194 | attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer 195 | attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # memory CLS token can attend to everything 196 | self.register_buffer('attn_mask', attn_mask) 197 | 198 | def forward(self, img): 199 | b = img.shape[0] 200 | 201 | tokens = self.vit.img_to_tokens(img) 202 | 203 | # add task specific memory tokens 204 | 205 | memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b) 206 | tokens = torch.cat((memory_cls_tokens, tokens), dim = 1) 207 | 208 | # pass memories along with image tokens through transformer for attending 209 | 210 | out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask) 211 | 212 | # extract memory CLS tokens 213 | 214 | memory_cls_tokens = out[:, 0] 215 | 216 | # pass through task specific adapter head 217 | 218 | return 
self.mlp_head(memory_cls_tokens) 219 | -------------------------------------------------------------------------------- /vit_pytorch/levit.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import torch 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | def cast_tuple(val, l = 3): 19 | val = val if isinstance(val, tuple) else (val,) 20 | return (*val, *((val[-1],) * max(l - len(val), 0))) 21 | 22 | def always(val): 23 | return lambda *args, **kwargs: val 24 | 25 | # classes 26 | 27 | class FeedForward(nn.Module): 28 | def __init__(self, dim, mult, dropout = 0.): 29 | super().__init__() 30 | self.net = nn.Sequential( 31 | nn.Conv2d(dim, dim * mult, 1), 32 | nn.Hardswish(), 33 | nn.Dropout(dropout), 34 | nn.Conv2d(dim * mult, dim, 1), 35 | nn.Dropout(dropout) 36 | ) 37 | def forward(self, x): 38 | return self.net(x) 39 | 40 | class Attention(nn.Module): 41 | def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False): 42 | super().__init__() 43 | inner_dim_key = dim_key * heads 44 | inner_dim_value = dim_value * heads 45 | dim_out = default(dim_out, dim) 46 | 47 | self.heads = heads 48 | self.scale = dim_key ** -0.5 49 | 50 | self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key)) 51 | self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key)) 52 | self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value)) 53 | 54 | self.attend = nn.Softmax(dim = -1) 55 | self.dropout = nn.Dropout(dropout) 56 | 57 | out_batch_norm = nn.BatchNorm2d(dim_out) 58 | nn.init.zeros_(out_batch_norm.weight) 59 | 60 | self.to_out = nn.Sequential( 61 | nn.GELU(), 62 | nn.Conv2d(inner_dim_value, dim_out, 1), 63 | out_batch_norm, 64 | nn.Dropout(dropout) 65 | ) 66 | 67 | # positional bias 68 | 69 | self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads) 70 | 71 | q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1)) 72 | k_range = torch.arange(fmap_size) 73 | 74 | q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1) 75 | k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1) 76 | 77 | q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos)) 78 | rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs() 79 | 80 | x_rel, y_rel = rel_pos.unbind(dim = -1) 81 | pos_indices = (x_rel * fmap_size) + y_rel 82 | 83 | self.register_buffer('pos_indices', pos_indices) 84 | 85 | def apply_pos_bias(self, fmap): 86 | bias = self.pos_bias(self.pos_indices) 87 | bias = rearrange(bias, 'i j h -> () h i j') 88 | return fmap + (bias / self.scale) 89 | 90 | def forward(self, x): 91 | b, n, *_, h = *x.shape, self.heads 92 | 93 | q = self.to_q(x) 94 | y = q.shape[2] 95 | 96 | qkv = (q, self.to_k(x), self.to_v(x)) 97 | q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) 
d', h = h), qkv) 98 | 99 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 100 | 101 | dots = self.apply_pos_bias(dots) 102 | 103 | attn = self.attend(dots) 104 | attn = self.dropout(attn) 105 | 106 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 107 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y) 108 | return self.to_out(out) 109 | 110 | class Transformer(nn.Module): 111 | def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False): 112 | super().__init__() 113 | dim_out = default(dim_out, dim) 114 | self.layers = nn.ModuleList([]) 115 | self.attn_residual = (not downsample) and dim == dim_out 116 | 117 | for _ in range(depth): 118 | self.layers.append(nn.ModuleList([ 119 | Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out), 120 | FeedForward(dim_out, mlp_mult, dropout = dropout) 121 | ])) 122 | def forward(self, x): 123 | for attn, ff in self.layers: 124 | attn_res = (x if self.attn_residual else 0) 125 | x = attn(x) + attn_res 126 | x = ff(x) + x 127 | return x 128 | 129 | class LeViT(nn.Module): 130 | def __init__( 131 | self, 132 | *, 133 | image_size, 134 | num_classes, 135 | dim, 136 | depth, 137 | heads, 138 | mlp_mult, 139 | stages = 3, 140 | dim_key = 32, 141 | dim_value = 64, 142 | dropout = 0., 143 | num_distill_classes = None 144 | ): 145 | super().__init__() 146 | 147 | dims = cast_tuple(dim, stages) 148 | depths = cast_tuple(depth, stages) 149 | layer_heads = cast_tuple(heads, stages) 150 | 151 | assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages' 152 | 153 | self.conv_embedding = nn.Sequential( 154 | nn.Conv2d(3, 32, 3, stride = 2, padding = 1), 155 | nn.Conv2d(32, 64, 3, stride = 2, padding = 1), 156 | nn.Conv2d(64, 128, 3, stride = 2, padding = 1), 157 | nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1) 158 | ) 159 | 160 | fmap_size = image_size // (2 ** 4) 161 | layers = [] 162 | 163 | for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads): 164 | is_last = ind == (stages - 1) 165 | layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout)) 166 | 167 | if not is_last: 168 | next_dim = dims[ind + 1] 169 | layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True)) 170 | fmap_size = ceil(fmap_size / 2) 171 | 172 | self.backbone = nn.Sequential(*layers) 173 | 174 | self.pool = nn.Sequential( 175 | nn.AdaptiveAvgPool2d(1), 176 | Rearrange('... 
() () -> ...') 177 | ) 178 | 179 | self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None) 180 | self.mlp_head = nn.Linear(dim, num_classes) 181 | 182 | def forward(self, img): 183 | x = self.conv_embedding(img) 184 | 185 | x = self.backbone(x) 186 | 187 | x = self.pool(x) 188 | 189 | out = self.mlp_head(x) 190 | distill = self.distill_head(x) 191 | 192 | if exists(distill): 193 | return out, distill 194 | 195 | return out 196 | -------------------------------------------------------------------------------- /vit_pytorch/local_vit.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import torch 3 | from torch import nn, einsum 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # classes 10 | 11 | class Residual(nn.Module): 12 | def __init__(self, fn): 13 | super().__init__() 14 | self.fn = fn 15 | 16 | def forward(self, x, **kwargs): 17 | return self.fn(x, **kwargs) + x 18 | 19 | class ExcludeCLS(nn.Module): 20 | def __init__(self, fn): 21 | super().__init__() 22 | self.fn = fn 23 | 24 | def forward(self, x, **kwargs): 25 | cls_token, x = x[:, :1], x[:, 1:] 26 | x = self.fn(x, **kwargs) 27 | return torch.cat((cls_token, x), dim = 1) 28 | 29 | # feed forward related classes 30 | 31 | class DepthWiseConv2d(nn.Module): 32 | def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True): 33 | super().__init__() 34 | self.net = nn.Sequential( 35 | nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias), 36 | nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias) 37 | ) 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | class FeedForward(nn.Module): 42 | def __init__(self, dim, hidden_dim, dropout = 0.): 43 | super().__init__() 44 | self.net = nn.Sequential( 45 | nn.LayerNorm(dim), 46 | nn.Conv2d(dim, hidden_dim, 1), 47 | nn.Hardswish(), 48 | DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1), 49 | nn.Hardswish(), 50 | nn.Dropout(dropout), 51 | nn.Conv2d(hidden_dim, dim, 1), 52 | nn.Dropout(dropout) 53 | ) 54 | def forward(self, x): 55 | h = w = int(sqrt(x.shape[-2])) 56 | x = rearrange(x, 'b (h w) c -> b c h w', h = h, w = w) 57 | x = self.net(x) 58 | x = rearrange(x, 'b c h w -> b (h w) c') 59 | return x 60 | 61 | # attention 62 | 63 | class Attention(nn.Module): 64 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 65 | super().__init__() 66 | inner_dim = dim_head * heads 67 | 68 | self.heads = heads 69 | self.scale = dim_head ** -0.5 70 | 71 | self.norm = nn.LayerNorm(dim) 72 | self.attend = nn.Softmax(dim = -1) 73 | self.dropout = nn.Dropout(dropout) 74 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 75 | 76 | self.to_out = nn.Sequential( 77 | nn.Linear(inner_dim, dim), 78 | nn.Dropout(dropout) 79 | ) 80 | 81 | def forward(self, x): 82 | b, n, _, h = *x.shape, self.heads 83 | 84 | x = self.norm(x) 85 | qkv = self.to_qkv(x).chunk(3, dim = -1) 86 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 87 | 88 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 89 | 90 | attn = self.attend(dots) 91 | attn = self.dropout(attn) 92 | 93 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 94 | out = rearrange(out, 'b h n d -> b n (h d)') 95 | return self.to_out(out) 96 | 97 | class Transformer(nn.Module): 98 | def __init__(self, dim, depth, 
heads, dim_head, mlp_dim, dropout = 0.): 99 | super().__init__() 100 | self.layers = nn.ModuleList([]) 101 | for _ in range(depth): 102 | self.layers.append(nn.ModuleList([ 103 | Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 104 | ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout))) 105 | ])) 106 | def forward(self, x): 107 | for attn, ff in self.layers: 108 | x = attn(x) 109 | x = ff(x) 110 | return x 111 | 112 | # main class 113 | 114 | class LocalViT(nn.Module): 115 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 116 | super().__init__() 117 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 118 | num_patches = (image_size // patch_size) ** 2 119 | patch_dim = channels * patch_size ** 2 120 | 121 | self.to_patch_embedding = nn.Sequential( 122 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 123 | nn.LayerNorm(patch_dim), 124 | nn.Linear(patch_dim, dim), 125 | nn.LayerNorm(dim), 126 | ) 127 | 128 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 129 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 130 | self.dropout = nn.Dropout(emb_dropout) 131 | 132 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 133 | 134 | self.mlp_head = nn.Sequential( 135 | nn.LayerNorm(dim), 136 | nn.Linear(dim, num_classes) 137 | ) 138 | 139 | def forward(self, img): 140 | x = self.to_patch_embedding(img) 141 | b, n, _ = x.shape 142 | 143 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 144 | x = torch.cat((cls_tokens, x), dim=1) 145 | x += self.pos_embedding[:, :(n + 1)] 146 | x = self.dropout(x) 147 | 148 | x = self.transformer(x) 149 | 150 | return self.mlp_head(x[:, 0]) 151 | -------------------------------------------------------------------------------- /vit_pytorch/mae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from vit_pytorch.vit import Transformer 7 | 8 | class MAE(nn.Module): 9 | def __init__( 10 | self, 11 | *, 12 | encoder, 13 | decoder_dim, 14 | masking_ratio = 0.75, 15 | decoder_depth = 1, 16 | decoder_heads = 8, 17 | decoder_dim_head = 64 18 | ): 19 | super().__init__() 20 | assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1' 21 | self.masking_ratio = masking_ratio 22 | 23 | # extract some hyperparameters and functions from encoder (vision transformer to be trained) 24 | 25 | self.encoder = encoder 26 | num_patches, encoder_dim = encoder.pos_embedding.shape[-2:] 27 | 28 | self.to_patch = encoder.to_patch_embedding[0] 29 | self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:]) 30 | 31 | pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1] 32 | 33 | # decoder parameters 34 | self.decoder_dim = decoder_dim 35 | self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity() 36 | self.mask_token = nn.Parameter(torch.randn(decoder_dim)) 37 | self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4) 38 | self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim) 39 | self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch) 40 | 41 | 
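# a minimal usage sketch for MAE pretraining (hypothetical hyperparameters), assuming the ViT exported from vit_pytorch.vit:
#
#   import torch
#   from vit_pytorch.vit import ViT
#   from vit_pytorch.mae import MAE
#
#   v = ViT(image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)
#   mae = MAE(encoder = v, masking_ratio = 0.75, decoder_dim = 512, decoder_depth = 6)
#   loss = mae(torch.randn(8, 3, 256, 256))   # scalar reconstruction loss over the masked patches
#   loss.backward()                           # after training, v can be reused as the pretrained encoder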
def forward(self, img): 42 | device = img.device 43 | 44 | # get patches 45 | 46 | patches = self.to_patch(img) 47 | batch, num_patches, *_ = patches.shape 48 | 49 | # patch to encoder tokens and add positions 50 | 51 | tokens = self.patch_to_emb(patches) 52 | if self.encoder.pool == "cls": 53 | tokens += self.encoder.pos_embedding[:, 1:(num_patches + 1)] 54 | elif self.encoder.pool == "mean": 55 | tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype) 56 | 57 | # calculate of patches needed to be masked, and get random indices, dividing it up for mask vs unmasked 58 | 59 | num_masked = int(self.masking_ratio * num_patches) 60 | rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1) 61 | masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:] 62 | 63 | # get the unmasked tokens to be encoded 64 | 65 | batch_range = torch.arange(batch, device = device)[:, None] 66 | tokens = tokens[batch_range, unmasked_indices] 67 | 68 | # get the patches to be masked for the final reconstruction loss 69 | 70 | masked_patches = patches[batch_range, masked_indices] 71 | 72 | # attend with vision transformer 73 | 74 | encoded_tokens = self.encoder.transformer(tokens) 75 | 76 | # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder 77 | 78 | decoder_tokens = self.enc_to_dec(encoded_tokens) 79 | 80 | # reapply decoder position embedding to unmasked tokens 81 | 82 | unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices) 83 | 84 | # repeat mask tokens for number of masked, and add the positions using the masked indices derived above 85 | 86 | mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked) 87 | mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices) 88 | 89 | # concat the masked tokens to the decoder tokens and attend with decoder 90 | 91 | decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device) 92 | decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens 93 | decoder_tokens[batch_range, masked_indices] = mask_tokens 94 | decoded_tokens = self.decoder(decoder_tokens) 95 | 96 | # splice out the mask tokens and project to pixel values 97 | 98 | mask_tokens = decoded_tokens[batch_range, masked_indices] 99 | pred_pixel_values = self.to_pixels(mask_tokens) 100 | 101 | # calculate reconstruction loss 102 | 103 | recon_loss = F.mse_loss(pred_pixel_values, masked_patches) 104 | return recon_loss 105 | -------------------------------------------------------------------------------- /vit_pytorch/mp3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def default(val, d): 14 | return val if exists(val) else d 15 | 16 | def pair(t): 17 | return t if isinstance(t, tuple) else (t, t) 18 | 19 | # positional embedding 20 | 21 | def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): 22 | _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype 23 | 24 | y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') 25 | assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' 26 
| omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) 27 | omega = 1. / (temperature ** omega) 28 | 29 | y = y.flatten()[:, None] * omega[None, :] 30 | x = x.flatten()[:, None] * omega[None, :] 31 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1) 32 | return pe.type(dtype) 33 | 34 | # feedforward 35 | 36 | class FeedForward(nn.Module): 37 | def __init__(self, dim, hidden_dim, dropout = 0.): 38 | super().__init__() 39 | self.net = nn.Sequential( 40 | nn.LayerNorm(dim), 41 | nn.Linear(dim, hidden_dim), 42 | nn.GELU(), 43 | nn.Dropout(dropout), 44 | nn.Linear(hidden_dim, dim), 45 | nn.Dropout(dropout) 46 | ) 47 | def forward(self, x): 48 | return self.net(x) 49 | 50 | # (cross)attention 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 54 | super().__init__() 55 | inner_dim = dim_head * heads 56 | self.heads = heads 57 | self.scale = dim_head ** -0.5 58 | 59 | self.attend = nn.Softmax(dim = -1) 60 | self.dropout = nn.Dropout(dropout) 61 | 62 | self.norm = nn.LayerNorm(dim) 63 | 64 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 65 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 66 | 67 | self.to_out = nn.Sequential( 68 | nn.Linear(inner_dim, dim), 69 | nn.Dropout(dropout) 70 | ) 71 | 72 | def forward(self, x, context = None): 73 | b, n, _, h = *x.shape, self.heads 74 | 75 | x = self.norm(x) 76 | 77 | context = self.norm(context) if exists(context) else x 78 | 79 | qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 80 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 81 | 82 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 83 | 84 | attn = self.attend(dots) 85 | attn = self.dropout(attn) 86 | 87 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 88 | out = rearrange(out, 'b h n d -> b n (h d)') 89 | return self.to_out(out) 90 | 91 | class Transformer(nn.Module): 92 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 93 | super().__init__() 94 | self.layers = nn.ModuleList([]) 95 | for _ in range(depth): 96 | self.layers.append(nn.ModuleList([ 97 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 98 | FeedForward(dim, mlp_dim, dropout = dropout) 99 | ])) 100 | def forward(self, x, context = None): 101 | for attn, ff in self.layers: 102 | x = attn(x, context = context) + x 103 | x = ff(x) + x 104 | return x 105 | 106 | class ViT(nn.Module): 107 | def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.): 108 | super().__init__() 109 | image_height, image_width = pair(image_size) 110 | patch_height, patch_width = pair(patch_size) 111 | 112 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
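# note: unlike the CLS-token ViTs elsewhere in this package, this variant adds fixed 2D sin-cos position
# embeddings in forward() and mean-pools the patch tokens before the classification head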
113 | 114 | num_patches = (image_height // patch_height) * (image_width // patch_width) 115 | patch_dim = channels * patch_height * patch_width 116 | 117 | self.dim = dim 118 | self.num_patches = num_patches 119 | 120 | self.to_patch_embedding = nn.Sequential( 121 | Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width), 122 | nn.LayerNorm(patch_dim), 123 | nn.Linear(patch_dim, dim), 124 | nn.LayerNorm(dim), 125 | ) 126 | 127 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 128 | 129 | self.to_latent = nn.Identity() 130 | self.linear_head = nn.Sequential( 131 | nn.LayerNorm(dim), 132 | nn.Linear(dim, num_classes) 133 | ) 134 | 135 | def forward(self, img): 136 | *_, h, w, dtype = *img.shape, img.dtype 137 | 138 | x = self.to_patch_embedding(img) 139 | pe = posemb_sincos_2d(x) 140 | x = rearrange(x, 'b ... d -> b (...) d') + pe 141 | 142 | x = self.transformer(x) 143 | x = x.mean(dim = 1) 144 | 145 | x = self.to_latent(x) 146 | return self.linear_head(x) 147 | 148 | # Masked Position Prediction Pre-Training 149 | 150 | class MP3(nn.Module): 151 | def __init__(self, vit: ViT, masking_ratio): 152 | super().__init__() 153 | self.vit = vit 154 | 155 | assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1' 156 | self.masking_ratio = masking_ratio 157 | 158 | dim = vit.dim 159 | self.mlp_head = nn.Sequential( 160 | nn.LayerNorm(dim), 161 | nn.Linear(dim, vit.num_patches) 162 | ) 163 | 164 | def forward(self, img): 165 | device = img.device 166 | tokens = self.vit.to_patch_embedding(img) 167 | tokens = rearrange(tokens, 'b ... d -> b (...) d') 168 | 169 | batch, num_patches, *_ = tokens.shape 170 | 171 | # Masking 172 | num_masked = int(self.masking_ratio * num_patches) 173 | rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1) 174 | masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:] 175 | 176 | batch_range = torch.arange(batch, device = device)[:, None] 177 | tokens_unmasked = tokens[batch_range, unmasked_indices] 178 | 179 | attended_tokens = self.vit.transformer(tokens, tokens_unmasked) 180 | logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d') 181 | 182 | # Define labels 183 | labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch) 184 | loss = F.cross_entropy(logits, labels) 185 | 186 | return loss 187 | -------------------------------------------------------------------------------- /vit_pytorch/mpp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, repeat, reduce 8 | 9 | # helpers 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def prob_mask_like(t, prob): 15 | batch, seq_length, _ = t.shape 16 | return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob 17 | 18 | def get_mask_subset_with_prob(patched_input, prob): 19 | batch, seq_len, _, device = *patched_input.shape, patched_input.device 20 | max_masked = math.ceil(prob * seq_len) 21 | 22 | rand = torch.rand((batch, seq_len), device=device) 23 | _, sampled_indices = rand.topk(max_masked, dim=-1) 24 | 25 | new_mask = torch.zeros((batch, seq_len), device=device) 26 | new_mask.scatter_(1, sampled_indices, 1) 27 | return new_mask.bool() 28 | 29 | 30 | # mpp loss 31 | 32 | 33 | class MPPLoss(nn.Module): 34 | def __init__( 35 | self, 
36 | patch_size, 37 | channels, 38 | output_channel_bits, 39 | max_pixel_val, 40 | mean, 41 | std 42 | ): 43 | super().__init__() 44 | self.patch_size = patch_size 45 | self.channels = channels 46 | self.output_channel_bits = output_channel_bits 47 | self.max_pixel_val = max_pixel_val 48 | 49 | self.mean = torch.tensor(mean).view(-1, 1, 1) if mean else None 50 | self.std = torch.tensor(std).view(-1, 1, 1) if std else None 51 | 52 | def forward(self, predicted_patches, target, mask): 53 | p, c, mpv, bits, device = self.patch_size, self.channels, self.max_pixel_val, self.output_channel_bits, target.device 54 | bin_size = mpv / (2 ** bits) 55 | 56 | # un-normalize input 57 | if exists(self.mean) and exists(self.std): 58 | target = target * self.std + self.mean 59 | 60 | # reshape target to patches 61 | target = target.clamp(max = mpv) # clamp just in case 62 | avg_target = reduce(target, 'b c (h p1) (w p2) -> b (h w) c', 'mean', p1 = p, p2 = p).contiguous() 63 | 64 | channel_bins = torch.arange(bin_size, mpv, bin_size, device = device) 65 | discretized_target = torch.bucketize(avg_target, channel_bins) 66 | 67 | bin_mask = (2 ** bits) ** torch.arange(0, c, device = device).long() 68 | bin_mask = rearrange(bin_mask, 'c -> () () c') 69 | 70 | target_label = torch.sum(bin_mask * discretized_target, dim = -1) 71 | 72 | loss = F.cross_entropy(predicted_patches[mask], target_label[mask]) 73 | return loss 74 | 75 | 76 | # main class 77 | 78 | 79 | class MPP(nn.Module): 80 | def __init__( 81 | self, 82 | transformer, 83 | patch_size, 84 | dim, 85 | output_channel_bits=3, 86 | channels=3, 87 | max_pixel_val=1.0, 88 | mask_prob=0.15, 89 | replace_prob=0.5, 90 | random_patch_prob=0.5, 91 | mean=None, 92 | std=None 93 | ): 94 | super().__init__() 95 | self.transformer = transformer 96 | self.loss = MPPLoss(patch_size, channels, output_channel_bits, 97 | max_pixel_val, mean, std) 98 | 99 | # extract patching function 100 | self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:]) 101 | 102 | # output transformation 103 | self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels)) 104 | 105 | # vit related dimensions 106 | self.patch_size = patch_size 107 | 108 | # mpp related probabilities 109 | self.mask_prob = mask_prob 110 | self.replace_prob = replace_prob 111 | self.random_patch_prob = random_patch_prob 112 | 113 | # token ids 114 | self.mask_token = nn.Parameter(torch.randn(1, 1, channels * patch_size ** 2)) 115 | 116 | def forward(self, input, **kwargs): 117 | transformer = self.transformer 118 | # clone original image for loss 119 | img = input.clone().detach() 120 | 121 | # reshape raw image to patches 122 | p = self.patch_size 123 | input = rearrange(input, 124 | 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 125 | p1=p, 126 | p2=p) 127 | 128 | mask = get_mask_subset_with_prob(input, self.mask_prob) 129 | 130 | # mask input with mask patches with probability of `replace_prob` (keep patches the same with probability 1 - replace_prob) 131 | masked_input = input.clone().detach() 132 | 133 | # if random token probability > 0 for mpp 134 | if self.random_patch_prob > 0: 135 | random_patch_sampling_prob = self.random_patch_prob / ( 136 | 1 - self.replace_prob) 137 | random_patch_prob = prob_mask_like(input, 138 | random_patch_sampling_prob).to(mask.device) 139 | 140 | bool_random_patch_prob = mask * (random_patch_prob == True) 141 | random_patches = torch.randint(0, 142 | input.shape[1], 143 | (input.shape[0], input.shape[1]), 144 | device=input.device) 145 | randomized_input = 
masked_input[ 146 | torch.arange(masked_input.shape[0]).unsqueeze(-1), 147 | random_patches] 148 | masked_input[bool_random_patch_prob] = randomized_input[ 149 | bool_random_patch_prob] 150 | 151 | # [mask] input 152 | replace_prob = prob_mask_like(input, self.replace_prob).to(mask.device) 153 | bool_mask_replace = (mask * replace_prob) == True 154 | masked_input[bool_mask_replace] = self.mask_token 155 | 156 | # linear embedding of patches 157 | masked_input = self.patch_to_emb(masked_input) 158 | 159 | # add cls token to input sequence 160 | b, n, _ = masked_input.shape 161 | cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b) 162 | masked_input = torch.cat((cls_tokens, masked_input), dim=1) 163 | 164 | # add positional embeddings to input 165 | masked_input += transformer.pos_embedding[:, :(n + 1)] 166 | masked_input = transformer.dropout(masked_input) 167 | 168 | # get generator output and get mpp loss 169 | masked_input = transformer.transformer(masked_input, **kwargs) 170 | cls_logits = self.to_bits(masked_input) 171 | logits = cls_logits[:, 1:, :] 172 | 173 | mpp_loss = self.loss(logits, img, mask) 174 | 175 | return mpp_loss 176 | -------------------------------------------------------------------------------- /vit_pytorch/nest.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from torch import nn, einsum 4 | 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange, Reduce 7 | 8 | # helpers 9 | 10 | def cast_tuple(val, depth): 11 | return val if isinstance(val, tuple) else ((val,) * depth) 12 | 13 | # classes 14 | 15 | class LayerNorm(nn.Module): 16 | def __init__(self, dim, eps = 1e-5): 17 | super().__init__() 18 | self.eps = eps 19 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 20 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) 21 | 22 | def forward(self, x): 23 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 24 | mean = torch.mean(x, dim = 1, keepdim = True) 25 | return (x - mean) / (var + self.eps).sqrt() * self.g + self.b 26 | 27 | class FeedForward(nn.Module): 28 | def __init__(self, dim, mlp_mult = 4, dropout = 0.): 29 | super().__init__() 30 | self.net = nn.Sequential( 31 | LayerNorm(dim), 32 | nn.Conv2d(dim, dim * mlp_mult, 1), 33 | nn.GELU(), 34 | nn.Dropout(dropout), 35 | nn.Conv2d(dim * mlp_mult, dim, 1), 36 | nn.Dropout(dropout) 37 | ) 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | class Attention(nn.Module): 42 | def __init__(self, dim, heads = 8, dropout = 0.): 43 | super().__init__() 44 | dim_head = dim // heads 45 | inner_dim = dim_head * heads 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | 49 | self.norm = LayerNorm(dim) 50 | self.attend = nn.Softmax(dim = -1) 51 | self.dropout = nn.Dropout(dropout) 52 | self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False) 53 | 54 | self.to_out = nn.Sequential( 55 | nn.Conv2d(inner_dim, dim, 1), 56 | nn.Dropout(dropout) 57 | ) 58 | 59 | def forward(self, x): 60 | b, c, h, w, heads = *x.shape, self.heads 61 | 62 | x = self.norm(x) 63 | 64 | qkv = self.to_qkv(x).chunk(3, dim = 1) 65 | q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv) 66 | 67 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 68 | 69 | attn = self.attend(dots) 70 | attn = self.dropout(attn) 71 | 72 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 73 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) 74 | return 
self.to_out(out) 75 | 76 | def Aggregate(dim, dim_out): 77 | return nn.Sequential( 78 | nn.Conv2d(dim, dim_out, 3, padding = 1), 79 | LayerNorm(dim_out), 80 | nn.MaxPool2d(3, stride = 2, padding = 1) 81 | ) 82 | 83 | class Transformer(nn.Module): 84 | def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.): 85 | super().__init__() 86 | self.layers = nn.ModuleList([]) 87 | self.pos_emb = nn.Parameter(torch.randn(seq_len)) 88 | 89 | for _ in range(depth): 90 | self.layers.append(nn.ModuleList([ 91 | Attention(dim, heads = heads, dropout = dropout), 92 | FeedForward(dim, mlp_mult, dropout = dropout) 93 | ])) 94 | def forward(self, x): 95 | *_, h, w = x.shape 96 | 97 | pos_emb = self.pos_emb[:(h * w)] 98 | pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w) 99 | x = x + pos_emb 100 | 101 | for attn, ff in self.layers: 102 | x = attn(x) + x 103 | x = ff(x) + x 104 | return x 105 | 106 | class NesT(nn.Module): 107 | def __init__( 108 | self, 109 | *, 110 | image_size, 111 | patch_size, 112 | num_classes, 113 | dim, 114 | heads, 115 | num_hierarchies, 116 | block_repeats, 117 | mlp_mult = 4, 118 | channels = 3, 119 | dim_head = 64, 120 | dropout = 0. 121 | ): 122 | super().__init__() 123 | assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.' 124 | num_patches = (image_size // patch_size) ** 2 125 | patch_dim = channels * patch_size ** 2 126 | fmap_size = image_size // patch_size 127 | blocks = 2 ** (num_hierarchies - 1) 128 | 129 | seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy 130 | hierarchies = list(reversed(range(num_hierarchies))) 131 | mults = [2 ** i for i in reversed(hierarchies)] 132 | 133 | layer_heads = list(map(lambda t: t * heads, mults)) 134 | layer_dims = list(map(lambda t: t * dim, mults)) 135 | last_dim = layer_dims[-1] 136 | 137 | layer_dims = [*layer_dims, layer_dims[-1]] 138 | dim_pairs = zip(layer_dims[:-1], layer_dims[1:]) 139 | 140 | self.to_patch_embedding = nn.Sequential( 141 | Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size), 142 | LayerNorm(patch_dim), 143 | nn.Conv2d(patch_dim, layer_dims[0], 1), 144 | LayerNorm(layer_dims[0]) 145 | ) 146 | 147 | block_repeats = cast_tuple(block_repeats, num_hierarchies) 148 | 149 | self.layers = nn.ModuleList([]) 150 | 151 | for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats): 152 | is_last = level == 0 153 | depth = block_repeat 154 | 155 | self.layers.append(nn.ModuleList([ 156 | Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout), 157 | Aggregate(dim_in, dim_out) if not is_last else nn.Identity() 158 | ])) 159 | 160 | 161 | self.mlp_head = nn.Sequential( 162 | LayerNorm(last_dim), 163 | Reduce('b c h w -> b c', 'mean'), 164 | nn.Linear(last_dim, num_classes) 165 | ) 166 | 167 | def forward(self, img): 168 | x = self.to_patch_embedding(img) 169 | b, c, h, w = x.shape 170 | 171 | num_hierarchies = len(self.layers) 172 | 173 | for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers): 174 | block_size = 2 ** level 175 | x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size) 176 | x = transformer(x) 177 | x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size) 178 | x = aggregate(x) 179 | 180 | return self.mlp_head(x) 181 | -------------------------------------------------------------------------------- 
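# a minimal usage sketch for the NesT model above (hypothetical hyperparameters):
#
#   import torch
#   from vit_pytorch.nest import NesT
#
#   nest = NesT(image_size = 224, patch_size = 4, dim = 96, heads = 3, num_hierarchies = 3, block_repeats = (2, 2, 8), num_classes = 1000)
#   logits = nest(torch.randn(1, 3, 224, 224))   # (1, 1000)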
/vit_pytorch/normalized_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module, ModuleList 4 | import torch.nn.functional as F 5 | import torch.nn.utils.parametrize as parametrize 6 | 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | # functions 11 | 12 | def exists(v): 13 | return v is not None 14 | 15 | def default(v, d): 16 | return v if exists(v) else d 17 | 18 | def pair(t): 19 | return t if isinstance(t, tuple) else (t, t) 20 | 21 | def divisible_by(numer, denom): 22 | return (numer % denom) == 0 23 | 24 | def l2norm(t, dim = -1): 25 | return F.normalize(t, dim = dim, p = 2) 26 | 27 | # for use with parametrize 28 | 29 | class L2Norm(Module): 30 | def __init__(self, dim = -1): 31 | super().__init__() 32 | self.dim = dim 33 | 34 | def forward(self, t): 35 | return l2norm(t, dim = self.dim) 36 | 37 | class NormLinear(Module): 38 | def __init__( 39 | self, 40 | dim, 41 | dim_out, 42 | norm_dim_in = True 43 | ): 44 | super().__init__() 45 | self.linear = nn.Linear(dim, dim_out, bias = False) 46 | 47 | parametrize.register_parametrization( 48 | self.linear, 49 | 'weight', 50 | L2Norm(dim = -1 if norm_dim_in else 0) 51 | ) 52 | 53 | @property 54 | def weight(self): 55 | return self.linear.weight 56 | 57 | def forward(self, x): 58 | return self.linear(x) 59 | 60 | # attention and feedforward 61 | 62 | class Attention(Module): 63 | def __init__( 64 | self, 65 | dim, 66 | *, 67 | dim_head = 64, 68 | heads = 8, 69 | dropout = 0. 70 | ): 71 | super().__init__() 72 | dim_inner = dim_head * heads 73 | self.to_q = NormLinear(dim, dim_inner) 74 | self.to_k = NormLinear(dim, dim_inner) 75 | self.to_v = NormLinear(dim, dim_inner) 76 | 77 | self.dropout = dropout 78 | 79 | self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25)) 80 | self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25)) 81 | 82 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads) 83 | self.merge_heads = Rearrange('b h n d -> b n (h d)') 84 | 85 | self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False) 86 | 87 | def forward( 88 | self, 89 | x 90 | ): 91 | q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) 92 | 93 | q, k, v = map(self.split_heads, (q, k, v)) 94 | 95 | # query key rmsnorm 96 | 97 | q, k = map(l2norm, (q, k)) 98 | 99 | q = q * self.q_scale 100 | k = k * self.k_scale 101 | 102 | # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16 103 | 104 | out = F.scaled_dot_product_attention( 105 | q, k, v, 106 | dropout_p = self.dropout if self.training else 0., 107 | scale = 1. 108 | ) 109 | 110 | out = self.merge_heads(out) 111 | return self.to_out(out) 112 | 113 | class FeedForward(Module): 114 | def __init__( 115 | self, 116 | dim, 117 | *, 118 | dim_inner, 119 | dropout = 0. 
120 | ): 121 | super().__init__() 122 | dim_inner = int(dim_inner * 2 / 3) 123 | 124 | self.dim = dim 125 | self.dropout = nn.Dropout(dropout) 126 | 127 | self.to_hidden = NormLinear(dim, dim_inner) 128 | self.to_gate = NormLinear(dim, dim_inner) 129 | 130 | self.hidden_scale = nn.Parameter(torch.ones(dim_inner)) 131 | self.gate_scale = nn.Parameter(torch.ones(dim_inner)) 132 | 133 | self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False) 134 | 135 | def forward(self, x): 136 | hidden, gate = self.to_hidden(x), self.to_gate(x) 137 | 138 | hidden = hidden * self.hidden_scale 139 | gate = gate * self.gate_scale * (self.dim ** 0.5) 140 | 141 | hidden = F.silu(gate) * hidden 142 | 143 | hidden = self.dropout(hidden) 144 | return self.to_out(hidden) 145 | 146 | # classes 147 | 148 | class nViT(Module): 149 | """ https://arxiv.org/abs/2410.01131 """ 150 | 151 | def __init__( 152 | self, 153 | *, 154 | image_size, 155 | patch_size, 156 | num_classes, 157 | dim, 158 | depth, 159 | heads, 160 | mlp_dim, 161 | dropout = 0., 162 | channels = 3, 163 | dim_head = 64, 164 | residual_lerp_scale_init = None 165 | ): 166 | super().__init__() 167 | image_height, image_width = pair(image_size) 168 | 169 | # calculate patching related stuff 170 | 171 | assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.' 172 | 173 | patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size) 174 | patch_dim = channels * (patch_size ** 2) 175 | num_patches = patch_height_dim * patch_width_dim 176 | 177 | self.channels = channels 178 | self.patch_size = patch_size 179 | 180 | self.to_patch_embedding = nn.Sequential( 181 | Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size), 182 | NormLinear(patch_dim, dim, norm_dim_in = False), 183 | ) 184 | 185 | self.abs_pos_emb = NormLinear(dim, num_patches) 186 | 187 | residual_lerp_scale_init = default(residual_lerp_scale_init, 1. 
/ depth) 188 | 189 | # layers 190 | 191 | self.dim = dim 192 | self.scale = dim ** 0.5 193 | 194 | self.layers = ModuleList([]) 195 | self.residual_lerp_scales = nn.ParameterList([]) 196 | 197 | for _ in range(depth): 198 | self.layers.append(ModuleList([ 199 | Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout), 200 | FeedForward(dim, dim_inner = mlp_dim, dropout = dropout), 201 | ])) 202 | 203 | self.residual_lerp_scales.append(nn.ParameterList([ 204 | nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale), 205 | nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale), 206 | ])) 207 | 208 | self.logit_scale = nn.Parameter(torch.ones(num_classes)) 209 | 210 | self.to_pred = NormLinear(dim, num_classes) 211 | 212 | @torch.no_grad() 213 | def norm_weights_(self): 214 | for module in self.modules(): 215 | if not isinstance(module, NormLinear): 216 | continue 217 | 218 | normed = module.weight 219 | original = module.linear.parametrizations.weight.original 220 | 221 | original.copy_(normed) 222 | 223 | def forward(self, images): 224 | device = images.device 225 | 226 | tokens = self.to_patch_embedding(images) 227 | 228 | seq_len = tokens.shape[-2] 229 | pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)] 230 | 231 | tokens = l2norm(tokens + pos_emb) 232 | 233 | for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales): 234 | 235 | attn_out = l2norm(attn(tokens)) 236 | tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale)) 237 | 238 | ff_out = l2norm(ff(tokens)) 239 | tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale)) 240 | 241 | pooled = reduce(tokens, 'b n d -> b d', 'mean') 242 | 243 | logits = self.to_pred(pooled) 244 | logits = logits * self.logit_scale * self.scale 245 | 246 | return logits 247 | 248 | # quick test 249 | 250 | if __name__ == '__main__': 251 | 252 | v = nViT( 253 | image_size = 256, 254 | patch_size = 16, 255 | num_classes = 1000, 256 | dim = 1024, 257 | depth = 6, 258 | heads = 8, 259 | mlp_dim = 2048, 260 | ) 261 | 262 | img = torch.randn(4, 3, 256, 256) 263 | logits = v(img) # (4, 1000) 264 | assert logits.shape == (4, 1000) 265 | -------------------------------------------------------------------------------- /vit_pytorch/parallel_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | # classes 13 | 14 | class Parallel(nn.Module): 15 | def __init__(self, *fns): 16 | super().__init__() 17 | self.fns = nn.ModuleList(fns) 18 | 19 | def forward(self, x): 20 | return sum([fn(x) for fn in self.fns]) 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout = 0.): 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | nn.LayerNorm(dim), 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.norm = nn.LayerNorm(dim) 46 | self.attend = 
nn.Softmax(dim = -1) 47 | self.dropout = nn.Dropout(dropout) 48 | 49 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 50 | 51 | self.to_out = nn.Sequential( 52 | nn.Linear(inner_dim, dim), 53 | nn.Dropout(dropout) 54 | ) if project_out else nn.Identity() 55 | 56 | def forward(self, x): 57 | x = self.norm(x) 58 | qkv = self.to_qkv(x).chunk(3, dim = -1) 59 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 60 | 61 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 62 | 63 | attn = self.attend(dots) 64 | attn = self.dropout(attn) 65 | 66 | out = torch.matmul(attn, v) 67 | out = rearrange(out, 'b h n d -> b n (h d)') 68 | return self.to_out(out) 69 | 70 | class Transformer(nn.Module): 71 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.): 72 | super().__init__() 73 | self.layers = nn.ModuleList([]) 74 | 75 | attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout) 76 | ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout) 77 | 78 | for _ in range(depth): 79 | self.layers.append(nn.ModuleList([ 80 | Parallel(*[attn_block() for _ in range(num_parallel_branches)]), 81 | Parallel(*[ff_block() for _ in range(num_parallel_branches)]), 82 | ])) 83 | 84 | def forward(self, x): 85 | for attns, ffs in self.layers: 86 | x = attns(x) + x 87 | x = ffs(x) + x 88 | return x 89 | 90 | class ViT(nn.Module): 91 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 92 | super().__init__() 93 | image_height, image_width = pair(image_size) 94 | patch_height, patch_width = pair(patch_size) 95 | 96 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
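# note: in the Transformer below, each layer sums the outputs of `num_parallel_branches` attention branches
# (and likewise feedforward branches) via the Parallel wrapper above, then adds the residual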
97 | 98 | num_patches = (image_height // patch_height) * (image_width // patch_width) 99 | patch_dim = channels * patch_height * patch_width 100 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 101 | 102 | self.to_patch_embedding = nn.Sequential( 103 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 104 | nn.Linear(patch_dim, dim), 105 | ) 106 | 107 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 108 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 109 | self.dropout = nn.Dropout(emb_dropout) 110 | 111 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout) 112 | 113 | self.pool = pool 114 | self.to_latent = nn.Identity() 115 | 116 | self.mlp_head = nn.Sequential( 117 | nn.LayerNorm(dim), 118 | nn.Linear(dim, num_classes) 119 | ) 120 | 121 | def forward(self, img): 122 | x = self.to_patch_embedding(img) 123 | b, n, _ = x.shape 124 | 125 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 126 | x = torch.cat((cls_tokens, x), dim=1) 127 | x += self.pos_embedding[:, :(n + 1)] 128 | x = self.dropout(x) 129 | 130 | x = self.transformer(x) 131 | 132 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 133 | 134 | x = self.to_latent(x) 135 | return self.mlp_head(x) 136 | -------------------------------------------------------------------------------- /vit_pytorch/pit.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | # helpers 11 | 12 | def cast_tuple(val, num): 13 | return val if isinstance(val, tuple) else (val,) * num 14 | 15 | def conv_output_size(image_size, kernel_size, stride, padding = 0): 16 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 17 | 18 | # classes 19 | 20 | class FeedForward(nn.Module): 21 | def __init__(self, dim, hidden_dim, dropout = 0.): 22 | super().__init__() 23 | self.net = nn.Sequential( 24 | nn.LayerNorm(dim), 25 | nn.Linear(dim, hidden_dim), 26 | nn.GELU(), 27 | nn.Dropout(dropout), 28 | nn.Linear(hidden_dim, dim), 29 | nn.Dropout(dropout) 30 | ) 31 | def forward(self, x): 32 | return self.net(x) 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 36 | super().__init__() 37 | inner_dim = dim_head * heads 38 | project_out = not (heads == 1 and dim_head == dim) 39 | 40 | self.heads = heads 41 | self.scale = dim_head ** -0.5 42 | 43 | self.norm = nn.LayerNorm(dim) 44 | self.attend = nn.Softmax(dim = -1) 45 | self.dropout = nn.Dropout(dropout) 46 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 47 | 48 | self.to_out = nn.Sequential( 49 | nn.Linear(inner_dim, dim), 50 | nn.Dropout(dropout) 51 | ) if project_out else nn.Identity() 52 | 53 | def forward(self, x): 54 | b, n, _, h = *x.shape, self.heads 55 | 56 | x = self.norm(x) 57 | qkv = self.to_qkv(x).chunk(3, dim = -1) 58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 59 | 60 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 61 | 62 | attn = self.attend(dots) 63 | attn = self.dropout(attn) 64 | 65 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 66 | out = rearrange(out, 'b h n d -> b n (h d)') 67 | return self.to_out(out) 68 | 69 | class 
Transformer(nn.Module): 70 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 71 | super().__init__() 72 | self.layers = nn.ModuleList([]) 73 | for _ in range(depth): 74 | self.layers.append(nn.ModuleList([ 75 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 76 | FeedForward(dim, mlp_dim, dropout = dropout) 77 | ])) 78 | def forward(self, x): 79 | for attn, ff in self.layers: 80 | x = attn(x) + x 81 | x = ff(x) + x 82 | return x 83 | 84 | # depthwise convolution, for pooling 85 | 86 | class DepthWiseConv2d(nn.Module): 87 | def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True): 88 | super().__init__() 89 | self.net = nn.Sequential( 90 | nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias), 91 | nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias) 92 | ) 93 | def forward(self, x): 94 | return self.net(x) 95 | 96 | # pooling layer 97 | 98 | class Pool(nn.Module): 99 | def __init__(self, dim): 100 | super().__init__() 101 | self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1) 102 | self.cls_ff = nn.Linear(dim, dim * 2) 103 | 104 | def forward(self, x): 105 | cls_token, tokens = x[:, :1], x[:, 1:] 106 | 107 | cls_token = self.cls_ff(cls_token) 108 | 109 | tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1]))) 110 | tokens = self.downsample(tokens) 111 | tokens = rearrange(tokens, 'b c h w -> b (h w) c') 112 | 113 | return torch.cat((cls_token, tokens), dim = 1) 114 | 115 | # main class 116 | 117 | class PiT(nn.Module): 118 | def __init__( 119 | self, 120 | *, 121 | image_size, 122 | patch_size, 123 | num_classes, 124 | dim, 125 | depth, 126 | heads, 127 | mlp_dim, 128 | dim_head = 64, 129 | dropout = 0., 130 | emb_dropout = 0., 131 | channels = 3 132 | ): 133 | super().__init__() 134 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
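# note: PiT extracts overlapping patches (nn.Unfold with stride = patch_size // 2 below), and `depth` must be a
# per-stage tuple, e.g. a hypothetical (3, 3, 3); the Pool layer between stages downsamples the feature map and doubles dim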
135 | assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing' 136 | heads = cast_tuple(heads, len(depth)) 137 | 138 | patch_dim = channels * patch_size ** 2 139 | 140 | self.to_patch_embedding = nn.Sequential( 141 | nn.Unfold(kernel_size = patch_size, stride = patch_size // 2), 142 | Rearrange('b c n -> b n c'), 143 | nn.Linear(patch_dim, dim) 144 | ) 145 | 146 | output_size = conv_output_size(image_size, patch_size, patch_size // 2) 147 | num_patches = output_size ** 2 148 | 149 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 150 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 151 | self.dropout = nn.Dropout(emb_dropout) 152 | 153 | layers = [] 154 | 155 | for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)): 156 | not_last = ind < (len(depth) - 1) 157 | 158 | layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout)) 159 | 160 | if not_last: 161 | layers.append(Pool(dim)) 162 | dim *= 2 163 | 164 | self.layers = nn.Sequential(*layers) 165 | 166 | self.mlp_head = nn.Sequential( 167 | nn.LayerNorm(dim), 168 | nn.Linear(dim, num_classes) 169 | ) 170 | 171 | def forward(self, img): 172 | x = self.to_patch_embedding(img) 173 | b, n, _ = x.shape 174 | 175 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 176 | x = torch.cat((cls_tokens, x), dim=1) 177 | x += self.pos_embedding[:, :n+1] 178 | x = self.dropout(x) 179 | 180 | x = self.layers(x) 181 | 182 | return self.mlp_head(x[:, 0]) 183 | -------------------------------------------------------------------------------- /vit_pytorch/recorder.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import torch 3 | from torch import nn 4 | 5 | from vit_pytorch.vit import Attention 6 | 7 | def find_modules(nn_module, type): 8 | return [module for module in nn_module.modules() if isinstance(module, type)] 9 | 10 | class Recorder(nn.Module): 11 | def __init__(self, vit, device = None): 12 | super().__init__() 13 | self.vit = vit 14 | 15 | self.data = None 16 | self.recordings = [] 17 | self.hooks = [] 18 | self.hook_registered = False 19 | self.ejected = False 20 | self.device = device 21 | 22 | def _hook(self, _, input, output): 23 | self.recordings.append(output.clone().detach()) 24 | 25 | def _register_hook(self): 26 | modules = find_modules(self.vit.transformer, Attention) 27 | for module in modules: 28 | handle = module.attend.register_forward_hook(self._hook) 29 | self.hooks.append(handle) 30 | self.hook_registered = True 31 | 32 | def eject(self): 33 | self.ejected = True 34 | for hook in self.hooks: 35 | hook.remove() 36 | self.hooks.clear() 37 | return self.vit 38 | 39 | def clear(self): 40 | self.recordings.clear() 41 | 42 | def record(self, attn): 43 | recording = attn.clone().detach() 44 | self.recordings.append(recording) 45 | 46 | def forward(self, img): 47 | assert not self.ejected, 'recorder has been ejected, cannot be used anymore' 48 | self.clear() 49 | if not self.hook_registered: 50 | self._register_hook() 51 | 52 | pred = self.vit(img) 53 | 54 | # move all recordings to one device before stacking 55 | target_device = self.device if self.device is not None else img.device 56 | recordings = tuple(map(lambda t: t.to(target_device), self.recordings)) 57 | 58 | attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None 59 | return pred, attns 60 | 
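# a minimal usage sketch for the Recorder above (hypothetical hyperparameters), assuming the ViT exported from vit_pytorch.vit:
#
#   import torch
#   from vit_pytorch.vit import ViT
#   from vit_pytorch.recorder import Recorder
#
#   v = Recorder(ViT(image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024, depth = 6, heads = 16, mlp_dim = 2048))
#   preds, attns = v(torch.randn(1, 3, 256, 256))   # attns: (1, 6, 16, 65, 65) = (batch, depth, heads, tokens, tokens)
#   v = v.eject()                                    # removes the hooks and returns the wrapped ViT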
-------------------------------------------------------------------------------- /vit_pytorch/simmim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | class SimMIM(nn.Module): 7 | def __init__( 8 | self, 9 | *, 10 | encoder, 11 | masking_ratio = 0.5 12 | ): 13 | super().__init__() 14 | assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1' 15 | self.masking_ratio = masking_ratio 16 | 17 | # extract some hyperparameters and functions from encoder (vision transformer to be trained) 18 | 19 | self.encoder = encoder 20 | num_patches, encoder_dim = encoder.pos_embedding.shape[-2:] 21 | 22 | self.to_patch = encoder.to_patch_embedding[0] 23 | self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:]) 24 | 25 | pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1] 26 | 27 | # simple linear head 28 | 29 | self.mask_token = nn.Parameter(torch.randn(encoder_dim)) 30 | self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch) 31 | 32 | def forward(self, img): 33 | device = img.device 34 | 35 | # get patches 36 | 37 | patches = self.to_patch(img) 38 | batch, num_patches, *_ = patches.shape 39 | 40 | # for indexing purposes 41 | 42 | batch_range = torch.arange(batch, device = device)[:, None] 43 | 44 | # get positions 45 | 46 | pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)] 47 | 48 | # patch to encoder tokens and add positions 49 | 50 | tokens = self.patch_to_emb(patches) 51 | tokens = tokens + pos_emb 52 | 53 | # prepare mask tokens 54 | 55 | mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches) 56 | mask_tokens = mask_tokens + pos_emb 57 | 58 | # calculate of patches needed to be masked, and get positions (indices) to be masked 59 | 60 | num_masked = int(self.masking_ratio * num_patches) 61 | masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices 62 | masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool() 63 | 64 | # mask tokens 65 | 66 | tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens) 67 | 68 | # attend with vision transformer 69 | 70 | encoded = self.encoder.transformer(tokens) 71 | 72 | # get the masked tokens 73 | 74 | encoded_mask_tokens = encoded[batch_range, masked_indices] 75 | 76 | # small linear projection for predicted pixel values 77 | 78 | pred_pixel_values = self.to_pixels(encoded_mask_tokens) 79 | 80 | # get the masked patches for the final reconstruction loss 81 | 82 | masked_patches = patches[batch_range, masked_indices] 83 | 84 | # calculate reconstruction loss 85 | 86 | recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked 87 | return recon_loss 88 | -------------------------------------------------------------------------------- /vit_pytorch/simple_flash_attn_vit.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from packaging import version 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | # constants 12 | 13 | Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 14 | 15 | # helpers 16 | 17 | def pair(t): 18 | return t if isinstance(t, 
tuple) else (t, t) 19 | 20 | def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): 21 | _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype 22 | 23 | y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') 24 | assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' 25 | omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) 26 | omega = 1. / (temperature ** omega) 27 | 28 | y = y.flatten()[:, None] * omega[None, :] 29 | x = x.flatten()[:, None] * omega[None, :] 30 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1) 31 | return pe.type(dtype) 32 | 33 | # main class 34 | 35 | class Attend(nn.Module): 36 | def __init__(self, use_flash = False): 37 | super().__init__() 38 | self.use_flash = use_flash 39 | assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 40 | 41 | # determine efficient attention configs for cuda and cpu 42 | 43 | self.cpu_config = Config(True, True, True) 44 | self.cuda_config = None 45 | 46 | if not torch.cuda.is_available() or not use_flash: 47 | return 48 | 49 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 50 | 51 | if device_properties.major == 8 and device_properties.minor == 0: 52 | self.cuda_config = Config(True, False, False) 53 | else: 54 | self.cuda_config = Config(False, True, True) 55 | 56 | def flash_attn(self, q, k, v): 57 | config = self.cuda_config if q.is_cuda else self.cpu_config 58 | 59 | # flash attention - https://arxiv.org/abs/2205.14135 60 | 61 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 62 | out = F.scaled_dot_product_attention(q, k, v) 63 | 64 | return out 65 | 66 | def forward(self, q, k, v): 67 | n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5 68 | 69 | if self.use_flash: 70 | return self.flash_attn(q, k, v) 71 | 72 | # similarity 73 | 74 | sim = einsum("b h i d, b j d -> b h i j", q, k) * scale 75 | 76 | # attention 77 | 78 | attn = sim.softmax(dim=-1) 79 | 80 | # aggregate values 81 | 82 | out = einsum("b h i j, b j d -> b h i d", attn, v) 83 | 84 | return out 85 | 86 | # classes 87 | 88 | class FeedForward(nn.Module): 89 | def __init__(self, dim, hidden_dim): 90 | super().__init__() 91 | self.net = nn.Sequential( 92 | nn.LayerNorm(dim), 93 | nn.Linear(dim, hidden_dim), 94 | nn.GELU(), 95 | nn.Linear(hidden_dim, dim), 96 | ) 97 | def forward(self, x): 98 | return self.net(x) 99 | 100 | class Attention(nn.Module): 101 | def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True): 102 | super().__init__() 103 | inner_dim = dim_head * heads 104 | self.heads = heads 105 | self.scale = dim_head ** -0.5 106 | self.norm = nn.LayerNorm(dim) 107 | 108 | self.attend = Attend(use_flash = use_flash) 109 | 110 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 111 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 112 | 113 | def forward(self, x): 114 | x = self.norm(x) 115 | 116 | qkv = self.to_qkv(x).chunk(3, dim = -1) 117 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 118 | 119 | out = self.attend(q, k, v) 120 | 121 | out = rearrange(out, 'b h n d -> b n (h d)') 122 | return self.to_out(out) 123 | 124 | class Transformer(nn.Module): 125 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash): 126 | super().__init__() 127 | self.layers = nn.ModuleList([]) 128 | for _ in 
range(depth): 129 | self.layers.append(nn.ModuleList([ 130 | Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash), 131 | FeedForward(dim, mlp_dim) 132 | ])) 133 | def forward(self, x): 134 | for attn, ff in self.layers: 135 | x = attn(x) + x 136 | x = ff(x) + x 137 | return x 138 | 139 | class SimpleViT(nn.Module): 140 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash = True): 141 | super().__init__() 142 | image_height, image_width = pair(image_size) 143 | patch_height, patch_width = pair(patch_size) 144 | 145 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 146 | 147 | num_patches = (image_height // patch_height) * (image_width // patch_width) 148 | patch_dim = channels * patch_height * patch_width 149 | 150 | self.to_patch_embedding = nn.Sequential( 151 | Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width), 152 | nn.LayerNorm(patch_dim), 153 | nn.Linear(patch_dim, dim), 154 | nn.LayerNorm(dim), 155 | ) 156 | 157 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash) 158 | 159 | self.to_latent = nn.Identity() 160 | self.linear_head = nn.Sequential( 161 | nn.LayerNorm(dim), 162 | nn.Linear(dim, num_classes) 163 | ) 164 | 165 | def forward(self, img): 166 | *_, h, w, dtype = *img.shape, img.dtype 167 | 168 | x = self.to_patch_embedding(img) 169 | pe = posemb_sincos_2d(x) 170 | x = rearrange(x, 'b ... d -> b (...) d') + pe 171 | 172 | x = self.transformer(x) 173 | x = x.mean(dim = 1) 174 | 175 | x = self.to_latent(x) 176 | return self.linear_head(x) 177 | -------------------------------------------------------------------------------- /vit_pytorch/simple_flash_attn_vit_3d.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | from collections import namedtuple 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn import Module, ModuleList 8 | 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | 12 | # constants 13 | 14 | Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 15 | 16 | # helpers 17 | 18 | def pair(t): 19 | return t if isinstance(t, tuple) else (t, t) 20 | 21 | def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32): 22 | _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype 23 | 24 | z, y, x = torch.meshgrid( 25 | torch.arange(f, device = device), 26 | torch.arange(h, device = device), 27 | torch.arange(w, device = device), 28 | indexing = 'ij') 29 | 30 | fourier_dim = dim // 6 31 | 32 | omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1) 33 | omega = 1. 
/ (temperature ** omega) 34 | 35 | z = z.flatten()[:, None] * omega[None, :] 36 | y = y.flatten()[:, None] * omega[None, :] 37 | x = x.flatten()[:, None] * omega[None, :] 38 | 39 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1) 40 | 41 | pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6 42 | return pe.type(dtype) 43 | 44 | # main class 45 | 46 | class Attend(Module): 47 | def __init__(self, use_flash = False, config: Config = Config(True, True, True)): 48 | super().__init__() 49 | self.config = config 50 | self.use_flash = use_flash 51 | assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 52 | 53 | def flash_attn(self, q, k, v): 54 | # flash attention - https://arxiv.org/abs/2205.14135 55 | 56 | with torch.backends.cuda.sdp_kernel(**self.config._asdict()): 57 | out = F.scaled_dot_product_attention(q, k, v) 58 | 59 | return out 60 | 61 | def forward(self, q, k, v): 62 | n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5 63 | 64 | if self.use_flash: 65 | return self.flash_attn(q, k, v) 66 | 67 | # similarity - q, k, v are (batch, heads, seq, dim_head), so the head dimension must appear in the einsum 68 | 69 | sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale 70 | 71 | # attention 72 | 73 | attn = sim.softmax(dim=-1) 74 | 75 | # aggregate values 76 | 77 | out = torch.einsum("b h i j, b h j d -> b h i d", attn, v) 78 | 79 | return out 80 | 81 | # classes 82 | 83 | class FeedForward(Module): 84 | def __init__(self, dim, hidden_dim): 85 | super().__init__() 86 | self.net = nn.Sequential( 87 | nn.LayerNorm(dim), 88 | nn.Linear(dim, hidden_dim), 89 | nn.GELU(), 90 | nn.Linear(hidden_dim, dim), 91 | ) 92 | def forward(self, x): 93 | return self.net(x) 94 | 95 | class Attention(Module): 96 | def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True): 97 | super().__init__() 98 | inner_dim = dim_head * heads 99 | self.heads = heads 100 | self.scale = dim_head ** -0.5 101 | self.norm = nn.LayerNorm(dim) 102 | 103 | self.attend = Attend(use_flash = use_flash) 104 | 105 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 106 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 107 | 108 | def forward(self, x): 109 | x = self.norm(x) 110 | 111 | qkv = self.to_qkv(x).chunk(3, dim = -1) 112 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 113 | 114 | out = self.attend(q, k, v) 115 | 116 | out = rearrange(out, 'b h n d -> b n (h d)') 117 | return self.to_out(out) 118 | 119 | class Transformer(Module): 120 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash): 121 | super().__init__() 122 | self.layers = ModuleList([]) 123 | for _ in range(depth): 124 | self.layers.append(ModuleList([ 125 | Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash), 126 | FeedForward(dim, mlp_dim) 127 | ])) 128 | 129 | def forward(self, x): 130 | for attn, ff in self.layers: 131 | x = attn(x) + x 132 | x = ff(x) + x 133 | 134 | return x 135 | 136 | class SimpleViT(Module): 137 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True): 138 | super().__init__() 139 | image_height, image_width = pair(image_size) 140 | patch_height, patch_width = pair(image_patch_size) 141 | 142 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
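        # the video is cut into spatio-temporal "tubelets" below; with illustrative values (assumed here, not fixed by this file)
        # of image_size = 256, image_patch_size = 32, frames = 16, frame_patch_size = 2, this yields
        # num_patches = (256 // 32) * (256 // 32) * (16 // 2) = 8 * 8 * 8 = 512 tokens, each flattening
        # patch_dim = 3 * 32 * 32 * 2 = 6144 pixel values (with the default channels = 3) before the linear projection to dim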
143 | assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size' 144 | 145 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) 146 | patch_dim = channels * patch_height * patch_width * frame_patch_size 147 | 148 | self.to_patch_embedding = nn.Sequential( 149 | Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), 150 | nn.LayerNorm(patch_dim), 151 | nn.Linear(patch_dim, dim), 152 | nn.LayerNorm(dim), 153 | ) 154 | 155 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn) 156 | 157 | self.to_latent = nn.Identity() 158 | self.linear_head = nn.Linear(dim, num_classes) 159 | 160 | def forward(self, video): 161 | *_, h, w, dtype = *video.shape, video.dtype 162 | 163 | x = self.to_patch_embedding(video) 164 | pe = posemb_sincos_3d(x) 165 | x = rearrange(x, 'b ... d -> b (...) d') + pe 166 | 167 | x = self.transformer(x) 168 | x = x.mean(dim = 1) 169 | 170 | x = self.to_latent(x) 171 | return self.linear_head(x) 172 | -------------------------------------------------------------------------------- /vit_pytorch/simple_uvit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module, ModuleList 4 | 5 | from einops import rearrange, repeat, pack, unpack 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | def exists(v): 14 | return v is not None 15 | 16 | def divisible_by(num, den): 17 | return (num % den) == 0 18 | 19 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 20 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 21 | assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb" 22 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 23 | omega = temperature ** -omega 24 | 25 | y = y.flatten()[:, None] * omega[None, :] 26 | x = x.flatten()[:, None] * omega[None, :] 27 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 28 | return pe.type(dtype) 29 | 30 | # classes 31 | 32 | def FeedForward(dim, hidden_dim): 33 | return nn.Sequential( 34 | nn.LayerNorm(dim), 35 | nn.Linear(dim, hidden_dim), 36 | nn.GELU(), 37 | nn.Linear(hidden_dim, dim), 38 | ) 39 | 40 | class Attention(Module): 41 | def __init__(self, dim, heads = 8, dim_head = 64): 42 | super().__init__() 43 | inner_dim = dim_head * heads 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | self.norm = nn.LayerNorm(dim) 47 | 48 | self.attend = nn.Softmax(dim = -1) 49 | 50 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 51 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 52 | 53 | def forward(self, x): 54 | x = self.norm(x) 55 | 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 58 | 59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 60 | 61 | attn = self.attend(dots) 62 | 63 | out = torch.matmul(attn, v) 64 | out = rearrange(out, 'b h n d -> b n (h d)') 65 | return self.to_out(out) 66 | 67 | class Transformer(Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 69 | super().__init__() 70 | self.depth = depth 71 | self.norm = nn.LayerNorm(dim) 72 | self.layers = ModuleList([]) 73 | 74 | for layer in range(1, depth + 1): 75 | latter_half = 
layer >= (depth / 2 + 1) 76 | 77 | self.layers.append(nn.ModuleList([ 78 | nn.Linear(dim * 2, dim) if latter_half else None, 79 | Attention(dim, heads = heads, dim_head = dim_head), 80 | FeedForward(dim, mlp_dim) 81 | ])) 82 | 83 | def forward(self, x): 84 | 85 | skips = [] 86 | 87 | for ind, (combine_skip, attn, ff) in enumerate(self.layers): 88 | layer = ind + 1 89 | first_half = layer <= (self.depth / 2) 90 | 91 | if first_half: 92 | skips.append(x) 93 | 94 | if exists(combine_skip): 95 | skip = skips.pop() 96 | skip_and_x = torch.cat((skip, x), dim = -1) 97 | x = combine_skip(skip_and_x) 98 | 99 | x = attn(x) + x 100 | x = ff(x) + x 101 | 102 | assert len(skips) == 0 103 | 104 | return self.norm(x) 105 | 106 | class SimpleUViT(Module): 107 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64): 108 | super().__init__() 109 | image_height, image_width = pair(image_size) 110 | patch_height, patch_width = pair(patch_size) 111 | 112 | assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.' 113 | 114 | patch_dim = channels * patch_height * patch_width 115 | 116 | self.to_patch_embedding = nn.Sequential( 117 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 118 | nn.LayerNorm(patch_dim), 119 | nn.Linear(patch_dim, dim), 120 | nn.LayerNorm(dim), 121 | ) 122 | 123 | pos_embedding = posemb_sincos_2d( 124 | h = image_height // patch_height, 125 | w = image_width // patch_width, 126 | dim = dim 127 | ) 128 | 129 | self.register_buffer('pos_embedding', pos_embedding, persistent = False) 130 | 131 | self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim)) 132 | 133 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 134 | 135 | self.pool = "mean" 136 | self.to_latent = nn.Identity() 137 | 138 | self.linear_head = nn.Linear(dim, num_classes) 139 | 140 | def forward(self, img): 141 | batch, device = img.shape[0], img.device 142 | 143 | x = self.to_patch_embedding(img) 144 | x = x + self.pos_embedding.type(x.dtype) 145 | 146 | r = repeat(self.register_tokens, 'n d -> b n d', b = batch) 147 | 148 | x, ps = pack([x, r], 'b * d') 149 | 150 | x = self.transformer(x) 151 | 152 | x, _ = unpack(x, ps, 'b * d') 153 | 154 | x = x.mean(dim = 1) 155 | 156 | x = self.to_latent(x) 157 | return self.linear_head(x) 158 | 159 | # quick test on odd number of layers 160 | 161 | if __name__ == '__main__': 162 | 163 | v = SimpleUViT( 164 | image_size = 256, 165 | patch_size = 32, 166 | num_classes = 1000, 167 | dim = 1024, 168 | depth = 7, 169 | heads = 16, 170 | mlp_dim = 2048 171 | ).cuda() 172 | 173 | img = torch.randn(2, 3, 256, 256).cuda() 174 | 175 | preds = v(img) 176 | assert preds.shape == (2, 1000) 177 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 13 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 14 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 15 | 
omega = torch.arange(dim // 4) / (dim // 4 - 1) 16 | omega = 1.0 / (temperature ** omega) 17 | 18 | y = y.flatten()[:, None] * omega[None, :] 19 | x = x.flatten()[:, None] * omega[None, :] 20 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 21 | return pe.type(dtype) 22 | 23 | # classes 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.LayerNorm(dim), 30 | nn.Linear(dim, hidden_dim), 31 | nn.GELU(), 32 | nn.Linear(hidden_dim, dim), 33 | ) 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads = 8, dim_head = 64): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | self.heads = heads 42 | self.scale = dim_head ** -0.5 43 | self.norm = nn.LayerNorm(dim) 44 | 45 | self.attend = nn.Softmax(dim = -1) 46 | 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 49 | 50 | def forward(self, x): 51 | x = self.norm(x) 52 | 53 | qkv = self.to_qkv(x).chunk(3, dim = -1) 54 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 55 | 56 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 57 | 58 | attn = self.attend(dots) 59 | 60 | out = torch.matmul(attn, v) 61 | out = rearrange(out, 'b h n d -> b n (h d)') 62 | return self.to_out(out) 63 | 64 | class Transformer(nn.Module): 65 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 66 | super().__init__() 67 | self.norm = nn.LayerNorm(dim) 68 | self.layers = nn.ModuleList([]) 69 | for _ in range(depth): 70 | self.layers.append(nn.ModuleList([ 71 | Attention(dim, heads = heads, dim_head = dim_head), 72 | FeedForward(dim, mlp_dim) 73 | ])) 74 | def forward(self, x): 75 | for attn, ff in self.layers: 76 | x = attn(x) + x 77 | x = ff(x) + x 78 | return self.norm(x) 79 | 80 | class SimpleViT(nn.Module): 81 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 82 | super().__init__() 83 | image_height, image_width = pair(image_size) 84 | patch_height, patch_width = pair(patch_size) 85 | 86 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
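        # each patch_height x patch_width x channels patch becomes a single token; with assumed example values
        # (not prescribed by this file) of image_size = 256 and patch_size = 16, that is
        # (256 // 16) * (256 // 16) = 256 tokens, each flattened to patch_dim = 3 * 16 * 16 = 768 values
        # (with the default channels = 3) before the LayerNorm -> Linear -> LayerNorm projection below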
87 | 88 | patch_dim = channels * patch_height * patch_width 89 | 90 | self.to_patch_embedding = nn.Sequential( 91 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 92 | nn.LayerNorm(patch_dim), 93 | nn.Linear(patch_dim, dim), 94 | nn.LayerNorm(dim), 95 | ) 96 | 97 | self.pos_embedding = posemb_sincos_2d( 98 | h = image_height // patch_height, 99 | w = image_width // patch_width, 100 | dim = dim, 101 | ) 102 | 103 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 104 | 105 | self.pool = "mean" 106 | self.to_latent = nn.Identity() 107 | 108 | self.linear_head = nn.Linear(dim, num_classes) 109 | 110 | def forward(self, img): 111 | device = img.device 112 | 113 | x = self.to_patch_embedding(img) 114 | x += self.pos_embedding.to(device, dtype=x.dtype) 115 | 116 | x = self.transformer(x) 117 | x = x.mean(dim = 1) 118 | 119 | x = self.to_latent(x) 120 | return self.linear_head(x) 121 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32): 10 | _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype 11 | 12 | n = torch.arange(n, device = device) 13 | assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb' 14 | omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1) 15 | omega = 1. / (temperature ** omega) 16 | 17 | n = n.flatten()[:, None] * omega[None, :] 18 | pe = torch.cat((n.sin(), n.cos()), dim = 1) 19 | return pe.type(dtype) 20 | 21 | # classes 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, hidden_dim), 29 | nn.GELU(), 30 | nn.Linear(hidden_dim, dim), 31 | ) 32 | def forward(self, x): 33 | return self.net(x) 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, heads = 8, dim_head = 64): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | self.heads = heads 40 | self.scale = dim_head ** -0.5 41 | self.norm = nn.LayerNorm(dim) 42 | 43 | self.attend = nn.Softmax(dim = -1) 44 | 45 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 46 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 47 | 48 | def forward(self, x): 49 | x = self.norm(x) 50 | 51 | qkv = self.to_qkv(x).chunk(3, dim = -1) 52 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 53 | 54 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 55 | 56 | attn = self.attend(dots) 57 | 58 | out = torch.matmul(attn, v) 59 | out = rearrange(out, 'b h n d -> b n (h d)') 60 | return self.to_out(out) 61 | 62 | class Transformer(nn.Module): 63 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 64 | super().__init__() 65 | self.norm = nn.LayerNorm(dim) 66 | self.layers = nn.ModuleList([]) 67 | for _ in range(depth): 68 | self.layers.append(nn.ModuleList([ 69 | Attention(dim, heads = heads, dim_head = dim_head), 70 | FeedForward(dim, mlp_dim) 71 | ])) 72 | def forward(self, x): 73 | for attn, ff in self.layers: 74 | x = attn(x) + x 75 | x = ff(x) + x 76 | return self.norm(x) 77 | 78 | class SimpleViT(nn.Module): 79 | def __init__(self, *, seq_len, patch_size, 
num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 80 | super().__init__() 81 | 82 | assert seq_len % patch_size == 0 83 | 84 | num_patches = seq_len // patch_size 85 | patch_dim = channels * patch_size 86 | 87 | self.to_patch_embedding = nn.Sequential( 88 | Rearrange('b c (n p) -> b n (p c)', p = patch_size), 89 | nn.LayerNorm(patch_dim), 90 | nn.Linear(patch_dim, dim), 91 | nn.LayerNorm(dim), 92 | ) 93 | 94 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 95 | 96 | self.to_latent = nn.Identity() 97 | self.linear_head = nn.Linear(dim, num_classes) 98 | 99 | def forward(self, series): 100 | *_, n, dtype = *series.shape, series.dtype 101 | 102 | x = self.to_patch_embedding(series) 103 | pe = posemb_sincos_1d(x) 104 | x = rearrange(x, 'b ... d -> b (...) d') + pe 105 | 106 | x = self.transformer(x) 107 | x = x.mean(dim = 1) 108 | 109 | x = self.to_latent(x) 110 | return self.linear_head(x) 111 | 112 | if __name__ == '__main__': 113 | 114 | v = SimpleViT( 115 | seq_len = 256, 116 | patch_size = 16, 117 | num_classes = 1000, 118 | dim = 1024, 119 | depth = 6, 120 | heads = 8, 121 | mlp_dim = 2048 122 | ) 123 | 124 | time_series = torch.randn(4, 3, 256) 125 | logits = v(time_series) # (4, 1000) 126 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32): 14 | _, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype 15 | 16 | z, y, x = torch.meshgrid( 17 | torch.arange(f, device = device), 18 | torch.arange(h, device = device), 19 | torch.arange(w, device = device), 20 | indexing = 'ij') 21 | 22 | fourier_dim = dim // 6 23 | 24 | omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1) 25 | omega = 1. 
/ (temperature ** omega) 26 | 27 | z = z.flatten()[:, None] * omega[None, :] 28 | y = y.flatten()[:, None] * omega[None, :] 29 | x = x.flatten()[:, None] * omega[None, :] 30 | 31 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1) 32 | 33 | pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6 34 | return pe.type(dtype) 35 | 36 | # classes 37 | 38 | class FeedForward(nn.Module): 39 | def __init__(self, dim, hidden_dim): 40 | super().__init__() 41 | self.net = nn.Sequential( 42 | nn.LayerNorm(dim), 43 | nn.Linear(dim, hidden_dim), 44 | nn.GELU(), 45 | nn.Linear(hidden_dim, dim), 46 | ) 47 | def forward(self, x): 48 | return self.net(x) 49 | 50 | class Attention(nn.Module): 51 | def __init__(self, dim, heads = 8, dim_head = 64): 52 | super().__init__() 53 | inner_dim = dim_head * heads 54 | self.heads = heads 55 | self.scale = dim_head ** -0.5 56 | self.norm = nn.LayerNorm(dim) 57 | 58 | self.attend = nn.Softmax(dim = -1) 59 | 60 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 61 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 62 | 63 | def forward(self, x): 64 | x = self.norm(x) 65 | 66 | qkv = self.to_qkv(x).chunk(3, dim = -1) 67 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 68 | 69 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 70 | 71 | attn = self.attend(dots) 72 | 73 | out = torch.matmul(attn, v) 74 | out = rearrange(out, 'b h n d -> b n (h d)') 75 | return self.to_out(out) 76 | 77 | class Transformer(nn.Module): 78 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 79 | super().__init__() 80 | self.norm = nn.LayerNorm(dim) 81 | self.layers = nn.ModuleList([]) 82 | for _ in range(depth): 83 | self.layers.append(nn.ModuleList([ 84 | Attention(dim, heads = heads, dim_head = dim_head), 85 | FeedForward(dim, mlp_dim) 86 | ])) 87 | def forward(self, x): 88 | for attn, ff in self.layers: 89 | x = attn(x) + x 90 | x = ff(x) + x 91 | return self.norm(x) 92 | 93 | class SimpleViT(nn.Module): 94 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 95 | super().__init__() 96 | image_height, image_width = pair(image_size) 97 | patch_height, patch_width = pair(image_patch_size) 98 | 99 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 100 | assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size' 101 | 102 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) 103 | patch_dim = channels * patch_height * patch_width * frame_patch_size 104 | 105 | self.to_patch_embedding = nn.Sequential( 106 | Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), 107 | nn.LayerNorm(patch_dim), 108 | nn.Linear(patch_dim, dim), 109 | nn.LayerNorm(dim), 110 | ) 111 | 112 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 113 | 114 | self.to_latent = nn.Identity() 115 | self.linear_head = nn.Linear(dim, num_classes) 116 | 117 | def forward(self, video): 118 | *_, h, w, dtype = *video.shape, video.dtype 119 | 120 | x = self.to_patch_embedding(video) 121 | pe = posemb_sincos_3d(x) 122 | x = rearrange(x, 'b ... d -> b (...) 
d') + pe 123 | 124 | x = self.transformer(x) 125 | x = x.mean(dim = 1) 126 | 127 | x = self.to_latent(x) 128 | return self.linear_head(x) 129 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_with_fft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fft2 3 | from torch import nn 4 | 5 | from einops import rearrange, reduce, pack, unpack 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 14 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 15 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 16 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 17 | omega = 1.0 / (temperature ** omega) 18 | 19 | y = y.flatten()[:, None] * omega[None, :] 20 | x = x.flatten()[:, None] * omega[None, :] 21 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 22 | return pe.type(dtype) 23 | 24 | # classes 25 | 26 | class FeedForward(nn.Module): 27 | def __init__(self, dim, hidden_dim): 28 | super().__init__() 29 | self.net = nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, hidden_dim), 32 | nn.GELU(), 33 | nn.Linear(hidden_dim, dim), 34 | ) 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, heads = 8, dim_head = 64): 40 | super().__init__() 41 | inner_dim = dim_head * heads 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | self.norm = nn.LayerNorm(dim) 45 | 46 | self.attend = nn.Softmax(dim = -1) 47 | 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 50 | 51 | def forward(self, x): 52 | x = self.norm(x) 53 | 54 | qkv = self.to_qkv(x).chunk(3, dim = -1) 55 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 56 | 57 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 58 | 59 | attn = self.attend(dots) 60 | 61 | out = torch.matmul(attn, v) 62 | out = rearrange(out, 'b h n d -> b n (h d)') 63 | return self.to_out(out) 64 | 65 | class Transformer(nn.Module): 66 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 67 | super().__init__() 68 | self.norm = nn.LayerNorm(dim) 69 | self.layers = nn.ModuleList([]) 70 | for _ in range(depth): 71 | self.layers.append(nn.ModuleList([ 72 | Attention(dim, heads = heads, dim_head = dim_head), 73 | FeedForward(dim, mlp_dim) 74 | ])) 75 | def forward(self, x): 76 | for attn, ff in self.layers: 77 | x = attn(x) + x 78 | x = ff(x) + x 79 | return self.norm(x) 80 | 81 | class SimpleViT(nn.Module): 82 | def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 83 | super().__init__() 84 | image_height, image_width = pair(image_size) 85 | patch_height, patch_width = pair(patch_size) 86 | freq_patch_height, freq_patch_width = pair(freq_patch_size) 87 | 88 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 89 | assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.' 
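        # freq_patch_dim below carries an extra factor of 2 because the forward pass feeds torch.view_as_real(fft2(img)),
        # i.e. a real and an imaginary component per pixel; with the freq_patch_size = 8 used in the quick test at the
        # bottom of this file (and the default channels = 3), that is freq_patch_dim = 3 * 2 * 8 * 8 = 384 values per frequency patch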
90 | 91 | patch_dim = channels * patch_height * patch_width 92 | freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width 93 | 94 | self.to_patch_embedding = nn.Sequential( 95 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 96 | nn.LayerNorm(patch_dim), 97 | nn.Linear(patch_dim, dim), 98 | nn.LayerNorm(dim), 99 | ) 100 | 101 | self.to_freq_embedding = nn.Sequential( 102 | Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width), 103 | nn.LayerNorm(freq_patch_dim), 104 | nn.Linear(freq_patch_dim, dim), 105 | nn.LayerNorm(dim) 106 | ) 107 | 108 | self.pos_embedding = posemb_sincos_2d( 109 | h = image_height // patch_height, 110 | w = image_width // patch_width, 111 | dim = dim, 112 | ) 113 | 114 | self.freq_pos_embedding = posemb_sincos_2d( 115 | h = image_height // freq_patch_height, 116 | w = image_width // freq_patch_width, 117 | dim = dim 118 | ) 119 | 120 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 121 | 122 | self.pool = "mean" 123 | self.to_latent = nn.Identity() 124 | 125 | self.linear_head = nn.Linear(dim, num_classes) 126 | 127 | def forward(self, img): 128 | device, dtype = img.device, img.dtype 129 | 130 | x = self.to_patch_embedding(img) 131 | freqs = torch.view_as_real(fft2(img)) 132 | 133 | f = self.to_freq_embedding(freqs) 134 | 135 | x += self.pos_embedding.to(device, dtype = dtype) 136 | f += self.freq_pos_embedding.to(device, dtype = dtype) 137 | 138 | x, ps = pack((f, x), 'b * d') 139 | 140 | x = self.transformer(x) 141 | 142 | _, x = unpack(x, ps, 'b * d') 143 | x = reduce(x, 'b n d -> b d', 'mean') 144 | 145 | x = self.to_latent(x) 146 | return self.linear_head(x) 147 | 148 | if __name__ == '__main__': 149 | vit = SimpleViT( 150 | num_classes = 1000, 151 | image_size = 256, 152 | patch_size = 8, 153 | freq_patch_size = 8, 154 | dim = 1024, 155 | depth = 1, 156 | heads = 8, 157 | mlp_dim = 2048, 158 | ) 159 | 160 | images = torch.randn(8, 3, 256, 256) 161 | 162 | logits = vit(images) 163 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_with_hyper_connections.py: -------------------------------------------------------------------------------- 1 | """ 2 | ViT + Hyper-Connections + Register Tokens 3 | https://arxiv.org/abs/2409.19606 4 | """ 5 | 6 | import torch 7 | from torch import nn, tensor 8 | from torch.nn import Module, ModuleList 9 | 10 | from einops import rearrange, repeat, reduce, einsum, pack, unpack 11 | from einops.layers.torch import Rearrange 12 | 13 | # b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension 14 | 15 | # helpers 16 | 17 | def pair(t): 18 | return t if isinstance(t, tuple) else (t, t) 19 | 20 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 21 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 22 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 23 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 24 | omega = 1.0 / (temperature ** omega) 25 | 26 | y = y.flatten()[:, None] * omega[None, :] 27 | x = x.flatten()[:, None] * omega[None, :] 28 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 29 | return pe.type(dtype) 30 | 31 | # hyper connections 32 | 33 | class HyperConnection(Module): 34 | def __init__( 35 | self, 36 | dim, 37 | num_residual_streams, 38 | layer_index 39 | ): 40 | """ Appendix J - Algorithm 2, Dynamic 
only """ 41 | super().__init__() 42 | 43 | self.norm = nn.LayerNorm(dim, bias = False) 44 | 45 | self.num_residual_streams = num_residual_streams 46 | self.layer_index = layer_index 47 | 48 | self.static_beta = nn.Parameter(torch.ones(num_residual_streams)) 49 | 50 | init_alpha0 = torch.zeros((num_residual_streams, 1)) 51 | init_alpha0[layer_index % num_residual_streams, 0] = 1. 52 | 53 | self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1)) 54 | 55 | self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1)) 56 | self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2)) 57 | self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim)) 58 | self.dynamic_beta_scale = nn.Parameter(tensor(1e-2)) 59 | 60 | def width_connection(self, residuals): 61 | normed = self.norm(residuals) 62 | 63 | wc_weight = (normed @ self.dynamic_alpha_fn).tanh() 64 | dynamic_alpha = wc_weight * self.dynamic_alpha_scale 65 | alpha = dynamic_alpha + self.static_alpha 66 | 67 | dc_weight = (normed @ self.dynamic_beta_fn).tanh() 68 | dynamic_beta = dc_weight * self.dynamic_beta_scale 69 | beta = dynamic_beta + self.static_beta 70 | 71 | # width connection 72 | mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d') 73 | 74 | branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :] 75 | 76 | return branch_input, residuals, beta 77 | 78 | def depth_connection( 79 | self, 80 | branch_output, 81 | residuals, 82 | beta 83 | ): 84 | return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals 85 | 86 | # classes 87 | 88 | class FeedForward(Module): 89 | def __init__(self, dim, hidden_dim): 90 | super().__init__() 91 | self.net = nn.Sequential( 92 | nn.LayerNorm(dim), 93 | nn.Linear(dim, hidden_dim), 94 | nn.GELU(), 95 | nn.Linear(hidden_dim, dim), 96 | ) 97 | def forward(self, x): 98 | return self.net(x) 99 | 100 | class Attention(Module): 101 | def __init__(self, dim, heads = 8, dim_head = 64): 102 | super().__init__() 103 | inner_dim = dim_head * heads 104 | self.heads = heads 105 | self.scale = dim_head ** -0.5 106 | self.norm = nn.LayerNorm(dim) 107 | 108 | self.attend = nn.Softmax(dim = -1) 109 | 110 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 111 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 112 | 113 | def forward(self, x): 114 | x = self.norm(x) 115 | 116 | qkv = self.to_qkv(x).chunk(3, dim = -1) 117 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 118 | 119 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 120 | 121 | attn = self.attend(dots) 122 | 123 | out = torch.matmul(attn, v) 124 | out = rearrange(out, 'b h n d -> b n (h d)') 125 | return self.to_out(out) 126 | 127 | class Transformer(Module): 128 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams): 129 | super().__init__() 130 | 131 | self.num_residual_streams = num_residual_streams 132 | 133 | self.norm = nn.LayerNorm(dim) 134 | self.layers = ModuleList([]) 135 | 136 | for layer_index in range(depth): 137 | self.layers.append(nn.ModuleList([ 138 | HyperConnection(dim, num_residual_streams, layer_index), 139 | Attention(dim, heads = heads, dim_head = dim_head), 140 | HyperConnection(dim, num_residual_streams, layer_index), 141 | FeedForward(dim, mlp_dim) 142 | ])) 143 | 144 | def forward(self, x): 145 | 146 | x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams) 147 | 148 | for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers: 149 | 150 | x, attn_res, beta = 
attn_hyper_conn.width_connection(x) 151 | 152 | x = attn(x) 153 | 154 | x = attn_hyper_conn.depth_connection(x, attn_res, beta) 155 | 156 | x, ff_res, beta = ff_hyper_conn.width_connection(x) 157 | 158 | x = ff(x) 159 | 160 | x = ff_hyper_conn.depth_connection(x, ff_res, beta) 161 | 162 | x = reduce(x, 'b n e d -> b n d', 'sum') 163 | 164 | return self.norm(x) 165 | 166 | class SimpleViT(nn.Module): 167 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64): 168 | super().__init__() 169 | image_height, image_width = pair(image_size) 170 | patch_height, patch_width = pair(patch_size) 171 | 172 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 173 | 174 | patch_dim = channels * patch_height * patch_width 175 | 176 | self.to_patch_embedding = nn.Sequential( 177 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 178 | nn.LayerNorm(patch_dim), 179 | nn.Linear(patch_dim, dim), 180 | nn.LayerNorm(dim), 181 | ) 182 | 183 | self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim)) 184 | 185 | self.pos_embedding = posemb_sincos_2d( 186 | h = image_height // patch_height, 187 | w = image_width // patch_width, 188 | dim = dim, 189 | ) 190 | 191 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams) 192 | 193 | self.pool = "mean" 194 | self.to_latent = nn.Identity() 195 | 196 | self.linear_head = nn.Linear(dim, num_classes) 197 | 198 | def forward(self, img): 199 | batch, device = img.shape[0], img.device 200 | 201 | x = self.to_patch_embedding(img) 202 | x += self.pos_embedding.to(x) 203 | 204 | r = repeat(self.register_tokens, 'n d -> b n d', b = batch) 205 | 206 | x, ps = pack([x, r], 'b * d') 207 | 208 | x = self.transformer(x) 209 | 210 | x, _ = unpack(x, ps, 'b * d') 211 | 212 | x = x.mean(dim = 1) 213 | 214 | x = self.to_latent(x) 215 | return self.linear_head(x) 216 | 217 | # main 218 | 219 | if __name__ == '__main__': 220 | vit = SimpleViT( 221 | num_classes = 1000, 222 | image_size = 256, 223 | patch_size = 8, 224 | dim = 1024, 225 | depth = 12, 226 | heads = 8, 227 | mlp_dim = 2048, 228 | num_residual_streams = 8 229 | ) 230 | 231 | images = torch.randn(3, 3, 256, 256) 232 | 233 | logits = vit(images) 234 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_with_patch_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): 13 | _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype 14 | 15 | y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') 16 | assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' 17 | omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) 18 | omega = 1. 
/ (temperature ** omega) 19 | 20 | y = y.flatten()[:, None] * omega[None, :] 21 | x = x.flatten()[:, None] * omega[None, :] 22 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1) 23 | return pe.type(dtype) 24 | 25 | # patch dropout 26 | 27 | class PatchDropout(nn.Module): 28 | def __init__(self, prob): 29 | super().__init__() 30 | assert 0 <= prob < 1. 31 | self.prob = prob 32 | 33 | def forward(self, x): 34 | if not self.training or self.prob == 0.: 35 | return x 36 | 37 | b, n, _, device = *x.shape, x.device 38 | 39 | batch_indices = torch.arange(b, device = device) 40 | batch_indices = rearrange(batch_indices, '... -> ... 1') 41 | num_patches_keep = max(1, int(n * (1 - self.prob))) 42 | patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices 43 | 44 | return x[batch_indices, patch_indices_keep] 45 | 46 | # classes 47 | 48 | class FeedForward(nn.Module): 49 | def __init__(self, dim, hidden_dim): 50 | super().__init__() 51 | self.net = nn.Sequential( 52 | nn.LayerNorm(dim), 53 | nn.Linear(dim, hidden_dim), 54 | nn.GELU(), 55 | nn.Linear(hidden_dim, dim), 56 | ) 57 | def forward(self, x): 58 | return self.net(x) 59 | 60 | class Attention(nn.Module): 61 | def __init__(self, dim, heads = 8, dim_head = 64): 62 | super().__init__() 63 | inner_dim = dim_head * heads 64 | self.heads = heads 65 | self.scale = dim_head ** -0.5 66 | self.norm = nn.LayerNorm(dim) 67 | 68 | self.attend = nn.Softmax(dim = -1) 69 | 70 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 71 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 72 | 73 | def forward(self, x): 74 | x = self.norm(x) 75 | 76 | qkv = self.to_qkv(x).chunk(3, dim = -1) 77 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 78 | 79 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 80 | 81 | attn = self.attend(dots) 82 | 83 | out = torch.matmul(attn, v) 84 | out = rearrange(out, 'b h n d -> b n (h d)') 85 | return self.to_out(out) 86 | 87 | class Transformer(nn.Module): 88 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 89 | super().__init__() 90 | self.norm = nn.LayerNorm(dim) 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append(nn.ModuleList([ 94 | Attention(dim, heads = heads, dim_head = dim_head), 95 | FeedForward(dim, mlp_dim) 96 | ])) 97 | def forward(self, x): 98 | for attn, ff in self.layers: 99 | x = attn(x) + x 100 | x = ff(x) + x 101 | return self.norm(x) 102 | 103 | class SimpleViT(nn.Module): 104 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5): 105 | super().__init__() 106 | image_height, image_width = pair(image_size) 107 | patch_height, patch_width = pair(patch_size) 108 | 109 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
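        # PatchDropout above keeps only max(1, int(n * (1 - prob))) randomly selected tokens during training and is
        # a no-op at eval time; e.g. with the default patch_dropout = 0.5 and an assumed 256 x 256 image cut into
        # 32 x 32 patches, 64 patch tokens are embedded but only 32 reach the transformer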
110 | 111 | num_patches = (image_height // patch_height) * (image_width // patch_width) 112 | patch_dim = channels * patch_height * patch_width 113 | 114 | self.to_patch_embedding = nn.Sequential( 115 | Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width), 116 | nn.LayerNorm(patch_dim), 117 | nn.Linear(patch_dim, dim), 118 | nn.LayerNorm(dim) 119 | ) 120 | 121 | self.patch_dropout = PatchDropout(patch_dropout) 122 | 123 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 124 | 125 | self.to_latent = nn.Identity() 126 | self.linear_head = nn.Linear(dim, num_classes) 127 | 128 | def forward(self, img): 129 | *_, h, w, dtype = *img.shape, img.dtype 130 | 131 | x = self.to_patch_embedding(img) 132 | pe = posemb_sincos_2d(x) 133 | x = rearrange(x, 'b ... d -> b (...) d') + pe 134 | 135 | x = self.patch_dropout(x) 136 | 137 | x = self.transformer(x) 138 | x = x.mean(dim = 1) 139 | 140 | x = self.to_latent(x) 141 | return self.linear_head(x) 142 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_with_qk_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 14 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 15 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 16 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 17 | omega = 1.0 / (temperature ** omega) 18 | 19 | y = y.flatten()[:, None] * omega[None, :] 20 | x = x.flatten()[:, None] * omega[None, :] 21 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 22 | return pe.type(dtype) 23 | 24 | # they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper 25 | 26 | # in latest tweet, seem to claim more stable training at higher learning rates 27 | # unsure if this has taken off within Brain, or it has some hidden drawback 28 | 29 | class RMSNorm(nn.Module): 30 | def __init__(self, heads, dim): 31 | super().__init__() 32 | self.scale = dim ** 0.5 33 | self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale) 34 | 35 | def forward(self, x): 36 | normed = F.normalize(x, dim = -1) 37 | return normed * self.scale * self.gamma 38 | 39 | # classes 40 | 41 | class FeedForward(nn.Module): 42 | def __init__(self, dim, hidden_dim): 43 | super().__init__() 44 | self.net = nn.Sequential( 45 | nn.LayerNorm(dim), 46 | nn.Linear(dim, hidden_dim), 47 | nn.GELU(), 48 | nn.Linear(hidden_dim, dim), 49 | ) 50 | def forward(self, x): 51 | return self.net(x) 52 | 53 | class Attention(nn.Module): 54 | def __init__(self, dim, heads = 8, dim_head = 64): 55 | super().__init__() 56 | inner_dim = dim_head * heads 57 | self.heads = heads 58 | self.norm = nn.LayerNorm(dim) 59 | 60 | self.attend = nn.Softmax(dim = -1) 61 | 62 | self.q_norm = RMSNorm(heads, dim_head) 63 | self.k_norm = RMSNorm(heads, dim_head) 64 | 65 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 66 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 67 | 68 | def forward(self, x): 69 | x = self.norm(x) 70 | 71 | qkv = self.to_qkv(x).chunk(3, dim = -1) 72 | q, k, v = map(lambda t: 
rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 73 | 74 | q = self.q_norm(q) 75 | k = self.k_norm(k) 76 | 77 | dots = torch.matmul(q, k.transpose(-1, -2)) 78 | 79 | attn = self.attend(dots) 80 | 81 | out = torch.matmul(attn, v) 82 | out = rearrange(out, 'b h n d -> b n (h d)') 83 | return self.to_out(out) 84 | 85 | class Transformer(nn.Module): 86 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 87 | super().__init__() 88 | self.norm = nn.LayerNorm(dim) 89 | self.layers = nn.ModuleList([]) 90 | for _ in range(depth): 91 | self.layers.append(nn.ModuleList([ 92 | Attention(dim, heads = heads, dim_head = dim_head), 93 | FeedForward(dim, mlp_dim) 94 | ])) 95 | def forward(self, x): 96 | for attn, ff in self.layers: 97 | x = attn(x) + x 98 | x = ff(x) + x 99 | return self.norm(x) 100 | 101 | class SimpleViT(nn.Module): 102 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 103 | super().__init__() 104 | image_height, image_width = pair(image_size) 105 | patch_height, patch_width = pair(patch_size) 106 | 107 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 108 | 109 | patch_dim = channels * patch_height * patch_width 110 | 111 | self.to_patch_embedding = nn.Sequential( 112 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 113 | nn.LayerNorm(patch_dim), 114 | nn.Linear(patch_dim, dim), 115 | nn.LayerNorm(dim), 116 | ) 117 | 118 | self.pos_embedding = posemb_sincos_2d( 119 | h = image_height // patch_height, 120 | w = image_width // patch_width, 121 | dim = dim, 122 | ) 123 | 124 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 125 | 126 | self.pool = "mean" 127 | self.to_latent = nn.Identity() 128 | 129 | self.linear_head = nn.Linear(dim, num_classes) 130 | 131 | def forward(self, img): 132 | device = img.device 133 | 134 | x = self.to_patch_embedding(img) 135 | x += self.pos_embedding.to(device, dtype=x.dtype) 136 | 137 | x = self.transformer(x) 138 | x = x.mean(dim = 1) 139 | 140 | x = self.to_latent(x) 141 | return self.linear_head(x) 142 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_with_register_tokens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vision Transformers Need Registers 3 | https://arxiv.org/abs/2309.16588 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from einops import rearrange, repeat, pack, unpack 10 | from einops.layers.torch import Rearrange 11 | 12 | # helpers 13 | 14 | def pair(t): 15 | return t if isinstance(t, tuple) else (t, t) 16 | 17 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 18 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 19 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 20 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 21 | omega = 1.0 / (temperature ** omega) 22 | 23 | y = y.flatten()[:, None] * omega[None, :] 24 | x = x.flatten()[:, None] * omega[None, :] 25 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 26 | return pe.type(dtype) 27 | 28 | # classes 29 | 30 | class FeedForward(nn.Module): 31 | def __init__(self, dim, hidden_dim): 32 | super().__init__() 33 | self.net = nn.Sequential( 34 | nn.LayerNorm(dim), 35 | nn.Linear(dim, hidden_dim), 36 | nn.GELU(), 37 | nn.Linear(hidden_dim, dim), 38 |
) 39 | def forward(self, x): 40 | return self.net(x) 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, dim, heads = 8, dim_head = 64): 44 | super().__init__() 45 | inner_dim = dim_head * heads 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | self.norm = nn.LayerNorm(dim) 49 | 50 | self.attend = nn.Softmax(dim = -1) 51 | 52 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 53 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 54 | 55 | def forward(self, x): 56 | x = self.norm(x) 57 | 58 | qkv = self.to_qkv(x).chunk(3, dim = -1) 59 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 60 | 61 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 62 | 63 | attn = self.attend(dots) 64 | 65 | out = torch.matmul(attn, v) 66 | out = rearrange(out, 'b h n d -> b n (h d)') 67 | return self.to_out(out) 68 | 69 | class Transformer(nn.Module): 70 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 71 | super().__init__() 72 | self.norm = nn.LayerNorm(dim) 73 | self.layers = nn.ModuleList([]) 74 | for _ in range(depth): 75 | self.layers.append(nn.ModuleList([ 76 | Attention(dim, heads = heads, dim_head = dim_head), 77 | FeedForward(dim, mlp_dim) 78 | ])) 79 | def forward(self, x): 80 | for attn, ff in self.layers: 81 | x = attn(x) + x 82 | x = ff(x) + x 83 | return self.norm(x) 84 | 85 | class SimpleViT(nn.Module): 86 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64): 87 | super().__init__() 88 | image_height, image_width = pair(image_size) 89 | patch_height, patch_width = pair(patch_size) 90 | 91 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
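        # the num_register_tokens learned tokens are packed onto the patch sequence in forward, attended to alongside
        # the patches, then unpacked and discarded before mean pooling (per the paper referenced above);
        # a minimal usage sketch with assumed hyperparameters (not prescribed by this file):
        #   v = SimpleViT(image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024,
        #                 depth = 6, heads = 16, mlp_dim = 2048, num_register_tokens = 4)
        #   logits = v(torch.randn(1, 3, 256, 256))  # -> shape (1, 1000)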
92 | 93 | patch_dim = channels * patch_height * patch_width 94 | 95 | self.to_patch_embedding = nn.Sequential( 96 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 97 | nn.LayerNorm(patch_dim), 98 | nn.Linear(patch_dim, dim), 99 | nn.LayerNorm(dim), 100 | ) 101 | 102 | self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim)) 103 | 104 | self.pos_embedding = posemb_sincos_2d( 105 | h = image_height // patch_height, 106 | w = image_width // patch_width, 107 | dim = dim, 108 | ) 109 | 110 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 111 | 112 | self.pool = "mean" 113 | self.to_latent = nn.Identity() 114 | 115 | self.linear_head = nn.Linear(dim, num_classes) 116 | 117 | def forward(self, img): 118 | batch, device = img.shape[0], img.device 119 | 120 | x = self.to_patch_embedding(img) 121 | x += self.pos_embedding.to(device, dtype=x.dtype) 122 | 123 | r = repeat(self.register_tokens, 'n d -> b n d', b = batch) 124 | 125 | x, ps = pack([x, r], 'b * d') 126 | 127 | x = self.transformer(x) 128 | 129 | x, _ = unpack(x, ps, 'b * d') 130 | 131 | x = x.mean(dim = 1) 132 | 133 | x = self.to_latent(x) 134 | return self.linear_head(x) 135 | -------------------------------------------------------------------------------- /vit_pytorch/simple_vit_with_value_residual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module, ModuleList 4 | 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange 7 | 8 | # helpers 9 | 10 | def exists(v): 11 | return v is not None 12 | 13 | def default(v, d): 14 | return v if exists(v) else d 15 | 16 | def pair(t): 17 | return t if isinstance(t, tuple) else (t, t) 18 | 19 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 20 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 21 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 22 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 23 | omega = 1.0 / (temperature ** omega) 24 | 25 | y = y.flatten()[:, None] * omega[None, :] 26 | x = x.flatten()[:, None] * omega[None, :] 27 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 28 | return pe.type(dtype) 29 | 30 | # classes 31 | 32 | def FeedForward(dim, hidden_dim): 33 | return nn.Sequential( 34 | nn.LayerNorm(dim), 35 | nn.Linear(dim, hidden_dim), 36 | nn.GELU(), 37 | nn.Linear(hidden_dim, dim), 38 | ) 39 | 40 | class Attention(Module): 41 | def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False): 42 | super().__init__() 43 | inner_dim = dim_head * heads 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | self.norm = nn.LayerNorm(dim) 47 | 48 | self.attend = nn.Softmax(dim = -1) 49 | 50 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 51 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 52 | 53 | self.to_residual_mix = nn.Sequential( 54 | nn.Linear(dim, heads), 55 | nn.Sigmoid(), 56 | Rearrange('b n h -> b h n 1') 57 | ) if learned_value_residual_mix else (lambda _: 0.5) 58 | 59 | def forward(self, x, value_residual = None): 60 | x = self.norm(x) 61 | 62 | qkv = self.to_qkv(x).chunk(3, dim = -1) 63 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 64 | 65 | if exists(value_residual): 66 | mix = self.to_residual_mix(x) 67 | v = v * mix + value_residual * (1. 
- mix) 68 | 69 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 70 | 71 | attn = self.attend(dots) 72 | 73 | out = torch.matmul(attn, v) 74 | out = rearrange(out, 'b h n d -> b n (h d)') 75 | 76 | return self.to_out(out), v 77 | 78 | class Transformer(Module): 79 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 80 | super().__init__() 81 | self.norm = nn.LayerNorm(dim) 82 | self.layers = ModuleList([]) 83 | for i in range(depth): 84 | is_first = i == 0 85 | self.layers.append(ModuleList([ 86 | Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first), 87 | FeedForward(dim, mlp_dim) 88 | ])) 89 | def forward(self, x): 90 | value_residual = None 91 | 92 | for attn, ff in self.layers: 93 | 94 | attn_out, values = attn(x, value_residual = value_residual) 95 | value_residual = default(value_residual, values) 96 | 97 | x = attn_out + x 98 | x = ff(x) + x 99 | 100 | return self.norm(x) 101 | 102 | class SimpleViT(Module): 103 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64): 104 | super().__init__() 105 | image_height, image_width = pair(image_size) 106 | patch_height, patch_width = pair(patch_size) 107 | 108 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 109 | 110 | patch_dim = channels * patch_height * patch_width 111 | 112 | self.to_patch_embedding = nn.Sequential( 113 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 114 | nn.LayerNorm(patch_dim), 115 | nn.Linear(patch_dim, dim), 116 | nn.LayerNorm(dim), 117 | ) 118 | 119 | self.pos_embedding = posemb_sincos_2d( 120 | h = image_height // patch_height, 121 | w = image_width // patch_width, 122 | dim = dim, 123 | ) 124 | 125 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 126 | 127 | self.pool = "mean" 128 | self.to_latent = nn.Identity() 129 | 130 | self.linear_head = nn.Linear(dim, num_classes) 131 | 132 | def forward(self, img): 133 | device = img.device 134 | 135 | x = self.to_patch_embedding(img) 136 | x += self.pos_embedding.to(device, dtype=x.dtype) 137 | 138 | x = self.transformer(x) 139 | x = x.mean(dim = 1) 140 | 141 | x = self.to_latent(x) 142 | return self.linear_head(x) 143 | 144 | # quick test 145 | 146 | if __name__ == '__main__': 147 | v = SimpleViT( 148 | num_classes = 1000, 149 | image_size = 256, 150 | patch_size = 8, 151 | dim = 1024, 152 | depth = 6, 153 | heads = 8, 154 | mlp_dim = 2048, 155 | ) 156 | 157 | images = torch.randn(2, 3, 256, 256) 158 | 159 | logits = v(images) 160 | -------------------------------------------------------------------------------- /vit_pytorch/t2t.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | from vit_pytorch.vit import Transformer 6 | 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def conv_output_size(image_size, kernel_size, stride, padding): 16 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 17 | 18 | # classes 19 | 20 | class RearrangeImage(nn.Module): 21 | def forward(self, x): 22 | return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1]))) 23 | 24 | # main class 25 | 26 | class T2TViT(nn.Module): 27 | def __init__(self, *, image_size, num_classes, 
dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))): 28 | super().__init__() 29 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 30 | 31 | layers = [] 32 | layer_dim = channels 33 | output_image_size = image_size 34 | 35 | for i, (kernel_size, stride) in enumerate(t2t_layers): 36 | layer_dim *= kernel_size ** 2 37 | is_first = i == 0 38 | is_last = i == (len(t2t_layers) - 1) 39 | output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2) 40 | 41 | layers.extend([ 42 | RearrangeImage() if not is_first else nn.Identity(), 43 | nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2), 44 | Rearrange('b c n -> b n c'), 45 | Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(), 46 | ]) 47 | 48 | layers.append(nn.Linear(layer_dim, dim)) 49 | self.to_patch_embedding = nn.Sequential(*layers) 50 | 51 | self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim)) 52 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 53 | self.dropout = nn.Dropout(emb_dropout) 54 | 55 | if not exists(transformer): 56 | assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied' 57 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 58 | else: 59 | self.transformer = transformer 60 | 61 | self.pool = pool 62 | self.to_latent = nn.Identity() 63 | 64 | self.mlp_head = nn.Linear(dim, num_classes) 65 | 66 | def forward(self, img): 67 | x = self.to_patch_embedding(img) 68 | b, n, _ = x.shape 69 | 70 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 71 | x = torch.cat((cls_tokens, x), dim=1) 72 | x += self.pos_embedding[:, :n+1] 73 | x = self.dropout(x) 74 | 75 | x = self.transformer(x) 76 | 77 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 78 | 79 | x = self.to_latent(x) 80 | return self.mlp_head(x) 81 | -------------------------------------------------------------------------------- /vit_pytorch/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | # classes 13 | 14 | class FeedForward(nn.Module): 15 | def __init__(self, dim, hidden_dim, dropout = 0.): 16 | super().__init__() 17 | self.net = nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, hidden_dim), 20 | nn.GELU(), 21 | nn.Dropout(dropout), 22 | nn.Linear(hidden_dim, dim), 23 | nn.Dropout(dropout) 24 | ) 25 | 26 | def forward(self, x): 27 | return self.net(x) 28 | 29 | class Attention(nn.Module): 30 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 31 | super().__init__() 32 | inner_dim = dim_head * heads 33 | project_out = not (heads == 1 and dim_head == dim) 34 | 35 | self.heads = heads 36 | self.scale = dim_head ** -0.5 37 | 38 | self.norm = nn.LayerNorm(dim) 39 | 40 | self.attend = nn.Softmax(dim = -1) 41 | self.dropout = nn.Dropout(dropout) 42 | 43 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 44 | 45 | self.to_out = nn.Sequential( 46 | nn.Linear(inner_dim, dim), 47 | nn.Dropout(dropout) 48 | ) if project_out else 
nn.Identity() 49 | 50 | def forward(self, x): 51 | x = self.norm(x) 52 | 53 | qkv = self.to_qkv(x).chunk(3, dim = -1) 54 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 55 | 56 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 57 | 58 | attn = self.attend(dots) 59 | attn = self.dropout(attn) 60 | 61 | out = torch.matmul(attn, v) 62 | out = rearrange(out, 'b h n d -> b n (h d)') 63 | return self.to_out(out) 64 | 65 | class Transformer(nn.Module): 66 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 67 | super().__init__() 68 | self.norm = nn.LayerNorm(dim) 69 | self.layers = nn.ModuleList([]) 70 | for _ in range(depth): 71 | self.layers.append(nn.ModuleList([ 72 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 73 | FeedForward(dim, mlp_dim, dropout = dropout) 74 | ])) 75 | 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | 81 | return self.norm(x) 82 | 83 | class ViT(nn.Module): 84 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 85 | super().__init__() 86 | image_height, image_width = pair(image_size) 87 | patch_height, patch_width = pair(patch_size) 88 | 89 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 90 | 91 | num_patches = (image_height // patch_height) * (image_width // patch_width) 92 | patch_dim = channels * patch_height * patch_width 93 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 94 | 95 | self.to_patch_embedding = nn.Sequential( 96 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 97 | nn.LayerNorm(patch_dim), 98 | nn.Linear(patch_dim, dim), 99 | nn.LayerNorm(dim), 100 | ) 101 | 102 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 103 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 104 | self.dropout = nn.Dropout(emb_dropout) 105 | 106 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 107 | 108 | self.pool = pool 109 | self.to_latent = nn.Identity() 110 | 111 | self.mlp_head = nn.Linear(dim, num_classes) 112 | 113 | def forward(self, img): 114 | x = self.to_patch_embedding(img) 115 | b, n, _ = x.shape 116 | 117 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) 118 | x = torch.cat((cls_tokens, x), dim=1) 119 | x += self.pos_embedding[:, :(n + 1)] 120 | x = self.dropout(x) 121 | 122 | x = self.transformer(x) 123 | 124 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 125 | 126 | x = self.to_latent(x) 127 | return self.mlp_head(x) 128 | -------------------------------------------------------------------------------- /vit_pytorch/vit_1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat, pack, unpack 5 | from einops.layers.torch import Rearrange 6 | 7 | # classes 8 | 9 | class FeedForward(nn.Module): 10 | def __init__(self, dim, hidden_dim, dropout = 0.): 11 | super().__init__() 12 | self.net = nn.Sequential( 13 | nn.LayerNorm(dim), 14 | nn.Linear(dim, hidden_dim), 15 | nn.GELU(), 16 | nn.Dropout(dropout), 17 | nn.Linear(hidden_dim, dim), 18 | nn.Dropout(dropout) 19 | ) 20 | def forward(self, x): 21 | return self.net(x) 22 | 23 | class 
Attention(nn.Module): 24 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 25 | super().__init__() 26 | inner_dim = dim_head * heads 27 | project_out = not (heads == 1 and dim_head == dim) 28 | 29 | self.heads = heads 30 | self.scale = dim_head ** -0.5 31 | 32 | self.norm = nn.LayerNorm(dim) 33 | self.attend = nn.Softmax(dim = -1) 34 | self.dropout = nn.Dropout(dropout) 35 | 36 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 37 | 38 | self.to_out = nn.Sequential( 39 | nn.Linear(inner_dim, dim), 40 | nn.Dropout(dropout) 41 | ) if project_out else nn.Identity() 42 | 43 | def forward(self, x): 44 | x = self.norm(x) 45 | qkv = self.to_qkv(x).chunk(3, dim = -1) 46 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 47 | 48 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 49 | 50 | attn = self.attend(dots) 51 | attn = self.dropout(attn) 52 | 53 | out = torch.matmul(attn, v) 54 | out = rearrange(out, 'b h n d -> b n (h d)') 55 | return self.to_out(out) 56 | 57 | class Transformer(nn.Module): 58 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 59 | super().__init__() 60 | self.layers = nn.ModuleList([]) 61 | for _ in range(depth): 62 | self.layers.append(nn.ModuleList([ 63 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 64 | FeedForward(dim, mlp_dim, dropout = dropout) 65 | ])) 66 | def forward(self, x): 67 | for attn, ff in self.layers: 68 | x = attn(x) + x 69 | x = ff(x) + x 70 | return x 71 | 72 | class ViT(nn.Module): 73 | def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 74 | super().__init__() 75 | assert (seq_len % patch_size) == 0 76 | 77 | num_patches = seq_len // patch_size 78 | patch_dim = channels * patch_size 79 | 80 | self.to_patch_embedding = nn.Sequential( 81 | Rearrange('b c (n p) -> b n (p c)', p = patch_size), 82 | nn.LayerNorm(patch_dim), 83 | nn.Linear(patch_dim, dim), 84 | nn.LayerNorm(dim), 85 | ) 86 | 87 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 88 | self.cls_token = nn.Parameter(torch.randn(dim)) 89 | self.dropout = nn.Dropout(emb_dropout) 90 | 91 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 92 | 93 | self.mlp_head = nn.Sequential( 94 | nn.LayerNorm(dim), 95 | nn.Linear(dim, num_classes) 96 | ) 97 | 98 | def forward(self, series): 99 | x = self.to_patch_embedding(series) 100 | b, n, _ = x.shape 101 | 102 | cls_tokens = repeat(self.cls_token, 'd -> b d', b = b) 103 | 104 | x, ps = pack([cls_tokens, x], 'b * d') 105 | 106 | x += self.pos_embedding[:, :(n + 1)] 107 | x = self.dropout(x) 108 | 109 | x = self.transformer(x) 110 | 111 | cls_tokens, _ = unpack(x, ps, 'b * d') 112 | 113 | return self.mlp_head(cls_tokens) 114 | 115 | if __name__ == '__main__': 116 | 117 | v = ViT( 118 | seq_len = 256, 119 | patch_size = 16, 120 | num_classes = 1000, 121 | dim = 1024, 122 | depth = 6, 123 | heads = 8, 124 | mlp_dim = 2048, 125 | dropout = 0.1, 126 | emb_dropout = 0.1 127 | ) 128 | 129 | time_series = torch.randn(4, 3, 256) 130 | logits = v(time_series) # (4, 1000) 131 | -------------------------------------------------------------------------------- /vit_pytorch/vit_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 
| 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | # classes 13 | 14 | class FeedForward(nn.Module): 15 | def __init__(self, dim, hidden_dim, dropout = 0.): 16 | super().__init__() 17 | self.net = nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, hidden_dim), 20 | nn.GELU(), 21 | nn.Dropout(dropout), 22 | nn.Linear(hidden_dim, dim), 23 | nn.Dropout(dropout) 24 | ) 25 | def forward(self, x): 26 | return self.net(x) 27 | 28 | class Attention(nn.Module): 29 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 30 | super().__init__() 31 | inner_dim = dim_head * heads 32 | project_out = not (heads == 1 and dim_head == dim) 33 | 34 | self.heads = heads 35 | self.scale = dim_head ** -0.5 36 | 37 | self.norm = nn.LayerNorm(dim) 38 | self.attend = nn.Softmax(dim = -1) 39 | self.dropout = nn.Dropout(dropout) 40 | 41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 42 | 43 | self.to_out = nn.Sequential( 44 | nn.Linear(inner_dim, dim), 45 | nn.Dropout(dropout) 46 | ) if project_out else nn.Identity() 47 | 48 | def forward(self, x): 49 | x = self.norm(x) 50 | qkv = self.to_qkv(x).chunk(3, dim = -1) 51 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 52 | 53 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 54 | 55 | attn = self.attend(dots) 56 | attn = self.dropout(attn) 57 | 58 | out = torch.matmul(attn, v) 59 | out = rearrange(out, 'b h n d -> b n (h d)') 60 | return self.to_out(out) 61 | 62 | class Transformer(nn.Module): 63 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 64 | super().__init__() 65 | self.layers = nn.ModuleList([]) 66 | for _ in range(depth): 67 | self.layers.append(nn.ModuleList([ 68 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 69 | FeedForward(dim, mlp_dim, dropout = dropout) 70 | ])) 71 | def forward(self, x): 72 | for attn, ff in self.layers: 73 | x = attn(x) + x 74 | x = ff(x) + x 75 | return x 76 | 77 | class ViT(nn.Module): 78 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 79 | super().__init__() 80 | image_height, image_width = pair(image_size) 81 | patch_height, patch_width = pair(image_patch_size) 82 | 83 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
84 | assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' 85 | 86 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) 87 | patch_dim = channels * patch_height * patch_width * frame_patch_size 88 | 89 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 90 | 91 | self.to_patch_embedding = nn.Sequential( 92 | Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), 93 | nn.LayerNorm(patch_dim), 94 | nn.Linear(patch_dim, dim), 95 | nn.LayerNorm(dim), 96 | ) 97 | 98 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 99 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 100 | self.dropout = nn.Dropout(emb_dropout) 101 | 102 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 103 | 104 | self.pool = pool 105 | self.to_latent = nn.Identity() 106 | 107 | self.mlp_head = nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, num_classes) 110 | ) 111 | 112 | def forward(self, video): 113 | x = self.to_patch_embedding(video) 114 | b, n, _ = x.shape 115 | 116 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) 117 | x = torch.cat((cls_tokens, x), dim=1) 118 | x += self.pos_embedding[:, :(n + 1)] 119 | x = self.dropout(x) 120 | 121 | x = self.transformer(x) 122 | 123 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 124 | 125 | x = self.to_latent(x) 126 | return self.mlp_head(x) 127 | -------------------------------------------------------------------------------- /vit_pytorch/vit_for_small_dataset.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | # classes 15 | 16 | class FeedForward(nn.Module): 17 | def __init__(self, dim, hidden_dim, dropout = 0.): 18 | super().__init__() 19 | self.net = nn.Sequential( 20 | nn.LayerNorm(dim), 21 | nn.Linear(dim, hidden_dim), 22 | nn.GELU(), 23 | nn.Dropout(dropout), 24 | nn.Linear(hidden_dim, dim), 25 | nn.Dropout(dropout) 26 | ) 27 | def forward(self, x): 28 | return self.net(x) 29 | 30 | class LSA(nn.Module): 31 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 32 | super().__init__() 33 | inner_dim = dim_head * heads 34 | self.heads = heads 35 | self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5))) 36 | 37 | self.norm = nn.LayerNorm(dim) 38 | self.attend = nn.Softmax(dim = -1) 39 | self.dropout = nn.Dropout(dropout) 40 | 41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 42 | 43 | self.to_out = nn.Sequential( 44 | nn.Linear(inner_dim, dim), 45 | nn.Dropout(dropout) 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.norm(x) 50 | qkv = self.to_qkv(x).chunk(3, dim = -1) 51 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 52 | 53 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp() 54 | 55 | mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool) 56 | mask_value = -torch.finfo(dots.dtype).max 57 | dots = dots.masked_fill(mask, mask_value) 58 | 59 | attn = self.attend(dots) 60 | attn = self.dropout(attn) 61 | 62 | out = torch.matmul(attn, 
v) 63 | out = rearrange(out, 'b h n d -> b n (h d)') 64 | return self.to_out(out) 65 | 66 | class Transformer(nn.Module): 67 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 68 | super().__init__() 69 | self.layers = nn.ModuleList([]) 70 | for _ in range(depth): 71 | self.layers.append(nn.ModuleList([ 72 | LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout), 73 | FeedForward(dim, mlp_dim, dropout = dropout) 74 | ])) 75 | def forward(self, x): 76 | for attn, ff in self.layers: 77 | x = attn(x) + x 78 | x = ff(x) + x 79 | return x 80 | 81 | class SPT(nn.Module): 82 | def __init__(self, *, dim, patch_size, channels = 3): 83 | super().__init__() 84 | patch_dim = patch_size * patch_size * 5 * channels 85 | 86 | self.to_patch_tokens = nn.Sequential( 87 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 88 | nn.LayerNorm(patch_dim), 89 | nn.Linear(patch_dim, dim) 90 | ) 91 | 92 | def forward(self, x): 93 | shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1)) 94 | shifted_x = list(map(lambda shift: F.pad(x, shift), shifts)) 95 | x_with_shifts = torch.cat((x, *shifted_x), dim = 1) 96 | return self.to_patch_tokens(x_with_shifts) 97 | 98 | class ViT(nn.Module): 99 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 100 | super().__init__() 101 | image_height, image_width = pair(image_size) 102 | patch_height, patch_width = pair(patch_size) 103 | 104 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 105 | 106 | num_patches = (image_height // patch_height) * (image_width // patch_width) 107 | patch_dim = channels * patch_height * patch_width 108 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 109 | 110 | self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels) 111 | 112 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 113 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 114 | self.dropout = nn.Dropout(emb_dropout) 115 | 116 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 117 | 118 | self.pool = pool 119 | self.to_latent = nn.Identity() 120 | 121 | self.mlp_head = nn.Sequential( 122 | nn.LayerNorm(dim), 123 | nn.Linear(dim, num_classes) 124 | ) 125 | 126 | def forward(self, img): 127 | x = self.to_patch_embedding(img) 128 | b, n, _ = x.shape 129 | 130 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 131 | x = torch.cat((cls_tokens, x), dim=1) 132 | x += self.pos_embedding[:, :(n + 1)] 133 | x = self.dropout(x) 134 | 135 | x = self.transformer(x) 136 | 137 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 138 | 139 | x = self.to_latent(x) 140 | return self.mlp_head(x) 141 | -------------------------------------------------------------------------------- /vit_pytorch/vit_with_patch_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | 7 | # helpers 8 | 9 | def pair(t): 10 | return t if isinstance(t, tuple) else (t, t) 11 | 12 | # classes 13 | 14 | class PatchDropout(nn.Module): 15 | def __init__(self, prob): 16 | super().__init__() 17 | assert 0 <= prob < 1. 
18 | self.prob = prob 19 | 20 | def forward(self, x): 21 | if not self.training or self.prob == 0.: 22 | return x 23 | 24 | b, n, _, device = *x.shape, x.device 25 | 26 | batch_indices = torch.arange(b, device = device) 27 | batch_indices = rearrange(batch_indices, '... -> ... 1') 28 | num_patches_keep = max(1, int(n * (1 - self.prob))) 29 | patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices 30 | 31 | return x[batch_indices, patch_indices_keep] 32 | 33 | class FeedForward(nn.Module): 34 | def __init__(self, dim, hidden_dim, dropout = 0.): 35 | super().__init__() 36 | self.net = nn.Sequential( 37 | nn.LayerNorm(dim), 38 | nn.Linear(dim, hidden_dim), 39 | nn.GELU(), 40 | nn.Dropout(dropout), 41 | nn.Linear(hidden_dim, dim), 42 | nn.Dropout(dropout) 43 | ) 44 | def forward(self, x): 45 | return self.net(x) 46 | 47 | class Attention(nn.Module): 48 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 49 | super().__init__() 50 | inner_dim = dim_head * heads 51 | project_out = not (heads == 1 and dim_head == dim) 52 | 53 | self.heads = heads 54 | self.scale = dim_head ** -0.5 55 | 56 | self.norm = nn.LayerNorm(dim) 57 | self.attend = nn.Softmax(dim = -1) 58 | self.dropout = nn.Dropout(dropout) 59 | 60 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 61 | 62 | self.to_out = nn.Sequential( 63 | nn.Linear(inner_dim, dim), 64 | nn.Dropout(dropout) 65 | ) if project_out else nn.Identity() 66 | 67 | def forward(self, x): 68 | x = self.norm(x) 69 | qkv = self.to_qkv(x).chunk(3, dim = -1) 70 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 71 | 72 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 73 | 74 | attn = self.attend(dots) 75 | attn = self.dropout(attn) 76 | 77 | out = torch.matmul(attn, v) 78 | out = rearrange(out, 'b h n d -> b n (h d)') 79 | return self.to_out(out) 80 | 81 | class Transformer(nn.Module): 82 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 83 | super().__init__() 84 | self.layers = nn.ModuleList([]) 85 | for _ in range(depth): 86 | self.layers.append(nn.ModuleList([ 87 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 88 | FeedForward(dim, mlp_dim, dropout = dropout) 89 | ])) 90 | def forward(self, x): 91 | for attn, ff in self.layers: 92 | x = attn(x) + x 93 | x = ff(x) + x 94 | return x 95 | 96 | class ViT(nn.Module): 97 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25): 98 | super().__init__() 99 | image_height, image_width = pair(image_size) 100 | patch_height, patch_width = pair(patch_size) 101 | 102 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
103 | 104 | num_patches = (image_height // patch_height) * (image_width // patch_width) 105 | patch_dim = channels * patch_height * patch_width 106 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 107 | 108 | self.to_patch_embedding = nn.Sequential( 109 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 110 | nn.Linear(patch_dim, dim), 111 | ) 112 | 113 | self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim)) 114 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 115 | 116 | self.patch_dropout = PatchDropout(patch_dropout) 117 | self.dropout = nn.Dropout(emb_dropout) 118 | 119 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 120 | 121 | self.pool = pool 122 | self.to_latent = nn.Identity() 123 | 124 | self.mlp_head = nn.Sequential( 125 | nn.LayerNorm(dim), 126 | nn.Linear(dim, num_classes) 127 | ) 128 | 129 | def forward(self, img): 130 | x = self.to_patch_embedding(img) 131 | b, n, _ = x.shape 132 | 133 | x += self.pos_embedding 134 | 135 | x = self.patch_dropout(x) 136 | 137 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) 138 | 139 | x = torch.cat((cls_tokens, x), dim=1) 140 | x = self.dropout(x) 141 | 142 | x = self.transformer(x) 143 | 144 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 145 | 146 | x = self.to_latent(x) 147 | return self.mlp_head(x) 148 | -------------------------------------------------------------------------------- /vit_pytorch/vit_with_patch_merger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange, Reduce 6 | 7 | # helpers 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def default(val ,d): 13 | return val if exists(val) else d 14 | 15 | def pair(t): 16 | return t if isinstance(t, tuple) else (t, t) 17 | 18 | # patch merger class 19 | 20 | class PatchMerger(nn.Module): 21 | def __init__(self, dim, num_tokens_out): 22 | super().__init__() 23 | self.scale = dim ** -0.5 24 | self.norm = nn.LayerNorm(dim) 25 | self.queries = nn.Parameter(torch.randn(num_tokens_out, dim)) 26 | 27 | def forward(self, x): 28 | x = self.norm(x) 29 | sim = torch.matmul(self.queries, x.transpose(-1, -2)) * self.scale 30 | attn = sim.softmax(dim = -1) 31 | return torch.matmul(attn, x) 32 | 33 | # classes 34 | 35 | class FeedForward(nn.Module): 36 | def __init__(self, dim, hidden_dim, dropout = 0.): 37 | super().__init__() 38 | self.net = nn.Sequential( 39 | nn.LayerNorm(dim), 40 | nn.Linear(dim, hidden_dim), 41 | nn.GELU(), 42 | nn.Dropout(dropout), 43 | nn.Linear(hidden_dim, dim), 44 | nn.Dropout(dropout) 45 | ) 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 51 | super().__init__() 52 | inner_dim = dim_head * heads 53 | project_out = not (heads == 1 and dim_head == dim) 54 | 55 | self.heads = heads 56 | self.scale = dim_head ** -0.5 57 | 58 | self.norm = nn.LayerNorm(dim) 59 | self.attend = nn.Softmax(dim = -1) 60 | self.dropout = nn.Dropout(dropout) 61 | 62 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 63 | 64 | self.to_out = nn.Sequential( 65 | nn.Linear(inner_dim, dim), 66 | nn.Dropout(dropout) 67 | ) if project_out else nn.Identity() 68 | 69 | def forward(self, x): 70 | x = self.norm(x) 71 | qkv = self.to_qkv(x).chunk(3, dim 
= -1) 72 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 73 | 74 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 75 | 76 | attn = self.attend(dots) 77 | attn = self.dropout(attn) 78 | 79 | out = torch.matmul(attn, v) 80 | out = rearrange(out, 'b h n d -> b n (h d)') 81 | return self.to_out(out) 82 | 83 | class Transformer(nn.Module): 84 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8): 85 | super().__init__() 86 | self.norm = nn.LayerNorm(dim) 87 | self.layers = nn.ModuleList([]) 88 | 89 | self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper 90 | self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens) 91 | 92 | for _ in range(depth): 93 | self.layers.append(nn.ModuleList([ 94 | Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), 95 | FeedForward(dim, mlp_dim, dropout = dropout) 96 | ])) 97 | def forward(self, x): 98 | for index, (attn, ff) in enumerate(self.layers): 99 | x = attn(x) + x 100 | x = ff(x) + x 101 | 102 | if index == self.patch_merge_layer_index: 103 | x = self.patch_merger(x) 104 | 105 | return self.norm(x) 106 | 107 | class ViT(nn.Module): 108 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 109 | super().__init__() 110 | image_height, image_width = pair(image_size) 111 | patch_height, patch_width = pair(patch_size) 112 | 113 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 114 | 115 | num_patches = (image_height // patch_height) * (image_width // patch_width) 116 | patch_dim = channels * patch_height * patch_width 117 | 118 | self.to_patch_embedding = nn.Sequential( 119 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 120 | nn.LayerNorm(patch_dim), 121 | nn.Linear(patch_dim, dim), 122 | nn.LayerNorm(dim) 123 | ) 124 | 125 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 126 | self.dropout = nn.Dropout(emb_dropout) 127 | 128 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, patch_merge_layer, patch_merge_num_tokens) 129 | 130 | self.mlp_head = nn.Sequential( 131 | Reduce('b n d -> b d', 'mean'), 132 | nn.Linear(dim, num_classes) 133 | ) 134 | 135 | def forward(self, img): 136 | x = self.to_patch_embedding(img) 137 | b, n, _ = x.shape 138 | 139 | x += self.pos_embedding[:, :n] 140 | x = self.dropout(x) 141 | 142 | x = self.transformer(x) 143 | 144 | return self.mlp_head(x) 145 | --------------------------------------------------------------------------------
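
Quick usage sketch: unlike some of the modules above, t2t.py, vit_3d.py and vit_with_patch_merger.py ship without an inline __main__ smoke test. The following is a minimal, illustrative sketch (not part of the repository) showing how those three classes could be instantiated and called, assuming the package is importable as vit_pytorch with the constructor signatures shown above; all hyperparameters and tensor shapes are arbitrary placeholders.

# Illustrative smoke test (not part of the repository) for three modules above
# that carry no inline __main__ test: t2t.py, vit_3d.py, vit_with_patch_merger.py.
# Hyperparameters and shapes are placeholders chosen only to exercise the code.

import torch

from vit_pytorch.t2t import T2TViT
from vit_pytorch.vit_3d import ViT as ViT3D
from vit_pytorch.vit_with_patch_merger import ViT as ViTWithPatchMerger

# Tokens-to-Token ViT - t2t_layers lists (kernel_size, stride) per unfolding stage
t2t = T2TViT(
    image_size = 224,
    num_classes = 1000,
    dim = 512,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    t2t_layers = ((7, 4), (3, 2), (3, 2))
)
images = torch.randn(2, 3, 224, 224)
assert t2t(images).shape == (2, 1000)

# 3D ViT - video is patched over space (image_patch_size) and time (frame_patch_size)
vit3d = ViT3D(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024
)
video = torch.randn(2, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
assert vit3d(video).shape == (2, 1000)

# ViT with PatchMerger - tokens are merged down to patch_merge_num_tokens after the
# block selected by patch_merge_layer (defaults to the midpoint of the transformer)
vit_merger = ViTWithPatchMerger(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    patch_merge_layer = 3,
    patch_merge_num_tokens = 8
)
images = torch.randn(2, 3, 256, 256)
assert vit_merger(images).shape == (2, 1000)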