├── .gitignore
├── LICENSE
├── README.md
├── main.py
├── setup.py
├── tklr.png
└── tokenlearner_pytorch
    ├── __init__.py
    └── tokenlearner_pytorch.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Rishabh Anand

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tokenlearner-pytorch
Unofficial PyTorch implementation of `TokenLearner` by Ryoo et al. from Google AI ([`abs`](https://arxiv.org/abs/2106.11297), [`pdf`](https://arxiv.org/pdf/2106.11297.pdf))

## Installation
You can install TokenLearner via `pip`:

```
pip install tokenlearner-pytorch
```

## Usage
Import the `TokenLearner` class from the `tokenlearner_pytorch` package. The layer drops into a Vision Transformer, MLPMixer, or Video Vision Transformer, as done in the paper.

```python
import torch
from tokenlearner_pytorch import TokenLearner

tklr = TokenLearner(S=8)
x = torch.rand(512, 32, 32, 3)  # [B, H, W, C]
y = tklr(x)  # [512, 8, 3]
```

You can also pair `TokenLearner` with `TokenFuser` around multi-head self-attention, as done in the paper:

```python
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

mhsa = nn.MultiheadAttention(3, 1)
tklr = TokenLearner(S=8)
tkfr = TokenFuser(H=32, W=32, C=3, S=8)

x = torch.rand(512, 32, 32, 3)  # a batch of images

y = tklr(x)  # [512, 8, 3]
y = y.permute(1, 0, 2)  # [8, 512, 3]; nn.MultiheadAttention expects [S, B, C] by default
y, _ = mhsa(y, y, y)  # ignore attn weights
y = y.permute(1, 0, 2)  # back to [512, 8, 3]

out = tkfr(y, x)  # [512, 32, 32, 3]
```
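
In the paper, `TokenLearner` is inserted partway through a ViT so that every block after it attends over only `S` learned tokens instead of the full patch grid. The sketch below shows that pattern; the `TinyViT` wrapper and its sizes are hypothetical stand-ins, not part of this package:

```python
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner

class TinyViT(nn.Module):
    """Hypothetical toy encoder with TokenLearner halfway through the stack."""
    def __init__(self, dim=64, grid=8, S=8, depth=4):
        super().__init__()
        self.grid = grid
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=4, stride=4)  # 32x32 image -> 8x8 patch grid
        self.pre = nn.ModuleList([nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True) for _ in range(depth // 2)])
        self.tklr = TokenLearner(S=S)
        self.post = nn.ModuleList([nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True) for _ in range(depth // 2)])

    def forward(self, img):                        # img: [B, 3, 32, 32]
        x = self.patch_embed(img)                  # [B, dim, 8, 8]
        x = x.flatten(2).transpose(1, 2)           # [B, 64, dim] patch tokens
        for blk in self.pre:
            x = blk(x)
        B, _, C = x.shape
        x = x.reshape(B, self.grid, self.grid, C)  # back to [B, H, W, C] for TokenLearner
        x = self.tklr(x)                           # [B, S, C]
        for blk in self.post:                      # these blocks attend over only S tokens
            x = blk(x)
        return x

model = TinyViT()
print(model(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 8, 64])
```

The paper typically inserts `TokenLearner` around the middle of the network, so roughly half the blocks run on the reduced token set.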

## TODO
- [ ] Add support for temporal dimension `T` (a stopgap is sketched right after this list)
- [ ] Implement `TokenFuser` with `ViT`
- [ ] Implement `TokenFuser` with `ViViT`
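
Until native support lands, one stopgap for video shaped `[B, T, H, W, C]` is to fold time into the batch so each frame contributes its own `S` tokens, mirroring the paper's per-frame treatment of video. A sketch under that assumption:

```python
import torch
from tokenlearner_pytorch import TokenLearner

tklr = TokenLearner(S=8)

video = torch.rand(4, 16, 32, 32, 3)    # [B, T, H, W, C]
B, T, H, W, C = video.shape

frames = video.reshape(B * T, H, W, C)  # fold time into the batch
tokens = tklr(frames)                   # [B*T, 8, C]
tokens = tokens.reshape(B, T * 8, C)    # [B, T*S, C]: S tokens per frame
```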

## Contributions
If I've made any errors or you have any suggestions, feel free to raise an Issue or PR. All contributions are welcome!

## License
[MIT](https://github.com/rish-16/tokenlearner-pytorch/blob/main/LICENSE)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from tokenlearner_pytorch import TokenLearner, TokenFuser

mhsa = nn.MultiheadAttention(3, 1)
tklr = TokenLearner(S=8)
tkfr = TokenFuser(32, 32, 3, S=8)

x = torch.rand(512, 32, 32, 3)  # [B, H, W, C]

y = tklr(x)  # [512, 8, 3]
y = y.permute(1, 0, 2)  # [8, 512, 3]; nn.MultiheadAttention expects [S, B, C] by default
y, _ = mhsa(y, y, y)  # ignore attn weights
y = y.permute(1, 0, 2)  # back to [512, 8, 3]

out = tkfr(y, x)
print(out.shape)  # torch.Size([512, 32, 32, 3])
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from os import path

this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='tokenlearner_pytorch',
    version='0.1.2',
    description='Unofficial PyTorch implementation of TokenLearner by Google AI',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/rish-16/tokenlearner-pytorch',
    author='Rishabh Anand',
    author_email='mail.rishabh.anand@gmail.com',
    license='MIT',
    packages=['tokenlearner_pytorch'],
    install_requires=['torch'],
    classifiers=[
        'Intended Audience :: Science/Research',
        'Programming Language :: Python :: 3.6',
    ],
)
--------------------------------------------------------------------------------
/tklr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rish-16/tokenlearner-pytorch/2b4f107e544e199514e2bc0dfc9461bc8775b862/tklr.png
--------------------------------------------------------------------------------
/tokenlearner_pytorch/__init__.py:
--------------------------------------------------------------------------------
from tokenlearner_pytorch.tokenlearner_pytorch import TokenLearner, TokenFuser

__version__ = "0.1.2"
__author__ = 'Rishabh Anand'
__credits__ = 'Google AI'
--------------------------------------------------------------------------------
/tokenlearner_pytorch/tokenlearner_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=(1, 1), stride=1),
            nn.BatchNorm2d(1),
            nn.ReLU()
        )

    def forward(self, x):
        # x: [B, H, W, C] -> channels-first for the conv
        # (permute, not view: a plain view would scramble pixels and channels)
        x = x.permute(0, 3, 1, 2)  # [B, C, H, W]

        mx = torch.max(x, 1)[0].unsqueeze(1)    # channel-wise max: [B, 1, H, W]
        avg = torch.mean(x, 1).unsqueeze(1)     # channel-wise mean: [B, 1, H, W]
        combined = torch.cat([mx, avg], dim=1)  # [B, 2, H, W]
        fmap = self.conv(combined)              # [B, 1, H, W]
        weight_map = torch.sigmoid(fmap)
        weighted = x * weight_map               # [B, C, H, W]
        out = weighted.mean(dim=(-2, -1))       # spatial pooling -> one token: [B, C]

        return out, weighted.permute(0, 2, 3, 1)  # token, weighted map back in [B, H, W, C]

class TokenLearner(nn.Module):
    def __init__(self, S) -> None:
        super().__init__()
        self.S = S
        self.tokenizers = nn.ModuleList([SpatialAttention() for _ in range(S)])

    def forward(self, x):
        # x: [B, H, W, C] -> S learned tokens: [B, S, C]
        tokens = [tokenizer(x)[0] for tokenizer in self.tokenizers]  # S tensors of [B, C]
        return torch.stack(tokens, dim=1)

class TokenFuser(nn.Module):
    def __init__(self, H, W, C, S) -> None:
        super().__init__()
        self.projection = nn.Linear(S, S, bias=False)
        self.Bi = nn.Linear(C, S)
        self.spatial_attn = SpatialAttention()
        self.S = S

    def forward(self, y, x):
        # y: [B, S, C] tokens after attention, x: [B, H, W, C] original feature map
        B, S, C = y.shape
        _, H, W, _ = x.shape

        Y = self.projection(y.transpose(1, 2)).transpose(1, 2)  # mix across the token axis: [B, S, C]
        Bw = torch.sigmoid(self.Bi(x)).view(B, H * W, S)        # per-pixel token weights: [B, HW, S]
        BwY = torch.matmul(Bw, Y)                               # remap tokens onto pixels: [B, HW, C]

        _, xj = self.spatial_attn(x)                            # residual weighted map: [B, H, W, C]
        xj = xj.reshape(B, H * W, C)

        out = (BwY + xj).view(B, H, W, C)

        return out
--------------------------------------------------------------------------------