├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── aoa_pytorch
│   ├── __init__.py
│   └── aoa_pytorch.py
├── saoa.png
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Attention on Attention - Pytorch

A PyTorch implementation of the Attention on Attention (AoA) module, from the paper *An Improved Attention for Visual Question Answering*. This repository includes both the self-attention and guided (cross-attention) variants.
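
In brief, following the implementation in `aoa_pytorch/aoa_pytorch.py`, the module refines a standard attention output with a learned gate: the attended values are concatenated with the query, and a linear layer followed by a GLU produces the final output. Below is a minimal single-head sketch of that computation (untrained weights, with the learned query/key/value projections and dropout omitted), not the packaged module itself:

```python
import torch
import torch.nn.functional as F
from torch import nn

dim = 512
x = torch.randn(1, 1024, dim)                  # queries
context = x                                    # self-attention: keys / values also come from x

scores   = x @ context.transpose(-1, -2) * dim ** -0.5
attended = scores.softmax(dim = -1) @ context  # standard attention output

aoa = nn.Linear(2 * dim, 2 * dim)              # GLU halves the last dimension back to `dim`
out = F.glu(aoa(torch.cat((attended, x), dim = -1)))
assert out.shape == (1, 1024, dim)
```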

## Install

```bash
$ pip install aoa-pytorch
```

## Usage

Self Attention on Attention

```python
import torch
from aoa_pytorch import AoA

attn = AoA(
    dim = 512,
    heads = 8
)

x = torch.randn(1, 1024, 512)
attn(x) + x # (1, 1024, 512)
```

Guided Attention on Attention

```python
import torch
from aoa_pytorch import AoA

attn = AoA(
    dim = 512,
    heads = 8
)

x = torch.randn(1, 1024, 512)
context = torch.randn(1, 1024, 512)

attn(x, context = context) + x # (1, 1024, 512)
```
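
Besides `dim` and `heads`, the constructor in `aoa_pytorch/aoa_pytorch.py` also exposes `dim_head` (default 64), `dropout`, and `aoa_dropout` (both default 0). A sketch passing them explicitly, with illustrative dropout values:

```python
import torch
from aoa_pytorch import AoA

attn = AoA(
    dim = 512,
    dim_head = 64,      # dimension per attention head
    heads = 8,
    dropout = 0.1,      # dropout on the attention weights
    aoa_dropout = 0.1   # dropout after the attention-on-attention GLU
)

x = torch.randn(1, 1024, 512)
attn(x) + x # (1, 1024, 512)
```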

## Citations

```bibtex
@misc{rahman2020improved,
    title = {An Improved Attention for Visual Question Answering},
    author = {Tanzila Rahman and Shih-Han Chou and Leonid Sigal and Giuseppe Carenini},
    year = {2020},
    eprint = {2011.02164},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{huang2019attention,
    title = {Attention on Attention for Image Captioning},
    author = {Lun Huang and Wenmin Wang and Jie Chen and Xiao-Yong Wei},
    year = {2019},
    eprint = {1908.06954},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```
--------------------------------------------------------------------------------
/aoa_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from aoa_pytorch.aoa_pytorch import AttentionOnAttention
2 | AoA = AttentionOnAttention
3 |
--------------------------------------------------------------------------------
/aoa_pytorch/aoa_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange
6 |
7 | def exists(val):
8 | return val is not None
9 |
10 | def default(val, d):
11 | return val if exists(val) else d
12 |
13 | class AttentionOnAttention(nn.Module):
14 | def __init__(
15 | self,
16 | *,
17 | dim,
18 | dim_head = 64,
19 | heads = 8,
20 | dropout = 0.,
21 | aoa_dropout = 0.
22 | ):
23 | super().__init__()
24 | inner_dim = dim_head * heads
25 | self.heads = heads
26 | self.scale = dim_head ** -0.5
27 |
28 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
29 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
30 |
31 | self.dropout = nn.Dropout(dropout)
32 |
33 | self.aoa = nn.Sequential(
34 | nn.Linear(2 * inner_dim, 2 * dim),
35 | nn.GLU(),
36 | nn.Dropout(aoa_dropout)
37 | )
38 |
39 | def forward(self, x, context = None):
40 | h = self.heads
41 |
42 | q_ = self.to_q(x)
43 |
44 | context = default(context, x)
45 | kv = self.to_kv(context).chunk(2, dim = -1)
46 |
47 | # split heads
48 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q_, *kv))
49 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
50 |
51 | # attention
52 | attn = dots.softmax(dim = -1)
53 | attn = self.dropout(attn)
54 |
55 | # weighted average of values
56 | attn_out = einsum('b h i j, b h j d -> b h i d', attn, v)
57 |
58 | # concat heads
59 | out = rearrange(attn_out, 'b h n d -> b n (h d)', h = h)
60 |
61 | # attention on attention
62 | out = self.aoa(torch.cat((out, q_), dim = -1))
63 | return out
64 |
--------------------------------------------------------------------------------
/saoa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/AoA-pytorch/97d99d0fce4683fdba7b8fc05ff64aa69cdcf37a/saoa.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'aoa_pytorch',
5 | packages = find_packages(exclude=['examples']),
6 | version = '0.0.2',
7 | license='MIT',
8 | description = 'Attention on Attention - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 |   url = 'https://github.com/lucidrains/AoA-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'attention mechanism',
15 | 'visual question answering'
16 | ],
17 | install_requires=[
18 | 'torch>=1.6',
19 | 'einops>=0.3'
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Programming Language :: Python :: 3.6',
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------