├── README.md
├── LICENSE
├── .gitignore
└── ssan.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SSAN-pytorch
Unofficial implementation of "SSAN: Separable Self-Attention Network for Video Representation Learning" (CVPR 2021), in PyTorch.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Felix

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/ssan.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
from einops import rearrange
import timm


class SSA(nn.Module):
    """Separable self-attention: spatial (pixel + channel) attention per frame,
    followed by temporal attention across the segment dimension."""

    def __init__(self, dim, n_segment):
        super(SSA, self).__init__()
        self.scale = dim ** -0.5
        self.n_segment = n_segment

        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.attend = nn.Softmax(dim=-1)
        self.to_temporal_qk = nn.Conv3d(dim, dim * 2,
                                        kernel_size=(3, 1, 1),
                                        padding=(1, 0, 0))

    def forward(self, x):
        bt, c, h, w = x.shape  # input is (batch * n_segment, C, H, W)
        t = self.n_segment

        # Spatial attention
        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=1)  # each (bt, c, h, w)
        q, k, v = map(lambda u: rearrange(u, 'b c h w -> b (h w) c'), (q, k, v))  # (bt, hw, c)

        # - pixel attention: hw x hw
        pixel_dots = einsum('b i c, b j c -> b i j', q, k)  # * self.scale
        pixel_attn = self.attend(pixel_dots)
        pixel_out = einsum('b i j, b j d -> b i d', pixel_attn, v)

        # - channel attention: c x c
        chan_dots = einsum('b i c, b i k -> b c k', q, k)  # * self.scale
        chan_attn = self.attend(chan_dots)
        chan_out = einsum('b i j, b d j -> b d i', chan_attn, v)  # (bt, hw, c)

        # aggregation of the two spatial branches
        x_hat = pixel_out + chan_out
        x_hat = rearrange(x_hat, '(b t) (h w) c -> b c t h w', t=t, h=h, w=w)

        # Temporal attention: t x t over per-frame descriptors
        t_qk = self.to_temporal_qk(x_hat)
        tq, tk = t_qk.chunk(2, dim=1)  # (b, c, t, h, w)
        tq, tk = map(lambda u: rearrange(u, 'b c t h w -> b t (c h w)'), (tq, tk))  # (b, t, d)
        tv = rearrange(v, '(b t) (h w) c -> b t (c h w)', t=t, h=h, w=w)  # shared value embedding
        dots = einsum('b i d, b j d -> b i j', tq, tk)  # (b, t, t)
        attn = self.attend(dots)
        out = einsum('b k t, b t d -> b k d', attn, tv)  # (b, t, d)
        out = rearrange(out, 'b t (c h w) -> (b t) c h w', h=h, w=w, c=c)
        return out


class SSAWrapper(nn.Module):
    """Wraps a timm ResNet Bottleneck block and inserts an SSA module after its
    first conv->bn->relu stage. The index-based slicing below assumes timm's
    Bottleneck child ordering (conv1, bn1, act1, conv2, bn2, act2, conv3, bn3,
    act3[, downsample]) and may need adjustment for other timm versions."""

    def __init__(self, block, n_segment):
        super(SSAWrapper, self).__init__()
        self.block = block
        self.ssa = SSA(block.bn1.num_features, n_segment)
        self.n_segment = n_segment
        self.downsample = block.downsample

    def forward(self, x):
        residual = x

        for idx, subm in enumerate(self.block.children()):
            if idx < 3:
                x = subm(x)  # stage 1: conv1 -> bn1 -> relu

        x = self.ssa(x)  # separable self-attention on stage-1 features

        for idx, subm in enumerate(self.block.children()):
            if idx < 3 or idx > 7:
                continue
            x = subm(x)  # stages 2-3: conv2 -> bn2 -> relu -> conv3 -> bn3

        if self.downsample is not None:
            residual = self.downsample(residual)

        x += residual  # shortcut connection
        x = self.block.act3(x)  # final activation of the bottleneck
        return x


class SSAN(nn.Module):
    def __init__(self, n_segment, net):
        super(SSAN, self).__init__()
        self.n_segment = n_segment
        # insert SSA into layer2 and layer3, wrapping every other bottleneck block
        net.layer2 = nn.Sequential(
            SSAWrapper(net.layer2[0], n_segment),
            net.layer2[1],
            SSAWrapper(net.layer2[2], n_segment),
            net.layer2[3],
        )
        net.layer3 = nn.Sequential(
            SSAWrapper(net.layer3[0], n_segment),
            net.layer3[1],
            SSAWrapper(net.layer3[2], n_segment),
            net.layer3[3],
            SSAWrapper(net.layer3[4], n_segment),
            net.layer3[5],
        )
        self.backbone = net
        self.avgpool = net.global_pool
        self.fc = net.fc

    def forward(self, x):  # x: (B*T, C, H, W)
        x = self.backbone.forward_features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def ssan50(pretrained=False, n_segment=8, **kwargs):
    """Constructs an SSAN model on top of a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, uses a backbone pre-trained on ImageNet.
        n_segment (int): Number of frames (segments) per clip.
    """
    net = timm.create_model('resnet50', pretrained=pretrained, **kwargs)
    model = SSAN(n_segment=n_segment, net=net)
    return model


if __name__ == '__main__':
    x = torch.randn(8, 3, 224, 224)  # (B*T, C, H, W): 2 clips x 4 segments
    backbone = timm.create_model('resnet50', pretrained=False)
    model = SSAN(n_segment=4, net=backbone)
    y = model(x)  # (B*T, num_classes)
    print(y.shape)
--------------------------------------------------------------------------------
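Usage note (not part of the repository files above): the network expects frames already flattened along the batch dimension, i.e. a (B*T, C, H, W) tensor where T equals n_segment. The sketch below shows one way to feed it a clip tensor of shape (B, T, C, H, W) and to average the per-frame logits into a clip-level prediction; the clip layout, the batch size of 2, and the averaging consensus are illustrative assumptions, not something prescribed by ssan.py.

import torch
from einops import rearrange

from ssan import ssan50

n_segment = 4
model = ssan50(pretrained=False, n_segment=n_segment).eval()

clips = torch.randn(2, n_segment, 3, 224, 224)          # (B, T, C, H, W), assumed clip layout
frames = rearrange(clips, 'b t c h w -> (b t) c h w')   # flatten time into the batch dimension

with torch.no_grad():
    frame_logits = model(frames)                        # (B*T, num_classes)

# Assumed consensus step: average per-frame logits over the T segments of each clip.
clip_logits = rearrange(frame_logits, '(b t) k -> b t k', t=n_segment).mean(dim=1)
print(clip_logits.shape)                                # torch.Size([2, 1000]) with the default 1000-class head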