├── LICENSE
├── README.md
├── images
│   └── arch.png
├── main.py
└── models.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Rishikesh (ऋषिकेश)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## ViViT: A Video Vision Transformer - PyTorch
4 |
5 | An unofficial PyTorch implementation of [ViViT](https://arxiv.org/abs/2103.15691).
6 |
7 | From the abstract of [the paper](https://arxiv.org/abs/2103.15691):
8 |
9 | > We present pure-transformer based models for video classification, drawing upon the recent success of such
10 | > models in image classification. Our model extracts spatiotemporal tokens from the input video, which are then
11 | > encoded by a series of transformer layers. In order to handle the long sequences of tokens encountered in video,
12 | > we propose several, efficient variants of our model which factorise the spatial- and temporal-dimensions of the
13 | > input. Although transformer-based models are known to only be effective when large training datasets are available,
14 | > we show how we can effectively regularise the model during training and leverage pretrained image models to be able
15 | > to train on comparatively small datasets. We conduct thorough ablation studies, and achieve state-of-the-art results
16 | > on multiple video classification benchmarks including Kinetics 400 and 600, Epic Kitchens, Something-Something v2 and
17 | > Moments in Time, outperforming prior methods based on deep 3D convolutional networks. To facilitate further research, we will release code and models.
18 |
19 |
20 | ## Notes:
21 | * Currently, the implementation only includes Model-3 (the factorised self-attention encoder).
22 | * Embedding technique: tubelet embedding (see the sketch below).
23 | * For Model-2, see the repo https://github.com/rishikksh20/ViViT-pytorch by [@rishikksh20](https://github.com/rishikksh20).
24 |
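Tubelet embedding splits the clip into non-overlapping `patch_t x patch_h x patch_w` blocks and linearly projects each one to a token. Below is a minimal sketch of the embedding used in `models.py`; the sizes are purely illustrative and chosen to match the usage example further down:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

# Illustrative sizes: 32-frame 64x64 RGB clips, 8x4x4 tubelets, token dimension 512
pt, ph, pw, dim, channels = 8, 4, 4, 512, 3

to_tubelet_embedding = nn.Sequential(
    # carve the video into non-overlapping pt x ph x pw tubelets and flatten each one
    Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=pt, ph=ph, pw=pw),
    nn.Linear(pt * ph * pw * channels, dim),  # project every flattened tubelet to a dim-d token
)

video = torch.rand(2, channels, 32, 64, 64)  # (batch, channels, frames, height, width)
tokens = to_tubelet_embedding(video)         # (2, 4, 256, 512) = (batch, nt, nh*nw, dim)
```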
25 |
26 | ## Usage
27 |
28 | ```python
29 | import torch
30 | from models import ViViTBackbone
31 |
32 | v = ViViTBackbone(
33 | t=32,
34 | h=64,
35 | w=64,
36 | patch_t=8,
37 | patch_h=4,
38 | patch_w=4,
39 | num_classes=10,
40 | dim=512,
41 | depth=6,
42 | heads=10,
43 | mlp_dim=8,
44 | model=3
45 | )
46 |
47 | device = torch.device('cpu')
48 | vid = torch.rand(32, 3, 32, 64, 64).to(device)  # (batch, channels, frames, height, width)
49 |
50 | pred = v(vid) # (32, 10)
51 |
52 | parameters = filter(lambda p: p.requires_grad, v.parameters())
53 | parameters = sum(p.numel() for p in parameters) / 1_000_000  # count trainable parameters in millions
54 | print('Trainable Parameters: %.3fM' % parameters)
55 | ```
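
The `ViViTBackbone` constructed above uses the factorised self-attention encoder (Model-3): in every layer, tokens sharing a time index first attend over the spatial dimension, then tokens at the same spatial location attend across time, followed by an MLP. The sketch below summarises one such layer, using `einops.rearrange` in place of the chunk/concatenate reshapes in `FSATransformerEncoder`; `spatial_attn`, `temporal_attn`, and `mlp` are stand-ins for the PreNorm-wrapped modules in `models.py`:

```python
from einops import rearrange

def fsa_layer(tokens, spatial_attn, temporal_attn, mlp):
    """One factorised self-attention layer over tokens of shape (batch, nt, nh*nw, dim)."""
    b, nt, ns, _ = tokens.shape
    x = rearrange(tokens, 'b t s d -> (b t) s d')    # spatial attention: tokens within the same time index
    x = spatial_attn(x) + x
    x = rearrange(x, '(b t) s d -> (b s) t d', b=b)  # temporal attention: tokens at the same spatial location
    x = temporal_attn(x) + x
    x = mlp(x) + x
    return rearrange(x, '(b s) t d -> b t s d', b=b)
```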
56 |
57 | ## Citation:
58 | ```
59 | @misc{arnab2021vivit,
60 | title={ViViT: A Video Vision Transformer},
61 | author={Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lučić and Cordelia Schmid},
62 | year={2021},
63 | eprint={2103.15691},
64 | archivePrefix={arXiv},
65 | primaryClass={cs.CV}
66 | }
67 | ```
68 |
69 | ## Acknowledgement:
70 | * The implementation is derived from [@lucidrains](https://github.com/lucidrains)'s repo: https://github.com/lucidrains/vit-pytorch
71 |
--------------------------------------------------------------------------------
/images/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drv-agwl/ViViT-pytorch/619bcaca2d6fc74b580930c34711336fb1917351/images/arch.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import ViViTBackbone
3 |
4 | v = ViViTBackbone(
5 | t=32,
6 | h=64,
7 | w=64,
8 | patch_t=8,
9 | patch_h=4,
10 | patch_w=4,
11 | num_classes=10,
12 | dim=512,
13 | depth=6,
14 | heads=10,
15 | mlp_dim=8,
16 | model=3
17 | )
18 |
19 | device = torch.device('cpu')
20 | vid = torch.rand(32, 3, 32, 64, 64).to(device)  # (batch, channels, frames, height, width)
21 |
22 | pred = v(vid) # (32, 10)
23 |
24 | parameters = filter(lambda p: p.requires_grad, v.parameters())
25 | parameters = sum(p.numel() for p in parameters) / 1_000_000  # count trainable parameters in millions
26 | print('Trainable Parameters: %.3fM' % parameters)
27 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from torch import nn, einsum
2 | import torch
3 | from einops.layers.torch import Rearrange
4 | from einops import rearrange, repeat
5 |
6 |
7 | class PreNorm(nn.Module):
8 | def __init__(self, dim, fn):
9 | super().__init__()
10 | self.norm = nn.LayerNorm(dim)
11 | self.fn = fn
12 |
13 | def forward(self, x, **kwargs):
14 | return self.fn(self.norm(x), **kwargs)
15 |
16 |
17 | class FSAttention(nn.Module):
18 | """Factorized Self-Attention"""
19 |
20 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
21 | super().__init__()
22 | inner_dim = dim_head * heads
23 | project_out = not (heads == 1 and dim_head == dim)
24 |
25 | self.heads = heads
26 | self.scale = dim_head ** -0.5
27 |
28 | self.attend = nn.Softmax(dim=-1)
29 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
30 |
31 | self.to_out = nn.Sequential(
32 | nn.Linear(inner_dim, dim),
33 | nn.Dropout(dropout)
34 | ) if project_out else nn.Identity()
35 |
36 | def forward(self, x):
37 | b, n, _, h = *x.shape, self.heads
38 | qkv = self.to_qkv(x).chunk(3, dim=-1)
39 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
40 |
41 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
42 |
43 | attn = self.attend(dots)
44 |
45 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
46 | out = rearrange(out, 'b h n d -> b n (h d)')
47 | return self.to_out(out)
48 |
49 |
50 | class FDAttention(nn.Module):
51 | """Factorized Dot-product Attention"""
52 |
53 | def __init__(self, dim, nt, nh, nw, heads=8, dim_head=64, dropout=0.):
54 | super().__init__()
55 | inner_dim = dim_head * heads
56 | project_out = not (heads == 1 and dim_head == dim)
57 |
58 | self.nt = nt
59 | self.nh = nh
60 | self.nw = nw
61 |
62 | self.heads = heads
63 | self.scale = dim_head ** -0.5
64 |
65 | self.attend = nn.Softmax(dim=-1)
66 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
67 |
68 | self.to_out = nn.Sequential(
69 | nn.Linear(inner_dim, dim),
70 | nn.Dropout(dropout)
71 | ) if project_out else nn.Identity()
72 |
73 | def forward(self, x):
74 | b, n, d, h = *x.shape, self.heads
75 |
76 | qkv = self.to_qkv(x).chunk(3, dim=-1)
77 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
78 | qs, qt = q.chunk(2, dim=1)
79 | ks, kt = k.chunk(2, dim=1)
80 | vs, vt = v.chunk(2, dim=1)
81 |
82 |         # Attention over spatial dimension (tokens assumed to be ordered time-major, i.e. n = nt * nh * nw)
83 |         qs, ks, vs = map(lambda y: rearrange(y, 'b h (t s) d -> b h t s d', t=self.nt),
84 |                          (qs, ks, vs))
85 |         spatial_dots = einsum('b h t i d, b h t j d -> b h t i j', qs, ks) * self.scale
86 |         sp_attn = self.attend(spatial_dots)
87 |         spatial_out = einsum('b h t i j, b h t j d -> b h t i d', sp_attn, vs)
88 |
89 |         # Attention over temporal dimension
90 |         qt, kt, vt = map(lambda y: rearrange(y, 'b h (t s) d -> b h s t d', t=self.nt),
91 |                          (qt, kt, vt))
92 |         temporal_dots = einsum('b h s i d, b h s j d -> b h s i j', qt, kt) * self.scale
93 |         temporal_attn = self.attend(temporal_dots)
94 |         temporal_out = einsum('b h s i j, b h s j d -> b h s i d', temporal_attn, vt)
95 |         # Merge the spatial and temporal halves of the heads back into one token sequence and project out
96 |         out = torch.cat((rearrange(spatial_out, 'b h t s d -> b (t s) (h d)'), rearrange(temporal_out, 'b h s t d -> b (t s) (h d)')), dim=-1)
97 |         return self.to_out(out)
98 |
99 | class FeedForward(nn.Module):
100 | def __init__(self, dim, hidden_dim, dropout=0.):
101 | super().__init__()
102 | self.net = nn.Sequential(
103 | nn.Linear(dim, hidden_dim),
104 | nn.GELU(),
105 | nn.Dropout(dropout),
106 | nn.Linear(hidden_dim, dim),
107 | nn.Dropout(dropout)
108 | )
109 |
110 | def forward(self, x):
111 | return self.net(x)
112 |
113 |
114 | class FSATransformerEncoder(nn.Module):
115 | """Factorized Self-Attention Transformer Encoder"""
116 |
117 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, nt, nh, nw, dropout=0.):
118 | super().__init__()
119 | self.layers = nn.ModuleList([])
120 | self.nt = nt
121 | self.nh = nh
122 | self.nw = nw
123 |
124 | for _ in range(depth):
125 | self.layers.append(nn.ModuleList(
126 | [PreNorm(dim, FSAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
127 | PreNorm(dim, FSAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
128 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
129 | ]))
130 |
131 | def forward(self, x):
132 |
133 | b = x.shape[0]
134 |         x = torch.flatten(x, start_dim=0, end_dim=1)  # (b, nt, nh*nw, dim) -> (b*nt, nh*nw, dim): group tokens by time index for spatial attention
135 |
136 | for sp_attn, temp_attn, ff in self.layers:
137 | sp_attn_x = sp_attn(x) + x # Spatial attention
138 |
139 | # Reshape tensors for temporal attention
140 | sp_attn_x = sp_attn_x.chunk(b, dim=0)
141 | sp_attn_x = [temp[None] for temp in sp_attn_x]
142 | sp_attn_x = torch.cat(sp_attn_x, dim=0).transpose(1, 2)
143 | sp_attn_x = torch.flatten(sp_attn_x, start_dim=0, end_dim=1)
144 |
145 | temp_attn_x = temp_attn(sp_attn_x) + sp_attn_x # Temporal attention
146 |
147 | x = ff(temp_attn_x) + temp_attn_x # MLP
148 |
149 | # Again reshape tensor for spatial attention
150 | x = x.chunk(b, dim=0)
151 | x = [temp[None] for temp in x]
152 | x = torch.cat(x, dim=0).transpose(1, 2)
153 | x = torch.flatten(x, start_dim=0, end_dim=1)
154 |
155 |         # Reshape the output back to (b, nt*nh*nw, dim)
156 | x = x.chunk(b, dim=0)
157 | x = [temp[None] for temp in x]
158 | x = torch.cat(x, dim=0)
159 | x = torch.flatten(x, start_dim=1, end_dim=2)
160 | return x
161 |
162 |
163 | class FDATransformerEncoder(nn.Module):
164 | """Factorized Dot-product Attention Transformer Encoder"""
165 |
166 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, nt, nh, nw, dropout=0.):
167 | super().__init__()
168 | self.layers = nn.ModuleList([])
169 | self.nt = nt
170 | self.nh = nh
171 | self.nw = nw
172 |
173 | for _ in range(depth):
174 | self.layers.append(
175 | PreNorm(dim, FDAttention(dim, nt, nh, nw, heads=heads, dim_head=dim_head, dropout=dropout)))
176 |
177 |     def forward(self, x):
178 |         x = torch.flatten(x, start_dim=1, end_dim=2)  # (b, nt, nh*nw, dim) -> (b, nt*nh*nw, dim), time-major
179 |         for attn in self.layers:
180 |             x = attn(x) + x
181 |         return x
182 |
183 |
184 | class ViViTBackbone(nn.Module):
185 | """ Model-3 backbone of ViViT """
186 |
187 | def __init__(self, t, h, w, patch_t, patch_h, patch_w, num_classes, dim, depth, heads, mlp_dim, dim_head=3,
188 | channels=3, mode='tubelet', device='cuda', emb_dropout=0., dropout=0., model=3):
189 | super().__init__()
190 |
191 |         assert t % patch_t == 0 and h % patch_h == 0 and w % patch_w == 0, \
192 |             "Video dimensions must be divisible by the tubelet size"
193 |
194 | self.T = t
195 | self.H = h
196 | self.W = w
197 | self.channels = channels
198 | self.t = patch_t
199 | self.h = patch_h
200 | self.w = patch_w
201 | self.mode = mode
202 | self.device = device
203 |
204 | self.nt = self.T // self.t
205 | self.nh = self.H // self.h
206 | self.nw = self.W // self.w
207 |
208 | tubelet_dim = self.t * self.h * self.w * channels
209 |
210 | self.to_tubelet_embedding = nn.Sequential(
211 | Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=self.t, ph=self.h, pw=self.w),
212 | nn.Linear(tubelet_dim, dim)
213 | )
214 |
215 |         # the same (learnable) spatial position encoding is reused at every temporal index via broadcasting
216 |         self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.nh * self.nw, dim))
217 |
218 | self.dropout = nn.Dropout(emb_dropout)
219 |
220 | if model == 3:
221 | self.transformer = FSATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
222 | self.nt, self.nh, self.nw, dropout)
223 | elif model == 4:
224 | assert heads % 2 == 0, "Number of heads should be even"
225 | self.transformer = FDATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
226 | self.nt, self.nh, self.nw, dropout)
227 |
228 | self.to_latent = nn.Identity()
229 |
230 | self.mlp_head = nn.Sequential(
231 | nn.LayerNorm(dim),
232 | nn.Linear(dim, num_classes)
233 | )
234 |
235 | def forward(self, x):
236 | """ x is a video: (b, C, T, H, W) """
237 |
238 | tokens = self.to_tubelet_embedding(x)
239 |
240 |         tokens = tokens + self.pos_embedding  # broadcast the spatial position encoding over batch and time
241 | tokens = self.dropout(tokens)
242 |
243 | x = self.transformer(tokens)
244 | x = x.mean(dim=1)
245 |
246 | x = self.to_latent(x)
247 | return self.mlp_head(x)
248 |
249 |
250 | if __name__ == '__main__':
251 | device = torch.device('cpu')
252 |     x = torch.rand(32, 3, 32, 64, 64).to(device)  # (batch, channels, frames, height, width)
253 |
254 | vivit = ViViTBackbone(32, 64, 64, 8, 4, 4, 10, 512, 6, 10, 8, model=3).to(device)
255 | out = vivit(x)
256 |     print(out.shape)  # expected: torch.Size([32, 10])
257 |
--------------------------------------------------------------------------------