├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── hamburger.png
├── hamburger_pytorch
├── __init__.py
└── hamburger_pytorch.py
├── mu.png
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

# Trigger only when a GitHub release is created (not on every push/tag)
on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        # PyPI credentials; must be configured as repository secrets
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## 🍔 - Pytorch
6 |
Pytorch implementation of the hamburger module from the ICLR 2021 paper [Is Attention Better Than Matrix Decomposition?](https://openreview.net/forum?id=1FvkSpWosOl). Following Betteridge's law, the answer according to the paper is "No" for segmentation and GANs.
8 |
This repository contains the NMF-MU (nonnegative matrix factorization w/ multiplicative update) module sandwiched by linear projections.
10 |
11 | Update: I tried this, but did not get better results than just using linear attention
12 |
13 | ## Install
14 |
15 | ```bash
16 | $ pip install hamburger-pytorch
17 | ```
18 |
19 | ## Usage
20 |
21 | ```python
22 | import torch
23 | from hamburger_pytorch import Hamburger
24 |
25 | hamburger = Hamburger(
26 | dim = 512, # input dimension
27 | n = 32 * 32, # n will be size of the sequence, in this case, height times width of the images
28 | ratio = 8, # matrix factorization ratio, recommended to be at 8
29 | K = 6 # number of iterations, optimal at 6 as shown in paper
30 | )
31 |
32 | x = torch.randn(1, 512, 32, 32)
33 | hamburger(x) + x # (1, 512, 32, 32)
34 | ```
35 |
36 | ## Citations
37 |
38 | ```bibtex
39 | @inproceedings{
40 | anonymous2021is,
41 | title={Is Attention Better Than Matrix Decomposition?},
42 | author={Anonymous},
43 | booktitle={Submitted to International Conference on Learning Representations},
44 | year={2021},
45 | url={https://openreview.net/forum?id=1FvkSpWosOl},
46 | note={under review}
47 | }
48 | ```
49 |
--------------------------------------------------------------------------------
/hamburger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/hamburger-pytorch/e8519212b31a8baf2072ffd62a3dd728914cb338/hamburger.png
--------------------------------------------------------------------------------
/hamburger_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from hamburger_pytorch.hamburger_pytorch import Hamburger
2 |
--------------------------------------------------------------------------------
/hamburger_pytorch/hamburger_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 | from contextlib import contextmanager
5 | from einops import repeat, rearrange
6 |
7 | # helper fn
8 |
@contextmanager
def null_context():
    """Do-nothing context manager; stands in for `torch.no_grad` when gradients should flow."""
    yield None
12 |
def exists(val):
    """Return True when *val* holds an actual value, i.e. is not None."""
    return not (val is None)
15 |
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*."""
    if val is None:
        return d
    return val
18 |
19 | # classes
20 |
class NMF(nn.Module):
    """Batched nonnegative matrix factorization solved by multiplicative updates (MU).

    Learns a dictionary ``D`` (dim x r) and codes ``C`` (r x n) whose product
    approximates the (relu'd, hence nonnegative) input. Only the final MU
    iteration tracks gradients -- the paper's 'One-step Gradient' trick.
    """

    def __init__(
        self,
        dim,
        n,
        ratio = 8,
        K = 6,
        eps = 2e-8
    ):
        """
        Args:
            dim: channel dimension of the input
            n: sequence length (e.g. height * width of a feature map)
            ratio: factorization ratio; the inner rank is dim // ratio
            K: number of multiplicative-update iterations
            eps: small constant guarding the divisions against zero
        """
        super().__init__()
        rank = dim // ratio

        # both factors start from a nonnegative uniform init
        self.K = K
        self.D = nn.Parameter(torch.zeros(dim, rank).uniform_(0, 1))
        self.C = nn.Parameter(torch.zeros(rank, n).uniform_(0, 1))

        self.eps = eps

    def forward(self, x):
        """x: (batch, dim, n) -> reconstruction D @ C of relu(x), same shape."""
        eps = self.eps
        batch = x.shape[0]

        # NMF assumes nonnegative inputs; enforced with a relu as in the paper
        x = F.relu(x)

        # broadcast the shared factors across the batch dimension
        D = self.D.unsqueeze(0).expand(batch, -1, -1)
        C = self.C.unsqueeze(0).expand(batch, -1, -1)

        def mu_step(C, D):
            # simultaneous multiplicative update of codes and dictionary,
            # both computed from the previous (C, D) pair
            Dt, Ct = D.transpose(1, 2), C.transpose(1, 2)
            C_next = C * ((Dt @ x) / ((Dt @ D @ C) + eps))
            D_next = D * ((x @ Ct) / ((D @ C @ Ct) + eps))
            return C_next, D_next

        # every iteration except the last runs without autograd bookkeeping,
        # per the paper's 'One-step Gradient' scheme
        with torch.no_grad():
            for _ in range(self.K - 1):
                C, D = mu_step(C, D)

        if self.K > 0:
            C, D = mu_step(C, D)

        return D @ C
63 |
class Hamburger(nn.Module):
    """NMF 'ham' sandwiched between two pointwise conv projections ('bread').

    Accepts a (batch, dim, *spatial) tensor; the spatial dims are flattened
    into a sequence of length n for factorization and restored on the way out.
    """

    def __init__(
        self,
        *,
        dim,
        n,
        inner_dim = None,
        ratio = 8,
        K = 6
    ):
        super().__init__()
        # the sandwiched (inner) width defaults to the input width
        inner_dim = dim if inner_dim is None else inner_dim

        self.lower_bread = nn.Conv1d(dim, inner_dim, 1, bias = False)
        self.ham = NMF(inner_dim, n, ratio = ratio, K = K)
        self.upper_bread = nn.Conv1d(inner_dim, dim, 1, bias = False)

    def forward(self, x):
        """Project down, factorize, project back up; input shape is preserved."""
        original_shape = x.shape
        seq = x.flatten(2)  # (b, dim, n): collapse all spatial dims
        out = self.upper_bread(self.ham(self.lower_bread(seq)))
        return out.reshape(original_shape)
89 |
--------------------------------------------------------------------------------
/mu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/hamburger-pytorch/e8519212b31a8baf2072ffd62a3dd728914cb338/mu.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Packaging metadata for the hamburger-pytorch PyPI distribution.
setup(
  name = 'hamburger-pytorch',                  # distribution name on PyPI
  packages = find_packages(),                  # picks up the hamburger_pytorch package
  version = '0.0.3',
  license='MIT',
  description = 'Hamburger - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/hamburger-pytorch',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'matrix factorization'
  ],
  # runtime dependencies; einops is used for tensor reshaping in the module
  install_requires=[
    'torch',
    'einops>=0.3'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------