├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── denoise_sparse.py
├── egnn.png
├── egnn_pytorch
├── __init__.py
├── egnn_pytorch.py
├── egnn_pytorch_geometric.py
└── utils.py
├── examples
└── egnn_test.ipynb
├── setup.cfg
├── setup.py
└── tests
└── test_equivariance.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # mac files and custom data for examples
10 | *.DS_Store
11 | .DS_Store
12 | *sidechainnet_data/*
13 | *.pkl
14 | *tests/custom_tests.py
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang, Eric Alcaide
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | **A bug was discovered in the neighbor selection when masking is used. If you ran any experiments with masking prior to version 0.1.12, please rerun them. 🙏**
4 |
5 | ## EGNN - Pytorch
6 |
7 | Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. It may eventually be used for Alphafold2 replication. This technique uses simple invariant features, yet ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. It is SOTA in dynamical system modeling, molecular activity prediction tasks, etc.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install egnn-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from egnn_pytorch import EGNN
20 |
21 | layer1 = EGNN(dim = 512)
22 | layer2 = EGNN(dim = 512)
23 |
24 | feats = torch.randn(1, 16, 512)
25 | coors = torch.randn(1, 16, 3)
26 |
27 | feats, coors = layer1(feats, coors)
28 | feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
29 | ```
30 |
31 | With edges
32 |
33 | ```python
34 | import torch
35 | from egnn_pytorch import EGNN
36 |
37 | layer1 = EGNN(dim = 512, edge_dim = 4)
38 | layer2 = EGNN(dim = 512, edge_dim = 4)
39 |
40 | feats = torch.randn(1, 16, 512)
41 | coors = torch.randn(1, 16, 3)
42 | edges = torch.randn(1, 16, 16, 4)
43 |
44 | feats, coors = layer1(feats, coors, edges)
45 | feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
46 | ```
47 |
48 | A full EGNN network
49 |
50 | ```python
51 | import torch
52 | from egnn_pytorch import EGNN_Network
53 |
54 | net = EGNN_Network(
55 | num_tokens = 21,
56 | num_positions = 1024, # unless what you are passing in is an unordered set, set this to the maximum sequence length
57 | dim = 32,
58 | depth = 3,
59 | num_nearest_neighbors = 8,
60 | coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
61 | )
62 |
63 | feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
64 | coors = torch.randn(1, 1024, 3) # (1, 1024, 3)
65 | mask = torch.ones_like(feats).bool() # (1, 1024)
66 |
67 | feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
68 | ```
69 |
70 | Only attend to sparse neighbors, given to the network as an adjacency matrix.
71 |
72 | ```python
73 | import torch
74 | from egnn_pytorch import EGNN_Network
75 |
76 | net = EGNN_Network(
77 | num_tokens = 21,
78 | dim = 32,
79 | depth = 3,
80 | only_sparse_neighbors = True
81 | )
82 |
83 | feats = torch.randint(0, 21, (1, 1024))
84 | coors = torch.randn(1, 1024, 3)
85 | mask = torch.ones_like(feats).bool()
86 |
87 | # naive adjacency matrix
88 | # assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
89 | i = torch.arange(1024)
90 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
91 |
92 | feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
93 | ```
94 |
95 | You can also have the network automatically determine the Nth-degree neighbors and pass in an adjacency embedding (depending on the degree) to be used as an edge feature, using two extra keyword arguments:
96 |
97 | ```python
98 | import torch
99 | from egnn_pytorch import EGNN_Network
100 |
101 | net = EGNN_Network(
102 | num_tokens = 21,
103 | dim = 32,
104 | depth = 3,
105 | num_adj_degrees = 3, # fetch up to 3rd degree neighbors
106 | adj_dim = 8, # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
107 | only_sparse_neighbors = True
108 | )
109 |
110 | feats = torch.randint(0, 21, (1, 1024))
111 | coors = torch.randn(1, 1024, 3)
112 | mask = torch.ones_like(feats).bool()
113 |
114 | # naive adjacency matrix
115 | # assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
116 | i = torch.arange(1024)
117 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
118 |
119 | feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
120 | ```
121 |
122 | ## Edges
123 |
124 | If you need to pass in continuous edges
125 |
126 | ```python
127 | import torch
128 | from egnn_pytorch import EGNN_Network
129 |
130 | net = EGNN_Network(
131 | num_tokens = 21,
132 | dim = 32,
133 | depth = 3,
134 | edge_dim = 4,
135 | num_nearest_neighbors = 3
136 | )
137 |
138 | feats = torch.randint(0, 21, (1, 1024))
139 | coors = torch.randn(1, 1024, 3)
140 | mask = torch.ones_like(feats).bool()
141 |
142 | continuous_edges = torch.randn(1, 1024, 1024, 4)
143 |
144 | # naive adjacency matrix
145 | # assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
146 | i = torch.arange(1024)
147 | adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
148 |
149 | feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
150 | ```
151 |
152 | ## Stability
153 |
154 | The initial architecture for EGNN suffered from instability when there was a high number of neighbors. Thankfully, there seem to be two solutions that largely mitigate this.
155 |
156 | ```python
157 | import torch
158 | from egnn_pytorch import EGNN_Network
159 |
160 | net = EGNN_Network(
161 | num_tokens = 21,
162 | dim = 32,
163 | depth = 3,
164 | num_nearest_neighbors = 32,
165 | norm_coors = True, # normalize the relative coordinates
166 | coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
167 | )
168 |
169 | feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
170 | coors = torch.randn(1, 1024, 3) # (1, 1024, 3)
171 | mask = torch.ones_like(feats).bool() # (1, 1024)
172 |
173 | feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
174 | ```
175 |
176 | ## All parameters
177 |
178 | ```python
179 | import torch
180 | from egnn_pytorch import EGNN
181 |
182 | model = EGNN(
183 | dim = dim, # input dimension
184 | edge_dim = 0, # dimension of the edges; if edges are passed in, this should be > 0
185 | m_dim = 16, # hidden model dimension
186 | fourier_features = 0, # number of fourier features for encoding of relative distance - defaults to none as in paper
187 | num_nearest_neighbors = 0, # cap the number of neighbors doing message passing by relative distance
188 | dropout = 0.0, # dropout
189 | norm_feats = False, # whether to layernorm the features
190 | norm_coors = False, # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
191 | update_feats = True, # whether to update features - you can build a layer that only updates one or the other
192 | update_coors = True, # whether to update coordinates
193 | only_sparse_neighbors = False, # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
194 | valid_radius = float('inf'), # the valid radius each node considers for message passing
195 | m_pool_method = 'sum', # whether to mean or sum pool for output node representation
196 | soft_edges = False, # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper
197 | coor_weights_clamp_value = None # clamping of the coordinate updates, again, for stabilization purposes
198 | )
199 |
200 | ```
201 |
202 | ## Examples
203 |
204 | To run the protein backbone denoising example, first install `sidechainnet`
205 |
206 | ```bash
207 | $ pip install sidechainnet
208 | ```
209 |
210 | Then
211 |
212 | ```bash
213 | $ python denoise_sparse.py
214 | ```
215 |
216 | ## Tests
217 |
218 | Make sure you have pytorch geometric installed locally
219 |
220 | ```bash
221 | $ python setup.py test
222 | ```
223 |
224 | ## Citations
225 |
226 | ```bibtex
227 | @misc{satorras2021en,
228 | title = {E(n) Equivariant Graph Neural Networks},
229 | author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
230 | year = {2021},
231 | eprint = {2102.09844},
232 | archivePrefix = {arXiv},
233 | primaryClass = {cs.LG}
234 | }
235 | ```
236 |
--------------------------------------------------------------------------------
/denoise_sparse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.optim import Adam
5 |
6 | from einops import rearrange, repeat
7 |
8 | import sidechainnet as scn
9 | from egnn_pytorch.egnn_pytorch import EGNN_Network
10 |
# samples per sidechainnet dataloader batch
BATCH_SIZE = 1
# micro-batches whose gradients are summed before each optimizer step
GRADIENT_ACCUMULATE_EVERY = 16

# all tensors created without an explicit dtype (model weights included) are float64
torch.set_default_dtype(torch.float64)
15 |
def cycle(loader, len_thres = 200):
    """Yield batches from ``loader`` forever, skipping over-long sequences.

    The loader is iterated again each time it is exhausted. A batch is skipped
    when its ``seqs`` tensor is longer than ``len_thres`` along dimension 1.

    Args:
        loader: any re-iterable of batches exposing a ``seqs`` attribute with a
            ``shape`` (e.g. a PyTorch ``DataLoader``).
        len_thres: maximum accepted ``seqs.shape[1]``.

    Yields:
        Batches whose sequence length is at most ``len_thres``.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no usable batch
            (empty loader, an already-exhausted one-shot iterator, or every
            sequence longer than ``len_thres``). Without this check the
            generator would loop forever without ever yielding.
    """
    while True:
        yielded_any = False
        for data in loader:
            if data.seqs.shape[1] > len_thres:
                continue
            yielded_any = True
            yield data

        if not yielded_any:
            raise ValueError(f'no batch in the loader has a sequence length <= {len_thres}')
22 |
# EGNN that maps noised coordinates (plus per-atom tokens) to denoised coordinates
net = EGNN_Network(
    num_tokens = 21,
    num_positions = 200 * 3, # maximum number of positions - absolute positional embedding since there is inherent order in the sequence
    depth = 5,
    dim = 8,
    num_nearest_neighbors = 16,
    fourier_features = 2,
    norm_coors = True,
    coor_weights_clamp_value = 2.
).cuda()

# sidechainnet data (casp_version 12, thinning 30) wrapped as PyTorch dataloaders
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# endless stream of training batches, skipping overly long sequences (see `cycle`)
dl = cycle(data['train'])
optim = Adam(net.parameters(), lr=1e-3)

for _ in range(10000):
    # accumulate gradients over several micro-batches before one optimizer step
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # assumes `seqs` is one-hot over its last dim - argmax recovers token ids (TODO confirm)
        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda().type(torch.float64)
        masks = masks.cuda().bool()

        # NOTE(review): `l` is never used below
        l = seqs.shape[1]
        # split the flat atom axis into 14 atom slots per sequence position
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # Keeping only the backbone coordinates
        # (the first 3 atom slots of every position; presumably N, CA, C - verify against sidechainnet)

        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # give each of the 3 kept atoms its position's token and mask value
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # chain adjacency over the flattened atoms: i and j are adjacent when |i - j| <= 1
        # (this includes i == j); with num_nearest_neighbors set, EGNN ranks adjacent atoms
        # ahead of all others when choosing neighbors
        i = torch.arange(seq.shape[-1], device = seq.device)
        adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

        # corrupt the coordinates with standard-normal noise
        noised_coords = coords + torch.randn_like(coords)

        feats, denoised_coords = net(seq, noised_coords, adj_mat = adj_mat, mask = masks)

        # regress the clean coordinates on valid (unmasked) atoms only
        loss = F.mse_loss(denoised_coords[masks], coords[masks])

        # divide so the accumulated gradient is the mean over micro-batches
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # reports the loss of the last micro-batch only
    print('loss:', loss.item())
    optim.step()
    optim.zero_grad()
79 |
--------------------------------------------------------------------------------
/egnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/egnn-pytorch/d3ba97d5b2540085cf8e02c838ebf5e42071697f/egnn.png
--------------------------------------------------------------------------------
/egnn_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from egnn_pytorch.egnn_pytorch import EGNN, EGNN_Network
2 | from egnn_pytorch.egnn_pytorch_geometric import EGNN_Sparse, EGNN_Sparse_Network
3 |
--------------------------------------------------------------------------------
/egnn_pytorch/egnn_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum, broadcast_tensors
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, repeat
6 | from einops.layers.torch import Rearrange
7 |
8 | # helper functions
9 |
def exists(val):
    """Report whether ``val`` holds a value, i.e. is not ``None``.

    Falsy values such as ``0``, ``False`` or an empty container still count
    as existing.
    """
    return not (val is None)
12 |
def safe_div(num, den, eps = 1e-8):
    """Divide ``num`` by ``den``, returning 0 wherever ``den`` is exactly zero.

    Non-zero denominators whose magnitude is below ``eps`` are pushed to
    ``+eps`` / ``-eps`` according to their sign. Clamping with
    ``den.clamp(min = eps)`` alone would turn every negative denominator into
    ``+eps`` and return huge quotients of the wrong sign. For non-negative
    denominators the result is unchanged.

    Args:
        num: numerator tensor.
        den: denominator tensor, broadcastable with ``num``.
        eps: smallest magnitude allowed for a denominator.

    Returns:
        ``num / den`` (broadcast shape) with entries where ``den == 0`` set to 0.
    """
    # keep the denominator at least `eps` away from zero without flipping its sign
    den_safe = torch.where(den >= 0, den.clamp(min = eps), den.clamp(max = -eps))
    res = num.div(den_safe)
    # the quotient already has the broadcast shape, so the zero-mask always fits in place
    res.masked_fill_(den == 0, 0.)
    return res
17 |
def batched_index_select(values, indices, dim = 1):
    """Gather entries of ``values`` along ``dim`` using per-batch ``indices``.

    ``values`` and ``indices`` are assumed to share their leading ``dim``
    (batch-like) dimensions. Extra trailing axes of ``indices`` (beyond
    ``dim + 1``) are supported by inserting matching singleton axes into
    ``values`` just before ``dim`` and expanding them. The dimensions of
    ``values`` after ``dim`` are carried through unchanged, so one
    ``gather`` performs the whole lookup.

    Example: ``values`` of shape ``(b, n, d)``, ``indices`` of shape
    ``(b, i, k)`` and ``dim = 1`` give a result of shape ``(b, i, k, d)``
    with ``out[b, i, k] = values[b, indices[b, i, k]]``.

    Args:
        values: tensor to select from.
        indices: integer index tensor.
        dim: axis of ``values`` that ``indices`` index into.

    Returns:
        The gathered tensor.
    """
    # trailing dims of `values` that every selected entry carries along
    value_dims = values.shape[(dim + 1):]
    # NOTE(review): `values_shape` is not used below
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    # append singleton axes to `indices` and expand them over the trailing value dims
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    # number of extra index axes that `values` lacks (e.g. the `k` in (b, i, k) for dim = 1)
    value_expand_len = len(indices_shape) - (dim + 1)
    # insert that many singleton axes into `values` right before `dim`
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    # expand the inserted axes to the sizes of the corresponding axes of `indices`
    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    # the original `dim` has shifted right by the number of inserted axes
    dim += value_expand_len
    return values.gather(dim, indices)
33 |
def fourier_encode_dist(x, num_encodings = 4, include_self = True):
    """Encode ``x`` with sines and cosines at halving frequencies.

    A new trailing axis is added, holding ``sin(x / 2**k)`` for
    ``k = 0 .. num_encodings - 1``, then the matching cosines, and finally
    (when ``include_self`` is true) the raw ``x`` itself.

    Args:
        x: input tensor of any shape.
        num_encodings: number of frequencies.
        include_self: append the unencoded value as the last feature.

    Returns:
        Tensor of shape ``(*x.shape, 2 * num_encodings + include_self)``.
    """
    value = x.unsqueeze(-1)
    divisors = 2 ** torch.arange(num_encodings, device = value.device, dtype = value.dtype)
    scaled = value / divisors

    features = [scaled.sin(), scaled.cos()]
    if include_self:
        features.append(value)
    return torch.cat(features, dim = -1)
42 |
def embedd_token(x, dims, layers):
    """Replace the trailing token columns of ``x`` with their embeddings.

    The last ``len(dims)`` columns of ``x`` are read as integer token ids.
    The first layer's output is appended to ``x`` with those columns removed;
    each further layer's output is appended after that. When ``layers`` is
    empty, ``x`` is returned untouched.

    Args:
        x: 2-D tensor ``(rows, features)``; assumes the last ``len(dims)``
            columns hold token ids.
        dims: sequence whose length gives the number of token columns.
        layers: embedding modules, the ``k``-th applied to token column ``k``.

    Returns:
        The concatenated tensor.
    """
    split = -len(dims)
    token_ids = x[:, split:].long()

    out = None
    for col, layer in enumerate(layers):
        # the first embedding drops the raw token columns; later ones keep everything built so far
        base = x[:, :split] if out is None else out
        out = torch.cat([base, layer(token_ids[:, col])], dim = -1)

    return x if out is None else out
53 |
54 | # swish activation fallback
55 |
class Swish_(nn.Module):
    """Swish / SiLU activation ``x * sigmoid(x)`` for torch builds lacking ``nn.SiLU``."""

    def forward(self, x):
        gate = x.sigmoid()
        return x * gate

# use the built-in activation whenever this torch version provides it
SiLU = getattr(nn, 'SiLU', Swish_)
61 |
62 | # helper classes
63 |
64 | # this follows the same strategy for normalization as done in SE3 Transformers
65 | # https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95
66 |
class CoorsNorm(nn.Module):
    """Rescale each coordinate vector to unit length times a learned scalar.

    Args:
        eps: lower bound on the vector norm, avoiding division by zero.
        scale_init: initial value of the learned scale.
    """

    def __init__(self, eps = 1e-8, scale_init = 1.):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.full((1,), float(scale_init)))

    def forward(self, coors):
        """Normalize along the last axis of ``coors`` and multiply by ``self.scale``."""
        lengths = coors.norm(dim = -1, keepdim = True).clamp(min = self.eps)
        return (coors / lengths) * self.scale
78 |
79 | # global linear attention
80 |
class Attention(nn.Module):
    """Multi-head softmax attention of ``x`` (queries) over ``context`` (keys/values).

    Args:
        dim: feature size of ``x``, ``context`` and the output.
        heads: number of attention heads.
        dim_head: feature size of each head.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        inner_dim = dim_head * heads
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context, mask = None):
        """Attend from every row of ``x`` to the rows of ``context``.

        Args:
            x: queries, ``(b, n, dim)``.
            context: keys/values source, ``(b, m, dim)``.
            mask: optional boolean ``(b, m)``; ``False`` keys get (near) zero
                weight unless every key in a row is masked.

        Returns:
            Tensor of shape ``(b, n, dim)``.
        """
        heads = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = heads)

        q = split_heads(self.to_q(x))
        key_src, value_src = self.to_kv(context).chunk(2, dim = -1)
        k = split_heads(key_src)
        v = split_heads(value_src)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if mask is not None:
            # broadcast the key mask over heads and query positions
            sim = sim.masked_fill(~mask[:, None, None, :], -torch.finfo(sim.dtype).max)

        weights = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', weights, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
111 |
class GlobalLinearAttention(nn.Module):
    """Exchange information between a sequence and a small set of global tokens.

    The global tokens first attend to the (masked) sequence. Every sequence
    position then attends to the updated tokens. Because each position only
    looks at the global tokens, the cost grows linearly with sequence length.
    A feed-forward block with a residual connection follows on the sequence
    side.

    Args:
        dim: feature size of the sequence and of the global tokens.
        heads: attention heads for both attention steps.
        dim_head: per-head feature size.
    """

    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.norm_seq = nn.LayerNorm(dim)
        self.norm_queries = nn.LayerNorm(dim)
        self.attn1 = Attention(dim, heads, dim_head)
        self.attn2 = Attention(dim, heads, dim_head)

        hidden_dim = dim * 4
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x, queries, mask = None):
        """Return the updated ``(x, queries)`` pair.

        Args:
            x: sequence features, ``(b, n, dim)``.
            queries: global tokens, ``(b, g, dim)``.
            mask: optional boolean ``(b, n)`` marking valid sequence positions.
        """
        normed_x = self.norm_seq(x)
        normed_queries = self.norm_queries(queries)

        # global tokens read from the valid sequence positions ...
        induced = self.attn1(normed_queries, normed_x, mask = mask)
        # ... and every sequence position reads back from them
        x = x + self.attn2(normed_x, induced)
        queries = queries + induced

        return x + self.ff(x), queries
145 |
146 | # classes
147 |
class EGNN(nn.Module):
    """A single E(n)-equivariant graph neural network layer.

    Edge messages ``m_ij`` are built from the features of both endpoints, their
    squared distance (optionally Fourier-encoded) and optional edge features.
    The messages then update the coordinates (a weighted sum of relative
    coordinates) and/or the node features (a residual MLP on pooled messages).

    Args:
        dim: node feature dimension.
        edge_dim: dimension of the extra edge features given to ``forward`` (0 for none).
        m_dim: hidden message dimension.
        fourier_features: number of Fourier encodings of the squared distance (0 disables).
        num_nearest_neighbors: if > 0, only the k nearest nodes by distance send
            messages. Nodes adjacent in ``adj_mat`` are ranked first.
        dropout: dropout rate inside the MLPs.
        init_eps: std of the normal init applied to every ``nn.Linear`` weight.
        norm_feats: layer-normalize features before the node MLP.
        norm_coors: normalize relative coordinates with ``CoorsNorm``.
        norm_coors_scale_init: initial scale of ``CoorsNorm``.
        update_feats: whether this layer updates node features.
        update_coors: whether this layer updates coordinates.
        only_sparse_neighbors: restrict message passing to the neighbors in ``adj_mat``.
        valid_radius: when neighbor selection is active, selected nodes farther
            than this euclidean distance are masked out.
        m_pool_method: ``'sum'`` or ``'mean'`` pooling of the messages of each node.
        soft_edges: gate every message with a learned sigmoid.
        coor_weights_clamp_value: if set, clamp coordinate weights to ``[-v, v]``.
    """

    def __init__(
        self,
        dim,
        edge_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        num_nearest_neighbors = 0,
        dropout = 0.0,
        init_eps = 1e-3,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True,
        only_sparse_neighbors = False,
        valid_radius = float('inf'),
        m_pool_method = 'sum',
        soft_edges = False,
        coor_weights_clamp_value = None
    ):
        super().__init__()
        assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'

        self.fourier_features = fourier_features

        # edge input = [feats_i, feats_j, encoded squared distance, edge features]
        edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            dropout,
            SiLU(),
            nn.Linear(edge_input_dim * 2, m_dim),
            SiLU()
        )

        self.edge_gate = nn.Sequential(
            nn.Linear(m_dim, 1),
            nn.Sigmoid()
        ) if soft_edges else None

        self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity()
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        self.m_pool_method = m_pool_method

        self.node_mlp = nn.Sequential(
            nn.Linear(dim + m_dim, dim * 2),
            dropout,
            SiLU(),
            nn.Linear(dim * 2, dim),
        ) if update_feats else None

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            dropout,
            SiLU(),
            nn.Linear(m_dim * 4, 1)
        ) if update_coors else None

        self.num_nearest_neighbors = num_nearest_neighbors
        self.only_sparse_neighbors = only_sparse_neighbors
        self.valid_radius = valid_radius

        self.coor_weights_clamp_value = coor_weights_clamp_value

        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        """Initialize ``nn.Linear`` weights from a normal with std ``self.init_eps``."""
        if type(module) in {nn.Linear}:
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(self, feats, coors, edges = None, mask = None, adj_mat = None):
        """Apply the layer.

        Args:
            feats: node features, ``(b, n, dim)``.
            coors: node coordinates, ``(b, n, c)``.
            edges: optional edge features, ``(b, n, n, edge_dim)``.
            mask: optional boolean node mask, ``(b, n)``; ``True`` marks valid nodes.
            adj_mat: optional adjacency, ``(n, n)`` or ``(b, n, n)``.

        Returns:
            ``(node_out, coors_out)`` with the shapes of ``feats`` and ``coors``.
        """
        b, n, _ = feats.shape
        device = feats.device
        fourier_features = self.fourier_features
        num_nearest = self.num_nearest_neighbors
        valid_radius = self.valid_radius
        only_sparse_neighbors = self.only_sparse_neighbors

        use_nearest = num_nearest > 0 or only_sparse_neighbors

        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
        # squared euclidean distances, (b, i, j, 1)
        rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)

        nbhd_mask = None

        if use_nearest:
            ranking = rel_dist[..., 0].clone()

            if exists(mask):
                rank_mask = mask[:, :, None] * mask[:, None, :]
                # rank pairs involving padded nodes last; the largest finite value (instead of a
                # fixed 1e5) keeps them behind real nodes however far apart real coordinates are
                ranking.masked_fill_(~rank_mask, torch.finfo(ranking.dtype).max)

            if exists(adj_mat):
                if len(adj_mat.shape) == 2:
                    adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

                adj_mat = adj_mat.bool()
                self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j')
                adj_mat = adj_mat.masked_fill(self_mask, False)

                if only_sparse_neighbors:
                    # each node always selects itself (ranked -1 below), so reserve one slot for
                    # it on top of the largest number of true neighbors - otherwise an adjacency
                    # without self-loops silently loses a neighbor
                    num_nearest = int(adj_mat.sum(dim = -1).max().item()) + 1
                    valid_radius = 0

                ranking.masked_fill_(self_mask, -1.)
                ranking.masked_fill_(adj_mat, 0.)

            # topk raises when k exceeds the number of nodes (e.g. short sequences)
            num_nearest = min(num_nearest, n)

            nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False)

            # the ranking holds squared distances, so compare against the squared radius
            nbhd_mask = nbhd_ranking <= valid_radius ** 2

            rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2)
            rel_dist = batched_index_select(rel_dist, nbhd_indices, dim = 2)

            if exists(edges):
                edges = batched_index_select(edges, nbhd_indices, dim = 2)

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features)
            rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

        if use_nearest:
            feats_j = batched_index_select(feats, nbhd_indices, dim = 1)
        else:
            feats_j = rearrange(feats, 'b j d -> b () j d')

        feats_i = rearrange(feats, 'b i d -> b i () d')
        feats_i, feats_j = broadcast_tensors(feats_i, feats_j)

        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1)

        if exists(edges):
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

        if exists(self.edge_gate):
            m_ij = m_ij * self.edge_gate(m_ij)

        # pairwise validity mask over (i, neighbor j)
        if exists(mask):
            mask_i = rearrange(mask, 'b i -> b i ()')

            if use_nearest:
                mask_j = batched_index_select(mask, nbhd_indices, dim = 1)
                mask = (mask_i * mask_j) & nbhd_mask
            else:
                mask_j = rearrange(mask, 'b j -> b () j')
                mask = mask_i * mask_j
        elif use_nearest:
            # without a node mask, the radius / adjacency restriction must still be applied
            mask = nbhd_mask

        if exists(self.coors_mlp):
            coor_weights = self.coors_mlp(m_ij)
            coor_weights = rearrange(coor_weights, 'b i j () -> b i j')

            rel_coors = self.coors_norm(rel_coors)

            if exists(mask):
                coor_weights.masked_fill_(~mask, 0.)

            if exists(self.coor_weights_clamp_value):
                clamp_value = self.coor_weights_clamp_value
                coor_weights.clamp_(min = -clamp_value, max = clamp_value)

            coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors
        else:
            coors_out = coors

        if exists(self.node_mlp):
            if exists(mask):
                m_ij_mask = rearrange(mask, '... -> ... ()')
                m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

            if self.m_pool_method == 'mean':
                if exists(mask):
                    # masked mean over valid neighbors
                    mask_sum = m_ij_mask.sum(dim = -2)
                    m_i = safe_div(m_ij.sum(dim = -2), mask_sum)
                else:
                    m_i = m_ij.mean(dim = -2)

            elif self.m_pool_method == 'sum':
                m_i = m_ij.sum(dim = -2)

            normed_feats = self.node_norm(feats)
            node_mlp_input = torch.cat((normed_feats, m_i), dim = -1)
            node_out = self.node_mlp(node_mlp_input) + feats
        else:
            node_out = feats

        return node_out, coors_out
342 |
class EGNN_Network(nn.Module):
    """A stack of ``EGNN`` layers with optional embeddings and global attention.

    Args:
        depth: number of EGNN layers.
        dim: node feature dimension.
        num_tokens: if set, ``feats`` are token ids embedded with ``nn.Embedding(num_tokens, dim)``.
        num_edge_tokens: if set, ``edges`` are token ids embedded to ``edge_dim``.
        num_positions: if set, adds a learned absolute positional embedding;
            sequences may be at most this long.
        edge_dim: dimension of the edge features passed to every layer.
        num_adj_degrees: if set, ``adj_mat`` is expanded to neighbors up to this degree.
        adj_dim: size of the learned adjacency-degree embedding appended to the
            edges (0 disables).
        global_linear_attn_every: put a ``GlobalLinearAttention`` before every
            layer whose index is a multiple of this (0 disables).
        global_linear_attn_heads: heads of the global attention.
        global_linear_attn_dim_head: per-head size of the global attention.
        num_global_tokens: number of learned global tokens.
        **kwargs: forwarded to every ``EGNN`` layer.
    """

    def __init__(
        self,
        *,
        depth,
        dim,
        num_tokens = None,
        num_edge_tokens = None,
        num_positions = None,
        edge_dim = 0,
        num_adj_degrees = None,
        adj_dim = 0,
        global_linear_attn_every = 0,
        global_linear_attn_heads = 8,
        global_linear_attn_dim_head = 64,
        num_global_tokens = 4,
        **kwargs
    ):
        super().__init__()
        assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than or equal to 1'
        self.num_positions = num_positions

        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        self.pos_emb = nn.Embedding(num_positions, dim) if exists(num_positions) else None
        self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
        self.has_edges = edge_dim > 0

        self.num_adj_degrees = num_adj_degrees
        # degree 0 means "not connected within num_adj_degrees hops"
        self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None

        edge_dim = edge_dim if self.has_edges else 0
        adj_dim = adj_dim if exists(num_adj_degrees) else 0

        has_global_attn = global_linear_attn_every > 0
        self.global_tokens = None
        if has_global_attn:
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, dim))

        self.layers = nn.ModuleList([])
        for ind in range(depth):
            is_global_layer = has_global_attn and (ind % global_linear_attn_every) == 0

            self.layers.append(nn.ModuleList([
                GlobalLinearAttention(dim = dim, heads = global_linear_attn_heads, dim_head = global_linear_attn_dim_head) if is_global_layer else None,
                EGNN(dim = dim, edge_dim = (edge_dim + adj_dim), norm_feats = True, **kwargs),
            ]))

    def forward(
        self,
        feats,
        coors,
        adj_mat = None,
        edges = None,
        mask = None,
        return_coor_changes = False
    ):
        """Run all layers.

        Args:
            feats: token ids ``(b, n)`` when ``num_tokens`` is set, else features ``(b, n, dim)``.
            coors: coordinates ``(b, n, c)``.
            adj_mat: first-degree adjacency ``(n, n)`` or ``(b, n, n)``; required
                when ``num_adj_degrees`` is set.
            edges: optional edge tokens or features.
            mask: optional boolean node mask ``(b, n)``.
            return_coor_changes: also return the coordinates after every layer.

        Returns:
            ``(feats, coors)``, or ``(feats, coors, coor_changes)`` where
            ``coor_changes`` starts with the input coordinates.
        """
        b, device = feats.shape[0], feats.device

        if exists(self.token_emb):
            feats = self.token_emb(feats)

        if exists(self.pos_emb):
            n = feats.shape[1]
            assert n <= self.num_positions, f'given sequence length {n} must be less than or equal to the number of positions {self.num_positions} set at init'
            pos_emb = self.pos_emb(torch.arange(n, device = device))
            # out-of-place: without a token embedding `feats` is the caller's own tensor
            feats = feats + rearrange(pos_emb, 'n d -> () n d')

        if exists(edges) and exists(self.edge_emb):
            edges = self.edge_emb(edges)

        # create N-degrees adjacent matrix from 1st degree connections
        if exists(self.num_adj_degrees):
            assert exists(adj_mat), 'adjacency matrix must be passed in (keyword argument adj_mat)'

            if len(adj_mat.shape) == 2:
                adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

            adj_indices = adj_mat.clone().long()

            # grow the reachable set one hop at a time through the first-degree adjacency;
            # squaring the running matrix instead would double the hop count every
            # iteration and mislabel every degree >= 3
            first_degree = adj_mat.float()
            reachable = adj_mat.bool()

            for degree in range(2, self.num_adj_degrees + 1):
                next_reachable = ((reachable.float() @ first_degree) > 0) | reachable
                newly_reached = next_reachable & ~reachable
                adj_indices.masked_fill_(newly_reached, degree)
                reachable = next_reachable

            adj_mat = reachable

            if exists(self.adj_emb):
                adj_emb = self.adj_emb(adj_indices)
                edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

        # setup global attention

        global_tokens = None
        if exists(self.global_tokens):
            global_tokens = repeat(self.global_tokens, 'n d -> b n d', b = b)

        # go through layers

        coor_changes = [coors]

        for global_attn, egnn in self.layers:
            if exists(global_attn):
                feats, global_tokens = global_attn(feats, global_tokens, mask = mask)

            feats, coors = egnn(feats, coors, adj_mat = adj_mat, edges = edges, mask = mask)
            coor_changes.append(coors)

        if return_coor_changes:
            return feats, coors, coor_changes

        return feats, coors
455 |
--------------------------------------------------------------------------------
/egnn_pytorch/egnn_pytorch_geometric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum, broadcast_tensors
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, repeat
6 | from einops.layers.torch import Rearrange
7 |
8 | # types
9 |
10 | from typing import Optional, List, Union
11 |
12 | # pytorch geometric
13 |
try:
    import torch_geometric
    from torch_geometric.nn import MessagePassing
    from torch_geometric.typing import Adj, Size, OptTensor, Tensor
    PYG_AVAILABLE = True
except Exception:
    # pytorch geometric is optional: fall back to placeholder types so this
    # module still imports. Classes that subclass MessagePassing or use
    # torch_geometric norms will not work without it.
    Tensor = OptTensor = Adj = MessagePassing = Size = object
    PYG_AVAILABLE = False

# to stop throwing errors from type suggestions
Adj = object
Size = object
OptTensor = object
Tensor = object
27 |
28 | from .egnn_pytorch import *
29 |
30 | # global linear attention
31 |
class Attention_Sparse(Attention):
    """ Wraps the attention class to operate with pytorch-geometric inputs.

        Nodes of every graph are stacked along dim 0. Attention is computed
        densely inside each graph and never across graphs.
    """
    def __init__(self, *args, **kwargs):
        # forward positional args too, so `Attention_Sparse(dim, heads, dim_head)` works
        super(Attention_Sparse, self).__init__(*args, **kwargs)

    def sparse_forward(self, x, context, batch=None, batch_uniques=None, mask=None):
        """ Block-sparse attention over a stacked batch of graphs.

            Inputs:
            * x: (n_nodes, d) queries.
            * context: (n_nodes, d) keys/values, sliced with the same
              per-graph node counts as `x`.
            * batch: (n_nodes,) graph id of every node. The slicing below
              assumes the nodes of each graph are contiguous and in sorted
              graph order.
            * batch_uniques: optional precomputed
              `torch.unique(batch, return_counts=True)`.
            * mask: accepted for API compatibility; currently not applied.
            Returns: the per-node attention output, nodes stacked along dim 0.
        """
        assert batch is not None or batch_uniques is not None, "Batch/(uniques) must be passed for block_sparse_attn"
        if batch_uniques is None:
            batch_uniques = torch.unique(batch, return_counts=True)
        # only one example in batch - do dense - faster
        if batch_uniques[0].shape[0] == 1:
            x, context = map(lambda t: rearrange(t, 'h d -> () h d'), (x, context))
            # drop only the added batch dim; a bare squeeze() would also drop
            # a size-1 node or feature dim
            return self.forward(x, context, mask=None).squeeze(0)
        # multiple examples in batch - do block-sparse by dense loop
        x_list = []
        start = 0
        for bi, n_idxs in zip(*batch_uniques):
            end = start + int(n_idxs)
            x_list.append(
                self.sparse_forward(
                    x[start:end],
                    context[start:end],
                    batch_uniques = (bi.unsqueeze(-1), n_idxs.unsqueeze(-1))
                )
            )
            # advance to the next graph's block of nodes
            start = end
        return torch.cat(x_list, dim=0)
58 |
59 |
class GlobalLinearAttention_Sparse(nn.Module):
    """ Global linear attention for pytorch-geometric style (stacked) batches.

        Uses pytorch-geometric LayerNorms (called with `batch`) and
        `Attention_Sparse.sparse_forward`, so nodes only attend within
        their own graph.
    """
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.norm_seq = torch_geometric.nn.norm.LayerNorm(dim)
        self.norm_queries = torch_geometric.nn.norm.LayerNorm(dim)
        # keyword arguments: Attention_Sparse forwards them to Attention
        self.attn1 = Attention_Sparse(dim = dim, heads = heads, dim_head = dim_head)
        self.attn2 = Attention_Sparse(dim = dim, heads = heads, dim_head = dim_head)

        # can't concat pyg norms with torch sequentials
        self.ff_norm = torch_geometric.nn.norm.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, queries, batch=None, batch_uniques=None, mask = None):
        """ Inputs:
            * x: (n_nodes, dim) node features.
            * queries: (n_nodes, dim) global query tokens aligned with `x`.
            * batch / batch_uniques: graph membership, see `Attention_Sparse.sparse_forward`.
            * mask: forwarded to the first attention (currently unused there).
            Returns: (updated x, updated queries).
        """
        res_x, res_queries = x, queries
        x, queries = self.norm_seq(x, batch=batch), self.norm_queries(queries, batch=batch)

        induced = self.attn1.sparse_forward(queries, x, batch=batch, batch_uniques=batch_uniques, mask = mask)
        out = self.attn2.sparse_forward(x, induced, batch=batch, batch_uniques=batch_uniques)

        x = out + res_x
        queries = induced + res_queries

        # pre-norm residual: add back the un-normalised input, not the normed one
        x = self.ff(self.ff_norm(x, batch=batch)) + x
        return x, queries
95 |
96 |
97 | # define pytorch-geometric equivalents
98 |
class EGNN_Sparse(MessagePassing):
    """ Sparse (edge-list) E(n)-equivariant message passing layer.

        Different from the dense EGNN layer since it separates the edge
        assignment from the computation (this allows for great reduction in
        time and computations when the graph is locally or sparse connected).

        Args:
        * feats_dim: int. number of node features (excluding coordinates).
        * pos_dim: int. number of coordinate columns at the start of `x`.
        * edge_attr_dim: int. dimension of the extra edge features.
        * m_dim: int. message dimension.
        * fourier_features: int. number of fourier encodings of the squared
          relative distance (0 disables them).
        * soft_edge: if truthy, messages are gated by a learned sigmoid weight
          before feature aggregation.
        * norm_feats: LayerNorm node features before the node MLP.
        * norm_coors: apply CoorsNorm to the relative coordinates.
        * norm_coors_scale_init: initial scale for CoorsNorm.
        * update_feats / update_coors: which outputs to update (at least one).
        * dropout: dropout probability inside the MLPs.
        * coor_weights_clamp_value: if set, clamp the per-edge coordinate
          weights to [-value, value].
        * aggr: one of ["add", "sum", "mean", "max"]
    """
    def __init__(
        self,
        feats_dim,
        pos_dim=3,
        edge_attr_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        soft_edge = 0,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True,
        dropout = 0.,
        coor_weights_clamp_value = None,
        aggr = "add",
        **kwargs
    ):
        assert aggr in {'add', 'sum', 'max', 'mean'}, 'pool method must be a valid option'
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'
        kwargs.setdefault('aggr', aggr)
        super(EGNN_Sparse, self).__init__(**kwargs)
        # model params
        self.fourier_features = fourier_features
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.m_dim = m_dim
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.update_coors = update_coors
        self.update_feats = update_feats
        # read in propagate() to clamp the coordinate weights
        self.coor_weights_clamp_value = coor_weights_clamp_value

        # edge MLP input: [x_i, x_j, edge_attr, (fourier-encoded) squared distance]
        self.edge_input_dim = (fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # EDGES
        self.edge_mlp = nn.Sequential(
            nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(self.edge_input_dim * 2, m_dim),
            SiLU()
        )

        self.edge_weight = nn.Sequential(nn.Linear(m_dim, 1),
                                         nn.Sigmoid()
        ) if soft_edge else None

        # NODES - can't do identity in node_norm bc pyg expects 2 inputs, but identity expects 1.
        self.node_norm = torch_geometric.nn.norm.LayerNorm(feats_dim) if norm_feats else None
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()

        self.node_mlp = nn.Sequential(
            nn.Linear(feats_dim + m_dim, feats_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(feats_dim * 2, feats_dim),
        ) if update_feats else None

        # COORS
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            self.dropout,
            SiLU(),
            nn.Linear(self.m_dim * 4, 1)
        ) if update_coors else None

        self.apply(self.init_)

    def init_(self, module):
        """ Xavier-normal weights and zero biases for every nn.Linear. """
        if type(module) in {nn.Linear}:
            # seems to be needed to keep the network from exploding to NaN with greater depths
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None, batch: Adj = None,
                angle_data: List = None, size: Size = None) -> Tensor:
        """ Inputs:
            * x: (n_points, d) where d is pos_dims + feat_dims
            * edge_index: (2, n_edges)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
            * batch: (n_points,) long tensor. specifies cloud belonging for each point
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
              Currently unused.
            * size: None. Currently not forwarded to `propagate`.
            Returns: (n_points, pos_dim + feats_dim) tensor of updated
            coordinates concatenated with updated features.
        """
        coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:]

        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True)

        if self.fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = self.fourier_features)
            rel_dist = rearrange(rel_dist, 'n () d -> n d')

        if exists(edge_attr):
            edge_attr_feats = torch.cat([edge_attr, rel_dist], dim=-1)
        else:
            edge_attr_feats = rel_dist

        hidden_out, coors_out = self.propagate(edge_index, x=feats, edge_attr=edge_attr_feats,
                                               coors=coors, rel_coors=rel_coors,
                                               batch=batch)
        return torch.cat([coors_out, hidden_out], dim=-1)

    def message(self, x_i, x_j, edge_attr) -> Tensor:
        """ Per-edge message m_ij = edge_mlp([x_i, x_j, edge_attr]). """
        m_ij = self.edge_mlp( torch.cat([x_i, x_j, edge_attr], dim=-1) )
        return m_ij

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.
            Args:
            `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional) if none, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
            Returns: the `update` of `(hidden_out, coors_out)`.
        """
        size = self._check_input(edge_index, size)
        coll_dict = self._collect(self._user_args,
                                  edge_index, size, kwargs)
        msg_kwargs = self.inspector.collect_param_data('message', coll_dict)
        aggr_kwargs = self.inspector.collect_param_data('aggregate', coll_dict)
        update_kwargs = self.inspector.collect_param_data('update', coll_dict)

        # get messages
        m_ij = self.message(**msg_kwargs)

        # update coors if specified
        if self.update_coors:
            coor_wij = self.coors_mlp(m_ij)
            # clamp if arg is set
            if self.coor_weights_clamp_value:
                clamp_value = self.coor_weights_clamp_value
                coor_wij = coor_wij.clamp(min = -clamp_value, max = clamp_value)

            # normalize if needed
            kwargs["rel_coors"] = self.coors_norm(kwargs["rel_coors"])

            mhat_i = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs)
            coors_out = kwargs["coors"] + mhat_i
        else:
            coors_out = kwargs["coors"]

        # update feats if specified
        if self.update_feats:
            # weight the edges if arg is passed
            if self.soft_edge:
                m_ij = m_ij * self.edge_weight(m_ij)
            m_i = self.aggregate(m_ij, **aggr_kwargs)

            hidden_feats = self.node_norm(kwargs["x"], kwargs["batch"]) if self.node_norm is not None else kwargs["x"]
            hidden_out = self.node_mlp( torch.cat([hidden_feats, m_i], dim = -1) )
            # residual connection on node features
            hidden_out = kwargs["x"] + hidden_out
        else:
            hidden_out = kwargs["x"]

        # return tuple
        return self.update((hidden_out, coors_out), **update_kwargs)

    def __repr__(self):
        return "E(n)-GNN Layer for Graphs " + str(self.__dict__)
272 |
273 |
class EGNN_Sparse_Network(nn.Module):
    r"""Sample GNN model architecture that uses the EGNN-Sparse
        message passing layer to learn over point clouds.
        Main MPNN layer introduced in https://arxiv.org/abs/2102.09844v1

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * ... : same interpretation as the base layer.
        * embedding_nums: list. number of unique keys to embedd. for points
                          1 entry per embedding needed.
        * embedding_dims: list. point - number of dimensions of
                          the resulting embedding. 1 entry per embedding needed.
        * edge_embedding_nums: list. number of unique keys to embedd. for edges.
                               1 entry per embedding needed.
        * edge_embedding_dims: list. point - number of dimensions of
                               the resulting embedding. 1 entry per embedding needed.
        * global_linear_attn_every: int. run a global attention layer before
                               MPNN layers 0, k, 2k, ... (k = this value). 0 disables it.
        * global_linear_attn_heads / global_linear_attn_dim_head: attention sizes.
        * num_global_tokens: int. number of learned global tokens.
        * recalc: int. Recalculate edge feats every `recalc` MPNN layers. 0 for no recalc
        -----
        Diff with normal layer: one has to do preprocessing before (radius, global token, ...)
    """
    def __init__(self, n_layers, feats_dim,
                 pos_dim = 3,
                 edge_attr_dim = 0,
                 m_dim = 16,
                 fourier_features = 0,
                 soft_edge = 0,
                 embedding_nums=[],
                 embedding_dims=[],
                 edge_embedding_nums=[],
                 edge_embedding_dims=[],
                 update_coors=True,
                 update_feats=True,
                 norm_feats=True,
                 norm_coors=False,
                 norm_coors_scale_init = 1e-2,
                 dropout=0.,
                 coor_weights_clamp_value=None,
                 aggr="add",
                 global_linear_attn_every = 0,
                 global_linear_attn_heads = 8,
                 global_linear_attn_dim_head = 64,
                 num_global_tokens = 4,
                 recalc=0 ,):
        super().__init__()

        self.n_layers = n_layers

        # Embeddings? solve here
        self.embedding_nums = embedding_nums
        self.embedding_dims = embedding_dims
        self.emb_layers = nn.ModuleList()
        self.edge_embedding_nums = edge_embedding_nums
        self.edge_embedding_dims = edge_embedding_dims
        self.edge_emb_layers = nn.ModuleList()

        # instantiate point and edge embedding layers
        # (the `- 1` assumes embedd_token replaces one input column by an
        #  embedding_dims[i]-wide vector - TODO confirm against embedd_token)
        for i in range( len(self.embedding_dims) ):
            self.emb_layers.append(nn.Embedding(num_embeddings = embedding_nums[i],
                                                embedding_dim = embedding_dims[i]))
            feats_dim += embedding_dims[i] - 1

        for i in range( len(self.edge_embedding_dims) ):
            self.edge_emb_layers.append(nn.Embedding(num_embeddings = edge_embedding_nums[i],
                                                     embedding_dim = edge_embedding_dims[i]))
            edge_attr_dim += edge_embedding_dims[i] - 1
        # rest
        self.mpnn_layers = nn.ModuleList()
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.edge_attr_dim = edge_attr_dim
        self.m_dim = m_dim
        self.fourier_features = fourier_features
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.norm_coors_scale_init = norm_coors_scale_init
        self.update_feats = update_feats
        self.update_coors = update_coors
        self.dropout = dropout
        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.recalc = recalc

        self.has_global_attn = global_linear_attn_every > 0
        self.global_tokens = None
        self.global_linear_attn_every = global_linear_attn_every
        if self.has_global_attn:
            # tokens live in the (post-embedding) feature space fed to attention
            self.global_tokens = nn.Parameter(torch.randn(num_global_tokens, self.feats_dim))

        # instantiate layers
        for i in range(n_layers):
            layer = EGNN_Sparse(feats_dim = feats_dim,
                                pos_dim = pos_dim,
                                edge_attr_dim = edge_attr_dim,
                                m_dim = m_dim,
                                fourier_features = fourier_features,
                                soft_edge = soft_edge,
                                norm_feats = norm_feats,
                                norm_coors = norm_coors,
                                norm_coors_scale_init = norm_coors_scale_init,
                                update_feats = update_feats,
                                update_coors = update_coors,
                                dropout = dropout,
                                coor_weights_clamp_value = coor_weights_clamp_value,
                                aggr = aggr)

            # global attention case
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if is_global_layer:
                # NOTE(review): the dense GlobalLinearAttention receives
                # (n_nodes, d) inputs in forward(); GlobalLinearAttention_Sparse
                # may be the intended layer here - confirm.
                attn_layer = GlobalLinearAttention(dim = self.feats_dim,
                                                   heads = global_linear_attn_heads,
                                                   dim_head = global_linear_attn_dim_head)
                # order (attention, mpnn) matches how forward() indexes the pair
                self.mpnn_layers.append(nn.ModuleList([attn_layer, layer]))
            # normal case
            else:
                self.mpnn_layers.append(layer)

    def forward(self, x, edge_index, batch, edge_attr,
                bsize=None, recalc_edge=None, verbose=0):
        """ Recalculate edge features every `self.recalc` layers with the
            `recalc_edge` function if `self.recalc` is set.

            * x: (N, pos_dim+feats_dim) will be unpacked into coors, feats.
            * edge_index: (2, n_edges) edge list.
            * batch: (N,) graph id of every node.
            * edge_attr: edge features, embedded with `self.edge_emb_layers`.
            * bsize: forwarded as `size` to every EGNN_Sparse layer.
            * recalc_edge: callable x -> (edge_index, edge_attr, extra);
              required when `self.recalc` > 0.
            * verbose: unused.
            Returns: x after the last layer, same layout as the input.
        """
        # NODES - Embedd each dim to its target dimensions:
        x = embedd_token(x, self.embedding_dims, self.emb_layers)

        # attn tokens - built once, then carried through the global layers
        global_tokens = None
        if exists(self.global_tokens):
            _, amounts = torch.unique(batch, return_counts=True)
            # NOTE(review): indexes one token per node by its position within
            # its graph, so graphs larger than num_global_tokens will fail - confirm intent.
            num_idxs = torch.cat([torch.arange(int(num_idxs_i), device=batch.device)
                                  for num_idxs_i in amounts], dim=-1)
            global_tokens = self.global_tokens[num_idxs]

        # regulates wether to embedd edges each layer
        edges_need_embedding = True
        for i, layer in enumerate(self.mpnn_layers):

            # EDGES - Embedd each dim to its target dimensions:
            if edges_need_embedding:
                edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers)
                edges_need_embedding = False

            # pass layers
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if not is_global_layer:
                x = layer(x, edge_index, edge_attr, batch=batch, size=bsize)
            else:
                attn_layer, mpnn_layer = layer
                # only pass feats to the attn layer; it returns (feats, tokens)
                x_attn, global_tokens = attn_layer(x[:, self.pos_dim:], global_tokens)
                # merge attn-ed feats and coords
                x = torch.cat( (x[:, :self.pos_dim], x_attn), dim=-1)
                x = mpnn_layer(x, edge_index, edge_attr, batch=batch, size=bsize)

            # recalculate edge info - not needed if last layer
            if self.recalc and ((i % self.recalc == 0) and not (i == len(self.mpnn_layers) - 1)):
                assert recalc_edge is not None, 'recalc_edge function must be passed when recalc is set'
                edge_index, edge_attr, _ = recalc_edge(x) # returns idx, attr, any_other_info
                edges_need_embedding = True

        return x

    def __repr__(self):
        return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))
440 |
--------------------------------------------------------------------------------
/egnn_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import sin, cos, atan2, acos
3 |
def rot_z(gamma):
    """Return the 3x3 rotation matrix by angle `gamma` (radians) about the z axis.

    `gamma` is expected to be a scalar (0-dim) tensor. The matrix is assembled
    with `torch.stack`, so it stays on `gamma`'s device and keeps autograd
    history. The result is cast to `gamma.dtype`.
    """
    c, s = cos(gamma), sin(gamma)
    zero, one = torch.zeros_like(c), torch.ones_like(c)
    return torch.stack([
        torch.stack([c, -s, zero]),
        torch.stack([s, c, zero]),
        torch.stack([zero, zero, one]),
    ]).to(gamma.dtype)
10 |
def rot_y(beta):
    """Return the 3x3 rotation matrix by angle `beta` (radians) about the y axis.

    `beta` is expected to be a scalar (0-dim) tensor. The matrix is assembled
    with `torch.stack`, so it stays on `beta`'s device and keeps autograd
    history. The result is cast to `beta.dtype`.
    """
    c, s = cos(beta), sin(beta)
    zero, one = torch.zeros_like(c), torch.ones_like(c)
    return torch.stack([
        torch.stack([c, zero, s]),
        torch.stack([zero, one, zero]),
        torch.stack([-s, zero, c]),
    ]).to(beta.dtype)
17 |
def rot(alpha, beta, gamma):
    """Compose a ZYZ Euler-angle rotation: Rz(alpha) @ Ry(beta) @ Rz(gamma)."""
    first = rot_z(alpha)
    middle = rot_y(beta)
    last = rot_z(gamma)
    return (first @ middle) @ last
20 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --verbose
6 | python_files = tests/*.py
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Package metadata for the egnn-pytorch distribution.
KEYWORDS = [
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'graph neural network',
]

INSTALL_REQUIRES = [
    'einops>=0.3',
    'numba',
    'numpy',
    'torch>=1.6',
]

CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name='egnn-pytorch',
    packages=find_packages(),
    version='0.2.8',
    license='MIT',
    description='E(n)-Equivariant Graph Neural Network - Pytorch',
    long_description_content_type='text/markdown',
    author='Phil Wang, Eric Alcaide',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/egnn-pytorch',
    keywords=KEYWORDS,
    install_requires=INSTALL_REQUIRES,
    setup_requires=['pytest-runner'],
    tests_require=['pytest'],
    classifiers=CLASSIFIERS,
)
39 |
--------------------------------------------------------------------------------
/tests/test_equivariance.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from egnn_pytorch import EGNN, EGNN_Sparse
4 | from egnn_pytorch.utils import rot
5 |
6 | torch.set_default_dtype(torch.float64)
7 |
def test_egnn_equivariance():
    """EGNN features are E(n)-invariant, coordinates E(n)-equivariant, and the
    layer is equivariant to permutations of node order."""
    layer = EGNN(dim=512, edge_dim=4)

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 3)
    edges = torch.randn(1, 16, 16, 4)
    mask = torch.ones(1, 16).bool()

    # permutation swapping the first two nodes
    perm = torch.tensor([1, 0, *range(2, 16)])
    feats_permuted_row_wise = feats[:, perm]

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask=mask)
    feats2, coors2 = layer(feats, coors, edges, mask=mask)
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask=mask)

    assert torch.allclose(feats1, feats2, atol=1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol=1e-6), 'type 1 features are equivariant'
    # permuting only the features (not coordinates / edges) must change the result
    assert not torch.allclose(feats1, feats3, atol=1e-6), 'permuting node features alone must change the output'

    # permuting every node-indexed input must permute the outputs the same way
    feats4, coors4 = layer(feats[:, perm], coors[:, perm], edges[:, perm][:, :, perm], mask=mask[:, perm])
    assert torch.allclose(feats4, feats2[:, perm], atol=1e-6), 'layer must be equivariant to permutations of node order'
    assert torch.allclose(coors4, coors2[:, perm], atol=1e-6), 'coordinates must be equivariant to permutations of node order'
35 |
def test_higher_dimension():
    """EGNN accepts coordinates with more than 3 dimensions and keeps shapes."""
    layer = EGNN(dim=512, edge_dim=4)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 5)
    edges = torch.randn(1, 16, 16, 4)
    mask = torch.ones(1, 16).bool()

    feats_out, coors_out = layer(feats, coors, edges, mask = mask)
    assert feats_out.shape == feats.shape, 'features must keep their shape'
    assert coors_out.shape == coors.shape, 'coordinates must keep their (5-dimensional) shape'
46 |
def test_egnn_equivariance_with_nearest_neighbors():
    """E(n) invariance/equivariance of EGNN restricted to 8 nearest neighbours."""
    n_nodes, dim = 256, 512
    layer = EGNN(dim=dim, edge_dim=1, num_nearest_neighbors=8)

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(1, n_nodes, dim)
    coors = torch.randn(1, n_nodes, 3)
    edges = torch.randn(1, n_nodes, n_nodes, 1)
    mask = torch.ones(1, n_nodes).bool()

    # identical features except nodes 0 and 1 are swapped
    swap = torch.tensor([1, 0, *range(2, n_nodes)])
    swapped_feats = feats[:, swap]

    rot_feats, rot_coors = layer(feats, coors @ R + T, edges, mask=mask)
    ref_feats, ref_coors = layer(feats, coors, edges, mask=mask)
    swp_feats, _ = layer(swapped_feats, coors, edges, mask=mask)

    assert torch.allclose(rot_feats, ref_feats, atol=1e-6), 'type 0 features are invariant'
    assert torch.allclose(rot_coors, (ref_coors @ R + T), atol=1e-6), 'type 1 features are equivariant'
    assert not torch.allclose(rot_feats, swp_feats, atol=1e-6), 'layer must be equivariant to permutations of node order'
74 |
75 |
def test_egnn_equivariance_with_coord_norm():
    """E(n) invariance/equivariance of EGNN with nearest neighbours and CoorsNorm."""
    n_nodes, dim = 256, 512
    layer = EGNN(dim=dim, edge_dim=1, num_nearest_neighbors=8, norm_coors=True)

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(1, n_nodes, dim)
    coors = torch.randn(1, n_nodes, 3)
    edges = torch.randn(1, n_nodes, n_nodes, 1)
    mask = torch.ones(1, n_nodes).bool()

    # identical features except nodes 0 and 1 are swapped
    swap = torch.tensor([1, 0, *range(2, n_nodes)])
    swapped_feats = feats[:, swap]

    rot_feats, rot_coors = layer(feats, coors @ R + T, edges, mask=mask)
    ref_feats, ref_coors = layer(feats, coors, edges, mask=mask)
    swp_feats, _ = layer(swapped_feats, coors, edges, mask=mask)

    assert torch.allclose(rot_feats, ref_feats, atol=1e-6), 'type 0 features are invariant'
    assert torch.allclose(rot_coors, (ref_coors @ R + T), atol=1e-6), 'type 1 features are equivariant'
    assert not torch.allclose(rot_feats, swp_feats, atol=1e-6), 'layer must be equivariant to permutations of node order'
103 |
104 |
def test_egnn_sparse_equivariance():
    """EGNN_Sparse features are E(n)-invariant and coordinates E(n)-equivariant."""
    layer = EGNN_Sparse(feats_dim=1,
                        m_dim=16,
                        fourier_features=4)

    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    def apply_action(t):
        # rotate then translate; squeeze drops the broadcast batch dim from T
        return (t @ R + T).squeeze()

    feats = torch.randn(16, 1)
    coors = torch.randn(16, 3)
    edge_idxs = (torch.rand(2, 20) * 16).long()

    # features with nodes 0 and 1 swapped
    perm = torch.tensor([1, 0, *range(2, 16)])
    feats_permuted_row_wise = feats[perm]

    x1 = torch.cat([coors, feats], dim=-1)
    x2 = torch.cat([apply_action(coors), feats], dim=-1)
    x3 = torch.cat([apply_action(coors), feats_permuted_row_wise], dim=-1)

    out1 = layer(x=x1, edge_index=edge_idxs)
    out2 = layer(x=x2, edge_index=edge_idxs)
    out3 = layer(x=x3, edge_index=edge_idxs)

    # output layout is [coors (3), feats]
    feats1, coors1 = out1[:, 3:], out1[:, :3]
    feats2, coors2 = out2[:, 3:], out2[:, :3]
    feats3, coors3 = out3[:, 3:], out3[:, :3]

    assert torch.allclose(feats1, feats2), 'features must be invariant'
    assert torch.allclose(apply_action(coors1), coors2), 'coordinates must be equivariant'
    # permuting only the features (not the edges) must change the result
    assert not torch.allclose(feats1, feats3, atol=1e-6), 'permuting node features alone must change the output'
144 |
145 |
def test_geom_equivalence():
    """EGNN_Sparse with edge attributes returns an output shaped like its input."""
    n_nodes, n_feats, n_edges, edge_dim = 16, 128, 20, 4
    layer = EGNN_Sparse(feats_dim=n_feats,
                        edge_attr_dim=edge_dim,
                        m_dim=16,
                        fourier_features=4)

    node_feats = torch.randn(n_nodes, n_feats)
    node_coors = torch.randn(n_nodes, 3)
    x = torch.cat([node_coors, node_feats], dim=-1)
    edge_index = (torch.rand(2, n_edges) * n_nodes).long()
    dense_edge_attr = torch.randn(n_nodes, n_nodes, edge_dim)
    # gather the attributes of the sampled edges only
    src, dst = edge_index[0], edge_index[1]
    edge_attr = dense_edge_attr[src, dst]

    out = layer.forward(x, edge_index, edge_attr=edge_attr)
    assert out.shape == x.shape
160 |
--------------------------------------------------------------------------------