├── diagram.png
├── timesformer_pytorch
│   ├── __init__.py
│   ├── rotary.py
│   └── timesformer_pytorch.py
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── README.md
└── .gitignore
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/TimeSformer-pytorch/HEAD/diagram.png
--------------------------------------------------------------------------------
/timesformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from timesformer_pytorch.timesformer_pytorch import TimeSformer
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 |     name = 'timesformer-pytorch',
5 |     packages = find_packages(),
6 |     version = '0.4.1',
7 |     license='MIT',
8 |     description = 'TimeSformer - Pytorch',
9 |     author = 'Phil Wang',
10 |     author_email = 'lucidrains@gmail.com',
11 |     url = 'https://github.com/lucidrains/TimeSformer-pytorch',
12 |     keywords = [
13 |         'artificial intelligence',
14 |         'attention mechanism',
15 |         'transformers',
16 |         'video classification',
17 |     ],
18 |     install_requires=[
19 |         'einops>=0.3',
20 |         'torch>=1.6'
21 |     ],
22 |     classifiers=[
23 |         'Development Status :: 4 - Beta',
24 |         'Intended Audience :: Developers',
25 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
26 |         'License :: OSI Approved :: MIT License',
27 |         'Programming Language :: Python :: 3.6',
28 |     ],
29 | )
30 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 |   release:
8 |     types: [created]
9 |
10 | jobs:
11 |   deploy:
12 |
13 |     runs-on: ubuntu-latest
14 |
15 |     steps:
16 |     - uses: actions/checkout@v2
17 |     - name: Set up Python
18 |       uses: actions/setup-python@v2
19 |       with:
20 |         python-version: '3.x'
21 |     - name: Install dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         pip install setuptools wheel twine
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 |       run: |
30 |         python setup.py sdist bdist_wheel
31 |         twine upload dist/*
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## TimeSformer - Pytorch
4 |
5 | Implementation of TimeSformer, from Facebook AI. A pure and simple attention-based solution for reaching SOTA on video classification. This repository houses only the best-performing variant, 'Divided Space-Time Attention', which is simply attention along the time axis applied before attention along the spatial axes (sketched below).
6 |
7 | Press release
8 |
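To make the "time first, then space" ordering concrete, here is a minimal, self-contained sketch (plain `einops`, not this library's API; the sizes are illustrative only) of how the same patch tokens are viewed by the two attention steps: time attention treats every spatial location as a sequence over frames, while space attention treats every frame as a sequence over patches.

```python
import torch
from einops import rearrange

# illustrative sizes: batch, frames, patches per frame, model dim
b, f, n, d = 2, 8, 196, 512
x = torch.randn(b, f * n, d)        # patch tokens (class token omitted)

# view used by time attention: one length-f sequence per spatial location
time_view = rearrange(x, 'b (f n) d -> (b n) f d', f = f)    # (392, 8, 512)

# view used by space attention: one length-n sequence per frame
space_view = rearrange(x, 'b (f n) d -> (b f) n d', f = f)   # (16, 196, 512)

print(time_view.shape, space_view.shape)
```

These are the same `'b (f n) d -> (b n) f d'` and `'b (f n) d -> (b f) n d'` rearrangements used inside `timesformer_pytorch.py`.
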
9 | ## Install
10 |
11 | ``` bash
12 | $ pip install timesformer-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from timesformer_pytorch import TimeSformer
20 |
21 | model = TimeSformer(
22 |     dim = 512,
23 |     image_size = 224,
24 |     patch_size = 16,
25 |     num_frames = 8,
26 |     num_classes = 10,
27 |     depth = 12,
28 |     heads = 8,
29 |     dim_head = 64,
30 |     attn_dropout = 0.1,
31 |     ff_dropout = 0.1
32 | )
33 |
34 | video = torch.randn(2, 8, 3, 224, 224) # (batch x frames x channels x height x width)
35 | mask = torch.ones(2, 8).bool() # (batch x frames) - use a mask if there are variable length videos in the same batch
36 |
37 | pred = model(video, mask = mask) # (2, 10)
38 | ```
39 |
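Two constructor flags are not exercised above: `rotary_emb` (default `True`) selects rotary position embeddings instead of a learned positional embedding, and `shift_tokens` (default `False`) shifts a third of the channels one frame backward and another third one frame forward before each block, following the token-shift paper cited below. A minimal sketch with illustrative flag values:

```python
import torch
from timesformer_pytorch import TimeSformer

model = TimeSformer(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_frames = 8,
    num_classes = 10,
    rotary_emb = False,   # use a learned positional embedding instead of rotary
    shift_tokens = True   # apply token shift along the frame axis before each block
)

video = torch.randn(1, 8, 3, 224, 224)
pred = model(video)  # (1, 10)
```
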
40 | ## Citations
41 |
42 | ```bibtex
43 | @misc{bertasius2021spacetime,
44 |     title = {Is Space-Time Attention All You Need for Video Understanding?},
45 |     author = {Gedas Bertasius and Heng Wang and Lorenzo Torresani},
46 |     year = {2021},
47 |     eprint = {2102.05095},
48 |     archivePrefix = {arXiv},
49 |     primaryClass = {cs.CV}
50 | }
51 | ```
52 |
53 | ```bibtex
54 | @misc{su2021roformer,
55 |     title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
56 |     author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
57 |     year = {2021},
58 |     eprint = {2104.09864},
59 |     archivePrefix = {arXiv},
60 |     primaryClass = {cs.CL}
61 | }
62 | ```
63 |
64 | ```bibtex
65 | @inproceedings{tokshift2021,
66 |     title = {Token Shift Transformer for Video Classification},
67 |     author = {Hao Zhang and Yanbin Hao and Chong-Wah Ngo},
68 |     booktitle = {ACM Multimedia 2021},
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/timesformer_pytorch/rotary.py:
--------------------------------------------------------------------------------
1 | from math import log, pi
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 |
7 | def rotate_every_two(x):
8 |     x = rearrange(x, '... (d j) -> ... d j', j = 2)
9 |     x1, x2 = x.unbind(dim = -1)
10 |     x = torch.stack((-x2, x1), dim = -1)
11 |     return rearrange(x, '... d j -> ... (d j)')
12 |
13 | def apply_rot_emb(q, k, rot_emb):
14 |     sin, cos = rot_emb
15 |     rot_dim = sin.shape[-1]
16 |     (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :rot_dim], t[..., rot_dim:]), (q, k))
17 |     q, k = map(lambda t: t * cos + rotate_every_two(t) * sin, (q, k))
18 |     q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
19 |     return q, k
20 |
21 | class AxialRotaryEmbedding(nn.Module):
22 |     def __init__(self, dim, max_freq = 10):
23 |         super().__init__()
24 |         self.dim = dim
25 |         scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
26 |         self.register_buffer('scales', scales)
27 |
28 |     def forward(self, h, w, device):
29 |         scales = rearrange(self.scales, '... -> () ...')
30 |         scales = scales.to(device)
31 |
32 |         h_seq = torch.linspace(-1., 1., steps = h, device = device)
33 |         h_seq = h_seq.unsqueeze(-1)
34 |
35 |         w_seq = torch.linspace(-1., 1., steps = w, device = device)
36 |         w_seq = w_seq.unsqueeze(-1)
37 |
38 |         h_seq = h_seq * scales * pi
39 |         w_seq = w_seq * scales * pi
40 |
41 |         x_sinu = repeat(h_seq, 'i d -> i j d', j = w)
42 |         y_sinu = repeat(w_seq, 'j d -> i j d', i = h)
43 |
44 |         sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
45 |         cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)
46 |
47 |         sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
48 |         sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
49 |         return sin, cos
50 |
51 | class RotaryEmbedding(nn.Module):
52 |     def __init__(self, dim):
53 |         super().__init__()
54 |         inv_freqs = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
55 |         self.register_buffer('inv_freqs', inv_freqs)
56 |
57 |     def forward(self, n, device):
58 |         seq = torch.arange(n, device = device).type_as(self.inv_freqs)
59 |         freqs = einsum('i, j -> i j', seq, self.inv_freqs)
60 |         freqs = torch.cat((freqs, freqs), dim = -1)
61 |         freqs = rearrange(freqs, 'n d -> () n d')
62 |         return freqs.sin(), freqs.cos()
--------------------------------------------------------------------------------
/timesformer_pytorch/timesformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 | from einops import rearrange, repeat
5 |
6 | from timesformer_pytorch.rotary import apply_rot_emb, AxialRotaryEmbedding, RotaryEmbedding
7 |
8 | # helpers
9 |
10 | def exists(val):
11 |     return val is not None
12 |
13 | # classes
14 |
15 | class PreNorm(nn.Module):
16 |     def __init__(self, dim, fn):
17 |         super().__init__()
18 |         self.fn = fn
19 |         self.norm = nn.LayerNorm(dim)
20 |
21 |     def forward(self, x, *args, **kwargs):
22 |         x = self.norm(x)
23 |         return self.fn(x, *args, **kwargs)
24 |
25 | # time token shift
26 |
27 | def shift(t, amt):
28 |     if amt == 0:
29 |         return t
30 |     return F.pad(t, (0, 0, 0, 0, amt, -amt))
31 |
32 | class PreTokenShift(nn.Module):
33 |     def __init__(self, frames, fn):
34 |         super().__init__()
35 |         self.frames = frames
36 |         self.fn = fn
37 |
38 |     def forward(self, x, *args, **kwargs):
39 |         f, dim = self.frames, x.shape[-1]
40 |         cls_x, x = x[:, :1], x[:, 1:]
41 |         x = rearrange(x, 'b (f n) d -> b f n d', f = f)
42 |
43 |         # shift along time frame before and after
44 |
45 |         dim_chunk = (dim // 3)
46 |         chunks = x.split(dim_chunk, dim = -1)
47 |         chunks_to_shift, rest = chunks[:3], chunks[3:]
48 |         shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
49 |         x = torch.cat((*shifted_chunks, *rest), dim = -1)
50 |
51 |         x = rearrange(x, 'b f n d -> b (f n) d')
52 |         x = torch.cat((cls_x, x), dim = 1)
53 |         return self.fn(x, *args, **kwargs)
54 |
55 | # feedforward
56 |
57 | class GEGLU(nn.Module):
58 |     def forward(self, x):
59 |         x, gates = x.chunk(2, dim = -1)
60 |         return x * F.gelu(gates)
61 |
62 | class FeedForward(nn.Module):
63 |     def __init__(self, dim, mult = 4, dropout = 0.):
64 |         super().__init__()
65 |         self.net = nn.Sequential(
66 |             nn.Linear(dim, dim * mult * 2),
67 |             GEGLU(),
68 |             nn.Dropout(dropout),
69 |             nn.Linear(dim * mult, dim)
70 |         )
71 |
72 |     def forward(self, x):
73 |         return self.net(x)
74 |
75 | # attention
76 |
77 | def attn(q, k, v, mask = None):
78 |     sim = einsum('b i d, b j d -> b i j', q, k)
79 |
80 |     if exists(mask):
81 |         max_neg_value = -torch.finfo(sim.dtype).max
82 |         sim.masked_fill_(~mask, max_neg_value)
83 |
84 |     attn = sim.softmax(dim = -1)
85 |     out = einsum('b i j, b j d -> b i d', attn, v)
86 |     return out
87 |
88 | class Attention(nn.Module):
89 |     def __init__(
90 |         self,
91 |         dim,
92 |         dim_head = 64,
93 |         heads = 8,
94 |         dropout = 0.
95 |     ):
96 |         super().__init__()
97 |         self.heads = heads
98 |         self.scale = dim_head ** -0.5
99 |         inner_dim = dim_head * heads
100 |
101 |         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
102 |         self.to_out = nn.Sequential(
103 |             nn.Linear(inner_dim, dim),
104 |             nn.Dropout(dropout)
105 |         )
106 |
107 |     def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
108 |         h = self.heads
109 |         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
110 |         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
111 |
112 |         q = q * self.scale
113 |
114 |         # splice out the classification token, which sits at index 0
115 |         (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))
116 |
117 |         # let classification token attend to key / values of all patches across time and space
118 |         cls_out = attn(cls_q, k, v, mask = cls_mask)
119 |
120 |         # rearrange across time or space
121 |         q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
122 |
123 |         # add rotary embeddings, if applicable
124 |         if exists(rot_emb):
125 |             q_, k_ = apply_rot_emb(q_, k_, rot_emb)
126 |
127 |         # expand cls token keys and values across time or space and concat
128 |         r = q_.shape[0] // cls_k.shape[0]
129 |         cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))
130 |
131 |         k_ = torch.cat((cls_k, k_), dim = 1)
132 |         v_ = torch.cat((cls_v, v_), dim = 1)
133 |
134 |         # attention
135 |         out = attn(q_, k_, v_, mask = mask)
136 |
137 |         # merge back time or space
138 |         out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
139 |
140 |         # concat back the cls token
141 |         out = torch.cat((cls_out, out), dim = 1)
142 |
143 |         # merge back the heads
144 |         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
145 |
146 |         # combine heads out
147 |         return self.to_out(out)
148 |
149 | # main classes
150 |
151 | class TimeSformer(nn.Module):
152 |     def __init__(
153 |         self,
154 |         *,
155 |         dim,
156 |         num_frames,
157 |         num_classes,
158 |         image_size = 224,
159 |         patch_size = 16,
160 |         channels = 3,
161 |         depth = 12,
162 |         heads = 8,
163 |         dim_head = 64,
164 |         attn_dropout = 0.,
165 |         ff_dropout = 0.,
166 |         rotary_emb = True,
167 |         shift_tokens = False
168 |     ):
169 |         super().__init__()
170 |         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
171 |
172 |         num_patches = (image_size // patch_size) ** 2
173 |         num_positions = num_frames * num_patches
174 |         patch_dim = channels * patch_size ** 2
175 |
176 |         self.heads = heads
177 |         self.patch_size = patch_size
178 |         self.to_patch_embedding = nn.Linear(patch_dim, dim)
179 |         self.cls_token = nn.Parameter(torch.randn(1, dim))
180 |
181 |         self.use_rotary_emb = rotary_emb
182 |         if rotary_emb:
183 |             self.frame_rot_emb = RotaryEmbedding(dim_head)
184 |             self.image_rot_emb = AxialRotaryEmbedding(dim_head)
185 |         else:
186 |             self.pos_emb = nn.Embedding(num_positions + 1, dim)
187 |
188 |
189 |         self.layers = nn.ModuleList([])
190 |         for _ in range(depth):
191 |             ff = FeedForward(dim, dropout = ff_dropout)
192 |             time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
193 |             spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
194 |
195 |             if shift_tokens:
196 |                 time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(num_frames, t), (time_attn, spatial_attn, ff))
197 |
198 |             time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff))
199 |
200 |             self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))
201 |
202 |         self.to_out = nn.Sequential(
203 |             nn.LayerNorm(dim),
204 |             nn.Linear(dim, num_classes)
205 |         )
206 |
207 |     def forward(self, video, mask = None):
208 |         b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
209 |         assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'
210 |
211 |         # calculate num patches in height and width dimension, and number of total patches (n)
212 |
213 |         hp, wp = (h // p), (w // p)
214 |         n = hp * wp
215 |
216 |         # video to patch embeddings
217 |
218 |         video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
219 |         tokens = self.to_patch_embedding(video)
220 |
221 |         # add cls token
222 |
223 |         cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
224 |         x = torch.cat((cls_token, tokens), dim = 1)
225 |
226 |         # positional embedding
227 |
228 |         frame_pos_emb = None
229 |         image_pos_emb = None
230 |         if not self.use_rotary_emb:
231 |             x += self.pos_emb(torch.arange(x.shape[1], device = device))
232 |         else:
233 |             frame_pos_emb = self.frame_rot_emb(f, device = device)
234 |             image_pos_emb = self.image_rot_emb(hp, wp, device = device)
235 |
236 |         # calculate masking for uneven number of frames
237 |
238 |         frame_mask = None
239 |         cls_attn_mask = None
240 |         if exists(mask):
241 |             mask_with_cls = F.pad(mask, (1, 0), value = True)
242 |
243 |             frame_mask = repeat(mask_with_cls, 'b f -> (b h n) () f', n = n, h = self.heads)
244 |
245 |             cls_attn_mask = repeat(mask, 'b f -> (b h) () (f n)', n = n, h = self.heads)
246 |             cls_attn_mask = F.pad(cls_attn_mask, (1, 0), value = True)
247 |
248 |         # time and space attention
249 |
250 |         for (time_attn, spatial_attn, ff) in self.layers:
251 |             x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, cls_mask = cls_attn_mask, rot_emb = frame_pos_emb) + x
252 |             x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, cls_mask = cls_attn_mask, rot_emb = image_pos_emb) + x
253 |             x = ff(x) + x
254 |
255 |         cls_token = x[:, 0]
256 |         return self.to_out(cls_token)
257 |
--------------------------------------------------------------------------------