├── diagram.png
├── timesformer_pytorch
│   ├── __init__.py
│   ├── rotary.py
│   └── timesformer_pytorch.py
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── README.md
└── .gitignore

/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/TimeSformer-pytorch/HEAD/diagram.png
--------------------------------------------------------------------------------
/timesformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
from timesformer_pytorch.timesformer_pytorch import TimeSformer
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'timesformer-pytorch',
  packages = find_packages(),
  version = '0.4.1',
  license='MIT',
  description = 'TimeSformer - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/TimeSformer-pytorch',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'video classification',
  ],
  install_requires=[
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## TimeSformer - Pytorch

Implementation of TimeSformer, from Facebook AI, a pure and simple attention-based solution for reaching SOTA on video classification. This repository will only house the best performing variant, 'Divided Space-Time Attention', which is simply attention along the time axis applied before attention over space.

Press release

## Install

``` bash
$ pip install timesformer-pytorch
```

## Usage

```python
import torch
from timesformer_pytorch import TimeSformer

model = TimeSformer(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_frames = 8,
    num_classes = 10,
    depth = 12,
    heads = 8,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

video = torch.randn(2, 8, 3, 224, 224) # (batch x frames x channels x height x width)
mask = torch.ones(2, 8).bool()         # (batch x frames) - use a mask if there are variable length videos in the same batch

pred = model(video, mask = mask) # (2, 10)
```

## Citations

```bibtex
@misc{bertasius2021spacetime,
    title = {Is Space-Time Attention All You Need for Video Understanding?},
    author = {Gedas Bertasius and Heng Wang and Lorenzo Torresani},
    year = {2021},
    eprint = {2102.05095},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{su2021roformer,
    title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year = {2021},
    eprint = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```

```bibtex
@article{tokshift2021,
    title = {Token Shift Transformer for Video Classification},
    author = {Hao Zhang and Yanbin Hao and Chong-Wah Ngo},
    year = {2021},
    journal = {ACM Multimedia 2021}
}
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/timesformer_pytorch/rotary.py:
--------------------------------------------------------------------------------
from math import log, pi
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rot_emb(q, k, rot_emb):
    sin, cos = rot_emb
    rot_dim = sin.shape[-1]
    (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :rot_dim], t[..., rot_dim:]), (q, k))
    q, k = map(lambda t: t * cos + rotate_every_two(t) * sin, (q, k))
    q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
    return q, k

class AxialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
        self.register_buffer('scales', scales)

    def forward(self, h, w, device):
        scales = rearrange(self.scales, '... -> () ...')
        scales = scales.to(device)

        h_seq = torch.linspace(-1., 1., steps = h, device = device)
        h_seq = h_seq.unsqueeze(-1)

        w_seq = torch.linspace(-1., 1., steps = w, device = device)
        w_seq = w_seq.unsqueeze(-1)

        h_seq = h_seq * scales * pi
        w_seq = w_seq * scales * pi

        x_sinu = repeat(h_seq, 'i d -> i j d', j = w)
        y_sinu = repeat(w_seq, 'j d -> i j d', i = h)

        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
        sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
        return sin, cos

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freqs = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, n, device):
        seq = torch.arange(n, device = device)
        freqs = einsum('i, j -> i j', seq, self.inv_freqs)
        freqs = torch.cat((freqs, freqs), dim = -1)
        freqs = rearrange(freqs, 'n d -> () n d')
        return freqs.sin(), freqs.cos()
--------------------------------------------------------------------------------
/timesformer_pytorch/timesformer_pytorch.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

from timesformer_pytorch.rotary import apply_rot_emb, AxialRotaryEmbedding, RotaryEmbedding

# helpers

def exists(val):
    return val is not None

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)

# time token shift

def shift(t, amt):
    if amt == 0:
        return t
    return F.pad(t, (0, 0, 0, 0, amt, -amt))

class PreTokenShift(nn.Module):
    def __init__(self, frames, fn):
        super().__init__()
        self.frames = frames
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        f, dim = self.frames, x.shape[-1]
        cls_x, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (f n) d -> b f n d', f = f)

        # shift along time frame before and after

        dim_chunk = (dim // 3)
        chunks = x.split(dim_chunk, dim = -1)
        chunks_to_shift, rest = chunks[:3], chunks[3:]
        shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
        x = torch.cat((*shifted_chunks, *rest), dim = -1)

        x = rearrange(x, 'b f n d -> b (f n) d')
        x = torch.cat((cls_x, x), dim = 1)
        return self.fn(x, *args, **kwargs)

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# attention

def attn(q, k, v, mask = None):
    sim = einsum('b i d, b j d -> b i j', q, k)

    if exists(mask):
        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(~mask, max_neg_value)

    attn = sim.softmax(dim = -1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q = q * self.scale

        # splice out the classification token (the first token)
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

        # let classification token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v, mask = cls_mask)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        # add rotary embeddings, if applicable
        if exists(rot_emb):
            q_, k_ = apply_rot_emb(q_, k_, rot_emb)

        # expand cls token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        # attention
        out = attn(q_, k_, v_, mask = mask)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim = 1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        # combine heads out
        return self.to_out(out)

# main classes

class TimeSformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_frames,
        num_classes,
        image_size = 224,
        patch_size = 16,
        channels = 3,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_dropout = 0.,
        rotary_emb = True,
        shift_tokens = False
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        num_positions = num_frames * num_patches
        patch_dim = channels * patch_size ** 2

        self.heads = heads
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, dim))

        self.use_rotary_emb = rotary_emb
        if rotary_emb:
            self.frame_rot_emb = RotaryEmbedding(dim_head)
            self.image_rot_emb = AxialRotaryEmbedding(dim_head)
        else:
            self.pos_emb = nn.Embedding(num_positions + 1, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            ff = FeedForward(dim, dropout = ff_dropout)
            time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)

            if shift_tokens:
                time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(num_frames, t), (time_attn, spatial_attn, ff))

            time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff))

            self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))

        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, video, mask = None):
        b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

        # calculate num patches in height and width dimension, and number of total patches (n)

        hp, wp = (h // p), (w // p)
        n = hp * wp

        # video to patch embeddings

        video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.to_patch_embedding(video)

        # add cls token

        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x = torch.cat((cls_token, tokens), dim = 1)

        # positional embedding

        frame_pos_emb = None
        image_pos_emb = None
        if not self.use_rotary_emb:
            x += self.pos_emb(torch.arange(x.shape[1], device = device))
        else:
            frame_pos_emb = self.frame_rot_emb(f, device = device)
            image_pos_emb = self.image_rot_emb(hp, wp, device = device)

        # calculate masking for uneven number of frames

        frame_mask = None
        cls_attn_mask = None
        if exists(mask):
            mask_with_cls = F.pad(mask, (1, 0), value = True)

            frame_mask = repeat(mask_with_cls, 'b f -> (b h n) () f', n = n, h = self.heads)

            cls_attn_mask = repeat(mask, 'b f -> (b h) () (f n)', n = n, h = self.heads)
            cls_attn_mask = F.pad(cls_attn_mask, (1, 0), value = True)

        # time and space attention

        for (time_attn, spatial_attn, ff) in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, cls_mask = cls_attn_mask, rot_emb = frame_pos_emb) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, cls_mask = cls_attn_mask, rot_emb = image_pos_emb) + x
            x = ff(x) + x

        cls_token = x[:, 0]
        return self.to_out(cls_token)
--------------------------------------------------------------------------------
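
For readers who want to go one step beyond the README usage snippet, below is a minimal training-step sketch built only on the `TimeSformer` API shown above (constructor arguments and `model(video, mask = mask)` returning class logits). The optimizer choice, learning rate, batch shapes, and the `labels` tensor are illustrative assumptions and are not part of this repository.

```python
import torch
import torch.nn.functional as F
from timesformer_pytorch import TimeSformer

# model hyperparameters mirror the README example; the optimizer and learning rate are assumptions
model = TimeSformer(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_frames = 8,
    num_classes = 10,
    depth = 12,
    heads = 8,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

video = torch.randn(2, 8, 3, 224, 224)   # (batch, frames, channels, height, width)
mask = torch.ones(2, 8).bool()           # True marks real (non-padded) frames per sample
labels = torch.randint(0, 10, (2,))      # hypothetical class labels for this batch

logits = model(video, mask = mask)       # (2, 10)
loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```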