├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── gotennet.png
├── gotennet_pytorch
│   ├── __init__.py
│   ├── gotennet.py
│   └── tensor_typing.py
├── pyproject.toml
└── tests
    └── test_gotennet.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests the examples in README
2 | on: push
3 |
4 | jobs:
5 | test:
6 | runs-on: ubuntu-latest
7 | steps:
8 | - uses: actions/checkout@v4
9 | - name: Install Python
10 | uses: actions/setup-python@v4
11 | - name: Install the latest version of rye
12 | uses: eifinger/setup-rye@v2
13 | - name: Use UV instead of pip
14 | run: rye config --set-bool behavior.use-uv=true
15 | - name: Install dependencies
16 | run: |
17 | rye sync
18 | - name: Run pytest
19 | run: rye run pytest tests/
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## GotenNet - Pytorch
4 |
5 | Implementation of GotenNet, a new SOTA 3D equivariant transformer, in Pytorch
6 |
7 | I know a lot of researchers have moved on from geometric learning after AlphaFold3. However, I just cannot help but wonder, so I am hedging my bets.
8 |
9 | The official repository has been released [here](https://github.com/sarpaykent/GotenNet/)
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install gotennet-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | import torch
21 | torch.set_default_dtype(torch.float64) # recommended for equivariant network training
22 |
23 | from gotennet_pytorch import GotenNet
24 |
25 | model = GotenNet(
26 | dim = 256,
27 | max_degree = 2,
28 | depth = 1,
29 | heads = 2,
30 | dim_head = 32,
31 | dim_edge_refinement = 256,
32 | return_coors = False
33 | )
34 |
35 | atom_ids = torch.randint(0, 14, (1, 12)) # negative atom indices will be assumed to be padding - length of molecule is thus `(atom_ids >= 0).sum(dim = -1)`
36 | coors = torch.randn(1, 12, 3)
37 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool()
38 |
39 | invariant, high_degree_feats = model(atom_ids, adj_mat = adj_mat, coors = coors) # with return_coors = False, the second output is the list of high degree steerable features
40 | ```
41 |
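The tests exercise a few more options than the example above. Below is a minimal sketch, mirroring the settings used in `tests/test_gotennet.py`, that turns on a cutoff radius for neighbor selection, passes an explicit padding mask, and returns coordinates (`return_coors = True`). `torch_default_dtype` is the context manager exported alongside `GotenNet`.

```python
import torch
from gotennet_pytorch import GotenNet, torch_default_dtype

with torch_default_dtype(torch.float64):
    model = GotenNet(
        dim = 256,
        max_degree = 2,
        depth = 4,
        heads = 2,
        dim_head = 32,
        dim_edge_refinement = 256,
        cutoff_radius = 5.,   # only atoms within ~5 angstroms are treated as neighbors
        return_coors = True   # degree-1 features are projected back out to coordinates
    )

    atom_ids = torch.randint(0, 14, (1, 12))
    coors = torch.randn(1, 12, 3)
    adj_mat = torch.randint(0, 2, (1, 12, 12)).bool()
    mask = torch.randint(0, 2, (1, 12)).bool() # True for real atoms, False for padding

    invariant, coors_out = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask)
```
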
42 | ## Citations
43 |
44 | ```bibtex
45 | @inproceedings{anonymous2024rethinking,
46 | title = {Rethinking Efficient 3D Equivariant Graph Neural Networks},
47 | author = {Anonymous},
48 | booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
49 | year = {2024},
50 | url = {https://openreview.net/forum?id=5wxCQDtbMo},
51 | note = {under review}
52 | }
53 | ```
54 |
55 | ```bibtex
56 | @inproceedings{Zhou2024ValueRL,
57 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
58 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
59 | year = {2024},
60 | url = {https://api.semanticscholar.org/CorpusID:273532030}
61 | }
62 | ```
63 |
64 | ```bibtex
65 | @article{Zhu2024HyperConnections,
66 | title = {Hyper-Connections},
67 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
68 | journal = {ArXiv},
69 | year = {2024},
70 | volume = {abs/2409.19606},
71 | url = {https://api.semanticscholar.org/CorpusID:272987528}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/gotennet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gotennet-pytorch/20a63dfabd5c9c71effd9e643941dfca99868a85/gotennet.png
--------------------------------------------------------------------------------
/gotennet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from gotennet_pytorch.gotennet import GotenNet, torch_default_dtype
2 |
--------------------------------------------------------------------------------
/gotennet_pytorch/gotennet.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 | from contextlib import contextmanager
5 | from collections.abc import Sequence
6 |
7 | import torch
8 | from torch import nn, cat, Tensor, einsum
9 | from torch.nn import Linear, Sequential, Module, ModuleList, ParameterList
10 |
11 | import einx
12 | from einx import get_at
13 |
14 | from einops import rearrange, repeat, reduce
15 | from einops.layers.torch import Rearrange
16 |
17 | from e3nn.o3 import spherical_harmonics
18 |
19 | from gotennet_pytorch.tensor_typing import Float, Int, Bool
20 |
21 | from hyper_connections import get_init_and_expand_reduce_stream_functions
22 |
23 | from x_transformers import Attention
24 |
25 | # ein notation
26 |
27 | # b - batch
28 | # h - heads
29 | # n - sequence
30 | # m - sequence (neighbors)
31 | # i, j - source and target sequence
32 | # d - feature
33 | # m - also reused for the orders of each degree (2l + 1)
34 | # l - degree
35 |
36 | # helper functions
37 |
38 | def exists(v):
39 | return v is not None
40 |
41 | def default(v, d):
42 | return v if exists(v) else d
43 |
44 | def max_neg_value(t):
45 | return -torch.finfo(t.dtype).max
46 |
47 | def mask_from_lens(lens, total_len):
48 | seq = torch.arange(total_len, device = lens.device)
49 | return einx.less('n, b -> b n', seq, lens)
50 |
51 | def softclamp(t, value = 50.):
52 | return (t / value).tanh() * value
53 |
54 | @contextmanager
55 | def torch_default_dtype(dtype):
56 | prev_dtype = torch.get_default_dtype()
57 | torch.set_default_dtype(dtype)
58 | yield
59 | torch.set_default_dtype(prev_dtype)
60 |
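# usage sketch: since this is built with @contextmanager, it can be used either as a `with` block
# or as a decorator (the tests decorate each test function with @torch_default_dtype(torch.float64)), e.g.
#
#   with torch_default_dtype(torch.float64):
#       model = GotenNet(dim = 256, depth = 1, max_degree = 2)
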
61 | # normalization
62 |
63 | LayerNorm = partial(nn.LayerNorm, bias = False)
64 |
65 | class HighDegreeNorm(Module):
66 | def __init__(self, dim, eps = 1e-6):
67 | super().__init__()
68 | self.eps = eps
69 | self.scale = dim ** 0.5
70 | self.gamma = nn.Parameter(torch.ones(dim, 1))
71 |
72 | def forward(self, x):
73 | norms = x.norm(dim = -1, keepdim = True)
74 | den = norms.norm(dim = -2, keepdim = True) * self.scale
75 | return x / den.clamp(min = self.eps) * self.gamma
76 |
77 | # radial basis function
78 |
79 | class Radial(Module):
80 | def __init__(
81 | self,
82 | dim,
83 | radial_hidden_dim = 64
84 | ):
85 | super().__init__()
86 |
87 | hidden = radial_hidden_dim
88 |
89 | self.rp = Sequential(
90 | Rearrange('... -> ... 1'),
91 | Linear(1, hidden),
92 | nn.SiLU(),
93 | LayerNorm(hidden),
94 | Linear(hidden, hidden),
95 | nn.SiLU(),
96 | LayerNorm(hidden),
97 | Linear(hidden, dim)
98 | )
99 |
100 | def forward(self, x):
101 | return self.rp(x)
102 |
103 | # node scalar feat init
104 | # eq (1) and (2)
105 |
106 | class NodeScalarFeatInit(Module):
107 | def __init__(
108 | self,
109 | num_atoms,
110 | dim,
111 | accept_embed = False,
112 | radial_hidden_dim = 64
113 | ):
114 | super().__init__()
115 | self.atom_embed = nn.Embedding(num_atoms, dim) if not accept_embed else nn.Linear(dim, dim)
116 | self.neighbor_atom_embed = nn.Embedding(num_atoms, dim) if not accept_embed else nn.Linear(dim, dim)
117 |
118 | self.rel_dist_mlp = Radial(
119 | dim = dim,
120 | radial_hidden_dim = radial_hidden_dim
121 | )
122 |
123 | self.to_node_feats = Sequential(
124 | Linear(dim * 2, dim),
125 | LayerNorm(dim),
126 | nn.SiLU(),
127 | Linear(dim, dim)
128 | )
129 |
130 | def forward(
131 | self,
132 | atoms: Int['b n'] | Float['b n d'],
133 | rel_dist: Float['b n m'],
134 | adj_mat: Bool['b n m'] | None = None,
135 | neighbor_indices: Int['b n m'] | None = None,
136 | neighbor_mask: Bool['b n m'] | None = None,
137 | mask: Bool['b n'] | None = None,
138 | ) -> Float['b n d']:
139 |
140 | dtype = rel_dist.dtype
141 | batch, seq, device = *atoms.shape[:2], atoms.device
142 |
143 | if not exists(adj_mat):
144 | adj_mat = torch.ones_like(rel_dist, device = device, dtype = dtype)
145 |
146 | if atoms.dtype in (torch.int, torch.long):
147 | atoms = atoms.masked_fill(atoms < 0, 0)
148 |
149 | embeds = self.atom_embed(atoms)
150 |
151 | rel_dist_feats = self.rel_dist_mlp(rel_dist)
152 |
153 | if exists(neighbor_indices):
154 | if exists(neighbor_mask):
155 | rel_dist_feats = einx.where('b i j, b i j d, -> b i j d', neighbor_mask, rel_dist_feats, 0.)
156 | else:
157 | if exists(mask):
158 | rel_dist_feats = einx.where('b j, b i j d, -> b i j d', mask, rel_dist_feats, 0.)
159 |
160 | neighbor_embeds = self.neighbor_atom_embed(atoms)
161 |
162 | if exists(neighbor_indices):
163 | neighbor_embeds = get_at('b [n] d, b i j -> b i j d', neighbor_embeds, neighbor_indices)
164 |
165 | neighbor_feats = einsum('b i j, b i j d, b i j d -> b i d', adj_mat.type(dtype), rel_dist_feats, neighbor_embeds)
166 | else:
167 | neighbor_feats = einsum('b i j, b i j d, b j d -> b i d', adj_mat.type(dtype), rel_dist_feats, neighbor_embeds)
168 |
169 | self_and_neighbor = torch.cat((embeds, neighbor_feats), dim = -1)
170 |
171 | return self.to_node_feats(self_and_neighbor)
172 |
173 | # edge scalar feat init
174 | # eq (3)
175 |
176 | class EdgeScalarFeatInit(Module):
177 | def __init__(
178 | self,
179 | dim,
180 | expansion_factor = 4.,
181 | ):
182 | super().__init__()
183 |
184 | dim_inner = int(dim * expansion_factor)
185 |
186 | self.rel_dist_mlp = Sequential(
187 | Rearrange('... -> ... 1'),
188 | nn.Linear(1, dim_inner, bias = False),
189 | LayerNorm(dim_inner),
190 | nn.SiLU(),
191 | nn.Linear(dim_inner, dim, bias = False)
192 | )
193 |
194 | def forward(
195 | self,
196 | h: Float['b n d'],
197 | rel_dist: Float['b n m'],
198 | neighbor_indices: Int['b n m'] | None = None
199 | ) -> Float['b n m d']:
200 |
201 | if exists(neighbor_indices):
202 | h_neighbors = get_at('b [n] d, b i j -> b i j d', h, neighbor_indices)
203 |
204 | outer_sum_feats = einx.add('b i d, b i j d -> b i j d', h, h_neighbors)
205 | else:
206 | outer_sum_feats = einx.add('b i d, b j d -> b i j d', h, h)
207 |
208 | rel_dist_feats = self.rel_dist_mlp(rel_dist)
209 |
210 | return outer_sum_feats + rel_dist_feats
211 |
212 | # equivariant feedforward
213 | # section 3.5
214 |
215 | class EquivariantFeedForward(Module):
216 | def __init__(
217 | self,
218 | dim,
219 | max_degree,
220 | mlp_expansion_factor = 2.,
221 | layernorm_input = False # no mention of this in the paper, but think there should be one based on my own intuition
222 | ):
223 | """
224 | following eq 13
225 | """
226 | super().__init__()
227 | assert max_degree > 1
228 | self.max_degree = max_degree
229 |
230 | mlp_dim = int(mlp_expansion_factor * dim * 2)
231 |
232 | self.projs = ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(max_degree)])
233 |
234 | self.mlps = ModuleList([
235 | Sequential(
236 | LayerNorm(dim * 2) if layernorm_input else nn.Identity(),
237 | Linear(dim * 2, mlp_dim),
238 | nn.SiLU(),
239 | Linear(mlp_dim, dim * 2)
240 | )
241 | for _ in range(max_degree)
242 | ])
243 |
244 | def forward(
245 | self,
246 | h: Float['b n d'],
247 | x: Sequence[Float['b n d _']]
248 | ):
249 | assert len(x) == self.max_degree
250 |
251 | h_residual = 0.
252 | x_residuals = []
253 |
254 | for one_degree, proj, mlp in zip(x, self.projs, self.mlps):
255 |
256 | # make higher degree tensor invariant through norm on `m` axis and then concat -> mlp -> split
257 |
258 | proj_one_degree = einsum('... d m, ... d e -> ... e m', one_degree, proj)
259 |
260 | normed_invariant = proj_one_degree.norm(dim = -1)
261 |
262 | mlp_inp = torch.cat((h, normed_invariant), dim = - 1)
263 | mlp_out = mlp(mlp_inp)
264 |
265 | m1, m2 = mlp_out.chunk(2, dim = -1) # named m1, m2 in equation 13, one is the residual for h, the other modulates the projected higher degree tensor for its residual
266 |
267 | modulated_one_degree = einx.multiply('... d m, ... d -> ... d m', proj_one_degree, m2)
268 |
269 | # aggregate residuals
270 |
271 | h_residual = h_residual + m1
272 |
273 | x_residuals.append(modulated_one_degree)
274 |
275 | # return residuals
276 |
277 | return h_residual, x_residuals
278 |
279 | # hierarchical tensor refinement
280 | # section 3.4
281 |
282 | class HierarchicalTensorRefinement(Module):
283 | def __init__(
284 | self,
285 | dim,
286 | dim_edge_refinement, # they made this value much higher for MD22 task. so it is an important hparam for certain more difficult tasks
287 | max_degree,
288 | norm_edge_proj_input = True # this was not in the paper, but added or else network explodes at around depth 4
289 | ):
290 | super().__init__()
291 | assert max_degree > 0
292 |
293 |
294 | # in paper, queries share the same projection, but each higher degree has its own key projection
295 |
296 | self.to_queries = nn.Parameter(torch.randn(dim, dim_edge_refinement))
297 |
298 | self.to_keys = ParameterList([nn.Parameter(torch.randn(dim, dim_edge_refinement)) for _ in range(max_degree)])
299 |
300 | # the two weight matrices
301 | # one for mixing the inner product between all queries and keys across degrees above
302 | # the other for refining the t_ij passed down from the layer before as a residual
303 |
304 | self.residue_update = nn.Linear(dim, dim, bias = False)
305 | self.edge_proj = nn.Linear(dim_edge_refinement * max_degree, dim, bias = False)
306 |
307 | # norm not in diagram or paper, added to prevent t_ij from exploding
308 |
309 | self.norm = LayerNorm(dim_edge_refinement * max_degree) if norm_edge_proj_input else nn.Identity()
310 |
311 | def forward(
312 | self,
313 | t_ij: Float['b n m d'],
314 | x: Sequence[Float['b n d _']],
315 | neighbor_indices: Int['b n m'] | None = None,
316 | ) -> Float['b n m d']:
317 |
318 | # eq (10)
319 |
320 | queries = [einsum('... d m, ... d e -> ... e m', one_degree, self.to_queries) for one_degree in x]
321 |
322 | keys = [einsum('... d m, ... d e -> ... e m', one_degree, to_keys) for one_degree, to_keys in zip(x, self.to_keys)]
323 |
324 | # eq (11)
325 |
326 | if exists(neighbor_indices):
327 | keys = [get_at('b [n] d m, b i j -> b i j d m', one_degree_key, neighbor_indices) for one_degree_key in keys]
328 |
329 | inner_product = [einsum('... i d m, ... i j d m -> ... i j d', one_degree_query, one_degree_key) for one_degree_query, one_degree_key in zip(queries, keys)]
330 | else:
331 | inner_product = [einsum('... i d m, ... j d m -> ... i j d', one_degree_query, one_degree_key) for one_degree_query, one_degree_key in zip(queries, keys)]
332 |
333 | w_ij = cat(inner_product, dim = -1)
334 |
335 | # this was not in the paper, but added or else network explodes at around depth 4
336 |
337 | w_ij = self.norm(w_ij)
338 |
339 | # eq (12)
340 |
341 | edge_proj_out = self.edge_proj(w_ij)
342 | edge_proj_out = torch.sigmoid(edge_proj_out)
343 |
344 | residue_update_out = self.residue_update(t_ij)
345 |
346 | return edge_proj_out * residue_update_out
347 |
348 | # geometry-aware tensor attention
349 | # section 3.3
350 |
351 | class GeometryAwareTensorAttention(Module):
352 | def __init__(
353 | self,
354 | dim,
355 | max_degree,
356 | dim_head = None,
357 | heads = 8,
358 | softclamp_value = 50.,
359 | mlp_expansion_factor = 2.,
360 | only_init_high_degree_feats = False, # if set to True, only returns high degree steerable features eq (4) in section 3.2
361 | learned_value_residual_mix = False,
362 | ):
363 | super().__init__()
364 | self.only_init_high_degree_feats = only_init_high_degree_feats
365 |
366 | assert max_degree > 0
367 | self.max_degree = max_degree
368 |
369 | dim_head = default(dim_head, dim)
370 |
371 | # for some reason, there is no mention of attention heads, will just improvise
372 |
373 | dim_inner = heads * dim_head
374 |
375 | self.split_heads = Rearrange('b ... (h d) -> b h ... d', h = heads)
376 | self.merge_heads = Rearrange('b h ... d -> b ... (h d)')
377 |
378 | # eq (5) - layernorms are present in the diagram in figure 2. but not mentioned in the equations..
379 |
380 | self.to_hi = LayerNorm(dim)
381 | self.to_hj = LayerNorm(dim)
382 |
383 | self.to_queries = Linear(dim, dim_inner, bias = False)
384 | self.to_keys = Linear(dim, dim_inner, bias = False)
385 |
386 | dim_mlp_inner = int(mlp_expansion_factor * dim_inner)
387 |
388 | # attention softclamping, used in Gemma
389 |
390 | self.softclamp = partial(softclamp, value = softclamp_value)
391 |
392 | # S has 2 * L_max + 1 slots: L_max to modulate each degree of r_ij, L_max to modulate each X_j, and one final slot to modulate h. incidentally, this count coincides with m = 2 * l + 1 (the number of orders of the highest degree), causing much confusion, cleared up in openreview
393 |
394 | self.S = (1, max_degree, max_degree) if not only_init_high_degree_feats else (max_degree,)
395 | S = sum(self.S)
396 |
397 | self.to_values = Sequential(
398 | Linear(dim, dim_mlp_inner),
399 | nn.SiLU(),
400 | Linear(dim_mlp_inner, S * dim_inner),
401 | Rearrange('... (s d) -> ... s d', s = S)
402 | )
403 |
404 | # value residual, iclr 2024 paper that is certain to take off
405 |
406 | self.to_value_residual_mix = Sequential(
407 | Linear(dim, heads, bias = False),
408 | nn.Sigmoid(),
409 | Rearrange('b n h -> b h n 1 1')
410 | ) if learned_value_residual_mix else None
411 |
412 | # eq (6) second half: t_ij -> edge scalar features
413 |
414 | self.to_edge_keys = Sequential(
415 | Linear(dim, S * dim_inner, bias = False), # W_re
416 | nn.SiLU(), # σ_k - never indicated in paper. just use Silu
417 | Rearrange('... (s d) -> ... s d', s = S)
418 | )
419 |
420 | # eq (7) - todo, handle cutoff radius
421 |
422 | self.to_edge_values = nn.Sequential( # t_ij modulating weighted sum
423 | Linear(dim, S * dim_inner, bias = False),
424 | Rearrange('... (s d) -> ... s d', s = S)
425 | )
426 |
427 | self.post_attn_h_values = Sequential(
428 | Linear(dim, dim_mlp_inner),
429 | nn.SiLU(),
430 | Linear(dim_mlp_inner, S * dim_inner),
431 | Rearrange('... (s d) -> ... s d', s = S)
432 | )
433 |
434 | # alphafold styled gating
435 |
436 | self.to_gates = Sequential(
437 | Linear(dim, heads * S),
438 | nn.Sigmoid(),
439 | Rearrange('b i (h s) -> b h i 1 s 1', s = S),
440 | )
441 |
442 | # combine heads
443 |
444 | self.combine_heads = Sequential(
445 | Linear(dim_inner, dim, bias = False)
446 | )
447 |
448 | def forward(
449 | self,
450 | h: Float['b n d'],
451 | t_ij: Float['b n m d'],
452 | r_ij: Sequence[Float['b n m _']],
453 | x: Sequence[Float['b n d _']] | None = None,
454 | mask: Bool['b n'] | None = None,
455 | neighbor_indices: Int['b n m'] | None = None,
456 | neighbor_mask: Bool['b n m'] | None = None,
457 | return_value_residual = False,
458 | value_residuals: tuple[Tensor, Tensor, Tensor] | None = None
459 | ):
460 | # validation
461 |
462 | assert exists(x) ^ self.only_init_high_degree_feats
463 |
464 | if not self.only_init_high_degree_feats:
465 | assert len(x) == self.max_degree
466 |
467 | assert len(r_ij) == self.max_degree
468 |
469 | # eq (5)
470 |
471 | hi = self.to_hi(h)
472 | hj = self.to_hj(h)
473 |
474 | queries = self.to_queries(hi)
475 | keys = self.to_keys(hj)
476 |
477 | # unsure why values are split into two, with one elementwise-multiplied with the edge values coming from t_ij
478 | # need another pair of eyes to double check
479 |
480 | values = self.to_values(hj)
481 | post_attn_values = self.post_attn_h_values(hj)
482 |
483 | # edge keys and values
484 |
485 | edge_keys = self.to_edge_keys(t_ij)
486 | edge_values = self.to_edge_values(t_ij)
487 |
488 | # split out attention heads
489 |
490 | queries, keys, values, post_attn_values, edge_keys, edge_values = map(self.split_heads, (queries, keys, values, post_attn_values, edge_keys, edge_values))
491 |
492 | # value residual mixing
493 |
494 | next_value_residuals = (values, post_attn_values, edge_values)
495 |
496 | if exists(self.to_value_residual_mix):
497 | assert exists(value_residuals)
498 |
499 | value_residual, post_attn_values_residual, edge_values_residual = value_residuals
500 |
501 | mix = self.to_value_residual_mix(hi)
502 |
503 | values = values.lerp(value_residual, mix)
504 | post_attn_values = post_attn_values.lerp(post_attn_values_residual, mix)
505 |
506 | if exists(neighbor_indices):
507 | mix = get_at('b h [n] ..., b i j -> b h i j ...', mix, neighbor_indices)
508 | else:
509 | mix = rearrange(mix, 'b h j ... -> b h 1 j ...')
510 |
511 | edge_values = edge_values.lerp(edge_values_residual, mix)
512 |
513 | # account for neighbor logic
514 |
515 | if exists(neighbor_indices):
516 | keys = get_at('b h [n] ..., b i j -> b h i j ...', keys, neighbor_indices)
517 | values = get_at('b h [n] ..., b i j -> b h i j ...', values, neighbor_indices)
518 | post_attn_values = get_at('b h [n] ..., b i j -> b h i j ...', post_attn_values, neighbor_indices)
519 |
520 | # eq (6)
521 |
522 | # unsure why there is a k-dimension in the paper math notation, in addition to i, j
523 |
524 | if exists(neighbor_indices):
525 | keys = einx.multiply('... i j d, ... i j s d -> ... i j s d', keys, edge_keys)
526 | else:
527 | keys = einx.multiply('... j d, ... i j s d -> ... i j s d', keys, edge_keys)
528 |
529 | # similarities
530 |
531 | sim = einsum('... i d, ... i j s d -> ... i j s', queries, keys)
532 |
533 | # soft clamping - used successfully in gemma to prevent attention logit overflows
534 |
535 | sim = self.softclamp(sim)
536 |
537 | # masking
538 |
539 | if exists(neighbor_indices):
540 | if exists(neighbor_mask):
541 | sim = einx.where('b i j, b h i j s, -> b h i j s', neighbor_mask, sim, max_neg_value(sim))
542 | else:
543 | if exists(mask):
544 | sim = einx.where('b j, b h i j s, -> b h i j s', mask, sim, max_neg_value(sim))
545 |
546 | # attend
547 |
548 | attn = sim.softmax(dim = -2)
549 |
550 | # aggregate values
551 |
552 | if exists(neighbor_indices):
553 | sea_ij = einsum('... i j s, ... i j s d -> ... i j s d', attn, values)
554 | else:
555 | sea_ij = einsum('... i j s, ... j s d -> ... i j s d', attn, values)
556 |
557 | # eq (7)
558 |
559 | if exists(neighbor_indices):
560 | sea_ij = sea_ij + einx.multiply('... i j s d, ... i j s d -> ... i j s d', edge_values, post_attn_values)
561 | else:
562 | sea_ij = sea_ij + einx.multiply('... i j s d, ... j s d -> ... i j s d', edge_values, post_attn_values)
563 |
564 | # alphafold style gating
565 |
566 | out = sea_ij * self.to_gates(hi)
567 |
568 | # combine heads - not in paper for some reason, but attention heads mentioned, so must be necessary?
569 |
570 | out = self.merge_heads(out)
571 |
572 | out = self.combine_heads(out)
573 |
574 | # maybe eq (4) and early return
575 |
576 | if self.only_init_high_degree_feats:
577 | x_ij_init = [einsum('... i j m, ... i j d -> ... i d m', one_r_ij, one_r_ij_scale) for one_r_ij, one_r_ij_scale in zip(r_ij, out.unbind(dim = -2))]
578 | return x_ij_init
579 |
580 | # split out all the O's (eq 7 second half)
581 |
582 | h_scales, r_ij_scales, x_scales = out.split(self.S, dim = -2)
583 |
584 | # modulate with invariant scales and sum residuals
585 |
586 | h_residual = reduce(h_scales, 'b i j 1 d -> b i d', 'sum')
587 | x_residuals = []
588 |
589 | for one_degree, one_r_ij, one_degree_scale, one_r_ij_scale in zip(x, r_ij, x_scales.unbind(dim = -2), r_ij_scales.unbind(dim = -2)):
590 |
591 | r_ij_residual = einsum('b i j m, b i j d -> b i d m', one_r_ij, one_r_ij_scale)
592 |
593 | if exists(neighbor_indices):
594 | one_degree_neighbors = get_at('b [n] d m, b i j -> b i j d m', one_degree, neighbor_indices)
595 |
596 | x_ij_residual = einsum('b i j d m, b i j d -> b i d m', one_degree_neighbors, one_degree_scale)
597 |
598 | else:
599 | x_ij_residual = einsum('b j d m, b i j d -> b i d m', one_degree, one_degree_scale)
600 |
601 | x_residuals.append(r_ij_residual + x_ij_residual)
602 |
603 | out = (h_residual, x_residuals)
604 |
605 | if not return_value_residual:
606 | return out
607 |
608 | return out, next_value_residuals
609 |
610 | # full attention
611 |
612 | class InvariantAttention(Module):
613 | def __init__(
614 | self,
615 | dim,
616 | **attn_kwargs
617 | ):
618 | super().__init__()
619 | self.norm = nn.RMSNorm(dim)
620 | self.attn = Attention(dim, **attn_kwargs)
621 |
622 | def forward(self, h, mask = None):
623 | h = self.norm(h)
624 | return self.attn(h, mask = mask)
625 |
626 | # main class
627 |
628 | class GotenNet(Module):
629 | def __init__(
630 | self,
631 | dim,
632 | depth,
633 | max_degree,
634 | dim_edge_refinement = None,
635 | accept_embed = False,
636 | num_atoms = 14,
637 | heads = 8,
638 | dim_head = None,
639 | cutoff_radius = None,
640 | invariant_full_attn = False,
641 | invariant_attn_use_flash = False,
642 | full_attn_kwargs: dict = dict(),
643 | max_neighbors = float('inf'),
644 | mlp_expansion_factor = 2.,
645 | edge_init_mlp_expansion_factor = 4.,
646 | ff_kwargs: dict = dict(),
647 | return_coors = True,
648 | proj_invariant_dim = None,
649 | final_norm = True,
650 | add_value_residual = True,
651 | num_residual_streams = 4,
652 | htr_kwargs: dict = dict()
653 | ):
654 | super().__init__()
655 | self.accept_embed = accept_embed
656 |
657 | assert max_degree > 0
658 | self.max_degree = max_degree
659 |
660 | dim_edge_refinement = default(dim_edge_refinement, dim)
661 |
662 | # hyper connections, applied to invariant h for starters
663 |
664 | init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
665 |
666 | # only consider neighbors less than `cutoff_radius`, in paper, they used ~ 5 angstroms
667 | # can further randomly select from eligible neighbors with `max_neighbors`
668 |
669 | self.cutoff_radius = cutoff_radius
670 | self.max_neighbors = max_neighbors
671 |
672 | # node and edge feature init
673 |
674 | self.node_init = NodeScalarFeatInit(num_atoms, dim, accept_embed = accept_embed)
675 | self.edge_init = EdgeScalarFeatInit(dim, expansion_factor = edge_init_mlp_expansion_factor)
676 |
677 | self.high_degree_init = GeometryAwareTensorAttention(
678 | dim,
679 | max_degree = max_degree,
680 | dim_head = dim_head,
681 | heads = heads,
682 | mlp_expansion_factor = mlp_expansion_factor,
683 | only_init_high_degree_feats = True
684 | )
685 |
686 | # layers, thus deep learning
687 |
688 | self.layers = ModuleList([])
689 | self.residual_fns = ModuleList([])
690 |
691 | for layer_index in range(depth):
692 | is_first = layer_index == 0
693 |
694 | self.layers.append(ModuleList([
695 | HierarchicalTensorRefinement(dim, dim_edge_refinement, max_degree, **htr_kwargs),
696 | InvariantAttention(dim = dim, flash = invariant_attn_use_flash, **full_attn_kwargs) if invariant_full_attn else None,
697 | GeometryAwareTensorAttention(dim, max_degree, dim_head, heads, mlp_expansion_factor, learned_value_residual_mix = add_value_residual and not is_first),
698 | EquivariantFeedForward(dim, max_degree, mlp_expansion_factor),
699 | ]))
700 |
701 | self.residual_fns.append(ModuleList([
702 | init_hyper_conn(dim = dim) if invariant_full_attn else None,
703 | init_hyper_conn(dim = dim),
704 | init_hyper_conn(dim = dim),
705 | ]))
706 |
707 | # not mentioned in paper, but transformers need a final norm
708 |
709 | self.final_norm = final_norm
710 |
711 | if final_norm:
712 | self.h_final_norm = nn.LayerNorm(dim)
713 |
714 | self.x_final_norms = ModuleList([HighDegreeNorm(dim) for _ in range(max_degree)])
715 |
716 | # maybe project invariant
717 |
718 | self.proj_invariant = None
719 |
720 | if exists(proj_invariant_dim):
721 | self.proj_invariant = Linear(dim, proj_invariant_dim, bias = False)
722 |
723 | # maybe project to coordinates
724 |
725 | self.proj_to_coors = Sequential(
726 | Rearrange('... d m -> ... m d'),
727 | Linear(dim, 1, bias = False),
728 | Rearrange('... 1 -> ...')
729 | ) if return_coors else None
730 |
731 | def forward(
732 | self,
733 | atoms: Int['b n'] | Float['b n d'],
734 | coors: Float['b n 3'],
735 | adj_mat: Bool['b n n'] | None = None,
736 | lens: Int['b'] | None = None,
737 | mask: Bool['b n'] | None = None
738 | ):
739 | assert (atoms.dtype in (torch.int, torch.long)) ^ self.accept_embed
740 |
741 | batch, seq_len, device = *atoms.shape[:2], atoms.device
742 |
743 | assert not (exists(lens) and exists(mask)), '`lens` and `mask` cannot both be passed in'
744 |
745 | if exists(lens):
746 | mask = mask_from_lens(lens, seq_len)
747 |
748 | # also allow for negative atom indices in place of `lens` or `mask`
749 |
750 | if atoms.dtype in (torch.int, torch.long):
751 | atom_mask = atoms >= 0
752 | mask = default(mask, atom_mask)
753 |
754 | rel_pos = einx.subtract('b i c, b j c -> b i j c', coors, coors)
755 | rel_dist = rel_pos.norm(dim = -1)
756 |
757 | # process adjacency matrix
758 |
759 | if exists(adj_mat):
760 | eye = torch.eye(seq_len, device = device, dtype = torch.bool)
761 | adj_mat = adj_mat & ~eye # remove self from adjacency matrix
762 |
763 | # figure out neighbors, if needed
764 |
765 | neighbor_indices: Int['b n m'] | None = None
766 | neighbor_mask: Bool['b n m'] | None = None
767 |
768 | if exists(self.cutoff_radius):
769 |
770 | if exists(mask):
771 | rel_dist = einx.where('b j, b i j, -> b i j', mask, rel_dist, 1e6)
772 |
773 | is_neighbor = (rel_dist <= self.cutoff_radius).float()
774 |
775 | max_eligible_neighbors = is_neighbor.sum(dim = -1).long().amax().item()
776 | max_neighbors = min(max_eligible_neighbors, self.max_neighbors)
777 |
778 | noised_is_neighbor = is_neighbor + torch.rand_like(is_neighbor) * 1e-3
779 | neighbor_indices = noised_is_neighbor.topk(k = max_neighbors, dim = -1).indices
780 |
781 | if exists(adj_mat):
782 | adj_mat = adj_mat.gather(-1, neighbor_indices)
783 |
784 | neighbor_dist = rel_dist.gather(-1, neighbor_indices)
785 | neighbor_mask = neighbor_dist <= self.cutoff_radius
786 |
787 | rel_dist = neighbor_dist
788 | rel_pos = rel_pos.gather(-2, repeat(neighbor_indices, '... -> ... c', c = 3))
789 |
790 | # initialization
791 |
792 | h = self.node_init(atoms, rel_dist, adj_mat, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask)
793 |
794 | t_ij = self.edge_init(h, rel_dist, neighbor_indices = neighbor_indices)
795 |
796 | # constitute r_ij from section 3.1
797 |
798 | r_ij = []
799 |
800 | for degree in range(1, self.max_degree + 1):
801 | one_degree_r_ij = spherical_harmonics(degree, rel_pos, normalize = True, normalization = 'norm')
802 | r_ij.append(one_degree_r_ij)
803 |
804 | # init the high degrees
805 |
806 | x = self.high_degree_init(h, t_ij, r_ij, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask)
807 |
808 | # value residual
809 |
810 | value_residuals = None
811 |
812 | # maybe expand invariant h residual stream
813 |
814 | h = self.expand_streams(h)
815 |
816 | # go through the layers
817 |
818 | for (htr, maybe_h_attn, attn, ff), (maybe_h_attn_residual_fn, attn_residual_fn, ff_residual_fn) in zip(self.layers, self.residual_fns):
819 |
820 | # hierarchical tensor refinement
821 |
822 | t_ij = htr(t_ij, x, neighbor_indices = neighbor_indices) + t_ij
823 |
824 | # maybe full flash attention across invariants
825 |
826 | if exists(maybe_h_attn):
827 | h, add_h_attn_residual = maybe_h_attn_residual_fn(h)
828 |
829 | h = maybe_h_attn(h, mask = mask)
830 |
831 | h = add_h_attn_residual(h)
832 |
833 | # followed by attention, but of course
834 |
835 | h, add_attn_residual = attn_residual_fn(h)
836 |
837 | (h_residual, x_residuals), next_value_residuals = attn(h, t_ij, r_ij, x, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask, value_residuals = value_residuals, return_value_residual = True)
838 |
839 | # add attention residuals
840 |
841 | h = add_attn_residual(h_residual)
842 |
843 | x = [*map(sum, zip(x, x_residuals))]
844 |
845 | # handle value residual
846 |
847 | value_residuals = default(value_residuals, next_value_residuals)
848 |
849 | # feedforward
850 |
851 | h, add_ff_residual = ff_residual_fn(h)
852 |
853 | h_residual, x_residuals = ff(h, x)
854 |
855 | # add feedforward residuals
856 |
857 | h = add_ff_residual(h_residual)
858 |
859 | x = [*map(sum, zip(x, x_residuals))]
860 |
861 | h = self.reduce_streams(h)
862 |
863 | # maybe final norms
864 |
865 | if self.final_norm:
866 | h = self.h_final_norm(h)
867 | x = [norm(one_degree) for one_degree, norm in zip(x, self.x_final_norms)]
868 |
869 | # maybe transform invariant h
870 |
871 | if exists(self.proj_invariant):
872 | h = self.proj_invariant(h)
873 |
874 | # return h and x if `return_coors = False`
875 |
876 | if not exists(self.proj_to_coors):
877 | return h, x
878 |
879 | degree1, *_ = x
880 |
881 | coors_out = self.proj_to_coors(degree1)
882 |
883 | return h, coors_out
884 |
--------------------------------------------------------------------------------
/gotennet_pytorch/tensor_typing.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 | from jaxtyping import (
4 | Float,
5 | Int,
6 | Bool
7 | )
8 |
9 | # jaxtyping is a misnomer, works for pytorch
10 |
11 | class TorchTyping:
12 | def __init__(self, abstract_dtype):
13 | self.abstract_dtype = abstract_dtype
14 |
15 | def __getitem__(self, shapes: str):
16 | return self.abstract_dtype[Tensor, shapes]
17 |
18 | Float = TorchTyping(Float)
19 | Int = TorchTyping(Int)
20 | Bool = TorchTyping(Bool)
21 |
22 | __all__ = [
23 |     'Float',
24 |     'Int',
25 |     'Bool'
26 | ]
27 |
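# usage sketch: these thin wrappers make annotations elsewhere in the package read as shaped tensors.
# for example, Float['b n d'] resolves to jaxtyping.Float[Tensor, 'b n d'], so
#
#   def forward(self, h: Float['b n d']) -> Float['b n d']: ...
#
# declares a float Tensor of shape (batch, nodes, dim) for both input and output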
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "gotennet-pytorch"
3 | version = "0.3.1"
4 | description = "GotenNet in Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.9"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'attention mechanism',
16 | 'se3 equivariance',
17 | 'molecules'
18 | ]
19 | classifiers=[
20 | 'Development Status :: 4 - Beta',
21 | 'Intended Audience :: Developers',
22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
23 | 'License :: OSI Approved :: MIT License',
24 | 'Programming Language :: Python :: 3.9',
25 | ]
26 |
27 | dependencies = [
28 | 'e3nn',
29 | 'einx>=0.3.0',
30 | 'einops>=0.8.0',
31 | 'jaxtyping',
32 | 'hyper-connections>=0.1.0',
33 | 'x-transformers>=1.44.4',
34 | 'torch>=2.4',
35 | ]
36 |
37 | [project.urls]
38 | Homepage = "https://pypi.org/project/gotennet-pytorch/"
39 | Repository = "https://github.com/lucidrains/gotennet-pytorch"
40 |
41 | [project.optional-dependencies]
42 | examples = ["tqdm", "numpy"]
43 |
44 | [build-system]
45 | requires = ["hatchling"]
46 | build-backend = "hatchling.build"
47 |
48 | [tool.rye]
49 | managed = true
50 | dev-dependencies = [
51 | "ruff>=0.4.2",
52 | "pytest>=8.2.0",
53 | "pytest-examples>=0.0.10",
54 | "pytest-cov>=5.0.0",
55 | ]
56 |
57 | [tool.pytest.ini_options]
58 | pythonpath = ["."]
59 |
60 | [tool.hatch.metadata]
61 | allow-direct-references = true
62 |
63 | [tool.hatch.build.targets.wheel]
64 | packages = ["gotennet_pytorch"]
65 |
--------------------------------------------------------------------------------
/tests/test_gotennet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import sin, cos, stack
4 | from einops import rearrange
5 |
6 | # random rotations
7 |
8 | def rot_z(gamma):
9 | c = cos(gamma)
10 | s = sin(gamma)
11 | z = torch.zeros_like(gamma)
12 | o = torch.ones_like(gamma)
13 |
14 | out = stack((
15 | c, -s, z,
16 | s, c, z,
17 | z, z, o
18 | ), dim = -1)
19 |
20 | return rearrange(out, '... (r1 r2) -> ... r1 r2', r1 = 3)
21 |
22 | def rot_y(beta):
23 | c = cos(beta)
24 | s = sin(beta)
25 | z = torch.zeros_like(beta)
26 | o = torch.ones_like(beta)
27 |
28 | out = stack((
29 | c, z, s,
30 | z, o, z,
31 | -s, z, c
32 | ), dim = -1)
33 |
34 | return rearrange(out, '... (r1 r2) -> ... r1 r2', r1 = 3)
35 |
36 | def rot(alpha, beta, gamma):
37 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
38 |
39 | # testing
40 |
41 | from gotennet_pytorch.gotennet import GotenNet, torch_default_dtype
42 |
43 | @torch_default_dtype(torch.float64)
44 | @pytest.mark.parametrize('invariant_full_attn', (False, True))
45 | def test_invariant(invariant_full_attn):
46 | model = GotenNet(
47 | dim = 256,
48 | max_degree = 2,
49 | depth = 4,
50 | heads = 2,
51 | dim_head = 32,
52 | dim_edge_refinement = 256,
53 | invariant_full_attn = invariant_full_attn,
54 | return_coors = False,
55 | ff_kwargs = dict(
56 | layernorm_input = True
57 | )
58 | )
59 |
60 | random_rotation = rot(*torch.randn(3))
61 |
62 | atom_ids = torch.randint(0, 14, (1, 12))
63 | coors = torch.randn(1, 12, 3)
64 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool()
65 | mask = torch.randint(0, 2, (1, 12)).bool()
66 |
67 | inv1, _ = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask)
68 | inv2, _ = model(atom_ids, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask)
69 |
70 | assert torch.allclose(inv1, inv2, atol = 1e-5)
71 |
72 | @torch_default_dtype(torch.float64)
73 | @pytest.mark.parametrize('num_residual_streams', (1, 4))
74 | @pytest.mark.parametrize('invariant_full_attn', (False, True))
75 | def test_equivariant(
76 | num_residual_streams,
77 | invariant_full_attn
78 | ):
79 |
80 | model = GotenNet(
81 | dim = 256,
82 | max_degree = 2,
83 | depth = 4,
84 | heads = 2,
85 | dim_head = 32,
86 | dim_edge_refinement = 256,
87 | return_coors = True,
88 | invariant_full_attn = invariant_full_attn,
89 | ff_kwargs = dict(
90 | layernorm_input = True
91 | ),
92 | num_residual_streams = num_residual_streams
93 | )
94 |
95 | random_rotation = rot(*torch.randn(3))
96 |
97 | atom_ids = torch.randint(0, 14, (1, 12))
98 | coors = torch.randn(1, 12, 3)
99 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool()
100 | mask = torch.randint(0, 2, (1, 12)).bool()
101 |
102 | _, coors1 = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask)
103 | _, coors2 = model(atom_ids, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask)
104 |
105 | assert torch.allclose(coors1 @ random_rotation, coors2, atol = 1e-5)
106 |
107 | @torch_default_dtype(torch.float64)
108 | def test_equivariant_with_atom_feats():
109 |
110 | model = GotenNet(
111 | dim = 256,
112 | max_degree = 2,
113 | depth = 4,
114 | heads = 2,
115 | dim_head = 32,
116 | dim_edge_refinement = 256,
117 | accept_embed = True,
118 | return_coors = True
119 | )
120 |
121 | random_rotation = rot(*torch.randn(3))
122 |
123 | atom_feats = torch.randn((1, 12, 256))
124 | coors = torch.randn(1, 12, 3)
125 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool()
126 | mask = torch.randint(0, 2, (1, 12)).bool()
127 |
128 | _, coors1 = model(atom_feats, adj_mat = adj_mat, coors = coors, mask = mask)
129 | _, coors2 = model(atom_feats, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask)
130 |
131 | assert torch.allclose(coors1 @ random_rotation, coors2, atol = 1e-5)
132 |
133 | @torch_default_dtype(torch.float64)
134 | @pytest.mark.parametrize('num_residual_streams', (1, 4))
135 | def test_equivariant_neighbors(num_residual_streams):
136 |
137 | model = GotenNet(
138 | dim = 256,
139 | max_degree = 2,
140 | depth = 4,
141 | heads = 2,
142 | dim_head = 32,
143 | cutoff_radius = 5.,
144 | dim_edge_refinement = 256,
145 | return_coors = True,
146 | num_residual_streams = num_residual_streams,
147 | ff_kwargs = dict(
148 | layernorm_input = True
149 | )
150 | )
151 |
152 | random_rotation = rot(*torch.randn(3))
153 |
154 | atom_ids = torch.randint(0, 14, (1, 6))
155 | coors = torch.randn(1, 6, 3)
156 | adj_mat = torch.randint(0, 2, (1, 6, 6)).bool()
157 | mask = torch.randint(0, 2, (1, 6)).bool()
158 |
159 | _, coors1 = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask)
160 | _, coors2 = model(atom_ids, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask)
161 |
162 | out1 = coors1 @ random_rotation
163 | out2 = coors2
164 |
165 | assert torch.allclose(out1, out2, atol = 1e-5)
166 |
--------------------------------------------------------------------------------