├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── denoise_sparse.py ├── egnn.png ├── egnn_pytorch ├── __init__.py ├── egnn_pytorch.py ├── egnn_pytorch_geometric.py └── utils.py ├── examples └── egnn_test.ipynb ├── setup.cfg ├── setup.py └── tests └── test_equivariance.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # mac files and custom data for examples 10 | *.DS_Store 11 | .DS_Store 12 | *sidechainnet_data/* 13 | *.pkl 14 | *tests/custom_tests.py 
15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Phil Wang, Eric Alcaide 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ** A bug has been discovered with the neighbor selection in the presence of masking. If you ran any experiments prior to 0.1.12 that had masking, please rerun them. 🙏 ** 4 | 5 | ## EGNN - Pytorch 6 | 7 | Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May be eventually used for Alphafold2 replication. This technique went for simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc. 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install egnn-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | from egnn_pytorch import EGNN 20 | 21 | layer1 = EGNN(dim = 512) 22 | layer2 = EGNN(dim = 512) 23 | 24 | feats = torch.randn(1, 16, 512) 25 | coors = torch.randn(1, 16, 3) 26 | 27 | feats, coors = layer1(feats, coors) 28 | feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3) 29 | ``` 30 | 31 | With edges 32 | 33 | ```python 34 | import torch 35 | from egnn_pytorch import EGNN 36 | 37 | layer1 = EGNN(dim = 512, edge_dim = 4) 38 | layer2 = EGNN(dim = 512, edge_dim = 4) 39 | 40 | feats = torch.randn(1, 16, 512) 41 | coors = torch.randn(1, 16, 3) 42 | edges = torch.randn(1, 16, 16, 4) 43 | 44 | feats, coors = layer1(feats, coors, edges) 45 | feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3) 46 | ``` 47 | 48 | A full EGNN network 49 | 50 | ```python 51 | import torch 52 | from egnn_pytorch import EGNN_Network 53 | 54 | net = EGNN_Network( 55 | num_tokens = 21, 56 | num_positions = 1024, # unless what you are passing in is an unordered set, set this to the maximum sequence length 57 | dim = 32, 58 | depth = 3, 59 | num_nearest_neighbors = 8, 60 | 
coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors 61 | ) 62 | 63 | feats = torch.randint(0, 21, (1, 1024)) # (1, 1024) 64 | coors = torch.randn(1, 1024, 3) # (1, 1024, 3) 65 | mask = torch.ones_like(feats).bool() # (1, 1024) 66 | 67 | feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3) 68 | ``` 69 | 70 | Only attend to sparse neighbors, given to the network as an adjacency matrix. 71 | 72 | ```python 73 | import torch 74 | from egnn_pytorch import EGNN_Network 75 | 76 | net = EGNN_Network( 77 | num_tokens = 21, 78 | dim = 32, 79 | depth = 3, 80 | only_sparse_neighbors = True 81 | ) 82 | 83 | feats = torch.randint(0, 21, (1, 1024)) 84 | coors = torch.randn(1, 1024, 3) 85 | mask = torch.ones_like(feats).bool() 86 | 87 | # naive adjacency matrix 88 | # assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024) 89 | i = torch.arange(1024) 90 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) 91 | 92 | feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3) 93 | ``` 94 | 95 | You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as an edge, with two extra keyword arguments: 96 | 97 | ```python 98 | import torch 99 | from egnn_pytorch import EGNN_Network 100 | 101 | net = EGNN_Network( 102 | num_tokens = 21, 103 | dim = 32, 104 | depth = 3, 105 | num_adj_degrees = 3, # fetch up to 3rd degree neighbors 106 | adj_dim = 8, # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP 107 | only_sparse_neighbors = True 108 | ) 109 | 110 | feats = torch.randint(0, 21, (1, 1024)) 111 | coors = torch.randn(1, 1024, 3) 112 | mask = torch.ones_like(feats).bool() 113 | 114 | # naive adjacency matrix 115 | # assuming the sequence is connected as a
chain, with at most 2 neighbors - (1024, 1024) 116 | i = torch.arange(1024) 117 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) 118 | 119 | feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3) 120 | ``` 121 | 122 | ## Edges 123 | 124 | If you need to pass in continuous edges: 125 | 126 | ```python 127 | import torch 128 | from egnn_pytorch import EGNN_Network 129 | 130 | net = EGNN_Network( 131 | num_tokens = 21, 132 | dim = 32, 133 | depth = 3, 134 | edge_dim = 4, 135 | num_nearest_neighbors = 3 136 | ) 137 | 138 | feats = torch.randint(0, 21, (1, 1024)) 139 | coors = torch.randn(1, 1024, 3) 140 | mask = torch.ones_like(feats).bool() 141 | 142 | continuous_edges = torch.randn(1, 1024, 1024, 4) 143 | 144 | # naive adjacency matrix 145 | # assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024) 146 | i = torch.arange(1024) 147 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) 148 | 149 | feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3) 150 | ``` 151 | 152 | ## Stability 153 | 154 | The initial architecture for EGNN suffered from instability when there was a high number of neighbors. Thankfully, there seem to be two solutions that largely mitigate this. 155 | 156 | ```python 157 | import torch 158 | from egnn_pytorch import EGNN_Network 159 | 160 | net = EGNN_Network( 161 | num_tokens = 21, 162 | dim = 32, 163 | depth = 3, 164 | num_nearest_neighbors = 32, 165 | norm_coors = True, # normalize the relative coordinates 166 | coor_weights_clamp_value = 2.
# absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors 167 | ) 168 | 169 | feats = torch.randint(0, 21, (1, 1024)) # (1, 1024) 170 | coors = torch.randn(1, 1024, 3) # (1, 1024, 3) 171 | mask = torch.ones_like(feats).bool() # (1, 1024) 172 | 173 | feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3) 174 | ``` 175 | 176 | ## All parameters 177 | 178 | ```python 179 | import torch 180 | from egnn_pytorch import EGNN 181 | 182 | model = EGNN( 183 | dim = dim, # input dimension 184 | edge_dim = 0, # dimension of the edges, if they exist, should be > 0 185 | m_dim = 16, # hidden model dimension 186 | fourier_features = 0, # number of fourier features for encoding of relative distance - defaults to none as in paper 187 | num_nearest_neighbors = 0, # cap the number of neighbors doing message passing by relative distance 188 | dropout = 0.0, # dropout 189 | norm_feats = False, # whether to layernorm the features 190 | norm_coors = False, # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper 191 | update_feats = True, # whether to update features - you can build a layer that only updates one or the other 192 | update_coors = True, # whether to update coordinates 193 | only_sparse_neighbors = False, # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in 194 | valid_radius = float('inf'), # the valid radius each node considers for message passing 195 | m_pool_method = 'sum', # whether to mean or sum pool for output node representation 196 | soft_edges = False, # extra GLU on the edges, purportedly helps stabilize the network in the updated version of the paper 197 | coor_weights_clamp_value = None # clamping of the coordinate updates, again, for stabilization purposes 198 | ) 199 | 200 | ``` 201 | 202 | ## Examples 203 | 204 | To run the protein backbone denoising example, first install `sidechainnet`: 205 | 206 |
```bash 207 | $ pip install sidechainnet 208 | ``` 209 | 210 | Then 211 | 212 | ```bash 213 | $ python denoise_sparse.py 214 | ``` 215 | 216 | ## Tests 217 | 218 | Make sure you have pytorch geometric installed locally 219 | 220 | ```bash 221 | $ python setup.py test 222 | ``` 223 | 224 | ## Citations 225 | 226 | ```bibtex 227 | @misc{satorras2021en, 228 | title = {E(n) Equivariant Graph Neural Networks}, 229 | author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling}, 230 | year = {2021}, 231 | eprint = {2102.09844}, 232 | archivePrefix = {arXiv}, 233 | primaryClass = {cs.LG} 234 | } 235 | ``` 236 | -------------------------------------------------------------------------------- /denoise_sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.optim import Adam 5 | 6 | from einops import rearrange, repeat 7 | 8 | import sidechainnet as scn 9 | from egnn_pytorch.egnn_pytorch import EGNN_Network 10 | 11 | torch.set_default_dtype(torch.float64) 12 | 13 | BATCH_SIZE = 1 14 | GRADIENT_ACCUMULATE_EVERY = 16 15 | 16 | def cycle(loader, len_thres = 200): 17 | while True: 18 | for data in loader: 19 | if data.seqs.shape[1] > len_thres: 20 | continue 21 | yield data 22 | 23 | net = EGNN_Network( 24 | num_tokens = 21, 25 | num_positions = 200 * 3, # maximum number of positions - absolute positional embedding since there is inherent order in the sequence 26 | depth = 5, 27 | dim = 8, 28 | num_nearest_neighbors = 16, 29 | fourier_features = 2, 30 | norm_coors = True, 31 | coor_weights_clamp_value = 2. 
32 | ).cuda() 33 | 34 | data = scn.load( 35 | casp_version = 12, 36 | thinning = 30, 37 | with_pytorch = 'dataloaders', 38 | batch_size = BATCH_SIZE, 39 | dynamic_batching = False 40 | ) 41 | 42 | dl = cycle(data['train']) 43 | optim = Adam(net.parameters(), lr=1e-3) 44 | 45 | for _ in range(10000): 46 | for _ in range(GRADIENT_ACCUMULATE_EVERY): 47 | batch = next(dl) 48 | seqs, coords, masks = batch.seqs, batch.crds, batch.msks 49 | 50 | seqs = seqs.cuda().argmax(dim = -1) 51 | coords = coords.cuda().type(torch.float64) 52 | masks = masks.cuda().bool() 53 | 54 | l = seqs.shape[1] 55 | coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14) 56 | 57 | # Keeping only the backbone coordinates 58 | 59 | coords = coords[:, :, 0:3, :] 60 | coords = rearrange(coords, 'b l s c -> b (l s) c') 61 | 62 | seq = repeat(seqs, 'b n -> b (n c)', c = 3) 63 | masks = repeat(masks, 'b n -> b (n c)', c = 3) 64 | 65 | i = torch.arange(seq.shape[-1], device = seq.device) 66 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1)) 67 | 68 | noised_coords = coords + torch.randn_like(coords) 69 | 70 | feats, denoised_coords = net(seq, noised_coords, adj_mat = adj_mat, mask = masks) 71 | 72 | loss = F.mse_loss(denoised_coords[masks], coords[masks]) 73 | 74 | (loss / GRADIENT_ACCUMULATE_EVERY).backward() 75 | 76 | print('loss:', loss.item()) 77 | optim.step() 78 | optim.zero_grad() 79 | -------------------------------------------------------------------------------- /egnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/egnn-pytorch/d3ba97d5b2540085cf8e02c838ebf5e42071697f/egnn.png -------------------------------------------------------------------------------- /egnn_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from egnn_pytorch.egnn_pytorch import EGNN, EGNN_Network 2 | from egnn_pytorch.egnn_pytorch_geometric import 
EGNN_Sparse, EGNN_Sparse_Network 3 | -------------------------------------------------------------------------------- /egnn_pytorch/egnn_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum, broadcast_tensors 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | # helper functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | def safe_div(num, den, eps = 1e-8): 14 | res = num.div(den.clamp(min = eps)) 15 | res.masked_fill_(den == 0, 0.) 16 | return res 17 | 18 | def batched_index_select(values, indices, dim = 1): 19 | value_dims = values.shape[(dim + 1):] 20 | values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices)) 21 | indices = indices[(..., *((None,) * len(value_dims)))] 22 | indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims) 23 | value_expand_len = len(indices_shape) - (dim + 1) 24 | values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)] 25 | 26 | value_expand_shape = [-1] * len(values.shape) 27 | expand_slice = slice(dim, (dim + value_expand_len)) 28 | value_expand_shape[expand_slice] = indices.shape[expand_slice] 29 | values = values.expand(*value_expand_shape) 30 | 31 | dim += value_expand_len 32 | return values.gather(dim, indices) 33 | 34 | def fourier_encode_dist(x, num_encodings = 4, include_self = True): 35 | x = x.unsqueeze(-1) 36 | device, dtype, orig_x = x.device, x.dtype, x 37 | scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype) 38 | x = x / scales 39 | x = torch.cat([x.sin(), x.cos()], dim=-1) 40 | x = torch.cat((x, orig_x), dim = -1) if include_self else x 41 | return x 42 | 43 | def embedd_token(x, dims, layers): 44 | stop_concat = -len(dims) 45 | to_embedd = x[:, stop_concat:].long() 46 | for i,emb_layer in enumerate(layers): 47 | # the portion corresponding to `to_embedd` 
part gets dropped 48 | x = torch.cat([ x[:, :stop_concat], 49 | emb_layer( to_embedd[:, i] ) 50 | ], dim=-1) 51 | stop_concat = x.shape[-1] 52 | return x 53 | 54 | # swish activation fallback 55 | 56 | class Swish_(nn.Module): 57 | def forward(self, x): 58 | return x * x.sigmoid() 59 | 60 | SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_ 61 | 62 | # helper classes 63 | 64 | # this follows the same strategy for normalization as done in SE3 Transformers 65 | # https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95 66 | 67 | class CoorsNorm(nn.Module): 68 | def __init__(self, eps = 1e-8, scale_init = 1.): 69 | super().__init__() 70 | self.eps = eps 71 | scale = torch.zeros(1).fill_(scale_init) 72 | self.scale = nn.Parameter(scale) 73 | 74 | def forward(self, coors): 75 | norm = coors.norm(dim = -1, keepdim = True) 76 | normed_coors = coors / norm.clamp(min = self.eps) 77 | return normed_coors * self.scale 78 | 79 | # global linear attention 80 | 81 | class Attention(nn.Module): 82 | def __init__(self, dim, heads = 8, dim_head = 64): 83 | super().__init__() 84 | inner_dim = heads * dim_head 85 | self.heads = heads 86 | self.scale = dim_head ** -0.5 87 | 88 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 89 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 90 | self.to_out = nn.Linear(inner_dim, dim) 91 | 92 | def forward(self, x, context, mask = None): 93 | h = self.heads 94 | 95 | q = self.to_q(x) 96 | kv = self.to_kv(context).chunk(2, dim = -1) 97 | 98 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv)) 99 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 100 | 101 | if exists(mask): 102 | mask_value = -torch.finfo(dots.dtype).max 103 | mask = rearrange(mask, 'b n -> b () () n') 104 | dots.masked_fill_(~mask, mask_value) 105 | 106 | attn = dots.softmax(dim = -1) 107 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 108 | 109 | out = 
rearrange(out, 'b h n d -> b n (h d)', h = h) 110 | return self.to_out(out) 111 | 112 | class GlobalLinearAttention(nn.Module): 113 | def __init__( 114 | self, 115 | *, 116 | dim, 117 | heads = 8, 118 | dim_head = 64 119 | ): 120 | super().__init__() 121 | self.norm_seq = nn.LayerNorm(dim) 122 | self.norm_queries = nn.LayerNorm(dim) 123 | self.attn1 = Attention(dim, heads, dim_head) 124 | self.attn2 = Attention(dim, heads, dim_head) 125 | 126 | self.ff = nn.Sequential( 127 | nn.LayerNorm(dim), 128 | nn.Linear(dim, dim * 4), 129 | nn.GELU(), 130 | nn.Linear(dim * 4, dim) 131 | ) 132 | 133 | def forward(self, x, queries, mask = None): 134 | res_x, res_queries = x, queries 135 | x, queries = self.norm_seq(x), self.norm_queries(queries) 136 | 137 | induced = self.attn1(queries, x, mask = mask) 138 | out = self.attn2(x, induced) 139 | 140 | x = out + res_x 141 | queries = induced + res_queries 142 | 143 | x = self.ff(x) + x 144 | return x, queries 145 | 146 | # classes 147 | 148 | class EGNN(nn.Module): 149 | def __init__( 150 | self, 151 | dim, 152 | edge_dim = 0, 153 | m_dim = 16, 154 | fourier_features = 0, 155 | num_nearest_neighbors = 0, 156 | dropout = 0.0, 157 | init_eps = 1e-3, 158 | norm_feats = False, 159 | norm_coors = False, 160 | norm_coors_scale_init = 1e-2, 161 | update_feats = True, 162 | update_coors = True, 163 | only_sparse_neighbors = False, 164 | valid_radius = float('inf'), 165 | m_pool_method = 'sum', 166 | soft_edges = False, 167 | coor_weights_clamp_value = None 168 | ): 169 | super().__init__() 170 | assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean' 171 | assert update_feats or update_coors, 'you must update either features, coordinates, or both' 172 | 173 | self.fourier_features = fourier_features 174 | 175 | edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1 176 | dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() 177 | 178 | self.edge_mlp = nn.Sequential( 179 | 
nn.Linear(edge_input_dim, edge_input_dim * 2), 180 | dropout, 181 | SiLU(), 182 | nn.Linear(edge_input_dim * 2, m_dim), 183 | SiLU() 184 | ) 185 | 186 | self.edge_gate = nn.Sequential( 187 | nn.Linear(m_dim, 1), 188 | nn.Sigmoid() 189 | ) if soft_edges else None 190 | 191 | self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity() 192 | self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity() 193 | 194 | self.m_pool_method = m_pool_method 195 | 196 | self.node_mlp = nn.Sequential( 197 | nn.Linear(dim + m_dim, dim * 2), 198 | dropout, 199 | SiLU(), 200 | nn.Linear(dim * 2, dim), 201 | ) if update_feats else None 202 | 203 | self.coors_mlp = nn.Sequential( 204 | nn.Linear(m_dim, m_dim * 4), 205 | dropout, 206 | SiLU(), 207 | nn.Linear(m_dim * 4, 1) 208 | ) if update_coors else None 209 | 210 | self.num_nearest_neighbors = num_nearest_neighbors 211 | self.only_sparse_neighbors = only_sparse_neighbors 212 | self.valid_radius = valid_radius 213 | 214 | self.coor_weights_clamp_value = coor_weights_clamp_value 215 | 216 | self.init_eps = init_eps 217 | self.apply(self.init_) 218 | 219 | def init_(self, module): 220 | if type(module) in {nn.Linear}: 221 | # seems to be needed to keep the network from exploding to NaN with greater depths 222 | nn.init.normal_(module.weight, std = self.init_eps) 223 | 224 | def forward(self, feats, coors, edges = None, mask = None, adj_mat = None): 225 | b, n, d, device, fourier_features, num_nearest, valid_radius, only_sparse_neighbors = *feats.shape, feats.device, self.fourier_features, self.num_nearest_neighbors, self.valid_radius, self.only_sparse_neighbors 226 | 227 | if exists(mask): 228 | num_nodes = mask.sum(dim = -1) 229 | 230 | use_nearest = num_nearest > 0 or only_sparse_neighbors 231 | 232 | rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d') 233 | rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True) 234 | 235 | i = j = n 236 | 237 | if 
use_nearest: 238 | ranking = rel_dist[..., 0].clone() 239 | 240 | if exists(mask): 241 | rank_mask = mask[:, :, None] * mask[:, None, :] 242 | ranking.masked_fill_(~rank_mask, 1e5) 243 | 244 | if exists(adj_mat): 245 | if len(adj_mat.shape) == 2: 246 | adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b) 247 | 248 | if only_sparse_neighbors: 249 | num_nearest = int(adj_mat.float().sum(dim = -1).max().item()) 250 | valid_radius = 0 251 | 252 | self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j') 253 | 254 | adj_mat = adj_mat.masked_fill(self_mask, False) 255 | ranking.masked_fill_(self_mask, -1.) 256 | ranking.masked_fill_(adj_mat, 0.) 257 | 258 | nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False) 259 | 260 | nbhd_mask = nbhd_ranking <= valid_radius 261 | 262 | rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2) 263 | rel_dist = batched_index_select(rel_dist, nbhd_indices, dim = 2) 264 | 265 | if exists(edges): 266 | edges = batched_index_select(edges, nbhd_indices, dim = 2) 267 | 268 | j = num_nearest 269 | 270 | if fourier_features > 0: 271 | rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features) 272 | rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d') 273 | 274 | if use_nearest: 275 | feats_j = batched_index_select(feats, nbhd_indices, dim = 1) 276 | else: 277 | feats_j = rearrange(feats, 'b j d -> b () j d') 278 | 279 | feats_i = rearrange(feats, 'b i d -> b i () d') 280 | feats_i, feats_j = broadcast_tensors(feats_i, feats_j) 281 | 282 | edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1) 283 | 284 | if exists(edges): 285 | edge_input = torch.cat((edge_input, edges), dim = -1) 286 | 287 | m_ij = self.edge_mlp(edge_input) 288 | 289 | if exists(self.edge_gate): 290 | m_ij = m_ij * self.edge_gate(m_ij) 291 | 292 | if exists(mask): 293 | mask_i = rearrange(mask, 'b i -> b i ()') 294 | 295 | if use_nearest: 296 | mask_j = 
batched_index_select(mask, nbhd_indices, dim = 1) 297 | mask = (mask_i * mask_j) & nbhd_mask 298 | else: 299 | mask_j = rearrange(mask, 'b j -> b () j') 300 | mask = mask_i * mask_j 301 | 302 | if exists(self.coors_mlp): 303 | coor_weights = self.coors_mlp(m_ij) 304 | coor_weights = rearrange(coor_weights, 'b i j () -> b i j') 305 | 306 | rel_coors = self.coors_norm(rel_coors) 307 | 308 | if exists(mask): 309 | coor_weights.masked_fill_(~mask, 0.) 310 | 311 | if exists(self.coor_weights_clamp_value): 312 | clamp_value = self.coor_weights_clamp_value 313 | coor_weights.clamp_(min = -clamp_value, max = clamp_value) 314 | 315 | coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors 316 | else: 317 | coors_out = coors 318 | 319 | if exists(self.node_mlp): 320 | if exists(mask): 321 | m_ij_mask = rearrange(mask, '... -> ... ()') 322 | m_ij = m_ij.masked_fill(~m_ij_mask, 0.) 323 | 324 | if self.m_pool_method == 'mean': 325 | if exists(mask): 326 | # masked mean 327 | mask_sum = m_ij_mask.sum(dim = -2) 328 | m_i = safe_div(m_ij.sum(dim = -2), mask_sum) 329 | else: 330 | m_i = m_ij.mean(dim = -2) 331 | 332 | elif self.m_pool_method == 'sum': 333 | m_i = m_ij.sum(dim = -2) 334 | 335 | normed_feats = self.node_norm(feats) 336 | node_mlp_input = torch.cat((normed_feats, m_i), dim = -1) 337 | node_out = self.node_mlp(node_mlp_input) + feats 338 | else: 339 | node_out = feats 340 | 341 | return node_out, coors_out 342 | 343 | class EGNN_Network(nn.Module): 344 | def __init__( 345 | self, 346 | *, 347 | depth, 348 | dim, 349 | num_tokens = None, 350 | num_edge_tokens = None, 351 | num_positions = None, 352 | edge_dim = 0, 353 | num_adj_degrees = None, 354 | adj_dim = 0, 355 | global_linear_attn_every = 0, 356 | global_linear_attn_heads = 8, 357 | global_linear_attn_dim_head = 64, 358 | num_global_tokens = 4, 359 | **kwargs 360 | ): 361 | super().__init__() 362 | assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is 
greater than 1' 363 | self.num_positions = num_positions 364 | 365 | self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None 366 | self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None 367 | self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None 368 | self.has_edges = edge_dim > 0 369 | 370 | self.num_adj_degrees = num_adj_degrees 371 | self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None 372 | 373 | edge_dim = edge_dim if self.has_edges else 0 374 | adj_dim = adj_dim if exists(num_adj_degrees) else 0 375 | 376 | has_global_attn = global_linear_attn_every > 0 377 | self.global_tokens = None 378 | if has_global_attn: 379 | self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, dim)) 380 | 381 | self.layers = nn.ModuleList([]) 382 | for ind in range(depth): 383 | is_global_layer = has_global_attn and (ind % global_linear_attn_every) == 0 384 | 385 | self.layers.append(nn.ModuleList([ 386 | GlobalLinearAttention(dim = dim, heads = global_linear_attn_heads, dim_head = global_linear_attn_dim_head) if is_global_layer else None, 387 | EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs), 388 | ])) 389 | 390 | def forward( 391 | self, 392 | feats, 393 | coors, 394 | adj_mat = None, 395 | edges = None, 396 | mask = None, 397 | return_coor_changes = False 398 | ): 399 | b, device = feats.shape[0], feats.device 400 | 401 | if exists(self.token_emb): 402 | feats = self.token_emb(feats) 403 | 404 | if exists(self.pos_emb): 405 | n = feats.shape[1] 406 | assert n <= self.num_positions, f'given sequence length {n} must be less than the number of positions {self.num_positions} set at init' 407 | pos_emb = self.pos_emb(torch.arange(n, device = device)) 408 | feats += rearrange(pos_emb, 'n d -> () n d') 409 | 410 | if exists(edges) and exists(self.edge_emb): 411 | edges = self.edge_emb(edges) 412 | 413 
| # create N-degrees adjacent matrix from 1st degree connections 414 | if exists(self.num_adj_degrees): 415 | assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)' 416 | 417 | if len(adj_mat.shape) == 2: 418 | adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b) 419 | 420 | adj_indices = adj_mat.clone().long() 421 | 422 | for ind in range(self.num_adj_degrees - 1): 423 | degree = ind + 2 424 | 425 | next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0 426 | next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool() 427 | adj_indices.masked_fill_(next_degree_mask, degree) 428 | adj_mat = next_degree_adj_mat.clone() 429 | 430 | if exists(self.adj_emb): 431 | adj_emb = self.adj_emb(adj_indices) 432 | edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb 433 | 434 | # setup global attention 435 | 436 | global_tokens = None 437 | if exists(self.global_tokens): 438 | global_tokens = repeat(self.global_tokens, 'n d -> b n d', b = b) 439 | 440 | # go through layers 441 | 442 | coor_changes = [coors] 443 | 444 | for global_attn, egnn in self.layers: 445 | if exists(global_attn): 446 | feats, global_tokens = global_attn(feats, global_tokens, mask = mask) 447 | 448 | feats, coors = egnn(feats, coors, adj_mat = adj_mat, edges = edges, mask = mask) 449 | coor_changes.append(coors) 450 | 451 | if return_coor_changes: 452 | return feats, coors, coor_changes 453 | 454 | return feats, coors 455 | -------------------------------------------------------------------------------- /egnn_pytorch/egnn_pytorch_geometric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum, broadcast_tensors 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | # types 9 | 10 | from typing import Optional, List, Union 11 | 12 | # pytorch geometric 13 | 14 | try: 15 
| import torch_geometric 16 | from torch_geometric.nn import MessagePassing 17 | from torch_geometric.typing import Adj, Size, OptTensor, Tensor 18 | except: 19 | Tensor = OptTensor = Adj = MessagePassing = Size = object 20 | PYG_AVAILABLE = False 21 | 22 | # to stop throwing errors from type suggestions 23 | Adj = object 24 | Size = object 25 | OptTensor = object 26 | Tensor = object 27 | 28 | from .egnn_pytorch import * 29 | 30 | # global linear attention 31 | 32 | class Attention_Sparse(Attention): 33 | def __init__(self, **kwargs): 34 | """ Wraps the attention class to operate with pytorch-geometric inputs. """ 35 | super(Attention_Sparse, self).__init__(**kwargs) 36 | 37 | def sparse_forward(self, x, context, batch=None, batch_uniques=None, mask=None): 38 | assert batch is not None or batch_uniques is not None, "Batch/(uniques) must be passed for block_sparse_attn" 39 | if batch_uniques is None: 40 | batch_uniques = torch.unique(batch, return_counts=True) 41 | # only one example in batch - do dense - faster 42 | if batch_uniques[0].shape[0] == 1: 43 | x, context = map(lambda t: rearrange(t, 'h d -> () h d'), (x, context)) 44 | return self.forward(x, context, mask=None).squeeze() # get rid of batch dim 45 | # multiple examples in batch - do block-sparse by dense loop 46 | else: 47 | x_list = [] 48 | aux_count = 0 49 | for bi,n_idxs in zip(*batch_uniques): 50 | x_list.append( 51 | self.sparse_forward( 52 | x[aux_count:aux_count+n_i], 53 | context[aux_count:aux_count+n_idxs], 54 | batch_uniques = (bi.unsqueeze(-1), n_idxs.unsqueeze(-1)) 55 | ) 56 | ) 57 | return torch.cat(x_list, dim=0) 58 | 59 | 60 | class GlobalLinearAttention_Sparse(nn.Module): 61 | def __init__( 62 | self, 63 | *, 64 | dim, 65 | heads = 8, 66 | dim_head = 64 67 | ): 68 | super().__init__() 69 | self.norm_seq = torch_geomtric.nn.norm.LayerNorm(dim) 70 | self.norm_queries = torch_geomtric.nn.norm.LayerNorm(dim) 71 | self.attn1 = Attention_Sparse(dim, heads, dim_head) 72 | self.attn2 = 
Attention_Sparse(dim, heads, dim_head) 73 | 74 | # can't concat pyg norms with torch sequentials 75 | self.ff_norm = torch_geomtric.nn.norm.LayerNorm(dim) 76 | self.ff = nn.Sequential( 77 | nn.Linear(dim, dim * 4), 78 | nn.GELU(), 79 | nn.Linear(dim * 4, dim) 80 | ) 81 | 82 | def forward(self, x, queries, batch=None, batch_uniques=None, mask = None): 83 | res_x, res_queries = x, queries 84 | x, queries = self.norm_seq(x, batch=batch), self.norm_queries(queries, batch=batch) 85 | 86 | induced = self.attn1.sparse_forward(queries, x, batch=batch, batch_uniques=batch_uniques, mask = mask) 87 | out = self.attn2.sparse_forward(x, induced, batch=batch, batch_uniques=batch_uniques) 88 | 89 | x = out + res_x 90 | queries = induced + res_queries 91 | 92 | x_norm = self.ff_norm(x, batch=batch) 93 | x = self.ff(x_norm) + x_norm 94 | return x, queries 95 | 96 | 97 | # define pytorch-geometric equivalents 98 | 99 | class EGNN_Sparse(MessagePassing): 100 | """ Different from the above since it separates the edge assignment 101 | from the computation (this allows for great reduction in time and 102 | computations when the graph is locally or sparse connected). 
103 | * aggr: one of ["add", "mean", "max"] 104 | """ 105 | def __init__( 106 | self, 107 | feats_dim, 108 | pos_dim=3, 109 | edge_attr_dim = 0, 110 | m_dim = 16, 111 | fourier_features = 0, 112 | soft_edge = 0, 113 | norm_feats = False, 114 | norm_coors = False, 115 | norm_coors_scale_init = 1e-2, 116 | update_feats = True, 117 | update_coors = True, 118 | dropout = 0., 119 | coor_weights_clamp_value = None, 120 | aggr = "add", 121 | **kwargs 122 | ): 123 | assert aggr in {'add', 'sum', 'max', 'mean'}, 'pool method must be a valid option' 124 | assert update_feats or update_coors, 'you must update either features, coordinates, or both' 125 | kwargs.setdefault('aggr', aggr) 126 | super(EGNN_Sparse, self).__init__(**kwargs) 127 | # model params 128 | self.fourier_features = fourier_features 129 | self.feats_dim = feats_dim 130 | self.pos_dim = pos_dim 131 | self.m_dim = m_dim 132 | self.soft_edge = soft_edge 133 | self.norm_feats = norm_feats 134 | self.norm_coors = norm_coors 135 | self.update_coors = update_coors 136 | self.update_feats = update_feats 137 | self.coor_weights_clamp_value = None 138 | 139 | self.edge_input_dim = (fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2) 140 | self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() 141 | 142 | # EDGES 143 | self.edge_mlp = nn.Sequential( 144 | nn.Linear(self.edge_input_dim, self.edge_input_dim * 2), 145 | self.dropout, 146 | SiLU(), 147 | nn.Linear(self.edge_input_dim * 2, m_dim), 148 | SiLU() 149 | ) 150 | 151 | self.edge_weight = nn.Sequential(nn.Linear(m_dim, 1), 152 | nn.Sigmoid() 153 | ) if soft_edge else None 154 | 155 | # NODES - can't do identity in node_norm bc pyg expects 2 inputs, but identity expects 1. 
156 | self.node_norm = torch_geometric.nn.norm.LayerNorm(feats_dim) if norm_feats else None 157 | self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity() 158 | 159 | self.node_mlp = nn.Sequential( 160 | nn.Linear(feats_dim + m_dim, feats_dim * 2), 161 | self.dropout, 162 | SiLU(), 163 | nn.Linear(feats_dim * 2, feats_dim), 164 | ) if update_feats else None 165 | 166 | # COORS 167 | self.coors_mlp = nn.Sequential( 168 | nn.Linear(m_dim, m_dim * 4), 169 | self.dropout, 170 | SiLU(), 171 | nn.Linear(self.m_dim * 4, 1) 172 | ) if update_coors else None 173 | 174 | self.apply(self.init_) 175 | 176 | def init_(self, module): 177 | if type(module) in {nn.Linear}: 178 | # seems to be needed to keep the network from exploding to NaN with greater depths 179 | nn.init.xavier_normal_(module.weight) 180 | nn.init.zeros_(module.bias) 181 | 182 | def forward(self, x: Tensor, edge_index: Adj, 183 | edge_attr: OptTensor = None, batch: Adj = None, 184 | angle_data: List = None, size: Size = None) -> Tensor: 185 | """ Inputs: 186 | * x: (n_points, d) where d is pos_dims + feat_dims 187 | * edge_index: (2, n_edges) 188 | * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats. 189 | * batch: (n_points,) long tensor. specifies xloud belonging for each point 190 | * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor. 
191 | * size: None 192 | """ 193 | coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:] 194 | 195 | rel_coors = coors[edge_index[0]] - coors[edge_index[1]] 196 | rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True) 197 | 198 | if self.fourier_features > 0: 199 | rel_dist = fourier_encode_dist(rel_dist, num_encodings = self.fourier_features) 200 | rel_dist = rearrange(rel_dist, 'n () d -> n d') 201 | 202 | if exists(edge_attr): 203 | edge_attr_feats = torch.cat([edge_attr, rel_dist], dim=-1) 204 | else: 205 | edge_attr_feats = rel_dist 206 | 207 | hidden_out, coors_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats, 208 | coors=coors, rel_coors=rel_coors, 209 | batch=batch) 210 | return torch.cat([coors_out, hidden_out], dim=-1) 211 | 212 | 213 | def message(self, x_i, x_j, edge_attr) -> Tensor: 214 | m_ij = self.edge_mlp( torch.cat([x_i, x_j, edge_attr], dim=-1) ) 215 | return m_ij 216 | 217 | def propagate(self, edge_index: Adj, size: Size = None, **kwargs): 218 | """The initial call to start propagating messages. 219 | Args: 220 | `edge_index` holds the indices of a general (sparse) 221 | assignment matrix of shape :obj:`[N, M]`. 222 | size (tuple, optional) if none, the size will be inferred 223 | and assumed to be quadratic. 224 | **kwargs: Any additional data which is needed to construct and 225 | aggregate messages, and to update node embeddings. 
226 | """ 227 | size = self._check_input(edge_index, size) 228 | coll_dict = self._collect(self._user_args, 229 | edge_index, size, kwargs) 230 | msg_kwargs = self.inspector.collect_param_data('message', coll_dict) 231 | aggr_kwargs = self.inspector.collect_param_data('aggregate', coll_dict) 232 | update_kwargs = self.inspector.collect_param_data('update', coll_dict) 233 | 234 | # get messages 235 | m_ij = self.message(**msg_kwargs) 236 | 237 | # update coors if specified 238 | if self.update_coors: 239 | coor_wij = self.coors_mlp(m_ij) 240 | # clamp if arg is set 241 | if self.coor_weights_clamp_value: 242 | coor_weights_clamp_value = self.coor_weights_clamp_value 243 | coor_weights.clamp_(min = -clamp_value, max = clamp_value) 244 | 245 | # normalize if needed 246 | kwargs["rel_coors"] = self.coors_norm(kwargs["rel_coors"]) 247 | 248 | mhat_i = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs) 249 | coors_out = kwargs["coors"] + mhat_i 250 | else: 251 | coors_out = kwargs["coors"] 252 | 253 | # update feats if specified 254 | if self.update_feats: 255 | # weight the edges if arg is passed 256 | if self.soft_edge: 257 | m_ij = m_ij * self.edge_weight(m_ij) 258 | m_i = self.aggregate(m_ij, **aggr_kwargs) 259 | 260 | hidden_feats = self.node_norm(kwargs["x"], kwargs["batch"]) if self.node_norm else kwargs["x"] 261 | hidden_out = self.node_mlp( torch.cat([hidden_feats, m_i], dim = -1) ) 262 | hidden_out = kwargs["x"] + hidden_out 263 | else: 264 | hidden_out = kwargs["x"] 265 | 266 | # return tuple 267 | return self.update((hidden_out, coors_out), **update_kwargs) 268 | 269 | def __repr__(self): 270 | dict_print = {} 271 | return "E(n)-GNN Layer for Graphs " + str(self.__dict__) 272 | 273 | 274 | class EGNN_Sparse_Network(nn.Module): 275 | r"""Sample GNN model architecture that uses the EGNN-Sparse 276 | message passing layer to learn over point clouds. 
277 | Main MPNN layer introduced in https://arxiv.org/abs/2102.09844v1 278 | 279 | Inputs will be standard GNN: x, edge_index, edge_attr, batch, ... 280 | 281 | Args: 282 | * n_layers: int. number of MPNN layers 283 | * ... : same interpretation as the base layer. 284 | * embedding_nums: list. number of unique keys to embedd. for points 285 | 1 entry per embedding needed. 286 | * embedding_dims: list. point - number of dimensions of 287 | the resulting embedding. 1 entry per embedding needed. 288 | * edge_embedding_nums: list. number of unique keys to embedd. for edges. 289 | 1 entry per embedding needed. 290 | * edge_embedding_dims: list. point - number of dimensions of 291 | the resulting embedding. 1 entry per embedding needed. 292 | * recalc: int. Recalculate edge feats every `recalc` MPNN layers. 0 for no recalc 293 | * verbose: bool. verbosity level. 294 | ----- 295 | Diff with normal layer: one has to do preprocessing before (radius, global token, ...) 296 | """ 297 | def __init__(self, n_layers, feats_dim, 298 | pos_dim = 3, 299 | edge_attr_dim = 0, 300 | m_dim = 16, 301 | fourier_features = 0, 302 | soft_edge = 0, 303 | embedding_nums=[], 304 | embedding_dims=[], 305 | edge_embedding_nums=[], 306 | edge_embedding_dims=[], 307 | update_coors=True, 308 | update_feats=True, 309 | norm_feats=True, 310 | norm_coors=False, 311 | norm_coors_scale_init = 1e-2, 312 | dropout=0., 313 | coor_weights_clamp_value=None, 314 | aggr="add", 315 | global_linear_attn_every = 0, 316 | global_linear_attn_heads = 8, 317 | global_linear_attn_dim_head = 64, 318 | num_global_tokens = 4, 319 | recalc=0 ,): 320 | super().__init__() 321 | 322 | self.n_layers = n_layers 323 | 324 | # Embeddings? 
solve here 325 | self.embedding_nums = embedding_nums 326 | self.embedding_dims = embedding_dims 327 | self.emb_layers = nn.ModuleList() 328 | self.edge_embedding_nums = edge_embedding_nums 329 | self.edge_embedding_dims = edge_embedding_dims 330 | self.edge_emb_layers = nn.ModuleList() 331 | 332 | # instantiate point and edge embedding layers 333 | 334 | for i in range( len(self.embedding_dims) ): 335 | self.emb_layers.append(nn.Embedding(num_embeddings = embedding_nums[i], 336 | embedding_dim = embedding_dims[i])) 337 | feats_dim += embedding_dims[i] - 1 338 | 339 | for i in range( len(self.edge_embedding_dims) ): 340 | self.edge_emb_layers.append(nn.Embedding(num_embeddings = edge_embedding_nums[i], 341 | embedding_dim = edge_embedding_dims[i])) 342 | edge_attr_dim += edge_embedding_dims[i] - 1 343 | # rest 344 | self.mpnn_layers = nn.ModuleList() 345 | self.feats_dim = feats_dim 346 | self.pos_dim = pos_dim 347 | self.edge_attr_dim = edge_attr_dim 348 | self.m_dim = m_dim 349 | self.fourier_features = fourier_features 350 | self.soft_edge = soft_edge 351 | self.norm_feats = norm_feats 352 | self.norm_coors = norm_coors 353 | self.norm_coors_scale_init = norm_coors_scale_init 354 | self.update_feats = update_feats 355 | self.update_coors = update_coors 356 | self.dropout = dropout 357 | self.coor_weights_clamp_value = coor_weights_clamp_value 358 | self.recalc = recalc 359 | 360 | self.has_global_attn = global_linear_attn_every > 0 361 | self.global_tokens = None 362 | self.global_linear_attn_every = global_linear_attn_every 363 | if self.has_global_attn: 364 | self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, dim)) 365 | 366 | # instantiate layers 367 | for i in range(n_layers): 368 | layer = EGNN_Sparse(feats_dim = feats_dim, 369 | pos_dim = pos_dim, 370 | edge_attr_dim = edge_attr_dim, 371 | m_dim = m_dim, 372 | fourier_features = fourier_features, 373 | soft_edge = soft_edge, 374 | norm_feats = norm_feats, 375 | norm_coors = norm_coors, 376 | 
norm_coors_scale_init = norm_coors_scale_init, 377 | update_feats = update_feats, 378 | update_coors = update_coors, 379 | dropout = dropout, 380 | coor_weights_clamp_value = coor_weights_clamp_value) 381 | 382 | # global attention case 383 | is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0 384 | if is_global_layer: 385 | attn_layer = GlobalLinearAttention(dim = self.feats_dim, 386 | heads = global_linear_attn_heads, 387 | dim_head = global_linear_attn_dim_head) 388 | self.mpnn_layers.append(nn.ModuleList([layer, attn_layer])) 389 | # normal case 390 | else: 391 | self.mpnn_layers.append(layer) 392 | 393 | 394 | def forward(self, x, edge_index, batch, edge_attr, 395 | bsize=None, recalc_edge=None, verbose=0): 396 | """ Recalculate edge features every `self.recalc_edge` with the 397 | `recalc_edge` function if self.recalc_edge is set. 398 | 399 | * x: (N, pos_dim+feats_dim) will be unpacked into coors, feats. 400 | """ 401 | # NODES - Embedd each dim to its target dimensions: 402 | x = embedd_token(x, self.embedding_dims, self.emb_layers) 403 | 404 | # regulates wether to embedd edges each layer 405 | edges_need_embedding = True 406 | for i,layer in enumerate(self.mpnn_layers): 407 | 408 | # EDGES - Embedd each dim to its target dimensions: 409 | if edges_need_embedding: 410 | edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers) 411 | edges_need_embedding = False 412 | 413 | # attn tokens 414 | global_tokens = None 415 | if exists(self.global_tokens): 416 | unique, amounts = torch.unique(batch, return_counts) 417 | num_idxs = torch.cat([torch.arange(num_idxs_i) for num_idxs_i in amounts], dim=-1) 418 | global_tokens = self.global_tokens[num_idxs] 419 | 420 | # pass layers 421 | is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0 422 | if not is_global_layer: 423 | x = layer(x, edge_index, edge_attr, batch=batch, size=bsize) 424 | else: 425 | # only pass feats to the 
attn layer 426 | x_attn = layer[0](x[:, self.pos_dim:], global_tokens) 427 | # merge attn-ed feats and coords 428 | x = torch.cat( (x[:, :self.pos_dim], x_attn), dim=-1) 429 | x = layer[-1](x, edge_index, edge_attr, batch=batch, size=bsize) 430 | 431 | # recalculate edge info - not needed if last layer 432 | if self.recalc and ((i%self.recalc == 0) and not (i == len(self.mpnn_layers)-1)) : 433 | edge_index, edge_attr, _ = recalc_edge(x) # returns attr, idx, any_other_info 434 | edges_need_embedding = True 435 | 436 | return x 437 | 438 | def __repr__(self): 439 | return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers)) 440 | -------------------------------------------------------------------------------- /egnn_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import sin, cos, atan2, acos 3 | 4 | def rot_z(gamma): 5 | return torch.tensor([ 6 | [cos(gamma), -sin(gamma), 0], 7 | [sin(gamma), cos(gamma), 0], 8 | [0, 0, 1] 9 | ], dtype=gamma.dtype) 10 | 11 | def rot_y(beta): 12 | return torch.tensor([ 13 | [cos(beta), 0, sin(beta)], 14 | [0, 1, 0], 15 | [-sin(beta), 0, cos(beta)] 16 | ], dtype=beta.dtype) 17 | 18 | def rot(alpha, beta, gamma): 19 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) 20 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose 6 | python_files = tests/*.py 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'egnn-pytorch', 5 | packages = find_packages(), 6 | version = '0.2.8', 7 | license='MIT', 8 | description = 'E(n)-Equivariant Graph Neural Network - 
Pytorch', 9 | long_description_content_type = 'text/markdown', 10 | author = 'Phil Wang, Eric Alcaide', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/egnn-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'equivariance', 17 | 'graph neural network' 18 | ], 19 | install_requires=[ 20 | 'einops>=0.3', 21 | 'numba', 22 | 'numpy', 23 | 'torch>=1.6' 24 | ], 25 | setup_requires=[ 26 | 'pytest-runner', 27 | ], 28 | tests_require=[ 29 | 'pytest' 30 | ], 31 | classifiers=[ 32 | 'Development Status :: 4 - Beta', 33 | 'Intended Audience :: Developers', 34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Programming Language :: Python :: 3.6', 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /tests/test_equivariance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from egnn_pytorch import EGNN, EGNN_Sparse 4 | from egnn_pytorch.utils import rot 5 | 6 | torch.set_default_dtype(torch.float64) 7 | 8 | def test_egnn_equivariance(): 9 | layer = EGNN(dim=512, edge_dim=4) 10 | 11 | R = rot(*torch.rand(3)) 12 | T = torch.randn(1, 1, 3) 13 | 14 | feats = torch.randn(1, 16, 512) 15 | coors = torch.randn(1, 16, 3) 16 | edges = torch.randn(1, 16, 16, 4) 17 | mask = torch.ones(1, 16).bool() 18 | 19 | # Cache first two nodes' features 20 | node1 = feats[:, 0, :] 21 | node2 = feats[:, 1, :] 22 | 23 | # Switch first and second nodes' positions 24 | feats_permuted_row_wise = feats.clone().detach() 25 | feats_permuted_row_wise[:, 0, :] = node2 26 | feats_permuted_row_wise[:, 1, :] = node1 27 | 28 | feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask) 29 | feats2, coors2 = layer(feats, coors, edges, mask=mask) 30 | feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask) 31 | 32 | assert torch.allclose(feats1, 
feats2, atol=1e-6), 'type 0 features are invariant' 33 | assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant' 34 | assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order' 35 | 36 | def test_higher_dimension(): 37 | layer = EGNN(dim=512, edge_dim=4) 38 | 39 | feats = torch.randn(1, 16, 512) 40 | coors = torch.randn(1, 16, 5) 41 | edges = torch.randn(1, 16, 16, 4) 42 | mask = torch.ones(1, 16).bool() 43 | 44 | feats, coors = layer(feats, coors, edges, mask = mask) 45 | assert True 46 | 47 | def test_egnn_equivariance_with_nearest_neighbors(): 48 | layer = EGNN(dim=512, edge_dim=1, num_nearest_neighbors=8) 49 | 50 | R = rot(*torch.rand(3)) 51 | T = torch.randn(1, 1, 3) 52 | 53 | feats = torch.randn(1, 256, 512) 54 | coors = torch.randn(1, 256, 3) 55 | edges = torch.randn(1, 256, 256, 1) 56 | mask = torch.ones(1, 256).bool() 57 | 58 | # Cache first two nodes' features 59 | node1 = feats[:, 0, :] 60 | node2 = feats[:, 1, :] 61 | 62 | # Switch first and second nodes' positions 63 | feats_permuted_row_wise = feats.clone().detach() 64 | feats_permuted_row_wise[:, 0, :] = node2 65 | feats_permuted_row_wise[:, 1, :] = node1 66 | 67 | feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask) 68 | feats2, coors2 = layer(feats, coors, edges, mask=mask) 69 | feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask) 70 | 71 | assert torch.allclose(feats1, feats2, atol=1e-6), 'type 0 features are invariant' 72 | assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant' 73 | assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order' 74 | 75 | 76 | def test_egnn_equivariance_with_coord_norm(): 77 | layer = EGNN(dim=512, edge_dim=1, num_nearest_neighbors=8, norm_coors=True) 78 | 79 | R = rot(*torch.rand(3)) 80 | T = torch.randn(1, 1, 3) 81 | 82 | feats = torch.randn(1, 256, 512) 83 
| coors = torch.randn(1, 256, 3) 84 | edges = torch.randn(1, 256, 256, 1) 85 | mask = torch.ones(1, 256).bool() 86 | 87 | # Cache first two nodes' features 88 | node1 = feats[:, 0, :] 89 | node2 = feats[:, 1, :] 90 | 91 | # Switch first and second nodes' positions 92 | feats_permuted_row_wise = feats.clone().detach() 93 | feats_permuted_row_wise[:, 0, :] = node2 94 | feats_permuted_row_wise[:, 1, :] = node1 95 | 96 | feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask) 97 | feats2, coors2 = layer(feats, coors, edges, mask=mask) 98 | feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask) 99 | 100 | assert torch.allclose(feats1, feats2, atol=1e-6), 'type 0 features are invariant' 101 | assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant' 102 | assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order' 103 | 104 | 105 | def test_egnn_sparse_equivariance(): 106 | layer = EGNN_Sparse(feats_dim=1, 107 | m_dim=16, 108 | fourier_features=4) 109 | 110 | R = rot(*torch.rand(3)) 111 | T = torch.randn(1, 1, 3) 112 | apply_action = lambda t: (t @ R + T).squeeze() 113 | 114 | feats = torch.randn(16, 1) 115 | coors = torch.randn(16, 3) 116 | edge_idxs = (torch.rand(2, 20) * 16).long() 117 | 118 | # Cache first two nodes' features 119 | node1 = feats[0, :] 120 | node2 = feats[1, :] 121 | 122 | # Switch first and second nodes' positions 123 | feats_permuted_row_wise = feats.clone().detach() 124 | feats_permuted_row_wise[0, :] = node2 125 | feats_permuted_row_wise[1, :] = node1 126 | 127 | x1 = torch.cat([coors, feats], dim=-1) 128 | x2 = torch.cat([apply_action(coors), feats], dim=-1) 129 | x3 = torch.cat([apply_action(coors), feats_permuted_row_wise], dim=-1) 130 | 131 | out1 = layer(x=x1, edge_index=edge_idxs) 132 | out2 = layer(x=x2, edge_index=edge_idxs) 133 | out3 = layer(x=x3, edge_index=edge_idxs) 134 | 135 | feats1, coors1 = out1[:, 3:], out1[:, :3] 136 
| feats2, coors2 = out2[:, 3:], out2[:, :3] 137 | feats3, coors3 = out3[:, 3:], out3[:, :3] 138 | 139 | print(feats1 - feats2) 140 | print(apply_action(coors1) - coors2) 141 | assert torch.allclose(feats1, feats2), 'features must be invariant' 142 | assert torch.allclose(apply_action(coors1), coors2), 'coordinates must be equivariant' 143 | assert not torch.allclose(feats1, feats3, atol=1e-6), 'layer must be equivariant to permutations of node order' 144 | 145 | 146 | def test_geom_equivalence(): 147 | layer = EGNN_Sparse(feats_dim=128, 148 | edge_attr_dim=4, 149 | m_dim=16, 150 | fourier_features=4) 151 | 152 | feats = torch.randn(16, 128) 153 | coors = torch.randn(16, 3) 154 | x = torch.cat([coors, feats], dim=-1) 155 | edge_idxs = (torch.rand(2, 20) * 16).long() 156 | edges_attrs = torch.randn(16, 16, 4) 157 | edges_attrs = edges_attrs[edge_idxs[0], edge_idxs[1]] 158 | 159 | assert layer.forward(x, edge_idxs, edge_attr=edges_attrs).shape == x.shape 160 | --------------------------------------------------------------------------------