├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── compositional-attention.png
├── compositional_attention_pytorch
│   ├── __init__.py
│   └── compositional_attention_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Compositional Attention - Pytorch
4 |
5 | Implementation of Compositional Attention from MILA. They reframe the "heads" of multi-head attention as "searches", and once the multi-headed / searched values are aggregated, there is an extra retrieval step (using attention) over the searched results. They then show that this variant of attention yields better OOD results on a toy task. Their ESBN results still leave a lot to be desired, but I like the general direction of the paper.
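
Schematically, the two-stage computation looks like the minimal sketch below. The shapes and einsum patterns mirror `compositional_attention_pytorch.py`, but the linear projections are omitted and all tensors are random placeholders, so this is only an illustration of the search-then-retrieve idea, not the module itself.

```python
import torch
from torch import einsum

b, n, s, r, d = 1, 4, 8, 2, 64   # batch, sequence length, searches, retrievals, head dim

# search queries / keys and retrieval values (projections omitted, random placeholders)
sq = torch.randn(b, s, n, d)
sk = torch.randn(b, s, n, d)
rv = torch.randn(b, r, n, d)

# stage 1 - each "search" attends over the sequence, like a head in multi-head attention
search_attn = einsum('b s i d, b s j d -> b s i j', sq * d ** -0.5, sk).softmax(dim = -1)
retrieved = einsum('b s i j, b r j d -> b s r i d', search_attn, rv)

# stage 2 - each search then attends over its r retrieval candidates
rq = torch.randn(b, s, n, d)
retrieval_attn = einsum('b s n d, b s r n d -> b s n r', rq * d ** -0.5, retrieved).softmax(dim = -1)
out = einsum('b s n r, b s r n d -> b s n d', retrieval_attn, retrieved)   # (b, s, n, d)
```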
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install compositional-attention-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from compositional_attention_pytorch import CompositionalAttention
18 |
19 | attn = CompositionalAttention(
20 | dim = 1024, # input dimension
21 | dim_head = 64, # dimension per attention 'head' - head is now either search or retrieval
22 | num_searches = 8, # number of searches
23 | num_retrievals = 2, # number of retrievals
24 | dropout = 0., # dropout of attention of search and retrieval
25 | )
26 |
27 | tokens = torch.randn(1, 512, 1024) # tokens
28 | mask = torch.ones((1, 512)).bool() # mask
29 |
30 | out = attn(tokens, mask = mask) # (1, 512, 1024)
31 | ```
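
The constructor also exposes `prenorm` and `causal` flags (see `compositional_attention_pytorch.py`), for pre-layernorm and autoregressive masking respectively. A minimal sketch of that variant:

```python
import torch
from compositional_attention_pytorch import CompositionalAttention

# autoregressive variant with pre-layernorm
attn = CompositionalAttention(
    dim = 1024,
    dim_head = 64,
    num_searches = 8,
    num_retrievals = 2,
    dropout = 0.,
    prenorm = True,    # layernorm the input before attention
    causal = True      # mask out attention to future positions
)

tokens = torch.randn(1, 512, 1024)
out = attn(tokens)  # (1, 512, 1024)
```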
32 |
33 | ## Citations
34 |
35 | ```bibtex
36 | @article{Mittal2021CompositionalAD,
37 | title = {Compositional Attention: Disentangling Search and Retrieval},
38 | author = {Sarthak Mittal and Sharath Chandra Raparthy and Irina Rish and Yoshua Bengio and Guillaume Lajoie},
39 | journal = {ArXiv},
40 | year = {2021},
41 | volume = {abs/2110.09419}
42 | }
43 | ```
44 |
--------------------------------------------------------------------------------
/compositional-attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/compositional-attention-pytorch/e4e59540c98b0f0da7ea8d89cfe983aeb6ddc1fd/compositional-attention.png
--------------------------------------------------------------------------------
/compositional_attention_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from compositional_attention_pytorch.compositional_attention_pytorch import CompositionalAttention
2 |
--------------------------------------------------------------------------------
/compositional_attention_pytorch/compositional_attention_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, einsum
4 |
5 | from einops import rearrange
6 | from einops_exts import rearrange_many
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 | def stable_softmax(t, dim = -1):
12 | t = t - t.amax(dim = dim, keepdim = True).detach()
13 | return t.softmax(dim = dim)
14 |
15 | class CompositionalAttention(nn.Module):
16 | def __init__(
17 | self,
18 | dim,
19 | dim_head = 64,
20 | num_searches = 8,
21 | num_retrievals = 2,
22 | dropout = 0.,
23 | prenorm = False,
24 | causal = False
25 | ):
26 | super().__init__()
27 | self.norm = nn.LayerNorm(dim) if prenorm else nn.Identity()
28 |
29 | self.scale = dim_head ** -0.5
30 | inner_search_dim = dim_head * num_searches
31 | inner_retrieval_dim = dim_head * num_retrievals
32 |
33 | self.num_searches = num_searches
34 | self.num_retrievals = num_retrievals
35 |
36 | self.to_searches_queries = nn.Linear(dim, inner_search_dim, bias = False)
37 | self.to_searches_keys = nn.Linear(dim, inner_search_dim, bias = False)
38 | self.to_retrieval_values = nn.Linear(dim, inner_retrieval_dim, bias = False)
39 |
40 | self.to_retrieval_queries = nn.Linear(dim, inner_search_dim, bias = False)
41 | self.to_retrieval_keys = nn.Linear(dim_head, dim_head, bias = False)
42 |
43 | self.to_out = nn.Linear(inner_search_dim, dim, bias = False)
44 |
45 | self.search_dropout = nn.Dropout(dropout)
46 | self.retrieval_dropout = nn.Dropout(dropout)
47 |
48 | # autoregressive variant for self-experimentation
49 | self.causal = causal
50 |
51 | def forward(self, x, mask = None):
52 | """
53 | einstein notation:
54 | b - batch
55 | n - sequence dimension
56 | i - sequence dimension (source)
57 | j - sequence dimension (target, aggregation dimension)
58 | s - number of searches
59 | r - number of retrievals
60 | d - feature dimension
61 | """
62 | x = self.norm(x)
63 |
64 | s = self.num_searches
65 | r = self.num_retrievals
66 |
67 | # get search queries and keys
68 |
69 | sq, sk = self.to_searches_queries(x), self.to_searches_keys(x)
70 | sq, sk = rearrange_many((sq, sk), 'b n (s d) -> b s n d', s = s)
71 |
72 | sq = sq * self.scale
73 |
74 | # search similarity and attention
75 |
76 | search_sim = einsum('b s i d, b s j d -> b s i j', sq, sk)
77 |
78 | if exists(mask):
79 | mask = rearrange(mask, 'b j -> b 1 1 j')
80 | search_sim = search_sim.masked_fill(~mask, -torch.finfo(search_sim.dtype).max)
81 |
82 | if self.causal:
83 | i, j = search_sim.shape[-2:]
84 | causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
85 | search_sim = search_sim.masked_fill(causal_mask, -torch.finfo(search_sim.dtype).max)
86 |
87 | search_attn = stable_softmax(search_sim, dim = -1)
88 | search_attn = self.search_dropout(search_attn)
89 |
90 | # get retrieval values
91 |
92 | rv = self.to_retrieval_values(x)
93 | rv = rearrange(rv, 'b n (r d) -> b r n d', r = r)
94 |
95 | retrieved = einsum('b s i j, b r j d -> b s r i d', search_attn, rv)
96 |
97 | # get retrieval queries and keys
98 |
99 | rq, rk = self.to_retrieval_queries(x), self.to_retrieval_keys(retrieved)
100 | rq = rearrange(rq, 'b n (s d) -> b s n d', s = s)
101 | rq = rq * self.scale
102 |
103 | # get retrieval attention
104 |
105 | retrieval_sim = einsum('b s n d , b s r n d -> b s n r', rq, rk)
106 |
107 | retrieval_attn = stable_softmax(retrieval_sim, dim = -1)
108 | retrieval_attn = self.retrieval_dropout(retrieval_attn)
109 |
110 | # aggregate retrievals
111 |
112 | out = einsum('b s n r, b s r n d -> b s n d', retrieval_attn, retrieved)
113 |
114 | # combine search results out
115 |
116 | out = rearrange(out, 'b s n d -> b n (s d)')
117 | return self.to_out(out)
118 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'compositional-attention-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.1',
7 | license='MIT',
8 | description = 'Compositional Attention - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/compositional-attention-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'deep learning',
15 | 'attention mechanism'
16 | ],
17 | install_requires=[
18 | 'einops>=0.4',
19 | 'einops-exts',
20 | 'torch>=1.6',
21 | ],
22 | classifiers=[
23 | 'Development Status :: 4 - Beta',
24 | 'Intended Audience :: Developers',
25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
26 | 'License :: OSI Approved :: MIT License',
27 | 'Programming Language :: Python :: 3.6',
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------