├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── lambda_networks
│   ├── __init__.py
│   ├── lambda_networks.py
│   └── tfkeras.py
├── setup.py
└── λ.png
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./λ.png" width="300px"></img>
2 |
3 | ## Lambda Networks - Pytorch
4 |
5 | Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The new method utilizes the λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
6 |
7 | Yannic Kilcher's paper review
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install lambda-networks
13 | ```
14 |
15 | ## Usage
16 |
17 | Global context
18 |
19 | ```python
20 | import torch
21 | from lambda_networks import LambdaLayer
22 |
23 | layer = LambdaLayer(
24 | dim = 32, # channels going in
25 | dim_out = 32, # channels out
26 |     n = 64,           # size of the receptive window (n = height = width of the feature map)
27 | dim_k = 16, # key dimension
28 | heads = 4, # number of heads, for multi-query
29 | dim_u = 1 # 'intra-depth' dimension
30 | )
31 |
32 | x = torch.randn(1, 32, 64, 64)
33 | layer(x) # (1, 32, 64, 64)
34 | ```
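Under the hood, the layer summarizes the context into a single `k × v` linear map (the content lambda) and applies that same map to the query at every position. A minimal sketch of the two einsums at the heart of `lambda_networks.py`, with the positional term omitted:

```python
import torch
from torch import einsum

b, h, k, v, u, n = 1, 4, 16, 8, 1, 64 * 64    # batch, heads, key dim, value dim, intra-depth, positions

q = torch.randn(b, h, k, n)                       # one set of queries per head
keys = torch.randn(b, u, k, n).softmax(dim = -1)  # keys are normalized over context positions
values = torch.randn(b, u, v, n)

λc = einsum('b u k m, b u v m -> b k v', keys, values)  # content lambda: a k x v linear function
Yc = einsum('b h k n, b k v -> b h v n', q, λc)         # the same lambda is applied to every query
```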
35 |
36 | Localized context
37 |
38 | ```python
39 | import torch
40 | from lambda_networks import LambdaLayer
41 |
42 | layer = LambdaLayer(
43 | dim = 32,
44 | dim_out = 32,
45 | r = 23, # the receptive field for relative positional encoding (23 x 23)
46 | dim_k = 16,
47 | heads = 4,
48 | dim_u = 4
49 | )
50 |
51 | x = torch.randn(1, 32, 64, 64)
52 | layer(x) # (1, 32, 64, 64)
53 | ```
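Note that `r` must be odd (the layer asserts this) so the positional convolution can be symmetrically padded, and no window size `n` is needed here: with localized contexts the position lambdas come from a convolution over the values rather than a learned relative-position table, so the layer is agnostic to input resolution. The paper builds LambdaResNets by swapping convolutions in ResNet blocks for λ layers; a hypothetical sketch of such a substitution (the block structure below is illustrative, not taken from this repository):

```python
import torch
from torch import nn
from lambda_networks import LambdaLayer

class LambdaBottleneck(nn.Module):
    # hypothetical bottleneck block with the 3x3 convolution replaced by a λ layer
    def __init__(self, dim, dim_hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim_hidden, 1, bias = False),
            nn.BatchNorm2d(dim_hidden),
            nn.ReLU(inplace = True),
            LambdaLayer(dim = dim_hidden, dim_out = dim_hidden, r = 23, dim_k = 16, heads = 4, dim_u = 1),
            nn.BatchNorm2d(dim_hidden),
            nn.ReLU(inplace = True),
            nn.Conv2d(dim_hidden, dim, 1, bias = False),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x):
        return torch.relu(self.net(x) + x)  # residual connection

block = LambdaBottleneck(dim = 64, dim_hidden = 32)
x = torch.randn(1, 64, 56, 56)
block(x) # (1, 64, 56, 56)
```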
54 |
55 | For fun, you can also import this as follows
56 |
57 | ```python
58 | from lambda_networks import λLayer
59 | ```
60 |
61 | ## Tensorflow / Keras version
62 |
63 | Shinel94 has added a Keras implementation! It isn't officially supported in this repository, so either copy / paste the code in `./lambda_networks/tfkeras.py` into your own project, or install `tensorflow` and `keras` and import it as follows.
64 |
65 | ```python
66 | import tensorflow as tf
67 | from lambda_networks.tfkeras import LambdaLayer
68 |
69 | layer = LambdaLayer(
70 | dim_out = 32,
71 | r = 23,
72 | dim_k = 16,
73 | heads = 4,
74 | dim_u = 1
75 | )
76 |
77 | x = tf.random.normal((1, 64, 64, 16)) # channel last format
78 | layer(x) # (1, 64, 64, 32)
79 | ```
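Because the Keras version infers the input channel count on its first call, it composes with standard Keras layers. A small sketch (the convolution in front is illustrative, not part of this repository):

```python
import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

conv = tf.keras.layers.Conv2D(16, 3, padding = 'same', activation = 'relu')
lam = LambdaLayer(dim_out = 32, r = 23, dim_k = 16, heads = 4, dim_u = 1)

x = tf.random.normal((1, 64, 64, 3))
y = lam(conv(x)) # (1, 64, 64, 32)
```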
80 |
81 | ## Citations
82 |
83 | ```bibtex
84 | @inproceedings{
85 | anonymous2021lambdanetworks,
86 | title={LambdaNetworks: Modeling long-range Interactions without Attention},
87 | author={Anonymous},
88 | booktitle={Submitted to International Conference on Learning Representations},
89 | year={2021},
90 | url={https://openreview.net/forum?id=xTJEN-ggl1b},
91 | note={under review}
92 | }
93 | ```
94 |
--------------------------------------------------------------------------------
/lambda_networks/__init__.py:
--------------------------------------------------------------------------------
1 | from lambda_networks.lambda_networks import LambdaLayer
2 | λLayer = LambdaLayer
--------------------------------------------------------------------------------
/lambda_networks/lambda_networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 | # helpers functions
6 |
7 | def exists(val):
8 | return val is not None
9 |
10 | def default(val, d):
11 | return val if exists(val) else d
12 |
13 | def calc_rel_pos(n):
14 | pos = torch.meshgrid(torch.arange(n), torch.arange(n))
15 | pos = rearrange(torch.stack(pos), 'n i j -> (i j) n') # [n*n, 2] pos[n] = (i, j)
16 | rel_pos = pos[None, :] - pos[:, None] # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
17 | rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2]
18 | return rel_pos
19 |
20 | # lambda layer
21 |
22 | class LambdaLayer(nn.Module):
23 | def __init__(
24 | self,
25 | dim,
26 | *,
27 | dim_k,
28 | n = None,
29 | r = None,
30 | heads = 4,
31 | dim_out = None,
32 | dim_u = 1):
33 | super().__init__()
34 | dim_out = default(dim_out, dim)
35 | self.u = dim_u # intra-depth dimension
36 | self.heads = heads
37 |
38 | assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
39 | dim_v = dim_out // heads
40 |
41 | self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
42 | self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
43 | self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
44 |
45 | self.norm_q = nn.BatchNorm2d(dim_k * heads)
46 | self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
47 |
48 | self.local_contexts = exists(r)
49 | if exists(r):
50 | assert (r % 2) == 1, 'Receptive kernel size should be odd'
51 | self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
52 | else:
53 | assert exists(n), 'You must specify the window size (n=h=w)'
54 | rel_lengths = 2 * n - 1
55 | self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
56 |         self.register_buffer('rel_pos', calc_rel_pos(n))  # buffer, so it moves with the module across devices
57 |
58 | def forward(self, x):
59 | b, c, hh, ww, u, h = *x.shape, self.u, self.heads
60 |
61 | q = self.to_q(x)
62 | k = self.to_k(x)
63 | v = self.to_v(x)
64 |
65 | q = self.norm_q(q)
66 | v = self.norm_v(v)
67 |
68 | q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
69 | k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
70 | v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
71 |
72 | k = k.softmax(dim=-1)
73 |
74 |         λc = einsum('b u k m, b u v m -> b k v', k, v)  # content lambda, shared across all query positions
75 |         Yc = einsum('b h k n, b k v -> b h v n', q, λc)  # content output
76 | 
77 |         if self.local_contexts:
78 |             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
79 |             λp = self.pos_conv(v)  # position lambdas via convolution over the values
80 |             Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
81 |         else:
82 |             n, m = self.rel_pos.unbind(dim = -1)
83 |             rel_pos_emb = self.rel_pos_emb[n, m]
84 |             λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)  # position lambdas, one per query position
85 |             Yp = einsum('b h k n, b n k v -> b h v n', q, λp)  # position output
86 |
87 | Y = Yc + Yp
88 | out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
89 | return out
90 |
--------------------------------------------------------------------------------
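For intuition about the `calc_rel_pos` helper above, a quick check at window size n = 2: every ordered pair of positions maps to its (row, column) offset, shifted into the range [0, 2n - 2] so it can index the (2n - 1) × (2n - 1) embedding table:

```python
import torch
from lambda_networks.lambda_networks import calc_rel_pos

rel_pos = calc_rel_pos(2)
rel_pos.shape  # torch.Size([4, 4, 2]) -> one (rel_i, rel_j) pair per pair of positions
rel_pos[0, 3]  # tensor([2, 2]) -> (1, 1) is one row and one column after (0, 0), shifted by n - 1 = 1
```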
/lambda_networks/tfkeras.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from einops.layers.tensorflow import Rearrange
3 | from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
4 | from tensorflow.keras import initializers
5 | from tensorflow import einsum, nn, meshgrid
6 |
7 | # helpers functions
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def default(val, d):
13 | return val if exists(val) else d
14 |
15 | def calc_rel_pos(n):
16 | pos = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing = 'ij'))
17 | pos = Rearrange('n i j -> (i j) n')(pos) # [n*n, 2] pos[n] = (i, j)
18 | rel_pos = pos[None, :] - pos[:, None] # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
19 | rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2]
20 | return rel_pos
21 |
22 | # lambda layer
23 |
24 | class LambdaLayer(Layer):
25 | def __init__(
26 | self,
27 | *,
28 | dim_k,
29 | n = None,
30 | r = None,
31 | heads = 4,
32 | dim_out = None,
33 | dim_u = 1):
34 | super(LambdaLayer, self).__init__()
35 |
36 | self.out_dim = dim_out
37 | self.u = dim_u # intra-depth dimension
38 | self.heads = heads
39 |
40 |         assert exists(dim_out), 'output dimension must be specified'
41 |         assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
42 |         self.dim_v = dim_out // heads
43 |         self.dim_k = dim_k
44 |
45 | self.to_q = Conv2D(self.dim_k * heads, 1, use_bias=False)
46 | self.to_k = Conv2D(self.dim_k * dim_u, 1, use_bias=False)
47 | self.to_v = Conv2D(self.dim_v * dim_u, 1, use_bias=False)
48 |
49 | self.norm_q = BatchNormalization()
50 | self.norm_v = BatchNormalization()
51 |
52 | self.local_contexts = exists(r)
53 | if exists(r):
54 | assert (r % 2) == 1, 'Receptive kernel size should be odd'
55 | self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
56 | else:
57 | assert exists(n), 'You must specify the window length (n = h = w)'
58 | rel_length = 2 * n - 1
59 | self.rel_pos_emb = self.add_weight(name='pos_emb',
60 | shape=(rel_length, rel_length, dim_k, dim_u),
61 |                                                initializer=initializers.RandomNormal(),
62 | trainable=True)
63 | self.rel_pos = calc_rel_pos(n)
64 |
65 | def call(self, x, **kwargs):
66 | b, hh, ww, c, u, h = *x.get_shape().as_list(), self.u, self.heads
67 |
68 | q = self.to_q(x)
69 | k = self.to_k(x)
70 | v = self.to_v(x)
71 |
72 | q = self.norm_q(q)
73 | v = self.norm_v(v)
74 |
75 | q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
76 | k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
77 | v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)
78 |
79 | k = nn.softmax(k)
80 |
81 | Lc = einsum('b u k m, b u v m -> b k v', k, v)
82 | Yc = einsum('b h k n, b k v -> b n h v', q, Lc)
83 |
84 | if self.local_contexts:
85 | v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
86 | Lp = self.pos_conv(v)
87 | Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
88 | Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
89 | else:
90 | rel_pos_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
91 | Lp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
92 | Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)
93 |
94 | Y = Yc + Yp
95 | out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
96 | return out
97 |
98 | def compute_output_shape(self, input_shape):
99 |         return (*input_shape[:3], self.out_dim)
100 |
101 | def get_config(self):
102 |         config = {'output_dim': (*self.input_shape[:3], self.out_dim)}
103 | base_config = super(LambdaLayer, self).get_config()
104 | return dict(list(base_config.items()) + list(config.items()))
105 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'lambda-networks',
5 | packages = find_packages(),
6 | version = '0.4.0',
7 | license='MIT',
8 | description = 'Lambda Networks - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/lambda-networks',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'attention mechanism',
15 | 'image recognition'
16 | ],
17 | install_requires=[
18 | 'torch>=1.6',
19 | 'einops>=0.3'
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Programming Language :: Python :: 3.6',
27 | ],
28 | )
--------------------------------------------------------------------------------
/λ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/lambda-networks/8116ff3685c679efaaca4137fc0f5d7cc88b5e4b/λ.png
--------------------------------------------------------------------------------