├── LICENSE
├── README.md
├── images
│   └── arch.png
├── main.py
└── models.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Rishikesh (ऋषिकेश)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## ViViT: A Video Vision Transformer - PyTorch

An unofficial PyTorch implementation of [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691). From the paper's abstract:

We present pure-transformer based models for video classification, drawing upon the recent success of such models in image classification. Our model extracts spatiotemporal tokens from the input video, which are then encoded by a series of transformer layers. In order to handle the long sequences of tokens encountered in video, we propose several efficient variants of our model which factorise the spatial and temporal dimensions of the input. Although transformer-based models are known to only be effective when large training datasets are available, we show how we can effectively regularise the model during training and leverage pretrained image models to be able to train on comparatively small datasets. We conduct thorough ablation studies, and achieve state-of-the-art results on multiple video classification benchmarks including Kinetics 400 and 600, Epic Kitchens, Something-Something v2 and Moments in Time, outperforming prior methods based on deep 3D convolutional networks. To facilitate further research, we will release code and models.

## Notes:
* Currently the implementation only includes Model-3 (factorised self-attention).
* Embedding technique: tubelet embedding (a minimal sketch is shown below).
* For Model-2, refer to the repo https://github.com/rishikksh20/ViViT-pytorch by [@rishikksh20](https://github.com/rishikksh20).

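For reference, tubelet embedding maps non-overlapping spatio-temporal patches ("tubelets") of the video to token vectors. The following is a minimal sketch that mirrors the `Rearrange` + `nn.Linear` pair used in `models.py`; the sizes assume the same settings as the usage example below:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

t, h, w = 32, 64, 64        # frames, height, width
pt, ph, pw = 8, 4, 4        # tubelet size along time, height, width
dim = 512

to_tubelet_embedding = nn.Sequential(
    # (b, c, t, h, w) -> (b, nt, nh*nw, pt*ph*pw*c)
    Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=pt, ph=ph, pw=pw),
    nn.Linear(pt * ph * pw * 3, dim),  # project every tubelet to a token of size `dim`
)

video = torch.rand(2, 3, t, h, w)
tokens = to_tubelet_embedding(video)
print(tokens.shape)  # torch.Size([2, 4, 256, 512]) = (b, nt, nh*nw, dim)
```
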
## Usage

```python
import numpy as np
import torch
from models import ViViTBackbone

v = ViViTBackbone(
    t=32,
    h=64,
    w=64,
    patch_t=8,
    patch_h=4,
    patch_w=4,
    num_classes=10,
    dim=512,
    depth=6,
    heads=10,
    mlp_dim=8,
    model=3
)

device = torch.device('cpu')
vid = torch.rand(32, 3, 32, 64, 64).to(device)

pred = v(vid)  # (32, 10)

parameters = filter(lambda p: p.requires_grad, v.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)
```

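With the settings above, the backbone splits each clip into a 4 × 16 × 16 grid of tubelets (32/8 temporal slots × 64/4 × 64/4 spatial patches), i.e. 1,024 spatio-temporal tokens of size 512 per clip, and the head returns one logit vector per clip. A quick back-of-the-envelope check (just arithmetic, not part of the API):

```python
nt, nh, nw = 32 // 8, 64 // 4, 64 // 4   # temporal, height, width grid sizes
print(nt, nh, nw, nt * nh * nw)          # 4 16 16 -> 1024 tokens per clip
```
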
## Citation:
```
@misc{arnab2021vivit,
      title={ViViT: A Video Vision Transformer},
      author={Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lučić and Cordelia Schmid},
      year={2021},
      eprint={2103.15691},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Acknowledgement:
* Code implementation is derived from [@lucidrains](https://github.com/lucidrains)' repo: https://github.com/lucidrains/vit-pytorch
--------------------------------------------------------------------------------
/images/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drv-agwl/ViViT-pytorch/619bcaca2d6fc74b580930c34711336fb1917351/images/arch.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from models import ViViTBackbone

v = ViViTBackbone(
    t=32,
    h=64,
    w=64,
    patch_t=8,
    patch_h=4,
    patch_w=4,
    num_classes=10,
    dim=512,
    depth=6,
    heads=10,
    mlp_dim=8,
    model=3
)

device = torch.device('cpu')
vid = torch.rand(32, 3, 32, 64, 64).to(device)

pred = v(vid)  # (32, 10)

parameters = filter(lambda p: p.requires_grad, v.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from torch import nn, einsum
import torch
from einops.layers.torch import Rearrange
from einops import rearrange, repeat


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FSAttention(nn.Module):
    """Factorized Self-Attention"""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class FDAttention(nn.Module):
    """Factorized Dot-product Attention"""

    def __init__(self, dim, nt, nh, nw, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.nt = nt
        self.nh = nh
        self.nw = nw

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (b, nt*nh*nw, dim); the token sequence is time-major, i.e. n = (t s)
        b, n, d, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        # Half of the heads attend over space, the other half over time
        qs, qt = q.chunk(2, dim=1)
        ks, kt = k.chunk(2, dim=1)
        vs, vt = v.chunk(2, dim=1)

        # Attention over the spatial dimension (tokens grouped by time index)
        qs = rearrange(qs, 'b h (t s) d -> b h t s d', t=self.nt)
        ks = rearrange(ks, 'b h (t s) d -> b h t s d', t=self.nt)
        vs = rearrange(vs, 'b h (t s) d -> b h t s d', t=self.nt)
        spatial_dots = einsum('b h t i d, b h t j d -> b h t i j', qs, ks) * self.scale
        sp_attn = self.attend(spatial_dots)
        spatial_out = einsum('b h t i j, b h t j d -> b h t i d', sp_attn, vs)

        # Attention over the temporal dimension (tokens grouped by spatial location)
        qt = rearrange(qt, 'b h (t s) d -> b h s t d', t=self.nt)
        kt = rearrange(kt, 'b h (t s) d -> b h s t d', t=self.nt)
        vt = rearrange(vt, 'b h (t s) d -> b h s t d', t=self.nt)
        temporal_dots = einsum('b h s i d, b h s j d -> b h s i j', qt, kt) * self.scale
        temporal_attn = self.attend(temporal_dots)
        temporal_out = einsum('b h s i j, b h s j d -> b h s i d', temporal_attn, vt)

        # Recombine the spatial and temporal heads into one token sequence and project back to `dim`
        spatial_out = rearrange(spatial_out, 'b h t s d -> b h (t s) d')
        temporal_out = rearrange(temporal_out, 'b h s t d -> b h (t s) d')
        out = rearrange(torch.cat((spatial_out, temporal_out), dim=1), 'b h n d -> b n (h d)')
        return self.to_out(out)

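# Illustrative shape check for FDAttention (assumed values, kept as a comment so
# nothing runs on import; `heads` must be even so it can be split across space/time):
#   fda = FDAttention(dim=512, nt=4, nh=16, nw=16, heads=8, dim_head=64)
#   x = torch.rand(2, 4 * 16 * 16, 512)               # (b, nt*nh*nw, dim)
#   assert fda(x).shape == (2, 4 * 16 * 16, 512)
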
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class FSATransformerEncoder(nn.Module):
    """Factorized Self-Attention Transformer Encoder"""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, nt, nh, nw, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.nt = nt
        self.nh = nh
        self.nw = nw

        for _ in range(depth):
            self.layers.append(nn.ModuleList(
                [PreNorm(dim, FSAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                 PreNorm(dim, FSAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
                 ]))

    def forward(self, x):
        """x: (b, nt, nh*nw, dim)"""
        b = x.shape[0]

        # Merge batch and time so spatial attention runs over the nh*nw tokens of each frame group:
        # (b, nt, nh*nw, dim) -> (b*nt, nh*nw, dim)
        x = torch.flatten(x, start_dim=0, end_dim=1)

        for sp_attn, temp_attn, ff in self.layers:
            sp_attn_x = sp_attn(x) + x  # Spatial attention

            # Reshape tensors for temporal attention:
            # (b*nt, nh*nw, dim) -> (b, nt, nh*nw, dim) -> (b, nh*nw, nt, dim) -> (b*nh*nw, nt, dim)
            sp_attn_x = sp_attn_x.chunk(b, dim=0)
            sp_attn_x = [temp[None] for temp in sp_attn_x]
            sp_attn_x = torch.cat(sp_attn_x, dim=0).transpose(1, 2)
            sp_attn_x = torch.flatten(sp_attn_x, start_dim=0, end_dim=1)

            temp_attn_x = temp_attn(sp_attn_x) + sp_attn_x  # Temporal attention

            x = ff(temp_attn_x) + temp_attn_x  # MLP

            # Reshape back for the next layer's spatial attention:
            # (b*nh*nw, nt, dim) -> (b, nt, nh*nw, dim) -> (b*nt, nh*nw, dim)
            x = x.chunk(b, dim=0)
            x = [temp[None] for temp in x]
            x = torch.cat(x, dim=0).transpose(1, 2)
            x = torch.flatten(x, start_dim=0, end_dim=1)

        # Reshape the output to (b, nt*nh*nw, dim)
        x = x.chunk(b, dim=0)
        x = [temp[None] for temp in x]
        x = torch.cat(x, dim=0)
        x = torch.flatten(x, start_dim=1, end_dim=2)
        return x


class FDATransformerEncoder(nn.Module):
    """Factorized Dot-product Attention Transformer Encoder"""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, nt, nh, nw, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.nt = nt
        self.nh = nh
        self.nw = nw

        for _ in range(depth):
            self.layers.append(
                PreNorm(dim, FDAttention(dim, nt, nh, nw, heads=heads, dim_head=dim_head, dropout=dropout)))

    def forward(self, x):
        # Tokens arrive as (b, nt, nh*nw, dim); flatten them into a single sequence of
        # nt*nh*nw tokens before applying factorized dot-product attention.
        x = torch.flatten(x, start_dim=1, end_dim=2)

        for attn in self.layers:
            x = attn(x) + x

        return x


class ViViTBackbone(nn.Module):
    """Model-3 backbone of ViViT"""

    def __init__(self, t, h, w, patch_t, patch_h, patch_w, num_classes, dim, depth, heads, mlp_dim, dim_head=3,
                 channels=3, mode='tubelet', device='cuda', emb_dropout=0., dropout=0., model=3):
        super().__init__()

        assert t % patch_t == 0 and h % patch_h == 0 and w % patch_w == 0, \
            "Video dimensions should be divisible by the tubelet size"

        self.T = t
        self.H = h
        self.W = w
        self.channels = channels
        self.t = patch_t
        self.h = patch_h
        self.w = patch_w
        self.mode = mode
        self.device = device

        self.nt = self.T // self.t
        self.nh = self.H // self.h
        self.nw = self.W // self.w

        tubelet_dim = self.t * self.h * self.w * channels

        self.to_tubelet_embedding = nn.Sequential(
            Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=self.t, ph=self.h, pw=self.w),
            nn.Linear(tubelet_dim, dim)
        )

        # The same spatial position encoding is shared across time (broadcast over nt in forward)
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.nh * self.nw, dim))

        self.dropout = nn.Dropout(emb_dropout)

        if model == 3:
            self.transformer = FSATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)
        elif model == 4:
            assert heads % 2 == 0, "Number of heads should be even"
            self.transformer = FDATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """x is a video: (b, C, T, H, W)"""

        tokens = self.to_tubelet_embedding(x)  # (b, nt, nh*nw, dim)

        tokens = tokens + self.pos_embedding  # broadcasts over the temporal dimension
        tokens = self.dropout(tokens)

        x = self.transformer(tokens)  # (b, nt*nh*nw, dim)
        x = x.mean(dim=1)

        x = self.to_latent(x)
        return self.mlp_head(x)

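# Illustrative standalone check of the Model-3 encoder (assumed values, kept as a
# comment so nothing runs on import):
#   enc = FSATransformerEncoder(dim=512, depth=2, heads=8, dim_head=64, mlp_dim=8,
#                               nt=4, nh=16, nw=16)
#   tokens = torch.rand(2, 4, 16 * 16, 512)            # (b, nt, nh*nw, dim)
#   assert enc(tokens).shape == (2, 4 * 16 * 16, 512)  # (b, nt*nh*nw, dim)
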
if __name__ == '__main__':
    device = torch.device('cpu')
    x = torch.rand(32, 3, 32, 64, 64).to(device)

    vivit = ViViTBackbone(32, 64, 64, 8, 4, 4, 10, 512, 6, 10, 8, model=3).to(device)
    out = vivit(x)
    print(out)
--------------------------------------------------------------------------------