├── uniformer.png
├── uniformer_pytorch
│   ├── __init__.py
│   └── uniformer_pytorch.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── README.md
└── .gitignore
/uniformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/uniformer-pytorch/HEAD/uniformer.png
--------------------------------------------------------------------------------
/uniformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from uniformer_pytorch.uniformer_pytorch import Uniformer
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'uniformer-pytorch',
5 | packages = find_packages(),
6 | version = '0.0.4',
7 | license='MIT',
8 | description = 'Uniformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/uniformer-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'attention mechanism',
15 | 'video classification'
16 | ],
17 | install_requires=[
18 | 'einops>=0.3',
19 | 'torch>=1.6'
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Programming Language :: Python :: 3.6',
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Uniformer - Pytorch
4 |
5 | Implementation of Uniformer, a simple network combining 3d convolutions (local aggregation) with self-attention (global aggregation) that achieved SOTA on a number of video classification tasks
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install uniformer-pytorch
11 | ```
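
Alternatively, it can be installed from a local clone of the repository (a standard editable install; the clone URL is the one listed in setup.py):

```bash
$ git clone https://github.com/lucidrains/uniformer-pytorch
$ cd uniformer-pytorch
$ pip install -e .
```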
12 |
13 | ## Usage
14 |
15 | Uniformer-S
16 |
17 | ```python
18 | import torch
19 | from uniformer_pytorch import Uniformer
20 |
21 | model = Uniformer(
22 | num_classes = 1000, # number of output classes
23 | dims = (64, 128, 256, 512), # feature dimensions per stage (4 stages)
24 | depths = (3, 4, 8, 3), # depth at each stage
25 | mhsa_types = ('l', 'l', 'g', 'g') # aggregation type at each stage, 'l' stands for local, 'g' stands for global
26 | )
27 |
28 | video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, time, height, width)
29 |
30 | logits = model(video) # (1, 1000)
31 | ```
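
The classification head can also be bypassed to obtain pooled video features. A minimal sketch that reuses the forward logic from `uniformer_pytorch.py` (`to_tokens` and `stages` are internal module attributes, not a documented API):

```python
import torch
from uniformer_pytorch import Uniformer

model = Uniformer(num_classes = 1000)
video = torch.randn(1, 3, 8, 224, 224)

x = model.to_tokens(video)                  # stem: (1, 64, 4, 56, 56)
for transformer, downsample in model.stages:
    x = transformer(x)
    if downsample is not None:
        x = downsample(x)

features = x.mean(dim = (2, 3, 4))          # (1, 512) pooled features, before the final LayerNorm + Linear head
```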
32 |
33 | Uniformer-B
34 |
35 | ```python
36 | import torch
37 | from uniformer_pytorch import Uniformer
38 |
39 | model = Uniformer(
40 | num_classes = 1000,
41 | depths = (5, 8, 20, 7)
42 | )
43 | ```
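
The larger model is called in exactly the same way; a minimal forward pass using the same example input shape as above:

```python
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, time, height, width)

logits = model(video) # (1, 1000)
```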
44 |
45 | ## Citations
46 |
47 | ```bibtex
48 | @inproceedings{anonymous2022uniformer,
49 | title = {UniFormer: Unified Transformer for Efficient Spatial-Temporal Representation Learning},
50 | author = {Anonymous},
51 | booktitle = {Submitted to The Tenth International Conference on Learning Representations},
52 | year = {2022},
53 | url = {https://openreview.net/forum?id=nBU_u6DLvoK},
54 | note = {under review}
55 | }
56 | ```
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/uniformer_pytorch/uniformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 | from einops.layers.torch import Reduce
5 |
6 | # helpers
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 | # classes
12 |
13 | class LayerNorm(nn.Module):
14 | def __init__(self, dim, eps = 1e-5):
15 | super().__init__()
16 | self.eps = eps
17 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
18 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))
19 |
20 | def forward(self, x):
21 | std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
22 | mean = torch.mean(x, dim = 1, keepdim = True)
23 | return (x - mean) / (std + self.eps) * self.g + self.b
24 |
25 | def FeedForward(dim, mult = 4, dropout = 0.):
26 | return nn.Sequential(
27 | LayerNorm(dim),
28 | nn.Conv3d(dim, dim * mult, 1),
29 | nn.GELU(),
30 | nn.Dropout(dropout),
31 | nn.Conv3d(dim * mult, dim, 1)
32 | )
33 |
34 | # MHRAs (multi-head relation aggregators)
35 |
36 | class LocalMHRA(nn.Module):
37 | def __init__(
38 | self,
39 | dim,
40 | heads,
41 | dim_head = 64,
42 | local_aggr_kernel = 5
43 | ):
44 | super().__init__()
45 | self.heads = heads
46 | inner_dim = dim_head * heads
47 |
48 | # they use batchnorm for the local MHRA instead of layer norm
49 | self.norm = nn.BatchNorm3d(dim)
50 |
51 | # only values, as the attention matrix is taken care of by a convolution
52 | self.to_v = nn.Conv3d(dim, inner_dim, 1, bias = False)
53 |
54 | # this should be equivalent to aggregating by an attention matrix parameterized as a function of the relative positions across each axis
55 | self.rel_pos = nn.Conv3d(heads, heads, local_aggr_kernel, padding = local_aggr_kernel // 2, groups = heads)
56 |
57 | # combine out across all the heads
58 | self.to_out = nn.Conv3d(inner_dim, dim, 1)
59 |
60 | def forward(self, x):
61 | x = self.norm(x)
62 |
63 | b, c, *_, h = *x.shape, self.heads
64 |
65 | # to values
66 | v = self.to_v(x)
67 |
68 | # split out heads
69 | v = rearrange(v, 'b (c h) ... -> (b c) h ...', h = h)
70 |
71 | # aggregate by relative positions
72 | out = self.rel_pos(v)
73 |
74 | # combine heads
75 | out = rearrange(out, '(b c) h ... -> b (c h) ...', b = b)
76 | return self.to_out(out)
77 |
78 | class GlobalMHRA(nn.Module):
79 | def __init__(
80 | self,
81 | dim,
82 | heads,
83 | dim_head = 64,
84 | dropout = 0.
85 | ):
86 | super().__init__()
87 | self.heads = heads
88 | self.scale = dim_head ** -0.5
89 | inner_dim = dim_head * heads
90 |
91 | self.norm = LayerNorm(dim)
92 | self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
93 | self.to_out = nn.Conv1d(inner_dim, dim, 1)
94 |
95 | def forward(self, x):
96 | x = self.norm(x)
97 |
98 | shape, h = x.shape, self.heads
99 |
100 | x = rearrange(x, 'b c ... -> b c (...)') # flatten the spatio-temporal dimensions into a single sequence axis for attention
101 |
102 | q, k, v = self.to_qkv(x).chunk(3, dim = 1)
103 | q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> b h n d', h = h), (q, k, v))
104 |
105 | q = q * self.scale
106 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
107 |
108 | # attention
109 | attn = sim.softmax(dim = -1)
110 |
111 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
112 | out = rearrange(out, 'b h n d -> b (h d) n', h = h)
113 |
114 | out = self.to_out(out)
115 | return out.view(*shape)
116 |
117 | class Transformer(nn.Module):
118 | def __init__(
119 | self,
120 | *,
121 | dim,
122 | depth,
123 | heads,
124 | mhsa_type = 'g',
125 | local_aggr_kernel = 5,
126 | dim_head = 64,
127 | ff_mult = 4,
128 | ff_dropout = 0.,
129 | attn_dropout = 0.
130 | ):
131 | super().__init__()
132 |
133 | self.layers = nn.ModuleList([])
134 |
135 | for _ in range(depth):
136 | if mhsa_type == 'l':
137 | attn = LocalMHRA(dim, heads = heads, dim_head = dim_head, local_aggr_kernel = local_aggr_kernel)
138 | elif mhsa_type == 'g':
139 | attn = GlobalMHRA(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)
140 | else:
141 | raise ValueError('unknown mhsa_type')
142 |
143 | self.layers.append(nn.ModuleList([
144 | nn.Conv3d(dim, dim, 3, padding = 1),
145 | attn,
146 | FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
147 | ]))
148 |
149 | def forward(self, x):
150 | for dpe, attn, ff in self.layers:
151 | x = dpe(x) + x
152 | x = attn(x) + x
153 | x = ff(x) + x
154 |
155 | return x
156 |
157 | # main class
158 |
159 | class Uniformer(nn.Module):
160 | def __init__(
161 | self,
162 | *,
163 | num_classes,
164 | dims = (64, 128, 256, 512),
165 | depths = (3, 4, 8, 3),
166 | mhsa_types = ('l', 'l', 'g', 'g'),
167 | local_aggr_kernel = 5,
168 | channels = 3,
169 | ff_mult = 4,
170 | dim_head = 64,
171 | ff_dropout = 0.,
172 | attn_dropout = 0.
173 | ):
174 | super().__init__()
175 | init_dim, *_, last_dim = dims
176 | self.to_tokens = nn.Conv3d(channels, init_dim, (3, 4, 4), stride = (2, 4, 4), padding = (1, 0, 0)) # stem downsamples time by 2x and space by 4x
177 |
178 | dim_in_out = tuple(zip(dims[:-1], dims[1:]))
179 | mhsa_types = tuple(map(lambda t: t.lower(), mhsa_types))
180 |
181 | self.stages = nn.ModuleList([])
182 |
183 | for ind, (depth, mhsa_type) in enumerate(zip(depths, mhsa_types)):
184 | is_last = ind == len(depths) - 1
185 | stage_dim = dims[ind]
186 | heads = stage_dim // dim_head
187 |
188 | self.stages.append(nn.ModuleList([
189 | Transformer(
190 | dim = stage_dim,
191 | depth = depth,
192 | heads = heads,
193 | mhsa_type = mhsa_type,
194 | ff_mult = ff_mult,
195 | ff_dropout = ff_dropout,
196 | attn_dropout = attn_dropout
197 | ),
198 | nn.Sequential(
199 | nn.Conv3d(stage_dim, dims[ind + 1], (1, 2, 2), stride = (1, 2, 2)), # downsample spatially (not temporally) between stages
200 | LayerNorm(dims[ind + 1]),
201 | ) if not is_last else None
202 | ]))
203 |
204 | self.to_logits = nn.Sequential(
205 | Reduce('b c t h w -> b c', 'mean'),
206 | nn.LayerNorm(last_dim),
207 | nn.Linear(last_dim, num_classes)
208 | )
209 |
210 | def forward(self, video):
211 | x = self.to_tokens(video)
212 |
213 | for transformer, conv in self.stages:
214 | x = transformer(x)
215 |
216 | if exists(conv):
217 | x = conv(x)
218 |
219 | return self.to_logits(x)
220 |
--------------------------------------------------------------------------------