├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── agent-attention.png
├── agent_attention_pytorch
├── __init__.py
├── agent_attention_pytorch.py
└── agent_transformer.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Agent Attention - Pytorch
4 |
5 | Implementation of Agent Attention in Pytorch.
6 |
7 | This work seems to be an elegant simplification of the `ISAB` architecture from the Set Transformer paper (it requires only one attention block rather than two). While ISAB works, I have found it to be a bit unstable, so I am curious whether the simplification in this work resolves that issue.
8 |
9 | This repository also supports variable sequence lengths (masking) and post-softmax talking heads.
10 |
11 | ## Appreciation
12 |
13 | - A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install agent-attention-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | ```python
24 | import torch
25 | from agent_attention_pytorch import AgentSelfAttention
26 |
27 | attn = AgentSelfAttention(
28 | dim = 512,
29 | num_agent_tokens = 256, # number of "agent" tokens
30 | dim_head = 64, # attention head dimension
31 | heads = 8 # number of heads
32 | )
33 |
34 | x = torch.randn(3, 65536, 512)
35 | mask = torch.ones(3, 65536).bool()
36 |
37 | out = attn(x, mask = mask)
38 |
39 | assert out.shape == x.shape
40 | ```
41 |
42 | For a full-fledged linear transformer built on agent tokens, just import `AgentTransformer`
43 |
44 | ```python
45 | import torch
46 | from agent_attention_pytorch import AgentTransformer
47 |
48 | transformer = AgentTransformer(
49 | dim = 512,
50 | depth = 6,
51 | num_agent_tokens = 128,
52 | dim_head = 64,
53 | heads = 8
54 | )
55 |
56 | x = torch.randn(3, 65536, 512)
57 | mask = torch.ones(3, 65536).bool()
58 |
59 | out, agent_tokens = transformer(x, mask = mask, return_agent_tokens = True)
60 |
61 | # (3, 65536, 512), (3, 128, 512)
62 | assert out.shape == x.shape
63 | ```
64 |
65 | ## Citations
66 |
67 | ```bibtex
68 | @inproceedings{Han2023AgentAO,
69 | title = {Agent Attention: On the Integration of Softmax and Linear Attention},
70 | author = {Dongchen Han and Tianzhu Ye and Yizeng Han and Zhuofan Xia and Shiji Song and Gao Huang},
71 | year = {2023},
72 | url = {https://api.semanticscholar.org/CorpusID:266210414}
73 | }
74 | ```
75 |
76 | ```bibtex
77 | @misc{shazeer2020talkingheads,
78 | title = {Talking-Heads Attention},
79 | author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
80 | year = {2020},
81 | eprint = {2003.02436},
82 | archivePrefix = {arXiv},
83 | primaryClass = {cs.LG}
84 | }
85 | ```
86 |
87 | ```bibtex
88 | @article{Bondarenko2023QuantizableTR,
89 | title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
90 | author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
91 | journal = {ArXiv},
92 | year = {2023},
93 | volume = {abs/2306.12929},
94 | url = {https://api.semanticscholar.org/CorpusID:259224568}
95 | }
96 | ```
97 |
98 | ```bibtex
99 | @article{Wang2022FoundationT,
100 | title = {Foundation Transformers},
101 | author = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
102 | journal = {ArXiv},
103 | year = {2022},
104 | volume = {abs/2210.06423},
105 | url = {https://api.semanticscholar.org/CorpusID:252846241}
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
/agent-attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/agent-attention-pytorch/091f6d447d5005cd8c8e16843688b5cc2a9b8cd2/agent-attention.png
--------------------------------------------------------------------------------
/agent_attention_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from agent_attention_pytorch.agent_attention_pytorch import (
2 | AgentSelfAttention
3 | )
4 |
5 | from agent_attention_pytorch.agent_transformer import (
6 | AgentTransformer
7 | )
8 |
--------------------------------------------------------------------------------
/agent_attention_pytorch/agent_attention_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module
3 | from torch import nn, einsum, Tensor
4 |
5 | from einops import rearrange, repeat
6 | from einops.layers.torch import Rearrange
7 |
8 | # functions
9 |
def exists(v):
    """Return True for any value except None (falsy values such as 0 or '' still count)."""
    if v is None:
        return False
    return True
12 |
13 | # main class
14 |
class AgentSelfAttention(Module):
    """Self-attention routed through a small set of learned agent tokens.

    Two softmax attentions are chained per head:

    1. the agents attend over the (optionally masked) keys and gather values,
       producing one summary vector per agent;
    2. every query attends over the agents and reads those summaries back.

    With ``m`` agents and ``n`` tokens the similarity maps are (n x m) and
    (m x n), so for a fixed number of agents they grow linearly with ``n``.

    Args:
        dim: feature dimension of the input tokens.
        num_agent_tokens: number of learned agent tokens (``m``).
        dim_head: per-head dimension.
        heads: number of attention heads.
        dropout: dropout probability applied to both attention maps.
        talking_heads: if True, mix each attention map across heads with a
            1x1 convolution (applied after softmax and dropout).
        gate: if True, scale each head's output by a per-token sigmoid gate
            computed from the input.
        combine_agent_tokens: accepted for interface compatibility; it is not
            read anywhere in this class.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        combine_agent_tokens = False
    ):
        super().__init__()
        dim_inner = heads * dim_head
        self.scale = dim_head ** -0.5

        # one fused projection, split into q, k, v each shaped (b, h, n, d)
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        if gate:
            # one gate value per (head, token), broadcast over the head dimension
            self.to_gates = nn.Sequential(
                nn.Linear(dim, heads),
                Rearrange('b n h -> b h n 1'),
                nn.Sigmoid()
            )
        else:
            self.to_gates = None

        self.agent_tokens = nn.Parameter(torch.zeros(heads, num_agent_tokens, dim_head))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        def talking_heads_mixer():
            # heads are the conv channels, so a 1x1 conv mixes heads at every (i, j)
            return nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        self.qa_talking_heads = talking_heads_mixer()
        self.ak_talking_heads = talking_heads_mixer()

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(
        self,
        x,
        mask = None,
        agent_tokens = None,
        return_agent_tokens = False
    ):
        """Apply agent attention to ``x``.

        Args:
            x: input tokens of shape (batch, seq, dim).
            mask: optional boolean (batch, seq) mask. False positions are
                excluded as keys and their outputs are zeroed.
            agent_tokens: optional agents of shape (batch, heads, m, dim_head).
                Defaults to the learned ``self.agent_tokens`` repeated over
                the batch.
            return_agent_tokens: if True, also return the gathered agent tokens.

        Returns:
            ``out`` of shape (batch, seq, dim), or ``(out, agent_gathered_tokens)``
            where the second item has shape (batch, heads, m, dim_head).
        """
        q, k, v = self.to_qkv(x)

        if agent_tokens is None:
            agent_tokens = repeat(self.agent_tokens, 'h m d -> b h m d', b = x.shape[0])

        # scaling the agents once applies dim_head ** -0.5 to both similarity maps
        agents = agent_tokens * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, agents)
        ak_sim = einsum('b h i d, b h j d -> b h i j', agents, k)

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b 1 1 j')
            ak_sim = ak_sim.masked_fill(~key_mask, -torch.finfo(qa_sim.dtype).max)

        # softmax -> dropout -> talking heads; the qa dropout is drawn before the ak dropout
        qa_attn = self.qa_talking_heads(self.qa_dropout(qa_sim.softmax(dim = -1)))
        ak_attn = self.ak_talking_heads(self.ak_dropout(ak_sim.softmax(dim = -1)))

        # agents gather values from the tokens: (b, h, m, d)
        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)

        # every token reads back from the agents: (b, h, n, d)
        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)

        out = self.to_out(out)

        return (out, agent_gathered_tokens) if return_agent_tokens else out
107 |
--------------------------------------------------------------------------------
/agent_attention_pytorch/agent_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Module, ModuleList
4 | from torch import nn, einsum, Tensor
5 |
6 | from einops import rearrange, repeat, pack, unpack
7 | from einops.layers.torch import Rearrange
8 |
9 | # functions
10 |
def exists(v):
    """Return True for any value except None (falsy values such as 0 or '' still count)."""
    if v is None:
        return False
    return True
13 |
14 | # norm
15 |
class RMSNorm(Module):
    """RMS-style normalization over the last dimension with a learned per-feature gain.

    ``F.normalize`` rescales each vector to unit L2 norm. Multiplying by
    ``sqrt(dim)`` turns that into unit root-mean-square. ``gamma`` starts at
    ones.
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) converts a unit-L2 vector into a unit-RMS vector
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma
24 |
25 | # feedforward
26 |
def FeedForward(dim, mult = 4):
    """Build a pre-norm MLP: RMSNorm -> Linear(dim, int(dim * mult)) -> GELU -> Linear back to dim.

    The returned ``nn.Sequential`` contains no residual connection.
    ``mult`` may be fractional; the hidden width is truncated with ``int``.
    """
    hidden = int(dim * mult)
    layers = [
        RMSNorm(dim),
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)
35 |
36 | # main class
37 |
class AgentSelfAttention(Module):
    """Agent attention layer whose agent tokens are supplied by the caller.

    The caller passes agent tokens of shape (batch, m, dim). They are
    RMS-normalized with the same norm as the input and projected with the
    same qkv weights. Token queries attend over the agent keys. Agent queries
    attend over the (optionally masked) token keys/values. Both the sequence
    and the agents receive an output.

    Args:
        dim: feature dimension of tokens and agents.
        num_agent_tokens: accepted for interface compatibility; not read here.
            The agent count comes from the ``agent_tokens`` passed to ``forward``.
        dim_head: per-head dimension.
        heads: number of attention heads.
        dropout: dropout probability on both attention maps.
        talking_heads: if True, mix each attention map across heads with a
            1x1 convolution (after softmax and dropout).
        gate: if True, scale each head's output by a sigmoid gate computed from
            the normalized tokens (and, for agent outputs, the normalized agents).
        sub_layernorm: if True, apply ``LayerNorm(dim_head)`` to each head's
            output before the output projections.
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        talking_heads = True,
        gate = True,
        sub_layernorm = False
    ):
        super().__init__()
        dim_inner = heads * dim_head
        self.scale = dim_head ** -0.5

        # shared by the input tokens and the agent tokens
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        if gate:
            self.to_gates = nn.Sequential(
                nn.Linear(dim, heads),
                Rearrange('b n h -> b h n 1'),
                nn.Sigmoid()
            )
        else:
            self.to_gates = None

        def talking_heads_mixer():
            # heads are the conv channels, so a 1x1 conv mixes heads at every (i, j)
            return nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()

        self.qa_talking_heads = talking_heads_mixer()
        self.ak_talking_heads = talking_heads_mixer()

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        def output_projection():
            return nn.Sequential(
                nn.LayerNorm(dim_head) if sub_layernorm else nn.Identity(),
                Rearrange('b h n d -> b n (h d)'),
                nn.Linear(dim_inner, dim, bias = False)
            )

        # separate output projections for agents and tokens (agents built first)
        self.to_agent_out = output_projection()
        self.to_out = output_projection()

    def forward(
        self,
        x,
        *,
        agent_tokens,
        mask = None,
        return_agent_tokens = False
    ):
        """Attend between ``x`` and ``agent_tokens``.

        Args:
            x: input tokens of shape (batch, seq, dim).
            agent_tokens: agents of shape (batch, m, dim).
            mask: optional boolean (batch, seq) mask. False positions are
                excluded as keys for the agents and their outputs are zeroed.
            return_agent_tokens: if True, also return the agent outputs.

        Returns:
            ``out`` of shape (batch, seq, dim), or ``(out, agent_out)`` where
            ``agent_out`` has shape (batch, m, dim).
        """
        x = self.norm(x)
        agents = self.norm(agent_tokens)

        # run agents and tokens through one qkv projection, then split them apart again
        packed, packed_shapes = pack([agents, x], 'b * d')
        qkv_agents, qkv_tokens = unpack(self.to_qkv(packed), packed_shapes, 'qkv b h * d')

        agent_q, agent_k, _ = qkv_agents  # agent values are not used
        q, k, v = qkv_tokens

        q = q * self.scale
        agent_q = agent_q * self.scale

        qa_sim = einsum('b h i d, b h j d -> b h i j', q, agent_k)
        ak_sim = einsum('b h i d, b h j d -> b h i j', agent_q, k)

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b 1 1 j')
            ak_sim = ak_sim.masked_fill(~key_mask, -torch.finfo(qa_sim.dtype).max)

        # softmax -> dropout -> talking heads; the qa dropout is drawn before the ak dropout
        qa_attn = self.qa_talking_heads(self.qa_dropout(qa_sim.softmax(dim = -1)))
        ak_attn = self.ak_talking_heads(self.ak_dropout(ak_sim.softmax(dim = -1)))

        # agents gather token values: (b, h, m, d); tokens read them back: (b, h, n, d)
        agent_out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_out)

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

        if exists(self.to_gates):
            out = out * self.to_gates(x)
            agent_out = agent_out * self.to_gates(agents)

        out = self.to_out(out)
        agent_out = self.to_agent_out(agent_out)

        return (out, agent_out) if return_agent_tokens else out
142 |
143 | # transformer
144 |
class AgentTransformer(Module):
    """Stack of (agent self-attention, feedforward) layers with a shared, learned set of agent tokens.

    ``num_agent_tokens`` learned agent embeddings of size ``dim`` are repeated
    over the batch. Every layer updates them residually alongside the input
    sequence.

    Args:
        dim: feature dimension of the input tokens.
        num_agent_tokens: number of learned agent tokens.
        depth: number of (attention, feedforward) layer pairs.
        heads: attention heads per layer.
        dim_head: per-head dimension.
        ff_mult: hidden-width multiplier of each feedforward.
        final_norm: if True, apply one RMSNorm to both the tokens and the
            agents after the last layer.
        **attn_kwargs: extra keyword arguments forwarded unchanged to every
            ``AgentSelfAttention`` (e.g. ``dropout``, ``talking_heads``,
            ``gate``, ``sub_layernorm``).
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        final_norm = True,
        **attn_kwargs  # values are scalars/bools, so no `dict` annotation
    ):
        super().__init__()

        self.agent_tokens = nn.Parameter(torch.zeros(num_agent_tokens, dim))
        nn.init.normal_(self.agent_tokens, std = 0.02)

        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                AgentSelfAttention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    num_agent_tokens = num_agent_tokens,
                    **attn_kwargs
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.final_norm = RMSNorm(dim) if final_norm else None

    def forward(
        self,
        x,
        mask = None,
        return_agent_tokens = False
    ):
        """Run the input through every layer.

        Args:
            x: input tokens of shape (batch, seq, dim).
            mask: optional boolean (batch, seq) mask passed to every attention
                layer. False marks positions excluded from attention.
            return_agent_tokens: if True, also return the final agent tokens.

        Returns:
            ``x`` of shape (batch, seq, dim), or ``(x, agent_tokens)`` where the
            agents have shape (batch, num_agent_tokens, dim).
        """
        batch = x.shape[0]
        a = repeat(self.agent_tokens, 'm d -> b m d', b = batch)

        for attn, ff in self.layers:
            attn_out, agent_out = attn(
                x,
                agent_tokens = a,
                mask = mask,
                return_agent_tokens = True
            )

            a = a + agent_out
            x = x + attn_out

            # agents and tokens share one feedforward, so run them in a single pass
            x, ps = pack([a, x], 'b * d')

            x = ff(x) + x

            a, x = unpack(x, ps, 'b * d')

        if exists(self.final_norm):
            x = self.final_norm(x)
            a = self.final_norm(a)

        if not return_agent_tokens:
            return x

        return x, a
213 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
setup(
    name = 'agent-attention-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.1.7',
    license='MIT',
    description = 'Agent Attention - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    # NOTE: no `long_description` is passed, so this content type has nothing to describe
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/agent-attention-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'attention',
        'linear attention'
    ],
    # torch>=2.0 (below) itself requires Python >= 3.8, so advertise the same floor
    python_requires = '>=3.8',
    install_requires=[
        'einops>=0.7.0',
        'torch>=2.0'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.8',
    ],
)
31 |
--------------------------------------------------------------------------------