├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── hamburger.png ├── hamburger_pytorch ├── __init__.py └── hamburger_pytorch.py ├── mu.png └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ## 🍔 - Pytorch 6 | 7 | Pytorch implementation of the hamburger module from the ICLR 2021 paper "Is Attention Better Than Matrix Decomposition?" Following Betteridge's law, the answer according to the paper is "No" for segmentation and GANs. 8 | 9 | This repository contains the NMF-MU (nonnegative matrix factorization w/ multiplicative update) module sandwiched between linear projections. 10 | 11 | Update: I tried this, but did not get better results than just using linear attention. 12 | 13 | ## Install 14 | 15 | ```bash 16 | $ pip install hamburger-pytorch 17 | ``` 18 | 19 | ## Usage 20 | 21 | ```python 22 | import torch 23 | from hamburger_pytorch import Hamburger 24 | 25 | hamburger = Hamburger( 26 | dim = 512, # input dimension 27 | n = 32 * 32, # n is the sequence length; here, the height times the width of the image 28 | ratio = 8, # matrix factorization ratio, recommended to be 8 29 | K = 6 # number of iterations, optimal at 6 as shown in the paper 30 | ) 31 | 32 | x = torch.randn(1, 512, 32, 32) 33 | hamburger(x) + x # (1, 512, 32, 32) 34 | ``` 35 | 36 | ## Citations 37 | 38 | ```bibtex 39 | @inproceedings{ 40 | geng2021is, 41 | title={Is Attention Better Than Matrix Decomposition?}, 42 | author={Zhengyang Geng and Meng-Hao Guo and Hongxu Chen and Xia Li and Ke Wei and Zhouchen Lin}, 43 | booktitle={International Conference on Learning Representations}, 44 | year={2021}, 45 | url={https://openreview.net/forum?id=1FvkSpWosOl} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /hamburger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/hamburger-pytorch/e8519212b31a8baf2072ffd62a3dd728914cb338/hamburger.png
-------------------------------------------------------------------------------- /hamburger_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from hamburger_pytorch.hamburger_pytorch import Hamburger 2 | -------------------------------------------------------------------------------- /hamburger_pytorch/hamburger_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from contextlib import contextmanager 5 | from einops import repeat, rearrange 6 | 7 | # helper fn 8 | 9 | @contextmanager 10 | def null_context(): 11 | yield 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | def default(val, d): 17 | return val if exists(val) else d 18 | 19 | # classes 20 | 21 | class NMF(nn.Module): 22 | def __init__( 23 | self, 24 | dim, 25 | n, 26 | ratio = 8, 27 | K = 6, 28 | eps = 2e-8 29 | ): 30 | super().__init__() 31 | r = dim // ratio 32 | 33 | D = torch.zeros(dim, r).uniform_(0, 1) 34 | C = torch.zeros(r, n).uniform_(0, 1) 35 | 36 | self.K = K 37 | self.D = nn.Parameter(D) 38 | self.C = nn.Parameter(C) 39 | 40 | self.eps = eps 41 | 42 | def forward(self, x): 43 | b, D, C, eps = x.shape[0], self.D, self.C, self.eps 44 | 45 | # x is made non-negative with relu as proposed in paper 46 | x = F.relu(x) 47 | 48 | D = repeat(D, 'd r -> b d r', b = b) 49 | C = repeat(C, 'r n -> b r n', b = b) 50 | 51 | # transpose 52 | t = lambda tensor: rearrange(tensor, 'b i j -> b j i') 53 | 54 | for k in reversed(range(self.K)): 55 | # only calculate gradients on the last step, per propose 'One-step Gradient' 56 | context = null_context if k == 0 else torch.no_grad 57 | with context(): 58 | C_new = C * ((t(D) @ x) / ((t(D) @ D @ C) + eps)) 59 | D_new = D * ((x @ t(C)) / ((D @ C @ t(C)) + eps)) 60 | C, D = C_new, D_new 61 | 62 | return D @ C 63 | 64 | class Hamburger(nn.Module): 65 | def __init__( 66 | 
self, 67 | *, 68 | dim, 69 | n, 70 | inner_dim = None, 71 | ratio = 8, 72 | K = 6 73 | ): 74 | super().__init__() 75 | inner_dim = default(inner_dim, dim) 76 | 77 | self.lower_bread = nn.Conv1d(dim, inner_dim, 1, bias = False) 78 | self.ham = NMF(inner_dim, n, ratio = ratio, K = K) 79 | self.upper_bread = nn.Conv1d(inner_dim, dim, 1, bias = False) 80 | 81 | def forward(self, x): 82 | shape = x.shape 83 | x = x.flatten(2) 84 | 85 | x = self.lower_bread(x) 86 | x = self.ham(x) 87 | x = self.upper_bread(x) 88 | return x.reshape(shape) 89 | -------------------------------------------------------------------------------- /mu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/hamburger-pytorch/e8519212b31a8baf2072ffd62a3dd728914cb338/mu.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'hamburger-pytorch', 5 | packages = find_packages(), 6 | version = '0.0.3', 7 | license='MIT', 8 | description = 'Hamburger - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/hamburger-pytorch', 12 | keywords = [ 13 | 'artificial intelligence', 14 | 'attention mechanism', 15 | 'matrix factorization' 16 | ], 17 | install_requires=[ 18 | 'torch', 19 | 'einops>=0.3' 20 | ], 21 | classifiers=[ 22 | 'Development Status :: 4 - Beta', 23 | 'Intended Audience :: Developers', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | 'License :: OSI Approved :: MIT License', 26 | 'Programming Language :: Python :: 3.6', 27 | ], 28 | ) --------------------------------------------------------------------------------