├── adjacent_attention_network
│   ├── __init__.py
│   └── adjacent_attention_network.py
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── .gitignore
└── README.md

/adjacent_attention_network/__init__.py:
--------------------------------------------------------------------------------
from adjacent_attention_network.adjacent_attention_network import AdjacentAttentionNetwork
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'adjacent-attention-pytorch',
  packages = find_packages(),
  version = '0.0.12',
  license='MIT',
  description = 'Adjacent Attention Network - Pytorch',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/adjacent-attention-pytorch',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'graph neural network',
    'transformers'
  ],
  install_requires=[
    'einops>=0.3',
    'torch>=1.6',
    'isab-pytorch<0.2'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Adjacent Attention Network

An implementation of a simple transformer that is equivalent to a graph neural network, where message passing is done with multi-head attention at each successive layer. Since the name Graph Attention Network is already taken, I decided to call this one Adjacent Attention Network instead. The design is more transformer-centric: instead of the inverse square root degree-normalized adjacency matrix trick from Kipf and Welling, in this framework the adjacency information is simply translated into the proper attention mask at each layer.
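
Concretely, the difference looks roughly like this (a toy, self-contained sketch with made-up tensors, not part of the package): a Kipf & Welling style layer propagates features through the degree-normalized adjacency matrix, whereas here the adjacency matrix is only used as a boolean mask that decides which nodes a given node is allowed to attend to.

```python
import torch

n, d = 8, 16
feats = torch.randn(n, d)
adj = ((torch.rand(n, n) < 0.3) | torch.eye(n).bool()).float()  # random graph with self loops

# Kipf & Welling (GCN) view: multiply features by D^-1/2 A D^-1/2
deg_inv_sqrt = adj.sum(dim = -1) ** -0.5
norm_adj = deg_inv_sqrt[:, None] * adj * deg_inv_sqrt[None, :]
gcn_out = norm_adj @ feats  # a real GCN layer would follow this with a weight matrix and nonlinearity

# attention-mask view (what this repository does): adjacency only gates the attention scores
scores = (feats @ feats.t()) * d ** -0.5               # unnormalized attention logits
scores = scores.masked_fill(adj == 0, float('-inf'))   # non-neighbors cannot be attended to
attn_out = scores.softmax(dim = -1) @ feats

print(gcn_out.shape, attn_out.shape)  # torch.Size([8, 16]) torch.Size([8, 16])
```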

This repository is for my own exploration into the graph neural network field. My gut tells me the transformer architecture can generalize and outperform graph neural networks.

## Install

```bash
$ pip install adjacent-attention-pytorch
```

## Usage

Basically a transformer in which each node attends only to its neighbors, as defined by the adjacency matrix. Complexity is O(n * max_neighbors), where the maximum number of neighbors is determined by the adjacency matrix.

The following example, with 1024 nodes and an adjacency density of roughly 10% (on the order of 100 neighbors per node), will therefore have a complexity of ~ 1024 * 100.

```python
import torch
from adjacent_attention_network import AdjacentAttentionNetwork

model = AdjacentAttentionNetwork(
    dim = 512,
    depth = 6,
    heads = 4
)

adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1) < 0.1
nodes = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
```

If the number of neighbors per node contains outliers, the above will lead to wasteful computation, since many nodes will be attending to padding. You can use the following stop-gap measure to account for these outliers: any node with more neighbors than the cutoff has its neighbors randomly subsampled on each forward pass.

```python
import torch
from adjacent_attention_network import AdjacentAttentionNetwork

model = AdjacentAttentionNetwork(
    dim = 512,
    depth = 6,
    heads = 4,
    num_neighbors_cutoff = 100
).cuda()

adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1).cuda() < 0.1
nodes = torch.randn(1, 1024, 512).cuda()
mask = torch.ones(1, 1024).bool().cuda()

# for some reason, one of the nodes is fully connected to all the others
adj_mat[:, 0] = 1.

model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
```
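
Under the hood, the cutoff adds a small amount of uniform noise to each row of the adjacency matrix and takes a topk, so that nodes above the cutoff end up attending to a random subset of their neighbors (similar in spirit to the random sparse attention of BigBird). Roughly, the idea looks like this (a standalone sketch with toy shapes, for illustration only).

```python
import torch

n, cutoff = 1024, 100
adj_mat = (torch.empty(n, n).uniform_(0, 1) < 0.1).float()

# add small uniform noise, then topk per row - for nodes with more than `cutoff`
# neighbors this picks a random subset; for the rest, the extra slots are padding
noised = adj_mat + torch.empty(n, n).uniform_(-0.01, 0.01)
adj_values, adj_kv_indices = noised.topk(dim = -1, k = cutoff)

# entries > 0.5 correspond to real neighbors, the rest is padding to be masked out
adj_mask = (adj_values > 0.5).float()

print(adj_kv_indices.shape, adj_mask.shape)  # torch.Size([1024, 100]) torch.Size([1024, 100])
```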

For non-local attention, I've decided to use a trick from the Set Transformer paper: the Induced Set Attention Block (ISAB). Through the lens of the graph neural net literature, this is analogous to having global nodes through which messages can be passed non-locally.

```python
import torch
from adjacent_attention_network import AdjacentAttentionNetwork

model = AdjacentAttentionNetwork(
    dim = 512,
    depth = 6,
    heads = 4,
    num_global_nodes = 5
).cuda()

adj_mat = torch.empty(1, 1024, 1024).uniform_(0, 1).cuda() < 0.1
nodes = torch.randn(1, 1024, 512).cuda()
mask = torch.ones(1, 1024).bool().cuda()

model(nodes, adj_mat, mask = mask) # (1, 1024, 512)
```
--------------------------------------------------------------------------------
/adjacent_attention_network/adjacent_attention_network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from isab_pytorch import ISAB

# helpers

def exists(val):
    return val is not None

def batched_index_select(values, indices):
    # gather rows of `values` (b, n, d) along the node dimension with a
    # per-batch index tensor (b, m), keeping the feature dimension intact
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

# helper classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

# adjacent attention class

class AdjacentAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.null_k = nn.Parameter(torch.randn(heads, dim_head))
        self.null_v = nn.Parameter(torch.randn(heads, dim_head))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,               # node features - (b, n, dim)
        adj_kv_indices,  # indices of each node's neighbors - (b, n, a)
        mask             # 1 for real neighbors, 0 for padding - (b, n, a)
    ):
        b, n, d, h = *x.shape, self.heads
        flat_indices = repeat(adj_kv_indices, 'b n a -> (b h) (n a)', h = h)

        # derive query, key, value
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # gather keys and values according to the adjacency matrix
        k, v = map(lambda t: rearrange(t, 'b h n d -> (b h) n d'), (k, v))
        k = batched_index_select(k, flat_indices)
        v = batched_index_select(v, flat_indices)
        k, v = map(lambda t: rearrange(t, '(b h) (n a) d -> b h n a d', h = h, n = n), (k, v))

        # add a null key / value, so a node can attend to nothing
        # I have come across this in the GNN literature under a different name
        nk, nv = map(lambda t: rearrange(t, 'h d -> () h () () d').expand(b, -1, n, 1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)
        mask = F.pad(mask, (1, 0), value = 1)

        # similarity of each node to its neighbors
        sim = einsum('b h n d, b h n a d -> b h n a', q, k) * self.scale

        # mask out neighbors that are just padding
        mask_value = -torch.finfo(sim.dtype).max
        mask = rearrange(mask.bool(), 'b n a -> b () n a')
        sim.masked_fill_(~mask, mask_value)

        # attention
        attn = sim.softmax(dim = -1)

        # dropout
        attn = self.dropout(attn)

        # get the weighted average of the values of all neighbors
        out = einsum('b h n a, b h n a d -> b h n d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # combine heads and project out
        return self.to_out(out)

# adjacent network (layers of adjacent attention)

class AdjacentAttentionNetwork(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 4,
        num_neighbors_cutoff = None,
        num_global_nodes = 0,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        self.num_neighbors_cutoff = num_neighbors_cutoff
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            global_attn = PreNorm(dim, ISAB(
                dim = dim,
                heads = heads,
                num_induced_points = num_global_nodes
            )) if num_global_nodes > 0 else None

            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, AdjacentAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout
                ))),
                global_attn,
                Residual(PreNorm(dim, FeedForward(
                    dim = dim,
                    dropout = ff_dropout
                )))
            ]))

    def forward(self, x, adjacency_mat, mask = None):
        device, n = x.device, x.shape[1]

        diag = torch.eye(adjacency_mat.shape[-1], device = device).bool()
        adjacency_mat |= diag # each node should also attend to itself (self-interaction)

        # zero out entries of the adjacency matrix
        # where the nodes are just padding
        if exists(mask):
            adjacency_mat &= (mask[:, :, None] * mask[:, None, :])

        adj_mat = adjacency_mat.float()

        # if a neighbor cutoff is set and exceeded:
        #   - randomly sample the cutoff number of neighbors for any node that exceeds it
        #   - this is similar to random sparse attention (bigbird)
        # otherwise:
        #   - gather up to the maximum number of neighbors, padding the nodes that have fewer

        # get the maximum number of neighbors
        max_neighbors = int(adj_mat.sum(dim = -1).max())

        if exists(self.num_neighbors_cutoff) and max_neighbors > self.num_neighbors_cutoff:
            # to randomly sample the neighbors, add a small uniform noise to the mask and topk
            # (the same noise is shared across the batch)
            noise = torch.empty((n, n), device = device).uniform_(-0.01, 0.01)
            adj_mat = adj_mat + noise

            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = self.num_neighbors_cutoff)

            # cast the mask back to 0s and 1s
            adj_mask = (adj_mask > 0.5).float()
        else:
            # todo - get the distribution of the number of neighbors, and strategically break up attention (message passing) into multiple steps
            #      - start with a bimodal num neighbors test case, then generalize

            # use topk to get all the neighbors
            # also pass the mask into the attention, as some of the gathered "neighbors" will just be padding
            adj_mask, adj_kv_indices = adj_mat.topk(dim = -1, k = max_neighbors)

        for attn, global_attn, ff in self.layers:
            x = attn(
                x,
                adj_kv_indices = adj_kv_indices,
                mask = adj_mask
            )

            if exists(global_attn):
                out, _ = global_attn(x, mask = mask)
                x = x + out

            x = ff(x)

        return x
--------------------------------------------------------------------------------