├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── lambda_networks
│   ├── __init__.py
│   ├── lambda_networks.py
│   └── tfkeras.py
├── setup.py
└── λ.png

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The new method utilizes the λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
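In code, the core idea reduces to two einsums: the keys and values are contracted into a single linear map (the content lambda), which is then applied to every query. The following is a minimal single-head, content-only sketch to build intuition (tensor names and sizes here are illustrative, not the full layer in `lambda_networks/lambda_networks.py`, which also adds positional lambdas and multi-query heads).

```python
import torch
from torch import einsum

# illustrative shapes: q, k are (batch, dim_k, n) and v is (batch, dim_v, n),
# where n is the number of context positions (pixels)
q = torch.randn(1, 16, 64)
k = torch.randn(1, 16, 64)
v = torch.randn(1, 32, 64)

k = k.softmax(dim = -1)                       # normalize keys across context positions
λc = einsum('b k n, b v n -> b k v', k, v)    # summarize the context as one linear function
out = einsum('b k n, b k v -> b v n', q, λc)  # apply that same lambda to every query
```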
Yannic Kilcher's paper review

## Install

```bash
$ pip install lambda-networks
```

## Usage

Global context

```python
import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)
```

Localized context

```python
import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)
```

For fun, you can also import this as follows

```python
from lambda_networks import λLayer
```

## Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under `./lambda_networks/tfkeras.py` or make sure to install `tensorflow` and `keras` before running the following.

```python
import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)
```

## Citations

```bibtex
@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}
```

--------------------------------------------------------------------------------
/lambda_networks/__init__.py:
--------------------------------------------------------------------------------
from lambda_networks.lambda_networks import LambdaLayer
λLayer = LambdaLayer

--------------------------------------------------------------------------------
/lambda_networks/lambda_networks.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_rel_pos(n):
    pos = torch.meshgrid(torch.arange(n), torch.arange(n))
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos
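# A worked example of calc_rel_pos (explanatory comment, not part of the original file):
# for n = 2 the flattened positions are (0,0), (0,1), (1,0), (1,1), so calc_rel_pos(2)
# returns a [4, 4, 2] tensor whose entry [a, b] holds pos[b] - pos[a] shifted by n - 1 = 1,
# e.g. rel_pos[0, 3] = (2, 2). These shifted offsets index directly into the learned
# rel_pos_emb table constructed in LambdaLayer below.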
# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        self.norm_q = nn.BatchNorm2d(dim_k * heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the window size (n = h = w)'
            rel_lengths = 2 * n - 1
            self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
            self.rel_pos = calc_rel_pos(n)

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim = -1)

        # content lambda: a single linear map summarizing the whole context
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        # position lambdas: one linear map per query position
        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            n, m = self.rel_pos.unbind(dim = -1)
            rel_pos_emb = self.rel_pos_emb[n, m]
            λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out
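# A minimal smoke test (an illustrative sketch, not part of the published package),
# mirroring the localized-context example from the README.
if __name__ == '__main__':
    layer = LambdaLayer(dim = 32, dim_out = 32, r = 23, dim_k = 16, heads = 4, dim_u = 4)
    x = torch.randn(1, 32, 64, 64)
    assert layer(x).shape == (1, 32, 64, 64) # channels and spatial dimensions are preserved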
--------------------------------------------------------------------------------
/lambda_networks/tfkeras.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from einops.layers.tensorflow import Rearrange
from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
from tensorflow.keras import initializers
from tensorflow import einsum, nn, meshgrid

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_rel_pos(n):
    pos = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing = 'ij'))
    pos = Rearrange('n i j -> (i j) n')(pos)  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]     # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                          # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

# lambda layer

class LambdaLayer(Layer):
    def __init__(
        self,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super(LambdaLayer, self).__init__()

        assert exists(dim_out), 'output dimension (dim_out) must be specified'
        self.out_dim = dim_out
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        self.dim_v = dim_out // heads
        self.dim_k = dim_k

        self.to_q = Conv2D(self.dim_k * heads, 1, use_bias=False)
        self.to_k = Conv2D(self.dim_k * dim_u, 1, use_bias=False)
        self.to_v = Conv2D(self.dim_v * dim_u, 1, use_bias=False)

        self.norm_q = BatchNormalization()
        self.norm_v = BatchNormalization()

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
        else:
            assert exists(n), 'You must specify the window length (n = h = w)'
            rel_length = 2 * n - 1
            self.rel_pos_emb = self.add_weight(name='pos_emb',
                                               shape=(rel_length, rel_length, dim_k, dim_u),
                                               initializer=initializers.random_normal,
                                               trainable=True)
            self.rel_pos = calc_rel_pos(n)

    def call(self, x, **kwargs):
        b, hh, ww, c, u, h = *x.get_shape().as_list(), self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)

        k = nn.softmax(k)

        Lc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b n h v', q, Lc)

        if self.local_contexts:
            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
            Lp = self.pos_conv(v)
            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
            Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
        else:
            rel_pos_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
            Lp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)

        Y = Yc + Yp
        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
        return out

    def compute_output_shape(self, input_shape):
        return (*input_shape[:3], self.out_dim)

    def get_config(self):
        config = {'output_dim': (*self.input_shape[:3], self.out_dim)}
        base_config = super(LambdaLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'lambda-networks',
  packages = find_packages(),
  version = '0.4.0',
  license='MIT',
  description = 'Lambda Networks - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/lambda-networks',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'image recognition'
  ],
  install_requires=[
    'torch>=1.6',
    'einops>=0.3'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

--------------------------------------------------------------------------------
/λ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lambda-networks/8116ff3685c679efaaca4137fc0f5d7cc88b5e4b/λ.png
--------------------------------------------------------------------------------