├── uniformer.png
├── uniformer_pytorch
│   ├── __init__.py
│   └── uniformer_pytorch.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── README.md
└── .gitignore
/uniformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/uniformer-pytorch/HEAD/uniformer.png
--------------------------------------------------------------------------------
/uniformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
from uniformer_pytorch.uniformer_pytorch import Uniformer
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'uniformer-pytorch',
  packages = find_packages(),
  version = '0.0.4',
  license='MIT',
  description = 'Uniformer - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/uniformer-pytorch',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'video classification'
  ],
  install_requires=[
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Uniformer - Pytorch

Implementation of UniFormer, a simple network combining 3D convolutions and self-attention that achieved state-of-the-art results on a number of video classification benchmarks.

## Install

```bash
$ pip install uniformer-pytorch
```

## Usage

Uniformer-S

```python
import torch
from uniformer_pytorch import Uniformer

model = Uniformer(
    num_classes = 1000,                 # number of output classes
    dims = (64, 128, 256, 512),         # feature dimensions per stage (4 stages)
    depths = (3, 4, 8, 3),              # depth at each stage
    mhsa_types = ('l', 'l', 'g', 'g')   # aggregation type at each stage, 'l' stands for local, 'g' stands for global
)

video = torch.randn(1, 3, 8, 224, 224)  # (batch, channels, time, height, width)

logits = model(video)                   # (1, 1000)
```

Uniformer-B

```python
import torch
from uniformer_pytorch import Uniformer

model = Uniformer(
    num_classes = 1000,
    depths = (5, 8, 20, 7)
)
```

## Citations

```bibtex
@inproceedings{anonymous2022uniformer,
    title = {UniFormer: Unified Transformer for Efficient Spatial-Temporal Representation Learning},
    author = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations},
    year = {2022},
    url = {https://openreview.net/forum?id=nBU_u6DLvoK},
    note = {under review}
}
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/uniformer_pytorch/uniformer_pytorch.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Reduce

# helpers

def exists(val):
    return val is not None

# classes

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Conv3d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv3d(dim * mult, dim, 1)
    )

# MHRAs (multi-head relation aggregators)

class LocalMHRA(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        local_aggr_kernel = 5
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        # they use batchnorm for the local MHRA instead of layer norm
        self.norm = nn.BatchNorm3d(dim)

        # only values, as the attention matrix is taken care of by a convolution
        self.to_v = nn.Conv3d(dim, inner_dim, 1, bias = False)

        # this should be equivalent to aggregating by an attention matrix parameterized as a function of the relative positions across each axis
        self.rel_pos = nn.Conv3d(heads, heads, local_aggr_kernel, padding = local_aggr_kernel // 2, groups = heads)

        # combine out across all the heads
        self.to_out = nn.Conv3d(inner_dim, dim, 1)

    def forward(self, x):
        x = self.norm(x)

        b, c, *_, h = *x.shape, self.heads

        # to values
        v = self.to_v(x)

        # split out heads
        v = rearrange(v, 'b (c h) ... -> (b c) h ...', h = h)

        # aggregate by relative positions
        out = self.rel_pos(v)

        # combine heads
        out = rearrange(out, '(b c) h ... -> b (c h) ...', b = b)
        return self.to_out(out)

class GlobalMHRA(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
        self.attn_dropout = nn.Dropout(dropout)
        self.to_out = nn.Conv1d(inner_dim, dim, 1)

    def forward(self, x):
        x = self.norm(x)

        shape, h = x.shape, self.heads

        x = rearrange(x, 'b c ... -> b c (...)')

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> b h n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # attention
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b (h d) n', h = h)

        out = self.to_out(out)
        return out.view(*shape)

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        mhsa_type = 'g',
        local_aggr_kernel = 5,
        dim_head = 64,
        ff_mult = 4,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        super().__init__()

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            if mhsa_type == 'l':
                attn = LocalMHRA(dim, heads = heads, dim_head = dim_head, local_aggr_kernel = local_aggr_kernel)
            elif mhsa_type == 'g':
                attn = GlobalMHRA(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            else:
                raise ValueError('unknown mhsa_type')

            self.layers.append(nn.ModuleList([
                nn.Conv3d(dim, dim, 3, padding = 1),
                attn,
                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    def forward(self, x):
        for dpe, attn, ff in self.layers:
            x = dpe(x) + x
            x = attn(x) + x
            x = ff(x) + x

        return x

# main class

class Uniformer(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dims = (64, 128, 256, 512),
        depths = (3, 4, 8, 3),
        mhsa_types = ('l', 'l', 'g', 'g'),
        local_aggr_kernel = 5,
        channels = 3,
        ff_mult = 4,
        dim_head = 64,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        super().__init__()
        init_dim, *_, last_dim = dims
        self.to_tokens = nn.Conv3d(channels, init_dim, (3, 4, 4), stride = (2, 4, 4), padding = (1, 0, 0))

        dim_in_out = tuple(zip(dims[:-1], dims[1:]))
        mhsa_types = tuple(map(lambda t: t.lower(), mhsa_types))

        self.stages = nn.ModuleList([])

        for ind, (depth, mhsa_type) in enumerate(zip(depths, mhsa_types)):
            is_last = ind == len(depths) - 1
            stage_dim = dims[ind]
            heads = stage_dim // dim_head

            self.stages.append(nn.ModuleList([
                Transformer(
                    dim = stage_dim,
                    depth = depth,
                    heads = heads,
                    mhsa_type = mhsa_type,
                    local_aggr_kernel = local_aggr_kernel,
                    dim_head = dim_head,
                    ff_mult = ff_mult,
                    ff_dropout = ff_dropout,
                    attn_dropout = attn_dropout
                ),
                nn.Sequential(
                    nn.Conv3d(stage_dim, dims[ind + 1], (1, 2, 2), stride = (1, 2, 2)),
                    LayerNorm(dims[ind + 1]),
                ) if not is_last else None
            ]))

        self.to_logits = nn.Sequential(
            Reduce('b c t h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, video):
        x = self.to_tokens(video)

        for transformer, conv in self.stages:
            x = transformer(x)

            if exists(conv):
                x = conv(x)

        return self.to_logits(x)
--------------------------------------------------------------------------------
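Below is a minimal end-to-end sketch of how the model above could be exercised for a single training step on a random clip. The batch size, clip length, optimizer, and learning rate are illustrative assumptions, not values prescribed by the repository.

```python
import torch
import torch.nn.functional as F
from uniformer_pytorch import Uniformer

# hypothetical smoke test: hyperparameters below are illustrative choices
model = Uniformer(
    num_classes = 1000,
    dims = (64, 128, 256, 512),
    depths = (3, 4, 8, 3),
    mhsa_types = ('l', 'l', 'g', 'g')
)

optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)

video = torch.randn(2, 3, 8, 224, 224)   # (batch, channels, time, height, width)
labels = torch.randint(0, 1000, (2,))    # dummy class indices

logits = model(video)                    # (2, 1000)
loss = F.cross_entropy(logits, labels)

loss.backward()                          # gradients flow through all four stages
optimizer.step()
optimizer.zero_grad()
```

Note that the first stage uses BatchNorm3d inside the local MHRA, so a batch size of at least two is used in this sketch.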
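The comment in `LocalMHRA` claims that the grouped convolution `rel_pos` is equivalent to attending with a matrix whose weights depend only on relative position. Below is a toy check of that equivalence in a simplified 1D, single-channel setting (illustrative only, not part of the package).

```python
import torch
import torch.nn.functional as F

n, k = 10, 5
v = torch.randn(1, 1, n)             # one batch, one head-channel, n positions
rel_kernel = torch.randn(1, 1, k)    # weights indexed purely by relative offset

# convolutional aggregation, analogous to what self.rel_pos does per head
conv_out = F.conv1d(v, rel_kernel, padding = k // 2)

# the same aggregation written as an n x n "attention" matrix whose entry (i, j)
# depends only on the relative position j - i
attn = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        offset = j - i + k // 2
        if 0 <= offset < k:
            attn[i, j] = rel_kernel[0, 0, offset]

attn_out = attn @ v[0, 0]

assert torch.allclose(conv_out[0, 0], attn_out, atol = 1e-5)
```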